import os

import numpy as np
from matplotlib import cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch

from utils import myroll2d


# (row_step, col_step) unit offsets for the 8-neighbourhood, named after the
# original per-direction variables: down, up, left, right, ur, ul, dr, dl.
# They are scaled by an edge width and passed to ``myroll2d(x, rows, cols)``
# (the same argument order ``move_mask`` uses: delta_y first, then delta_x).
_NEIGHBOR_DIRECTIONS = (
    (1, 0),
    (-1, 0),
    (0, -1),
    (0, 1),
    (-1, 1),
    (-1, -1),
    (1, 1),
    (1, -1),
)


def create_outer_edge_mask_torch(mask, edge_thickness=20):
    """Return pixels just outside ``mask`` as a ``torch.uint8`` tensor.

    ``mask`` is shifted by ``edge_thickness`` pixels in each of the 8
    neighbourhood directions via ``myroll2d``. A pixel is on the edge when
    some shifted copy covers it but the unshifted mask does not.

    Args:
        mask: torch mask tensor (presumably 2-D and binary -- confirm at callers).
        edge_thickness: shift distance in pixels.

    Returns:
        ``torch.uint8`` tensor, 1 on the outer edge and 0 elsewhere.
    """
    mask_f = mask.to(torch.float)
    edges = []
    for row_step, col_step in _NEIGHBOR_DIRECTIONS:
        shifted = myroll2d(mask, row_step * edge_thickness, col_step * edge_thickness)
        # All 8 directions are used. The previous version overwrote the
        # up-left edge with the down-left one and passed the latter twice.
        edges.append((shifted.to(torch.float) - mask_f) > 0)
    return mask_union_torch(*edges)


def mask_substract_torch(mask1, mask2):
    """Return ``mask1`` minus ``mask2`` (pixels in 1 but not 2) as CPU uint8."""
    return ((mask1.cpu().to(torch.float) - mask2.cpu().to(torch.float)) > 0).to(torch.uint8)


def check_mask_overlap_torch(*masks):
    """Assert that no pixel is covered by more than one of the torch ``masks``."""
    # Every pixel must satisfy the bound. ``torch.any`` accepted a single
    # non-overlapping pixel; this now matches ``check_mask_overlap_numpy``.
    assert torch.all(sum([m.float() for m in masks]) <= 1)


def check_mask_overlap_numpy(*masks):
    """Assert that no pixel is covered by more than one of the numpy ``masks``."""
    assert np.all(sum([m.astype(float) for m in masks]) <= 1)


def check_cover_all_torch(*masks):
    """Assert that every pixel is covered by exactly one of the torch ``masks``."""
    assert torch.all(sum([m.cpu().float() for m in masks]) == 1)


def process_mask_to_follow_priority(mask_list, priority_list):
    """Remove from each mask every pixel claimed by a higher-priority mask.

    ``mask_list`` (numpy masks) is modified in place and also returned.
    """
    for idx1, p1 in enumerate(priority_list[: len(mask_list)]):
        for m2, p2 in zip(mask_list, priority_list):
            if p2 > p1:
                # Subtract from the *current* mask so that removals from several
                # higher-priority masks accumulate. The old code restarted from
                # the original mask each time, so only the last subtraction stuck.
                mask_list[idx1] = (
                    (mask_list[idx1].astype(float) - m2.astype(float)) > 0
                ).astype(np.uint8)
    return mask_list


def mask_union(*masks):
    """Pixel-wise union of numpy masks, returned as ``np.uint8``."""
    masks = [m.astype(float) for m in masks]
    res = sum(masks) > 0
    return res.astype(np.uint8)


def mask_intersection(mask1, mask2):
    """Pixel-wise intersection of two binary numpy masks.

    Pixels where the masks agree, restricted to their union, i.e. where both are 1.
    """
    mask_uni = mask_union(mask1, mask2)
    mask_intersec = ((mask1.astype(float) - mask2.astype(float)) == 0) * mask_uni
    return mask_intersec


def mask_union_torch(*masks):
    """Pixel-wise union of torch masks, returned as ``torch.uint8``."""
    masks = [m.float() for m in masks]
    res = sum(masks) > 0
    return res.to(torch.uint8)


def mask_intersection_torch(mask1, mask2):
    """Pixel-wise intersection of two binary torch masks, returned as CPU uint8."""
    mask_uni = mask_union_torch(mask1, mask2)
    mask_intersec = ((mask1.float() - mask2.float()) == 0) * mask_uni
    return mask_intersec.cpu().to(torch.uint8)


