shivi commited on
Commit
ce088ab
1 Parent(s): c7ab35b

improve app design

Browse files
Files changed (2) hide show
  1. app.py +10 -5
  2. predict.py +13 -8
app.py CHANGED
@@ -16,11 +16,16 @@ with demo:
16
 
17
  with gr.Box():
18
 
 
19
  with gr.Row():
20
- segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
21
- with gr.Box():
22
- with gr.Row():
23
  input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
 
 
 
 
24
  output_mask = gr.Image(label="Predicted Masks", show_label=True)
25
 
26
  gr.Markdown("**Predict**")
@@ -32,10 +37,10 @@ with demo:
32
  gr.Markdown("**Examples:**")
33
 
34
  with gr.Column():
35
- gr.Examples(example_list, [input_image, segmentation_task], output_mask, predict_masks)
36
 
37
 
38
- submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=output_mask)
39
 
40
  gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
41
 
 
16
 
17
  with gr.Box():
18
 
19
+
20
  with gr.Row():
21
+ with gr.Column():
22
+ gr.Markdown("**Inputs**")
23
+ segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
24
  input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
25
+
26
+ with gr.Column():
27
+ gr.Markdown("**Outputs**")
28
+ output_heading = gr.Textbox(label="Output Type", show_label=True)
29
  output_mask = gr.Image(label="Predicted Masks", show_label=True)
30
 
31
  gr.Markdown("**Predict**")
 
37
  gr.Markdown("**Examples:**")
38
 
39
  with gr.Column():
40
+ gr.Examples(example_list, [input_image, segmentation_task], [output_mask,output_heading], predict_masks)
41
 
42
 
43
+ submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=[output_mask,output_heading])
44
 
45
  gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
46
 
predict.py CHANGED
@@ -4,7 +4,8 @@ import numpy as np
4
  from PIL import Image
5
  from collections import defaultdict
6
  import os
7
- # install detectron this way to avoid torch not available error if mentioned directly as dependency in requirements.txt
 
8
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
9
 
10
  from detectron2.data import MetadataCatalog
@@ -21,12 +22,12 @@ def load_model_and_processor(model_ckpt: str):
21
 
22
  def load_default_ckpt(segmentation_task: str):
23
  if segmentation_task == "semantic":
24
- default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
25
  elif segmentation_task == "instance":
26
- default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
27
  else:
28
- default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
29
- return default_pretrained_ckpt
30
 
31
  def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
32
  metadata = MetadataCatalog.get("coco_2017_val_panoptic")
@@ -73,8 +74,8 @@ def visualize_instance_seg_mask(mask, input_image):
73
  def predict_masks(input_img_path: str, segmentation_task: str):
74
 
75
  #load model and image processor
76
- default_pretrained_ckpt = load_default_ckpt(segmentation_task)
77
- model, image_processor = load_model_and_processor(default_pretrained_ckpt)
78
 
79
  ## pass input image through image processor
80
  image = Image.open(input_img_path)
@@ -90,16 +91,20 @@ def predict_masks(input_img_path: str, segmentation_task: str):
90
  predicted_segmentation_map = result.cpu().numpy()
91
  palette = ade_palette()
92
  output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
 
93
 
94
  elif segmentation_task == "instance":
95
  result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
96
  predicted_instance_map = result["segmentation"].cpu().detach().numpy()
97
  output_result = visualize_instance_seg_mask(predicted_instance_map, image)
 
98
 
99
  else:
100
  result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
101
  predicted_segmentation_map = result["segmentation"]
102
  seg_info = result['segments_info']
103
  output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
 
104
 
105
- return output_result
 
 
4
  from PIL import Image
5
  from collections import defaultdict
6
  import os
7
+ # Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
8
+ # Hence, installing detectron2 this way when using Gradio HF spaces.
9
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
10
 
11
  from detectron2.data import MetadataCatalog
 
22
 
23
  def load_default_ckpt(segmentation_task: str):
24
  if segmentation_task == "semantic":
25
+ default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
26
  elif segmentation_task == "instance":
27
+ default_ckpt = "facebook/mask2former-swin-small-coco-instance"
28
  else:
29
+ default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
30
+ return default_ckpt
31
 
32
  def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
33
  metadata = MetadataCatalog.get("coco_2017_val_panoptic")
 
74
  def predict_masks(input_img_path: str, segmentation_task: str):
75
 
76
  #load model and image processor
77
+ default_ckpt = load_default_ckpt(segmentation_task)
78
+ model, image_processor = load_model_and_processor(default_ckpt)
79
 
80
  ## pass input image through image processor
81
  image = Image.open(input_img_path)
 
91
  predicted_segmentation_map = result.cpu().numpy()
92
  palette = ade_palette()
93
  output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
94
+ output_heading = "Semantic Segmentation Output"
95
 
96
  elif segmentation_task == "instance":
97
  result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
98
  predicted_instance_map = result["segmentation"].cpu().detach().numpy()
99
  output_result = visualize_instance_seg_mask(predicted_instance_map, image)
100
+ output_heading = "Instance Segmentation Output"
101
 
102
  else:
103
  result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
104
  predicted_segmentation_map = result["segmentation"]
105
  seg_info = result['segments_info']
106
  output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
107
+ output_heading = "Panoptic Segmentation Output"
108
 
109
+
110
+ return output_result, output_heading