Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,789 Bytes
d807efd 6eaf2d7 d807efd 3f3b681 d807efd ed7479e d807efd 850ea5b d807efd 872b038 9599a85 872b038 d807efd 9599a85 62c1dec d807efd 9599a85 d807efd 9599a85 62c1dec 9599a85 d807efd 872b038 d807efd 9599a85 62c1dec d807efd 9599a85 d807efd 9599a85 d807efd ed7479e 3f3b681 d807efd 850ea5b d807efd 850ea5b 8963af6 850ea5b 9599a85 850ea5b 872b038 850ea5b 3f3b681 9599a85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import spaces
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
import os
import numpy as np
import argparse
import matplotlib
import gradio as gr
def load_image(image_path, left=0, right=0, top=0, bottom=0, size=512):
    """Load an image, trim its margins, center-crop it to a square and resize it.

    Args:
        image_path: path to an image file (``str`` or ``os.PathLike``), or an
            already-decoded numpy array of shape ``(H, W)`` or ``(H, W, C)``.
            Files are converted to 3-channel RGB on load.
        left, right, top, bottom: number of pixels to trim from each edge.
            They are clamped so that at least one row and one column remain.
        size: side length, in pixels, of the square output image.

    Returns:
        numpy array of shape ``(size, size)`` or ``(size, size, C)``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # `with` closes the file handle; convert("RGB") normalises grayscale,
        # palette, RGBA (alpha dropped) and CMYK files to 3-channel RGB.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path

    h, w = image.shape[:2]
    # Clamp the margins so the crop below always keeps at least one column/row.
    # (`top` was previously clamped against `left` instead of the height.)
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image is square before resizing.
    h, w = image.shape[:2]
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((size, size)))
def draw_panoptic_segmentation(segmentation, segments_info, save_folder=None, noseg=False, model=None):
    """Plot a panoptic segmentation map with a legend and return one mask per segment.

    Args:
        segmentation: 2-D tensor of segment ids (the ``"segmentation"`` entry
            produced by the image processor's ``post_process_panoptic_segmentation``).
            A map that is entirely ``-1`` is treated as "nothing detected".
            Pixels equal to ``0`` are reported as a separate ``"rest"`` segment.
        segments_info: iterable of dicts with at least ``"id"`` and
            ``"label_id"`` keys, one per segment id drawn in ``segmentation``.
        save_folder: directory that receives ``seg_init.png``. When ``None``,
            the figure is drawn but not saved.
        noseg: if True, skip per-segment masks and return a single all-True
            mask labelled ``"all-0"``. Forced to True when nothing is detected.
        model: object whose ``config.id2label`` maps ``label_id`` to a class
            name. When ``None`` (or an id is missing), the numeric id is used.

    Returns:
        ``(mask_np_list, label_list)``: parallel lists of boolean numpy masks
        (same shape as ``segmentation``) and their ``"<class>-<index>"`` labels.
    """
    seg_min = int(torch.min(segmentation))
    seg_max = int(torch.max(segmentation))
    if seg_min == seg_max == -1:
        print("nothing is detected!")
        noseg = True
        num_colors = 1
    else:
        num_colors = seg_max - seg_min + 1
    # One colour per distinct id, so that legend swatches line up with the
    # min/max-normalised viridis colouring that imshow applies below.
    viridis = matplotlib.colormaps['viridis'].resampled(num_colors)

    seg_np = segmentation.detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(seg_np)

    handles = []
    label_list = []
    mask_np_list = []

    def _add_entry(mask, label, color_index):
        """Record a mask together with its label and legend swatch."""
        mask_np_list.append(mask)
        handles.append(mpatches.Patch(color=viridis(color_index), label=label))
        label_list.append(label)

    if noseg:
        mask = np.full(segmentation.shape, True)
        print(mask.shape)
        _add_entry(mask, f"all-{0}", 0)
    else:
        if seg_min == 0:
            mask = seg_np == 0  # bool mask of unassigned pixels
            print(mask.shape)
            _add_entry(mask, f"rest-{0}", 0)
        id2label = model.config.id2label if model is not None else {}
        for segment in segments_info:
            segment_id = segment['id']
            mask = seg_np == segment_id
            # With no 0-valued pixels the colormap starts at the smallest id,
            # so shift ids down by one to index it from 0 (assumes the smallest
            # id is 1 in that case).
            if seg_min != 0:
                segment_id -= 1
            print(mask.shape)
            label_id = segment['label_id']
            segment_label = id2label.get(label_id, str(label_id))
            _add_entry(mask, f"{segment_label}-{segment_id}", segment_id)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(handles=handles)
    if save_folder is not None:
        fig.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500)
    # pyplot keeps every figure alive until it is closed; close it so repeated
    # calls in a long-running app do not accumulate figures.
    plt.close(fig)
    print("; ".join(label_list))
    return mask_np_list, label_list
_SEGMENTATION_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"
# checkpoint name -> (processor, model); filled lazily by _get_segmentation_model.
_SEGMENTATION_MODELS = {}


def _get_segmentation_model(checkpoint=_SEGMENTATION_CHECKPOINT):
    """Return ``(processor, model)`` for *checkpoint*, loading them at most once per process."""
    if checkpoint not in _SEGMENTATION_MODELS:
        _SEGMENTATION_MODELS[checkpoint] = (
            AutoImageProcessor.from_pretrained(checkpoint),
            Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint),
        )
    return _SEGMENTATION_MODELS[checkpoint]


@spaces.GPU(duration=10)
def run_segmentation(image, name="example_tmp", size=512, noseg=False):
    """Run Mask2Former panoptic segmentation on *image* and save a visualisation.

    Args:
        image: input image, either a ``PIL.Image.Image`` or an array accepted
            by ``PIL.Image.fromarray``. Non-RGB images are converted to RGB.
        name: output folder, relative to the working directory, that receives
            ``seg_init.png``. It is created if missing.
        size: the image is resized to ``size x size`` before segmentation.
        noseg: forwarded to ``draw_panoptic_segmentation``. When True, a single
            all-True mask is returned instead of per-segment masks.

    Returns:
        ``(image, mask_list, label_list)``: the resized RGB PIL image, followed
        by the masks and labels produced by ``draw_panoptic_segmentation``.
    """
    base_folder_path = "."
    processor, model = _get_segmentation_model()

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        # The image processor expects 3-channel input (RGBA/grayscale otherwise fail).
        image = image.convert("RGB")
    image = image.resize((size, size))

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); the post-processor wants (height, width).
    panoptic_segmentation = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    save_folder = os.path.join(base_folder_path, name)
    os.makedirs(save_folder, exist_ok=True)
    mask_list, label_list = draw_panoptic_segmentation(
        **panoptic_segmentation, save_folder=save_folder, noseg=noseg, model=model
    )
    print("Finish segment")
    return image, mask_list, label_list
|