def _mask_list_to_label_image(mask_list):
    """Sum ``mask_i * i`` over the list into a float numpy label image.

    Accepts numpy arrays or torch tensors. Tensors are moved to CPU so that
    matplotlib can draw them.
    """
    label_img = 0
    for midx, m in enumerate(mask_list):
        if isinstance(m, torch.Tensor):
            layer = m.detach().cpu().float().numpy()
        else:
            layer = np.asarray(m).astype(float)
        label_img = label_img + layer * midx
    return label_img


def _save_mask_list_figure(mask_list, savepath, show_legend, dpi=None):
    """Render ``mask_list`` as a label image to ``savepath`` and close the figure."""
    viridis = plt.get_cmap("viridis", len(mask_list))  # cm.get_cmap was removed in mpl 3.9
    fig, ax = plt.subplots()
    try:
        ax.imshow(_mask_list_to_label_image(mask_list))
        if show_legend:
            handles = [
                mpatches.Patch(color=viridis(idx), label=f"{idx}")
                for idx in range(len(mask_list))
            ]
            ax.legend(handles=handles)
        save_kwargs = {} if dpi is None else {"dpi": dpi}
        fig.savefig(savepath, **save_kwargs)
    finally:
        # Close the figure so that repeated calls do not accumulate open figures.
        plt.close(fig)


def visualize_mask_list(mask_list, savepath):
    """Save a label-image plot of ``mask_list`` with a per-index colour legend."""
    _save_mask_list_figure(mask_list, savepath, show_legend=True)


def visualize_mask_list_clean(mask_list, savepath):
    """Save a legend-free label-image plot of ``mask_list`` at 500 dpi."""
    _save_mask_list_figure(mask_list, savepath, show_legend=False, dpi=500)


def move_mask(mask_select, delta_x, delta_y):
    """Shift a mask by ``delta_y`` rows and ``delta_x`` columns using ``myroll2d``."""
    mask_edit = myroll2d(mask_select, delta_y, delta_x)
    return mask_edit


def stack_mask_with_priority(mask_list_np, priority_list, edit_idx_list):
    """Resolve overlaps after editing masks, following ``priority_list``.

    Each non-edited mask whose priority is <= that of the first edited mask
    loses the pixels covered by the union of edited masks. Among edited masks,
    a mask loses its overlap with any other edited mask of equal or higher
    priority. ``mask_list_np`` is updated in place (as uint8) and returned.
    """
    mask_sel = mask_union(*[mask_list_np[eid] for eid in edit_idx_list])
    # NOTE(review): only the first edited mask's priority is compared against
    # the non-edited masks. Confirm this is intended when edited priorities differ.
    for midx, mask in enumerate(mask_list_np):
        if midx not in edit_idx_list:
            if priority_list[edit_idx_list[0]] >= priority_list[midx]:
                mask = mask.astype(float) - np.logical_and(
                    mask.astype(bool), mask_sel.astype(bool)
                ).astype(float)
                mask_list_np[midx] = mask.astype("uint8")
    for midx in edit_idx_list:
        for midx_1 in edit_idx_list:
            if midx != midx_1:
                if priority_list[midx] <= priority_list[midx_1]:
                    mask = mask_list_np[midx].astype(float) - np.logical_and(
                        mask_list_np[midx].astype(bool), mask_list_np[midx_1].astype(bool)
                    ).astype(float)
                    mask_list_np[midx] = mask.astype("uint8")
    return mask_list_np


def _feasible_edge_indices(mask_feasible, edge_width):
    """Flat indices of feasible pixels lying on the boundary of ``mask_feasible``.

    A feasible pixel is on the boundary when some copy of ``mask_feasible``,
    shifted ``edge_width`` pixels in one of the 8 directions, is 0 there.
    """
    feasible_f = mask_feasible.astype(float)
    edges = [
        (myroll2d(mask_feasible, dr * edge_width, dc * edge_width).astype(float) - feasible_f) < 0
        for dr, dc in _NEIGHBOR_DIRECTIONS
    ]
    mask_feasible_edge = mask_intersection(mask_union(*edges), mask_feasible)
    return np.nonzero(mask_feasible_edge.reshape(-1))[0]


