Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import numpy as np | |
from matplotlib import cm | |
import matplotlib.patches as mpatches | |
import matplotlib.pyplot as plt | |
import torch | |
from utils import myroll2d | |
def create_outer_edge_mask_torch(mask, edge_thickness=20):
    """Return the band of pixels lying just outside ``mask``.

    ``mask`` is shifted with ``myroll2d`` by ``edge_thickness`` pixels in each
    of the eight compass directions. A pixel belongs to the edge when some
    shifted copy is larger there than the original mask (for binary masks:
    set in the shifted copy, clear in ``mask``).

    NOTE(review): whether ``myroll2d`` wraps around the border or pads is
    defined in ``utils`` -- confirm before relying on behaviour near borders.

    Args:
        mask: torch tensor mask; non-zero entries are treated as inside.
        edge_thickness: shift distance, in pixels, for every direction.

    Returns:
        uint8 tensor (from ``mask_union_torch``) that is 1 on the outer edge.
    """
    base = mask.to(torch.float)
    # (row shift, column shift) pairs; each of the eight directions appears
    # exactly once so the band is symmetric.
    offsets = [
        (edge_thickness, 0),                 # down
        (-edge_thickness, 0),                # up
        (0, -edge_thickness),                # left
        (0, edge_thickness),                 # right
        (-edge_thickness, edge_thickness),   # up-right
        (-edge_thickness, -edge_thickness),  # up-left
        (edge_thickness, edge_thickness),    # down-right
        (edge_thickness, -edge_thickness),   # down-left
    ]
    edges = [
        (myroll2d(mask, dy, dx).to(torch.float) - base) > 0
        for dy, dx in offsets
    ]
    return mask_union_torch(*edges)
def mask_substract_torch(mask1, mask2):
    """Return ``mask1`` minus ``mask2`` as a uint8 tensor on the CPU.

    A pixel is 1 where ``mask1`` is strictly greater than ``mask2`` (for binary
    masks: set in ``mask1`` and clear in ``mask2``) and 0 everywhere else.
    """
    minuend = mask1.cpu().to(torch.float)
    subtrahend = mask2.cpu().to(torch.float)
    kept = minuend - subtrahend > 0
    return kept.to(torch.uint8)
def check_mask_overlap_torch(*masks):
    """Check that no pixel is covered by more than one of the torch ``masks``.

    Raises:
        AssertionError: if the summed coverage exceeds 1 at any pixel.
    """
    if not masks:
        return  # nothing to compare
    coverage = sum(m.float() for m in masks)
    # Every pixel must be covered at most once (torch.all, not torch.any),
    # mirroring check_mask_overlap_numpy. Raised explicitly so running with
    # ``python -O`` does not silently disable the check.
    if not torch.all(coverage <= 1):
        raise AssertionError("masks overlap: a pixel is covered by more than one mask")
def check_mask_overlap_numpy(*masks):
    """Assert that the numpy ``masks`` never overlap (summed coverage <= 1 everywhere)."""
    coverage = sum(m.astype(float) for m in masks)
    assert np.all(coverage <= 1)
def check_cover_all_torch(*masks):
    """Assert that ``masks`` partition the image: every pixel is covered exactly once."""
    coverage = sum(m.cpu().float() for m in masks)
    assert torch.all(coverage == 1)
def process_mask_to_follow_priority(mask_list, priority_list):
    """Remove from each mask every pixel claimed by a higher-priority mask.

    Pairs are formed as with ``zip`` (extra entries in the longer list are
    ignored). ``mask_list`` is modified in place and also returned; a mask that
    has at least one higher-priority competitor is replaced by a uint8 array,
    the others are left untouched.

    Args:
        mask_list: list of numpy masks (non-zero = inside).
        priority_list: one comparable priority per mask; larger wins overlaps.

    Returns:
        The same ``mask_list`` object.
    """
    count = min(len(mask_list), len(priority_list))
    # Snapshot the inputs so every subtraction removes the competitor's
    # original footprint regardless of processing order.
    originals = list(mask_list[:count])
    for idx1 in range(count):
        p1 = priority_list[idx1]
        current = originals[idx1]
        for idx2 in range(count):
            if priority_list[idx2] > p1:
                # Subtract from the running result so that *all* higher-priority
                # masks are removed, not just the last one visited.
                current = ((current.astype(float) - originals[idx2].astype(float)) > 0).astype(np.uint8)
        mask_list[idx1] = current
    return mask_list
