Spaces:
Running
Running
import numpy as np | |
import gradio as gr | |
import cv2 | |
from copy import deepcopy | |
import torch | |
from torchvision import transforms | |
from PIL import Image, ImageDraw, ImageFont | |
from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits | |
from src.utils.utils import resize_numpy_image | |
# Build the EfficientSAM model once at import time (the "vits" builder --
# presumably the ViT-S variant, per its name). The segmentation callbacks in
# this module call this single shared instance.
sam = build_efficient_sam_vits()
def show_point_or_box(image, global_points):
    """Draw the current prompt on ``image``.

    One point is drawn as a filled circle (radius 10). Two points are drawn as
    a rectangle between the two corners. Any other number of points draws
    nothing.

    Args:
        image: image array for OpenCV drawing functions. OpenCV draws into the
            array it is given, which is why callers pass ``image.copy()``.
        global_points: list of ``[x, y]`` coordinates. Coordinates are cast to
            ``int`` so float clicks are accepted in both branches.

    Returns:
        The image returned by the OpenCV drawing call, or ``image`` unchanged
        when nothing is drawn.
    """
    if len(global_points) == 1:
        # Cast like the rectangle branch: OpenCV rejects non-integer centres.
        centre = (int(global_points[0][0]), int(global_points[0][1]))
        image = cv2.circle(image, centre, 10, (0, 0, 255), -1)
    elif len(global_points) == 2:
        p1, p2 = global_points
        image = cv2.rectangle(
            image,
            (int(p1[0]), int(p1[1])),
            (int(p2[0]), int(p2[1])),
            (0, 0, 255),
            2,
        )
    return image
def _order_box_corners(points):
    """Rewrite a two-point box prompt in place as [top-left, bottom-right].

    ``points`` holds two ``[x, y]`` lists in click order. Afterwards:
    - ``points[0]`` is the per-axis minimum;
    - ``points[1]`` is the per-axis maximum.

    The box is therefore valid whichever direction the user clicked in. The
    inner lists are updated in place, so they remain the same list objects.
    """
    (x1, y1), (x2, y2) = points
    points[0][:] = [min(x1, x2), min(y1, y2)]
    points[1][:] = [max(x1, x2), max(y1, y2)]


def _predict_box_mask(image, points, labels):
    """Run the module-level EfficientSAM model on a point/box prompt.

    Args:
        image: input for ``torchvision.transforms.ToTensor`` (presumably an
            HxWx3 uint8 numpy array from Gradio -- confirm).
        points: list of ``[x, y]`` prompt coordinates.
        labels: list of integer prompt labels, one per point.

    Returns:
        2-D uint8 array: 255 where the logits at ``[0, 0, 0]`` are >= 0, else 0.
    """
    pts_sampled = torch.reshape(torch.tensor(np.array(points)), [1, 1, -1, 2])
    pts_labels = torch.reshape(torch.tensor(np.array(labels)), [1, 1, -1])
    img_tensor = transforms.ToTensor()(image)
    # Inference only: avoid building an autograd graph for every click.
    with torch.no_grad():
        predicted_logits, _predicted_iou = sam(
            img_tensor[None, ...],
            pts_sampled,
            pts_labels,
        )
    # NOTE(review): always takes the mask at index 0 rather than the one with
    # the highest predicted IoU; confirm this is the intended choice.
    mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().numpy()
    return (mask * 255.0).astype(np.uint8)


def segment_with_points(
    image,
    original_image,
    global_points,
    global_point_label,
    evt: gr.SelectData,
    img_direction,
    save_dir="./tmp",
):
    """Handle a click on the image for box-prompted SAM segmentation.

    Clicks cycle through three states:

    1. No points yet: the click is stored as the first corner (label 2) and
       drawn.
    2. One point stored: the click completes the box (label 3). The corners
       are ordered top-left/bottom-right and EfficientSAM is run on the
       un-annotated image.
    3. Box already complete: the click discards it and starts a new box.

    Args:
        image: currently displayed image. It is replaced by ``original_image``
            when one has been stored, so drawings never accumulate.
        original_image: clean copy of the image, or None on the first click.
        global_points / global_point_label: prompt state; mutated in states 1
            and 2, replaced in state 3.
        evt: click event; ``evt.index`` holds the clicked ``(x, y)``.
        img_direction, save_dir: accepted for the callers' signature but not
            used by this function.

    Returns:
        A tuple ``(image_with_prompt, original_image, mask, global_points,
        global_point_label)``. ``mask`` is a uint8 0/255 array, or None
        unless a box was just completed.
    """
    if original_image is None:
        original_image = image
    else:
        image = original_image
    x, y = evt.index[0], evt.index[1]

    if len(global_points) == 0:
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label

    if len(global_points) == 1:
        global_points.append([x, y])
        global_point_label.append(3)
        _order_box_corners(global_points)
        image_with_point = show_point_or_box(image.copy(), global_points)
        # Named so it does not shadow the module-level ``mask_image`` function.
        mask_uint8 = _predict_box_mask(image, global_points, global_point_label)
        return image_with_point, original_image, mask_uint8, global_points, global_point_label

    # A box already exists: this click starts a new one.
    global_points = [[x, y]]
    global_point_label = [2]
    image_with_point = show_point_or_box(image.copy(), global_points)
    return image_with_point, original_image, None, global_points, global_point_label


def segment_with_points_paste(
    image,
    original_image,
    global_points,
    global_point_label,
    image_b,
    evt: gr.SelectData,
    dx,
    dy,
    resize_scale,
):
    """Box-prompted segmentation that also pastes the result onto ``image_b``.

    Click handling is the same as in ``segment_with_points``. Once a box is
    completed, the segmented region is pasted via
    ``paste_with_mask_and_offset(image, image_b, mask, dx, dy, resize_scale)``.

    Returns:
        A tuple ``(image_with_prompt, original_image, pasted_image,
        global_points, global_point_label, mask)``. ``pasted_image`` and
        ``mask`` (uint8, 0/255) are None unless a box was just completed.
    """
    if original_image is None:
        original_image = image
    else:
        image = original_image
    x, y = evt.index[0], evt.index[1]

    if len(global_points) == 0:
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label, None

    if len(global_points) == 1:
        global_points.append([x, y])
        global_point_label.append(3)
        _order_box_corners(global_points)
        image_with_point = show_point_or_box(image.copy(), global_points)
        mask_uint8 = _predict_box_mask(image, global_points, global_point_label)
        pasted = paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale)
        return image_with_point, original_image, pasted, global_points, global_point_label, mask_uint8

    # A box already exists: this click starts a new one.
    global_points = [[x, y]]
    global_point_label = [2]
    image_with_point = show_point_or_box(image.copy(), global_points)
    return image_with_point, original_image, None, global_points, global_point_label, None
def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1):
    """Paste the masked region of ``image_a`` onto a half-transparent ``image_b``.

    Steps:

    1. Cut out the pixels of ``image_a`` selected by ``mask``.
    2. Resize the cut-out to ``image_b``'s size, then scale it by ``delta``.
    3. Paste it at ``(x_offset, y_offset)``, shifted so the mask's
       bounding-box centre (mapped into ``image_b`` coordinates) stays at the
       same place regardless of ``delta``.

    ``image_b``'s alpha is set to 128 before compositing.

    Args:
        image_a: source image array (converted with ``Image.fromarray``).
        image_b: destination image array (converted with ``Image.fromarray``).
        mask: 2-D array; non-zero pixels select the region of ``image_a``.
        x_offset, y_offset: paste position in ``image_b`` pixels. Truncated
            to ``int``, because ``Image.paste`` rejects float coordinates.
        delta: scale factor applied to the resized cut-out.

    Returns:
        An RGBA ``PIL.Image`` the size of ``image_b``, or None when:
        - an input is missing;
        - the mask is empty or not 2-D;
        - compositing fails. This case is logged, so this best-effort UI
          callback does not raise.
    """
    if image_a is None or image_b is None or mask is None:
        return None
    numpy_mask = np.asarray(mask)
    if numpy_mask.ndim != 2 or not numpy_mask.any():
        # Nothing selected (or not a single-channel mask): nothing to paste.
        return None
    try:
        y_coords, x_coords = np.nonzero(numpy_mask)
        target_center_x = int((x_coords.min() + x_coords.max()) / 2)
        target_center_y = int((y_coords.min() + y_coords.max()) / 2)

        image_a = Image.fromarray(image_a)
        image_b = Image.fromarray(image_b)
        mask_img = Image.fromarray(numpy_mask)
        if image_a.size != mask_img.size:
            mask_img = mask_img.resize(image_a.size)
        # Keep image_a's pixels where the mask is set; everywhere else is transparent.
        transparent = Image.new("RGBA", image_a.size, (0, 0, 0, 0))
        cropped_image = Image.composite(image_a, transparent, mask_img)

        # Mask centre in image_b coordinates. Subtracting (delta - 1) * centre
        # keeps that point fixed when the cut-out is scaled by delta.
        x_b = int(target_center_x * (image_b.width / cropped_image.width))
        y_b = int(target_center_y * (image_b.height / cropped_image.height))
        x_offset = int(x_offset) - int((delta - 1) * x_b)
        y_offset = int(y_offset) - int((delta - 1) * y_b)

        cropped_image = cropped_image.resize(image_b.size)
        new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta))
        cropped_image = cropped_image.resize(new_size)

        # Set the background's alpha to 128 (half transparent).
        image_b.putalpha(128)
        result_image = Image.new("RGBA", image_b.size, (0, 0, 0, 0))
        result_image.paste(image_b, (0, 0))
        # The RGBA cut-out serves as its own paste mask (its alpha channel).
        result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image)
        return result_image
    except Exception:
        import logging

        logging.getLogger(__name__).exception("paste_with_mask_and_offset failed")
        return None
def upload_image_move(img, original_image):
    """Return the stored ``original_image`` if there is one, else ``img``."""
    return img if original_image is None else original_image
def fun_clear(*args):
    """Produce a "cleared" value for every argument, in order.

    List arguments map to a fresh empty list; anything else maps to None.

    Returns:
        A tuple with one cleared value per argument.
    """
    return tuple([] if isinstance(value, list) else None for value in args)
def clear_points(img):
    """Reset the selected points and build a preview of the sketched mask.

    ``img`` is a dict with ``"image"`` and ``"mask"`` arrays; channel 0 of the
    mask is used. When any mask pixel is set, pixels outside the mask are
    passed to ``mask_image`` with black at alpha 0.3. Otherwise the preview is
    a plain copy of the image.

    Returns:
        ``([], preview_image)``.
    """
    base = img["image"]
    soft_mask = np.float32(img["mask"][:, :, 0]) / 255.
    if not soft_mask.sum() > 0:
        return [], base.copy()
    hard_mask = np.uint8(soft_mask > 0)
    return [], mask_image(base, 1 - hard_mask, color=[0, 0, 0], alpha=0.3)
def get_point(img, sel_pix, evt: gr.SelectData):
    """Append the clicked pixel to ``sel_pix`` and draw every selected point.

    Drawing details:
    - points at even positions use colour (0, 0, 255);
    - points at odd positions use (255, 0, 0);
    - each consecutive pair (0-1, 2-3, ...) is joined by a white arrow.

    Drawing happens directly on ``img``.

    Returns:
        ``img`` as a numpy array.
    """
    sel_pix.append(evt.index)
    pending = []
    for position, pixel in enumerate(sel_pix):
        centre = tuple(pixel)
        colour = (255, 0, 0) if position % 2 else (0, 0, 255)
        cv2.circle(img, centre, 10, colour, -1)
        pending.append(centre)
        if len(pending) == 2:
            cv2.arrowedLine(img, pending[0], pending[1], (255, 255, 255), 4, tipLength=0.5)
            pending.clear()
    return img if isinstance(img, np.ndarray) else np.array(img)
def calculate_translation_percentage(ori_shape, selected_points):
    """Translation between the first two points as fractions of the image size.

    Args:
        ori_shape: image shape; index 0 is the height and index 1 the width.
        selected_points: sequence whose first two items are ``(x, y)`` points.

    Returns:
        ``(dx / width, dy / height)``, where ``dx``/``dy`` are measured from
        point 0 to point 1.
    """
    start, end = selected_points[0], selected_points[1]
    height, width = ori_shape[0], ori_shape[1]
    return (end[0] - start[0]) / width, (end[1] - start[1]) / height
def get_point_move(original_image, img, sel_pix, evt: gr.SelectData):
    """Record a click of a two-point move gesture and draw it.

    Image handling:
    - Drawing starts from a copy of ``original_image`` when one is stored.
    - Otherwise ``img`` is drawn on and a clean copy of it becomes
      ``original_image``.

    Click handling:
    - Up to two clicks are kept in ``sel_pix``; a third click starts a new
      gesture.
    - The first point is drawn in colour (0, 0, 255) and the second in
      (255, 0, 0).
    - When both points exist, a white arrow joins them and the translation
      is computed.

    Returns:
        ``(annotated_image, original_image, sel_pix, dx, dy)``. ``dx``/``dy``
        are fractions of the image width/height from
        ``calculate_translation_percentage``, and 0 until two points exist.
    """
    if original_image is not None:
        img = original_image.copy()
    else:
        original_image = img.copy()

    if len(sel_pix) < 2:
        sel_pix.append(evt.index)
    else:
        sel_pix = [evt.index]

    dx, dy = 0, 0
    for idx, point in enumerate(sel_pix):
        colour = (0, 0, 255) if idx % 2 == 0 else (255, 0, 0)
        cv2.circle(img, tuple(point), 10, colour, -1)
    # sel_pix holds at most two points here, so at most one arrow is drawn.
    if len(sel_pix) == 2:
        start, end = tuple(sel_pix[0]), tuple(sel_pix[1])
        cv2.arrowedLine(img, start, end, (255, 255, 255), 4, tipLength=0.5)
        dx, dy = calculate_translation_percentage(original_image.shape, sel_pix)
    return np.array(img), original_image, sel_pix, dx, dy
def store_img(img):
    """Split a sketch dict into image, preview and mask.

    ``img`` is a dict with ``"image"`` and ``"mask"`` arrays; channel 0 of the
    mask is used.

    Returns:
        ``(image, preview, mask)``:
        - If any mask pixel is set, ``mask`` is a uint8 0/1 array and
          ``preview`` is ``mask_image`` applied outside the mask (black,
          alpha 0.3).
        - Otherwise ``mask`` is the all-zero float32 array and ``preview`` is
          a copy of the image.
    """
    base = img["image"]
    soft_mask = np.float32(img["mask"][:, :, 0]) / 255.
    if not soft_mask.sum() > 0:
        return base, base.copy(), soft_mask
    hard_mask = np.uint8(soft_mask > 0)
    preview = mask_image(base, 1 - hard_mask, color=[0, 0, 0], alpha=0.3)
    return base, preview, hard_mask
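# Image-editor value fields read below: im["background"] is the base image and im["layers"][0] is the first layer, which is used as the mask.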
def store_img_move(img, mask=None):
    """Split an image-editor value into image, preview and mask.

    Args:
        img: dict with ``"background"`` (base image) and, when ``mask`` is
            None, ``"layers"`` (a list whose first entry supplies the mask).
        mask: precomputed mask. If given, it is returned unchanged with no
            preview.

    Returns:
        ``(image, preview, mask_uint8)``:
        - When ``mask`` is given, ``preview`` is None and ``mask`` is
          returned as passed.
        - Otherwise ``mask_uint8`` is 0/255. ``preview`` is ``mask_image``
          applied outside the mask (black, alpha 0.3), or a plain copy when
          the mask is empty.
    """
    if mask is not None:
        return img["background"], None, mask
    image = img["background"]
    # Fixed: read the first layer of ``img``. The previous code indexed the
    # literal list ["layers"] (the string "layers") and always raised TypeError.
    # NOTE(review): channel 0 of the layer is used as the mask; if layers are
    # RGBA, the alpha channel ([:, :, 3]) may be more robust -- confirm.
    mask = np.float32(img["layers"][0][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    return image, masked_img, (mask * 255.).astype(np.uint8)
def store_img_move_old(img, mask=None):
    """Legacy splitter for sketch dicts with ``"image"``/``"mask"`` keys.

    Returns:
        ``(image, preview, mask_uint8)``:
        - When ``mask`` is given, it is returned as passed with a None
          preview.
        - Otherwise the mask comes from channel 0 of ``img["mask"]``. It is
          returned as a 0/255 uint8 array, with a ``mask_image`` preview
          (black, alpha 0.3, outside the mask), or a plain copy when the mask
          is empty.
    """
    base = img["image"]
    if mask is not None:
        return base, None, mask
    soft_mask = np.float32(img["mask"][:, :, 0]) / 255.
    if soft_mask.sum() > 0:
        final_mask = np.uint8(soft_mask > 0)
        preview = mask_image(base, 1 - final_mask, color=[0, 0, 0], alpha=0.3)
    else:
        final_mask = soft_mask
        preview = base.copy()
    return base, preview, (final_mask * 255.).astype(np.uint8)
def mask_image(image, mask, color=(255, 0, 0), alpha=0.5, max_resolution=None):
    """Overlay a mask on an image for visualization.

    Pixels where ``mask == 1`` are set to ``color`` in a copy of the image.
    That copy is blended with the original as
    ``alpha * colored + (1 - alpha) * original``, so other pixels are blended
    with themselves.

    Args:
        image: (H, W, 3) or (H, W) image array.
        mask: (H, W) array; only entries equal to 1 are coloured.
        color: value assigned to masked pixels before blending. A tuple
            default avoids a shared mutable default.
        alpha: weight of the coloured copy in the blend.
        max_resolution: if given, ``image`` is first resized with
            ``resize_numpy_image(image, max_resolution * max_resolution)``.
            ``mask`` is then resized to match, using nearest-neighbour
            interpolation.

    Returns:
        The blended image, with the shape of the (possibly resized) ``image``.
    """
    if max_resolution is not None:
        image, _ = resize_numpy_image(image, max_resolution * max_resolution)
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    out = image.copy()
    colored = image.copy()
    colored[mask == 1] = color
    # Blend into ``out`` in place. The previously computed (and discarded)
    # contours were dead work and have been removed.
    out = cv2.addWeighted(colored, alpha, out, 1 - alpha, 0, out)
    return out