def _nearest_flat_indices(query_idx, target_idx, shape, chunk_size=4096):
    """For each flat index in ``query_idx``, return the nearest one in ``target_idx``.

    Distance is squared Euclidean distance in (row, col). Ties go to the
    earliest entry of ``target_idx``, as with ``np.argmin``. Queries are
    processed in chunks so the distance matrix stays at most
    ``chunk_size x len(target_idx)``.
    """
    q_rows, q_cols = np.unravel_index(query_idx, shape)
    t_rows, t_cols = np.unravel_index(target_idx, shape)
    nearest = np.empty(len(query_idx), dtype=target_idx.dtype)
    for start in range(0, len(query_idx), chunk_size):
        stop = start + chunk_size
        d_rows = q_rows[start:stop, np.newaxis] - t_rows[np.newaxis, :]
        d_cols = q_cols[start:stop, np.newaxis] - t_cols[np.newaxis, :]
        nearest[start:stop] = target_idx[np.argmin(d_rows ** 2 + d_cols ** 2, axis=1)]
    return nearest


def _assign_remain_to_nearest(mask_list, mask_remain, edit_idx_list, edge_width=2):
    """Give every uncovered pixel to the mask owning its nearest feasible edge pixel.

    Feasible pixels are those covered by a mask that is not in
    ``edit_idx_list``. Grown masks are replaced in ``mask_list`` by bool arrays.

    Raises:
        ValueError: if there are uncovered pixels but no feasible edge pixel.
    """
    if edit_idx_list:
        mask_edit = mask_union(*[mask_list[eidx] for eidx in edit_idx_list])
    else:
        mask_edit = np.zeros_like(mask_remain).astype(np.uint8)
    mask_feasible = (1 - mask_remain.astype(float) - mask_edit.astype(float)).astype(np.uint8)

    remain_idx = np.nonzero(mask_remain.reshape(-1))[0]
    if remain_idx.size == 0:
        return  # nothing left to assign
    edge_idx = _feasible_edge_indices(mask_feasible, edge_width)
    if edge_idx.size == 0:
        raise ValueError("No feasible (non-edited) mask edge to absorb the remaining pixels")

    # Flat label map: pixel -> index of the mask covering it.
    region_labels = np.zeros(mask_remain.size, dtype=np.int64)
    for midx, m in enumerate(mask_list):
        region_labels += m.reshape(-1).astype(np.int64) * midx

    nearest_region = region_labels[_nearest_flat_indices(remain_idx, edge_idx, mask_remain.shape)]
    if edit_idx_list is not None:
        for edit_idx in edit_idx_list:
            assert edit_idx not in nearest_region

    labels_present = set(nearest_region.tolist())
    for midx in range(len(mask_list)):
        if midx not in labels_present:
            continue
        grown = mask_list[midx].reshape(-1) > 0
        grown[remain_idx[nearest_region == midx]] = True
        mask_list[midx] = grown.reshape(mask_list[midx].shape)


def process_remain_mask(mask_list, edit_idx_list=None, force_mask_remain=None):
    """Assign pixels not covered by any mask in ``mask_list`` (numpy masks).

    If ``force_mask_remain`` is given, all uncovered pixels are added to that
    mask. Otherwise each uncovered pixel goes to the non-edited mask whose
    boundary pixel is nearest (edited masks are those in ``edit_idx_list``).
    ``mask_list`` is updated in place.

    Returns:
        ``(mask_list, mask_remain)``. ``mask_remain`` is the uint8 map of pixels
        that were uncovered before this call.
    """
    print("Start to process remaining mask using nearest neighbor")
    # Equals ``1 - sum`` for non-overlapping binary masks, but does not wrap
    # around in uint8 when masks overlap.
    mask_remain = (sum(m.astype(float) for m in mask_list) == 0).astype(np.uint8)
    if force_mask_remain is not None:
        mask_list[force_mask_remain] = (
            mask_list[force_mask_remain].astype(float) + mask_remain.astype(float)
        ).astype(np.uint8)
    else:
        _assign_remain_to_nearest(mask_list, mask_remain, edit_idx_list)
    print("Finish processing remaining mask, if you want to edit, launch the ui")
    return mask_list, mask_remain