def mask_union(*masks):
    """Return the pixel-wise union of numpy ``masks`` as a uint8 array (1 = covered)."""
    total = sum(m.astype(float) for m in masks)
    covered = total > 0
    return covered.astype(np.uint8)
def mask_intersection(mask1, mask2):
    """Return the pixel-wise intersection of two numpy masks.

    A pixel is kept when both masks hold the same value there and at least one
    of them is non-zero (per ``mask_union``); for binary masks this is
    ``mask1 AND mask2``. The result is the equality map multiplied by the
    uint8 union.
    """
    difference = mask1.astype(float) - mask2.astype(float)
    same_value = difference == 0
    covered = mask_union(mask1, mask2)
    return same_value * covered
def mask_union_torch(*masks):
    """Return the pixel-wise union of torch ``masks`` as a uint8 tensor (1 = covered)."""
    total = sum(m.float() for m in masks)
    return (total > 0).to(torch.uint8)
def mask_intersection_torch(mask1, mask2):
    """Return the pixel-wise intersection of two torch masks as a uint8 CPU tensor.

    A pixel is kept when both masks hold the same value there and at least one
    of them is non-zero (per ``mask_union_torch``); for binary masks this is
    ``mask1 AND mask2``.
    """
    difference = mask1.float() - mask2.float()
    covered = mask_union_torch(mask1, mask2)
    overlap = (difference == 0) * covered
    return overlap.cpu().to(torch.uint8)
def _compose_label_map(mask_list):
    """Combine masks into one float image in which mask ``i`` contributes value ``i``.

    Torch tensors are detached, moved to the CPU and converted to numpy first.
    Overlapping masks add up, and mask 0 is indistinguishable from uncovered
    pixels (both are 0).
    """
    label_map = 0
    for midx, m in enumerate(mask_list):
        if isinstance(m, torch.Tensor):
            layer = m.detach().cpu().numpy().astype(float)
        else:
            layer = np.asarray(m).astype(float)
        label_map = label_map + layer * midx
    return label_map


def _render_mask_list(mask_list, savepath, show_legend, dpi=None):
    """Draw ``mask_list`` as one colour-coded image and save it to ``savepath``.

    Raises:
        ValueError: if ``mask_list`` is empty.
    """
    num_masks = len(mask_list)
    if num_masks == 0:
        raise ValueError("mask_list must contain at least one mask")
    label_map = _compose_label_map(mask_list)
    # plt.get_cmap(name, lut) replaces cm.get_cmap, which Matplotlib 3.9 removed.
    palette = plt.get_cmap("viridis", num_masks)
    fig, ax = plt.subplots()
    try:
        # Pin the colour scale so label k always maps to palette(k) -- the colour
        # shown in the legend -- even when some labels are absent from the image.
        ax.imshow(label_map, cmap=palette, vmin=-0.5, vmax=num_masks - 0.5)
        if show_legend:
            handles = [
                mpatches.Patch(color=palette(idx), label=f"{idx}")
                for idx in range(num_masks)
            ]
            ax.legend(handles=handles)
        if dpi is None:
            fig.savefig(savepath)
        else:
            fig.savefig(savepath, dpi=dpi)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def visualize_mask_list(mask_list, savepath):
    """Save a colour-coded image of ``mask_list`` with a per-index legend.

    Args:
        mask_list: sequence of numpy arrays or torch tensors, one mask per label.
        savepath: output path passed to ``Figure.savefig``.
    """
    _render_mask_list(mask_list, savepath, show_legend=True)


def visualize_mask_list_clean(mask_list, savepath):
    """Save the same colour-coded image without a legend, at 500 dpi.

    Args:
        mask_list: sequence of numpy arrays or torch tensors, one mask per label.
        savepath: output path passed to ``Figure.savefig``.
    """
    _render_mask_list(mask_list, savepath, show_legend=False, dpi=500)
def move_mask(mask_select, delta_x, delta_y):
    """Shift ``mask_select`` by (``delta_x``, ``delta_y``) using ``myroll2d``.

    ``myroll2d`` receives the vertical offset (``delta_y``) first and the
    horizontal offset (``delta_x``) second.
    """
    vertical, horizontal = delta_y, delta_x
    return myroll2d(mask_select, vertical, horizontal)
