root commited on
Commit
9599a85
1 Parent(s): 872b038
Files changed (2) hide show
  1. app.py +2 -5
  2. segment.py +10 -10
app.py CHANGED
@@ -51,11 +51,8 @@ def load_mask_ui(input_folder="example_tmp",load_edit = False):
51
  def load_image_ui(load_edit, input_folder="example_tmp"):
52
  #try:
53
  if 1:
54
- for img_path in Path(input_folder).iterdir():
55
- if img_path.name in ["img_512.png"]:
56
- image = Image.open(img_path)
57
- mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
58
- image = image.convert('RGB')
59
  segmentation = create_segmentation(mask_np_list)
60
  print("!!", len(mask_np_list))
61
  max_val = len(mask_np_list)-1
 
51
  def load_image_ui(load_edit, input_folder="example_tmp"):
52
  #try:
53
  if 1:
54
+ image, mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
55
+ #image = image.convert('RGB')
 
 
 
56
  segmentation = create_segmentation(mask_np_list)
57
  print("!!", len(mask_np_list))
58
  max_val = len(mask_np_list)-1
segment.py CHANGED
@@ -46,16 +46,16 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
46
  handles = []
47
  label_list = []
48
 
49
- mask_list = []
50
 
51
  if not noseg:
52
  if torch.min(segmentation) == 0:
53
  mask = segmentation==0
54
- mask = mask.cpu().detach() # [512,512] bool
55
  segment_label = "rest"
56
  color = viridis(0)
57
  label = f"{segment_label}-{0}"
58
- mask_list.append(mask)
59
  handles.append(mpatches.Patch(color=color, label=label))
60
  label_list.append(label)
61
 
@@ -64,8 +64,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
64
  mask = segmentation==segment_id
65
  if torch.min(segmentation) != 0:
66
  segment_id -= 1
67
- mask = mask.cpu().detach() # [512,512] bool
68
- mask_list.append(mask)
69
  segment_label = model.config.id2label[segment['label_id']]
70
  instances_counter[segment['label_id']] += 1
71
 
@@ -75,9 +75,9 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
75
  handles.append(mpatches.Patch(color=color, label=label))
76
  label_list.append(label)
77
  else:
78
- mask = torch.from_numpy(np.full(segmentation.shape, True))
79
  segment_label = "all"
80
- mask_list.append(mask)
81
  color = viridis(0)
82
  label = f"{segment_label}-{0}"
83
  handles.append(mpatches.Patch(color=color, label=label))
@@ -89,7 +89,7 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
89
  ax.legend(handles=handles)
90
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
91
  print("; ".join(label_list))
92
- return mask_list,label_list
93
 
94
 
95
 
@@ -110,7 +110,7 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
110
  image =Image.fromarray(image)
111
  image = image.resize((size, size))
112
  os.makedirs(name, exist_ok=True)
113
- image.save(os.path.join(name,"img_{}.png".format(size)))
114
  inputs = processor(image, return_tensors="pt")
115
  with torch.no_grad():
116
  outputs = model(**inputs)
@@ -121,4 +121,4 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
121
  mask_list,label_list = draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
122
  print("Finish segment")
123
  #block_flag += 1
124
- return mask_list,label_list#, gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)
 
46
  handles = []
47
  label_list = []
48
 
49
+ mask_np_list = []
50
 
51
  if not noseg:
52
  if torch.min(segmentation) == 0:
53
  mask = segmentation==0
54
+ mask = mask.cpu().detach().numpy() # [512,512] bool
55
  segment_label = "rest"
56
  color = viridis(0)
57
  label = f"{segment_label}-{0}"
58
+ mask_np_list.append(mask)
59
  handles.append(mpatches.Patch(color=color, label=label))
60
  label_list.append(label)
61
 
 
64
  mask = segmentation==segment_id
65
  if torch.min(segmentation) != 0:
66
  segment_id -= 1
67
+ mask = mask.cpu().detach().numpy() # [512,512] bool
68
+ mask_np_list.append(mask)
69
  segment_label = model.config.id2label[segment['label_id']]
70
  instances_counter[segment['label_id']] += 1
71
 
 
75
  handles.append(mpatches.Patch(color=color, label=label))
76
  label_list.append(label)
77
  else:
78
+ mask = np.full(segmentation.shape, True)
79
  segment_label = "all"
80
+ mask_np_list.append(mask)
81
  color = viridis(0)
82
  label = f"{segment_label}-{0}"
83
  handles.append(mpatches.Patch(color=color, label=label))
 
89
  ax.legend(handles=handles)
90
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
91
  print("; ".join(label_list))
92
+ return mask_np_list,label_list
93
 
94
 
95
 
 
110
  image =Image.fromarray(image)
111
  image = image.resize((size, size))
112
  os.makedirs(name, exist_ok=True)
113
+ #image.save(os.path.join(name,"img_{}.png".format(size)))
114
  inputs = processor(image, return_tensors="pt")
115
  with torch.no_grad():
116
  outputs = model(**inputs)
 
121
  mask_list,label_list = draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
122
  print("Finish segment")
123
  #block_flag += 1
124
+ return image,mask_list,label_list#, gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)