Spaces: Running on Zero
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation | |
from PIL import Image | |
import spaces | |
import torch | |
from collections import defaultdict | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
import matplotlib.patches as mpatches | |
import os | |
import numpy as np | |
import argparse | |
import matplotlib | |
import gradio as gr | |
def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, trim its margins, center-crop it to a square and resize.

    Args:
        image_path: path (``str`` or ``os.PathLike``) to an image file, or an
            already-decoded ``H x W x C`` array which is used as-is.
        left, right, top, bottom: number of pixels to trim from each edge
            before the center crop. They are clamped so that at least one
            row/column always remains.
        size: side length of the returned square image.

    Returns:
        ``np.ndarray`` of shape ``(size, size, C)``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # Context manager closes the file handle; convert("RGB") also handles
        # grayscale/palette/RGBA files that the old `[:, :, :3]` slice broke on.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path
    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Height-axis clamp mirrors the width-axis one above. The previous
    # `h - left - 1` mixed in the horizontal offset and could go negative.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, _ = image.shape
    # Center-crop the longer side so the image becomes square.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((size, size)))
def _save_segmentation_figure(segmentation, label_list, color_indices, num_colors, save_folder):
    """Render *segmentation* with a legend and save it as ``seg_init.png`` in *save_folder*.

    ``color_indices[i]`` is the colormap entry used for the legend swatch of
    ``label_list[i]``; the colormap is viridis resampled to *num_colors* entries.
    """
    cmap = matplotlib.colormaps["viridis"].resampled(num_colors)
    fig, ax = plt.subplots()
    try:
        ax.imshow(segmentation.cpu().numpy())
        ax.set_xticks([])
        ax.set_yticks([])
        handles = [
            mpatches.Patch(color=cmap(index), label=label)
            for index, label in zip(color_indices, label_list)
        ]
        ax.legend(handles=handles)
        fig.savefig(os.path.join(save_folder, "seg_init.png"), dpi=500)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app leaks one figure per call.
        plt.close(fig)


def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False, model=None):
    """Split a panoptic segment-id map into per-segment boolean masks.

    Args:
        segmentation: tensor of segment ids (the ``"segmentation"`` entry of a
            panoptic post-processing result). A map that is ``-1`` everywhere
            is treated as "nothing detected".
        segments_info: list of dicts with at least ``"id"`` and ``"label_id"``.
        save_folder: if given, a visualization with a legend is written to
            ``<save_folder>/seg_init.png``; if ``None`` nothing is saved.
        noseg: if true (or nothing was detected), return a single all-True
            mask labelled ``"all-0"`` instead of per-segment masks.
        model: object whose ``config.id2label`` names the segments; when
            ``None`` the numeric label id is used as the name.

    Returns:
        ``(mask_np_list, label_list)``: parallel lists of numpy bool masks
        (same shape as *segmentation*) and ``"<name>-<index>"`` labels.
    """
    seg_min = int(torch.min(segmentation))
    seg_max = int(torch.max(segmentation))
    if seg_min == seg_max == -1:
        print("nothing is detected!")
        noseg = True

    mask_np_list = []
    label_list = []
    color_indices = []  # colormap entry for each label's legend swatch

    if noseg:
        mask_np_list.append(np.full(segmentation.shape, True))
        label_list.append(f"all-{0}")
        color_indices.append(0)
    else:
        if seg_min == 0:
            # Id 0 is reported as the leftover "rest" region.
            mask_np_list.append((segmentation == 0).cpu().numpy())
            label_list.append(f"rest-{0}")
            color_indices.append(0)
        for segment in segments_info:
            segment_id = segment["id"]
            mask_np_list.append((segmentation == segment_id).cpu().numpy())
            # With no id-0 region, shift ids down by one so the displayed
            # index / colour starts at 0 (assumes the smallest id is then 1).
            if seg_min != 0:
                segment_id -= 1
            label_id = segment["label_id"]
            if model is not None:
                segment_label = model.config.id2label[label_id]
            else:
                segment_label = str(label_id)
            label_list.append(f"{segment_label}-{segment_id}")
            color_indices.append(segment_id)

    if save_folder is not None:
        _save_segmentation_figure(
            segmentation, label_list, color_indices, seg_max - seg_min + 1, save_folder
        )

    print("; ".join(label_list))
    return mask_np_list, label_list
_MASK2FORMER_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"
_mask2former_cache = {}


def _get_mask2former(checkpoint=_MASK2FORMER_CHECKPOINT):
    """Return the ``(processor, model)`` pair for *checkpoint*, loading it once per process."""
    if checkpoint not in _mask2former_cache:
        _mask2former_cache[checkpoint] = (
            AutoImageProcessor.from_pretrained(checkpoint),
            Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint),
        )
    return _mask2former_cache[checkpoint]


def run_segmentation(image, name="example_tmp", size=512, noseg=False):
    """Run Mask2Former panoptic segmentation on *image* and save a visualization.

    Args:
        image: array accepted by ``PIL.Image.fromarray`` (or a ``PIL.Image``).
        name: folder, relative to the working directory, where the
            visualization is written.
        size: the image is resized to ``size x size`` before segmentation.
        noseg: forwarded to ``draw_panoptic_segmentation``; when true a single
            whole-image mask is returned instead of per-segment masks.

    Returns:
        ``(image, mask_list, label_list)``: the resized RGB ``PIL.Image`` and
        the masks/labels produced by ``draw_panoptic_segmentation``.
    """
    base_folder_path = "."
    # Loading weights is expensive; reuse the same processor/model across calls.
    processor, model = _get_mask2former()
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        # The image processor only accepts 1- or 3-channel input (e.g. RGBA fails).
        image = image.convert("RGB")
    image = image.resize((size, size))
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); target_sizes expects (height, width).
    panoptic_segmentation = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    save_folder = os.path.join(base_folder_path, name)
    os.makedirs(save_folder, exist_ok=True)
    mask_list, label_list = draw_panoptic_segmentation(
        **panoptic_segmentation, save_folder=save_folder, noseg=noseg, model=model
    )
    print("Finish segment")
    return image, mask_list, label_list