def stack_mask_with_priority(mask_list_np, priority_list, edit_idx_list):
    """Resolve overlaps after the masks listed in ``edit_idx_list`` were edited.

    1. Every non-edited mask loses the pixels covered by each edited mask whose
       priority is greater than or equal to its own.
    2. Among edited masks, a mask loses its overlap with every other edited
       mask of greater-or-equal priority. Pairs are processed in list order, so
       with equal priorities the earlier mask in ``edit_idx_list`` yields.

    List items of ``mask_list_np`` are replaced in place (replacements are
    uint8) and the list is returned. Pixels removed here are left uncovered.

    NOTE(review): an edited mask is never trimmed by a *non-edited* mask of
    higher priority, so such overlaps survive -- confirm this is intended.

    Args:
        mask_list_np: list of numpy masks (non-zero = inside).
        priority_list: one priority per mask; larger wins overlaps.
        edit_idx_list: indices into ``mask_list_np`` of the edited masks.

    Returns:
        ``mask_list_np``.
    """
    if not edit_idx_list:
        return mask_list_np
    edited = set(edit_idx_list)

    for midx, mask in enumerate(mask_list_np):
        if midx in edited:
            continue
        # Compare against each edited mask's own priority (not just the first
        # edited mask's), so mixed edit priorities are honoured.
        winners = [
            mask_list_np[eid].astype(bool)
            for eid in edit_idx_list
            if priority_list[eid] >= priority_list[midx]
        ]
        if not winners:
            continue
        covered = np.logical_or.reduce(winners)
        trimmed = mask.astype(float) - np.logical_and(mask.astype(bool), covered).astype(float)
        mask_list_np[midx] = trimmed.astype("uint8")

    for midx in edit_idx_list:
        for midx_1 in edit_idx_list:
            if midx != midx_1 and priority_list[midx] <= priority_list[midx_1]:
                current = mask_list_np[midx]
                overlap = np.logical_and(current.astype(bool), mask_list_np[midx_1].astype(bool))
                mask_list_np[midx] = (current.astype(float) - overlap.astype(float)).astype("uint8")
    return mask_list_np
def _nearest_flat_indices(query_idx, candidate_idx, n_cols, chunk_size=1024):
    """Map each flat pixel index in ``query_idx`` to the nearest one in ``candidate_idx``.

    Flat indices are row-major over an image with ``n_cols`` columns. Distance
    is squared Euclidean in (row, col); ties resolve to the earliest candidate,
    as with ``np.argmin``. Queries are handled ``chunk_size`` at a time so the
    pairwise-distance matrix holds at most ``chunk_size * len(candidate_idx)``
    entries.
    """
    cand_rows, cand_cols = np.divmod(candidate_idx, n_cols)
    nearest = np.empty(len(query_idx), dtype=candidate_idx.dtype)
    for start in range(0, len(query_idx), chunk_size):
        stop = start + chunk_size
        q_rows, q_cols = np.divmod(query_idx[start:stop], n_cols)
        dist = (
            (q_rows[:, np.newaxis] - cand_rows[np.newaxis, :]) ** 2
            + (q_cols[:, np.newaxis] - cand_cols[np.newaxis, :]) ** 2
        )
        nearest[start:stop] = candidate_idx[np.argmin(dist, axis=1)]
    return nearest


def _assign_remain_to_nearest_region(mask_list, mask_remain, edit_idx_list):
    """Grow non-edited masks in place so they absorb the pixels set in ``mask_remain``."""
    vec_ind_remain = np.flatnonzero(mask_remain)
    if vec_ind_remain.size == 0:
        return  # nothing uncovered, nothing to assign

    if edit_idx_list:
        mask_edit = mask_union(*[mask_list[eidx] for eidx in edit_idx_list])
    else:
        mask_edit = np.zeros_like(mask_remain).astype(np.uint8)

    # Feasible pixels: covered by some mask that is not being edited.
    mask_feasible = (1 - mask_remain.astype(float) - mask_edit.astype(float)).astype(np.uint8)

    # Inner edge of the feasible region: feasible pixels whose copy shifted by
    # edge_width in one of the eight compass directions is not feasible.
    edge_width = 2
    offsets = [
        (edge_width, 0),
        (-edge_width, 0),
        (0, -edge_width),
        (0, edge_width),
        (-edge_width, edge_width),
        (-edge_width, -edge_width),
        (edge_width, edge_width),
        (edge_width, -edge_width),
    ]
    feasible_f = mask_feasible.astype(float)
    edges = [
        (myroll2d(mask_feasible, dy, dx).astype(float) - feasible_f) < 0
        for dy, dx in offsets
    ]
    mask_feasible_edge = mask_intersection(mask_union(*edges), mask_feasible)
    vec_ind_feasible_edge = np.flatnonzero(mask_feasible_edge)
    if vec_ind_feasible_edge.size == 0:
        raise ValueError("no feasible mask pixels to assign the remaining pixels to")

    nearest_point = _nearest_flat_indices(vec_ind_remain, vec_ind_feasible_edge, mask_remain.shape[1])

    # Label image: index of the mask covering each pixel (0 where uncovered).
    # int64 avoids uint8 overflow when there are many masks.
    vec_region_partition = np.zeros(mask_remain.size, dtype=np.int64)
    for mask_idx, mask in enumerate(mask_list):
        vec_region_partition += mask.reshape(-1).astype(np.int64) * mask_idx
    nearest_region = vec_region_partition[nearest_point]

    if edit_idx_list is not None:
        for edit_idx in edit_idx_list:
            assert edit_idx not in nearest_region

    assigned = set(nearest_region.tolist())
    for midx in range(len(mask_list)):
        if midx not in assigned:
            continue
        additions = np.zeros(mask_remain.size)
        additions[vec_ind_remain[nearest_region == midx]] = 1
        mask_list[midx] = (mask_list[midx].astype(float) + additions.reshape(mask_list[midx].shape)) > 0