def resize_mask(mask_np, resize_ratio=1):
    """Rescale a 2-D numpy mask by ``resize_ratio`` while keeping its shape.

    Uses ``F.interpolate`` (default nearest mode). When shrinking, the
    resized mask is placed at the top-left corner of a zero canvas. When
    enlarging, the centre crop is returned.
    """
    w, h = mask_np.shape[0], mask_np.shape[1]
    resized_w, resized_h = int(w * resize_ratio), int(h * resize_ratio)
    # [0, 0] removes only the batch/channel dims added here; .squeeze()
    # would also drop a resized dimension of size 1.
    mask_resized = torch.nn.functional.interpolate(
        torch.from_numpy(mask_np)[None, None], (resized_w, resized_h)
    )[0, 0]
    if w > resized_w:
        mask = torch.zeros(w, h)
        mask[:resized_w, :resized_h] = mask_resized
    else:
        assert h <= resized_h
        top = resized_w // 2 - w // 2
        left = resized_h // 2 - h // 2
        mask = mask_resized[top:top + w, left:left + h]
    return mask.cpu().numpy().astype(np.uint8)


def process_mask_move_torch(
    mask_list,
    move_index_list,
    delta_x_list=None,
    delta_y_list=None,
    edit_priority_list=None,
    force_mask_remain=None,
    resize_list=None,
):
    """Move (and optionally resize) selected masks, then repair the layout.

    Args:
        mask_list: list of torch masks.
        move_index_list: indices of the masks to edit.
        delta_x_list / delta_y_list: per-move shifts; each defaults to 0 per move.
        edit_priority_list: per-move priorities; defaults to 0 per move.
            Non-edited masks have priority 0.
        force_mask_remain: if given, the index of the mask that receives all
            uncovered pixels.
        resize_list: optional per-move resize ratios, applied before the shift.

    Returns:
        ``(mask_list, mask_remain)`` as ``torch.uint8`` tensors.
    """
    n_moves = len(move_index_list)
    # ``None`` defaults used to crash inside zip(); treat them as all zeros.
    if delta_x_list is None:
        delta_x_list = [0] * n_moves
    if delta_y_list is None:
        delta_y_list = [0] * n_moves
    if edit_priority_list is None:
        edit_priority_list = [0] * n_moves

    mask_list_np = [m.cpu().numpy() for m in mask_list]
    priority_list = [0] * len(mask_list_np)
    for idx, (move_index, delta_x, delta_y, priority) in enumerate(
        zip(move_index_list, delta_x_list, delta_y_list, edit_priority_list)
    ):
        priority_list[move_index] = priority
        if resize_list is not None:
            mask = resize_mask(mask_list_np[move_index], resize_list[idx])
        else:
            mask = mask_list_np[move_index]
        mask_list_np[move_index] = move_mask(mask, delta_x=delta_x, delta_y=delta_y)
    mask_list_np = stack_mask_with_priority(mask_list_np, priority_list, move_index_list)
    # Blank (uncovered) pixels may exist here, but overlaps must not.
    check_mask_overlap_numpy(*mask_list_np)
    mask_list_np, mask_remain = process_remain_mask(mask_list_np, move_index_list, force_mask_remain)
    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain


def process_mask_remove_torch(mask_list, remove_idx):
    """Clear mask ``remove_idx`` and give its pixels to the nearest other masks.

    Returns:
        ``(mask_list, mask_remain)`` as ``torch.uint8`` tensors.
    """
    mask_list_np = [m.cpu().numpy() for m in mask_list]
    mask_list_np[remove_idx] = np.zeros_like(mask_list_np[0])
    mask_list_np, mask_remain = process_remain_mask(mask_list_np)
    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain


def get_mask_difference_torch(mask_list1, mask_list2):
    """Return a uint8 mask of pixels where any corresponding pair of masks differs."""
    assert len(mask_list1) == len(mask_list2)
    mask_diff = torch.zeros_like(mask_list1[0])
    for mask1, mask2 in zip(mask_list1, mask_list2):
        diff = ((mask1.float() - mask2.float()) != 0).to(torch.uint8)
        mask_diff = mask_union_torch(mask_diff, diff)
    return mask_diff


def save_mask_list_to_npys(folder, mask_list, mask_label_list, name="mask"):
    """Save each mask to ``{folder}/{name}{index}_{label}.npy`` with ``np.save``."""
    for midx, (mask, mask_label) in enumerate(zip(mask_list, mask_label_list)):
        np.save(os.path.join(folder, f"{name}{midx}_{mask_label}.npy"), mask)