Spaces:
Runtime error
Runtime error
improve app design
Browse files- app.py +10 -5
- predict.py +13 -8
app.py
CHANGED
@@ -16,11 +16,16 @@ with demo:
|
|
16 |
|
17 |
with gr.Box():
|
18 |
|
|
|
19 |
with gr.Row():
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
|
|
|
|
|
|
|
|
|
24 |
output_mask = gr.Image(label="Predicted Masks", show_label=True)
|
25 |
|
26 |
gr.Markdown("**Predict**")
|
@@ -32,10 +37,10 @@ with demo:
|
|
32 |
gr.Markdown("**Examples:**")
|
33 |
|
34 |
with gr.Column():
|
35 |
-
gr.Examples(example_list, [input_image, segmentation_task], output_mask, predict_masks)
|
36 |
|
37 |
|
38 |
-
submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=output_mask)
|
39 |
|
40 |
gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
|
41 |
|
|
|
16 |
|
17 |
with gr.Box():
|
18 |
|
19 |
+
|
20 |
with gr.Row():
|
21 |
+
with gr.Column():
|
22 |
+
gr.Markdown("**Inputs**")
|
23 |
+
segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
|
24 |
input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
|
25 |
+
|
26 |
+
with gr.Column():
|
27 |
+
gr.Markdown("**Outputs**")
|
28 |
+
output_heading = gr.Textbox(label="Output Type", show_label=True)
|
29 |
output_mask = gr.Image(label="Predicted Masks", show_label=True)
|
30 |
|
31 |
gr.Markdown("**Predict**")
|
|
|
37 |
gr.Markdown("**Examples:**")
|
38 |
|
39 |
with gr.Column():
|
40 |
+
gr.Examples(example_list, [input_image, segmentation_task], [output_mask,output_heading], predict_masks)
|
41 |
|
42 |
|
43 |
+
submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=[output_mask,output_heading])
|
44 |
|
45 |
gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
|
46 |
|
predict.py
CHANGED
@@ -4,7 +4,8 @@ import numpy as np
|
|
4 |
from PIL import Image
|
5 |
from collections import defaultdict
|
6 |
import os
|
7 |
-
#
|
|
|
8 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
9 |
|
10 |
from detectron2.data import MetadataCatalog
|
@@ -21,12 +22,12 @@ def load_model_and_processor(model_ckpt: str):
|
|
21 |
|
22 |
def load_default_ckpt(segmentation_task: str):
|
23 |
if segmentation_task == "semantic":
|
24 |
-
|
25 |
elif segmentation_task == "instance":
|
26 |
-
|
27 |
else:
|
28 |
-
|
29 |
-
return
|
30 |
|
31 |
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
|
32 |
metadata = MetadataCatalog.get("coco_2017_val_panoptic")
|
@@ -73,8 +74,8 @@ def visualize_instance_seg_mask(mask, input_image):
|
|
73 |
def predict_masks(input_img_path: str, segmentation_task: str):
|
74 |
|
75 |
#load model and image processor
|
76 |
-
|
77 |
-
model, image_processor = load_model_and_processor(
|
78 |
|
79 |
## pass input image through image processor
|
80 |
image = Image.open(input_img_path)
|
@@ -90,16 +91,20 @@ def predict_masks(input_img_path: str, segmentation_task: str):
|
|
90 |
predicted_segmentation_map = result.cpu().numpy()
|
91 |
palette = ade_palette()
|
92 |
output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
|
|
|
93 |
|
94 |
elif segmentation_task == "instance":
|
95 |
result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
96 |
predicted_instance_map = result["segmentation"].cpu().detach().numpy()
|
97 |
output_result = visualize_instance_seg_mask(predicted_instance_map, image)
|
|
|
98 |
|
99 |
else:
|
100 |
result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
101 |
predicted_segmentation_map = result["segmentation"]
|
102 |
seg_info = result['segments_info']
|
103 |
output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
|
|
|
104 |
|
105 |
-
|
|
|
|
4 |
from PIL import Image
|
5 |
from collections import defaultdict
|
6 |
import os
|
7 |
+
# Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
|
8 |
+
# Hence, installing detectron2 this way when using Gradio HF spaces.
|
9 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
10 |
|
11 |
from detectron2.data import MetadataCatalog
|
|
|
22 |
|
23 |
def load_default_ckpt(segmentation_task: str):
|
24 |
if segmentation_task == "semantic":
|
25 |
+
default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
|
26 |
elif segmentation_task == "instance":
|
27 |
+
default_ckpt = "facebook/mask2former-swin-small-coco-instance"
|
28 |
else:
|
29 |
+
default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
|
30 |
+
return default_ckpt
|
31 |
|
32 |
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
|
33 |
metadata = MetadataCatalog.get("coco_2017_val_panoptic")
|
|
|
74 |
def predict_masks(input_img_path: str, segmentation_task: str):
|
75 |
|
76 |
#load model and image processor
|
77 |
+
default_ckpt = load_default_ckpt(segmentation_task)
|
78 |
+
model, image_processor = load_model_and_processor(default_ckpt)
|
79 |
|
80 |
## pass input image through image processor
|
81 |
image = Image.open(input_img_path)
|
|
|
91 |
predicted_segmentation_map = result.cpu().numpy()
|
92 |
palette = ade_palette()
|
93 |
output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
|
94 |
+
output_heading = "Semantic Segmentation Output"
|
95 |
|
96 |
elif segmentation_task == "instance":
|
97 |
result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
98 |
predicted_instance_map = result["segmentation"].cpu().detach().numpy()
|
99 |
output_result = visualize_instance_seg_mask(predicted_instance_map, image)
|
100 |
+
output_heading = "Instance Segmentation Output"
|
101 |
|
102 |
else:
|
103 |
result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
104 |
predicted_segmentation_map = result["segmentation"]
|
105 |
seg_info = result['segments_info']
|
106 |
output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
|
107 |
+
output_heading = "Panoptic Segmentation Output"
|
108 |
|
109 |
+
|
110 |
+
return output_result, output_heading
|