def process_remain_mask(mask_list, edit_idx_list=None, force_mask_remain=None):
    """Hand every pixel that no mask covers to one of the masks.

    "Remaining" pixels are those where the masks in ``mask_list`` sum to 0
    (the masks are assumed not to overlap; otherwise ``1 - sum`` wraps when
    cast to uint8).

    * If ``force_mask_remain`` is given, all remaining pixels are added to
      ``mask_list[force_mask_remain]`` (stored as uint8).
    * Otherwise each remaining pixel joins the mask that owns the nearest
      pixel (squared Euclidean distance, first candidate on ties) on the inner
      edge of the feasible region -- pixels covered by masks not listed in
      ``edit_idx_list``. Edited masks therefore never grow here; masks that do
      grow are stored as boolean arrays.

    ``mask_list`` is modified in place.

    Args:
        mask_list: list of 2-D numpy masks of equal shape.
        edit_idx_list: indices of masks that must not receive pixels, or None.
        force_mask_remain: index of a mask that takes all remaining pixels, or None.

    Returns:
        ``(mask_list, mask_remain)`` where ``mask_remain`` is the uint8 map of
        pixels that were uncovered on entry.

    Raises:
        ValueError: if pixels remain but there is no feasible edge pixel.
    """
    print("Start to process remaining mask using nearest neighbor")
    mask_remain = (1 - sum(m.astype(float) for m in mask_list)).astype(np.uint8)
    if force_mask_remain is not None:
        forced = mask_list[force_mask_remain].astype(float) + mask_remain.astype(float)
        mask_list[force_mask_remain] = forced.astype(np.uint8)
    else:
        _assign_remain_to_nearest_region(mask_list, mask_remain, edit_idx_list)
    print("Finish processing remaining mask, if you want to edit, launch the ui")
    return mask_list, mask_remain
def resize_mask(mask_np, resize_ratio=1):
    """Rescale a 2-D mask with nearest-neighbour interpolation, keeping its shape.

    When shrinking, the scaled mask is placed in the top-left corner of an
    otherwise empty canvas of the original size. When enlarging, the centre
    of the scaled mask is cropped back to the original size.

    Args:
        mask_np: 2-D numpy mask (any dtype torch can convert, including bool).
        resize_ratio: scale factor applied to both dimensions.

    Returns:
        uint8 numpy array with the same shape as ``mask_np``.
    """
    rows, cols = mask_np.shape[0], mask_np.shape[1]
    new_rows, new_cols = int(rows * resize_ratio), int(cols * resize_ratio)
    canvas = torch.zeros(rows, cols)
    if new_rows == 0 or new_cols == 0:
        # Scaled to nothing: interpolate cannot produce an empty output.
        return canvas.numpy().astype(np.uint8)

    # Interpolate in float (nearest mode keeps values exact) so bool/uint8
    # inputs are accepted; index [0, 0] instead of squeeze() so size-1
    # dimensions are not dropped.
    source = torch.from_numpy(np.ascontiguousarray(mask_np)).float()[None, None]
    resized = torch.nn.functional.interpolate(source, size=(new_rows, new_cols), mode="nearest")[0, 0]

    if rows > new_rows:
        canvas[:new_rows, :new_cols] = resized
    else:
        assert cols <= new_cols
        top = new_rows // 2 - rows // 2
        left = new_cols // 2 - cols // 2
        canvas = resized[top:top + rows, left:left + cols]
    return canvas.cpu().numpy().astype(np.uint8)
def process_mask_move_torch(
    mask_list,
    move_index_list,
    delta_x_list=None,
    delta_y_list=None,
    edit_priority_list=None,
    force_mask_remain=None,
    resize_list=None,
):
    """Move (and optionally resize) selected masks, then repair the partition.

    For the i-th entry of ``move_index_list``:
    * the mask is rescaled with ``resize_mask(mask, resize_list[i])`` when
      ``resize_list`` is given;
    * it is shifted by ``(delta_x_list[i], delta_y_list[i])`` via ``move_mask``;
    * it is given priority ``edit_priority_list[i]`` (all other masks have 0).

    Overlaps are then resolved with ``stack_mask_with_priority``, checked with
    ``check_mask_overlap_numpy``, and the uncovered pixels are filled by
    ``process_remain_mask``.

    Args:
        mask_list: list of torch masks.
        move_index_list: indices of the masks to move.
        delta_x_list, delta_y_list: per-move offsets; default to 0 for every move.
        edit_priority_list: per-move priorities; default to 0 for every move.
        force_mask_remain: forwarded to ``process_remain_mask``.
        resize_list: optional per-move resize ratios.

    Returns:
        ``(mask_list, mask_remain)`` as uint8 torch tensors.

    Raises:
        AssertionError: from ``check_mask_overlap_numpy`` if overlaps remain.
    """
    num_moves = len(move_index_list)
    # Previously a None default made zip() raise TypeError; treat a missing
    # list as "0 for every move" instead.
    if delta_x_list is None:
        delta_x_list = [0] * num_moves
    if delta_y_list is None:
        delta_y_list = [0] * num_moves
    if edit_priority_list is None:
        edit_priority_list = [0] * num_moves

    mask_list_np = [m.cpu().numpy() for m in mask_list]
    priority_list = [0] * len(mask_list_np)
    moves = zip(move_index_list, delta_x_list, delta_y_list, edit_priority_list)
    for idx, (move_index, delta_x, delta_y, priority) in enumerate(moves):
        priority_list[move_index] = priority
        mask = mask_list_np[move_index]
        if resize_list is not None:
            mask = resize_mask(mask, resize_list[idx])
        mask_list_np[move_index] = move_mask(mask, delta_x=delta_x, delta_y=delta_y)

    # May leave uncovered pixels; process_remain_mask fills them below.
    mask_list_np = stack_mask_with_priority(mask_list_np, priority_list, move_index_list)
    check_mask_overlap_numpy(*mask_list_np)
    mask_list_np, mask_remain = process_remain_mask(mask_list_np, move_index_list, force_mask_remain)

    mask_list = [torch.from_numpy(m).to(dtype=torch.uint8) for m in mask_list_np]
    mask_remain = torch.from_numpy(mask_remain).to(dtype=torch.uint8)
    return mask_list, mask_remain
def process_mask_remove_torch(mask_list, remove_idx):
    """Clear mask ``remove_idx`` and redistribute its pixels with ``process_remain_mask``.

    Returns:
        ``(masks, mask_remain)`` as uint8 torch tensors, where ``mask_remain``
        marks the pixels that were uncovered before redistribution.
    """
    arrays = [tensor.cpu().numpy() for tensor in mask_list]
    arrays[remove_idx] = np.zeros_like(arrays[0])
    filled, remain = process_remain_mask(arrays)
    filled_tensors = [torch.from_numpy(arr).to(dtype=torch.uint8) for arr in filled]
    remain_tensor = torch.from_numpy(remain).to(dtype=torch.uint8)
    return filled_tensors, remain_tensor
def get_mask_difference_torch(mask_list1, mask_list2):
    """Return a mask of pixels where any corresponding pair of masks differs.

    Both lists must have the same length; the result accumulates, via
    ``mask_union_torch``, every pixel at which ``mask_list1[i]`` and
    ``mask_list2[i]`` disagree.
    """
    assert len(mask_list1) == len(mask_list2)
    changed = torch.zeros_like(mask_list1[0])
    for before, after in zip(mask_list1, mask_list2):
        delta = before.float() - after.float()
        changed = mask_union_torch(changed, (delta != 0).to(torch.uint8))
    return changed
def save_mask_list_to_npys(folder, mask_list, mask_label_list, name = "mask"):
    """Save each mask to ``<folder>/<name><index>_<label>.npy``.

    Masks and labels are paired as with ``zip``; extra entries in the longer
    list are ignored.
    """
    pairs = zip(mask_list, mask_label_list)
    for midx, (mask, label) in enumerate(pairs):
        filename = f"{name}{midx}_{label}.npy"
        np.save(os.path.join(folder, filename), mask)