root commited on
Commit
3f3b681
1 Parent(s): 9560d91

new demo button changed

Browse files
Files changed (2) hide show
  1. app.py +63 -28
  2. segment.py +4 -3
app.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from PIL import Image
12
  from functools import partial
13
  from main import run_main
 
14
  LENGTH=512 #length of the square area displaying/editing images
15
  TRANSPARENCY = 150 # transparency of the mask in display
16
 
@@ -53,7 +54,9 @@ def load_image_ui(load_edit, input_folder="example_tmp"):
53
  image = image.convert('RGB')
54
  segmentation = create_segmentation(mask_np_list)
55
  print("!!", len(mask_np_list))
56
- return image, segmentation, mask_np_list, mask_label_list, image
 
 
57
  except:
58
  print("Image folder invalid: The folder should contain image.png")
59
  return None, None, None, None, None
@@ -172,6 +175,7 @@ def slider_release(index, image, mask_np_list_updated, mask_label_list):
172
  return new_image, mask_label
173
 
174
  def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
 
175
  try:
176
  assert np.all(sum(mask_np_list_updated)==1)
177
  except:
@@ -186,6 +190,7 @@ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="examp
186
  visualize_mask_list_clean(mask_np_list_updated, savepath)
187
 
188
  def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
 
189
  try:
190
  assert np.all(sum(mask_np_list_updated)==1)
191
  except:
@@ -197,7 +202,22 @@ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="examp
197
  savepath = os.path.join(input_folder, "seg_edited.png")
198
  visualize_mask_list_clean(mask_np_list_updated, savepath)
199
 
200
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  import shutil
202
  if os.path.isdir("./example_tmp"):
203
  shutil.rmtree("./example_tmp")
@@ -224,35 +244,26 @@ with gr.Blocks() as demo:
224
  canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
225
 
226
  segment_button = gr.Button("1.1 Run segmentation")
227
- segment_button.click(run_segmentation,
228
- [canvas, block_flag] ,
229
- [block_flag] )
230
 
231
- text_button = gr.Button("Waiting 1.1 to complete")
232
- text_button.click(load_image_ui,
233
- [ false] ,
234
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
-
236
- load_edit_button = gr.Button("Waiting 1.1 to complete")
237
- load_edit_button.click(load_image_ui,
238
- [ true] ,
239
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
 
241
- show_segment = gr.Checkbox(label = "Waiting 1.1 to complete")
242
  flag = gr.State(False)
243
  show_segment.select(show_segmentation,
244
  [image_loaded, segmentation, flag],
245
  [canvas, flag])
246
- def show_more_buttons():
247
- return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")
248
- block_flag.change(show_more_buttons, [], [text_button,load_edit_button,show_segment ])
249
 
250
 
251
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
252
  mask_np_list_updated = mask_np_list
253
  with gr.Column():
254
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
255
- slider = gr.Slider(0, 20, step=1, interactive=True)
256
  label = gr.Textbox()
257
  slider.release(slider_release,
258
  inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
@@ -283,12 +294,25 @@ with gr.Blocks() as demo:
283
  add_mask_button.click(add_mask,
284
  [mask_np_list_updated, mask_label_list] ,
285
  [mask_np_list_updated, mask_label_list] )
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  with gr.Tab(label="2 Optimization"):
288
  with gr.Row():
289
  with gr.Column():
290
 
291
- txt_box = gr.Textbox("Click to start optimization...", interactive = False)
292
 
293
  opt_flag = gr.State(0)
294
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
@@ -312,7 +336,7 @@ with gr.Blocks() as demo:
312
  diffusion_model_learning_rate ,
313
  max_diffusion_train_steps,
314
  train_batch_size,
315
- gradient_accumulation_steps
316
  ):
317
  run_optimization = partial(
318
  run_main,
@@ -325,8 +349,16 @@ with gr.Blocks() as demo:
325
  gradient_accumulation_steps=int(gradient_accumulation_steps)
326
  )
327
  run_optimization()
328
- return opt_flag+1
 
329
 
 
 
 
 
 
 
 
330
  add_button.click(run_optimization_wrapper,
331
  inputs = [
332
  opt_flag,
@@ -338,15 +370,18 @@ with gr.Blocks() as demo:
338
  train_batch_size,
339
  gradient_accumulation_steps
340
  ],
341
- outputs = [opt_flag]
342
- )
 
 
 
343
 
344
  def change_text(txt_box):
345
  return gr.Textbox("Optimization Finished!", interactive = False)
346
- def change_text2(txt_box):
347
  return gr.Textbox("Start optimization, check logs for progress...", interactive = False)
348
- add_button.click(change_text2, txt_box, txt_box)
349
- opt_flag.change(change_text, txt_box, txt_box)
350
 
351
  with gr.Tab(label="3 Editing"):
352
  with gr.Tab(label="3.1 Text-based editing"):
@@ -401,7 +436,7 @@ with gr.Blocks() as demo:
401
  tgt_prompt ,
402
  tgt_index
403
  ],
404
- outputs = [canvas_text_edit]
405
  )
406
 
407
  def load_pil_img():
 
11
  from PIL import Image
12
  from functools import partial
13
  from main import run_main
14
+ import time
15
  LENGTH=512 #length of the square area displaying/editing images
16
  TRANSPARENCY = 150 # transparency of the mask in display
17
 
 
54
  image = image.convert('RGB')
55
  segmentation = create_segmentation(mask_np_list)
56
  print("!!", len(mask_np_list))
57
+ max_val = len(mask_np_list)-1
58
+ sliderup = gr.Slider.update(value = 0, minimum=0, maximum=max_val, step=1, interactive=True)
59
+ return image, segmentation, mask_np_list, mask_label_list, image, sliderup
60
  except:
61
  print("Image folder invalid: The folder should contain image.png")
62
  return None, None, None, None, None
 
175
  return new_image, mask_label
176
 
177
  def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
178
+ print(mask_np_list_updated)
179
  try:
180
  assert np.all(sum(mask_np_list_updated)==1)
181
  except:
 
190
  visualize_mask_list_clean(mask_np_list_updated, savepath)
191
 
192
  def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
193
+ print(mask_np_list_updated)
194
  try:
195
  assert np.all(sum(mask_np_list_updated)==1)
196
  except:
 
202
  savepath = os.path.join(input_folder, "seg_edited.png")
203
  visualize_mask_list_clean(mask_np_list_updated, savepath)
204
 
205
+
206
+ def image_change():
207
+ directory_path = "./example_tmp/"
208
+ for filename in os.listdir(directory_path):
209
+ file_path = os.path.join(directory_path, filename)
210
+ if os.path.isfile(file_path) or os.path.islink(file_path):
211
+ os.unlink(file_path)
212
+ elif os.path.isdir(file_path):
213
+ shutil.rmtree(file_path)
214
+ return gr.Button.update("1.2 Load original masks",visible = False), gr.Button.update("1.2 Load edited masks",visible = False), gr.Checkbox.update(label = "Show Segmentation",visible = False)
215
+
216
+
217
+ def button_clickable(is_clickable):
218
+ return gr.Button.update(interactive=is_clickable)
219
+
220
+
221
  import shutil
222
  if os.path.isdir("./example_tmp"):
223
  shutil.rmtree("./example_tmp")
 
244
  canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
245
 
246
  segment_button = gr.Button("1.1 Run segmentation")
 
 
 
247
 
248
+ text_button = gr.Button("Waiting 1.1 to complete",visible = False)
249
+
250
+ load_edit_button = gr.Button("Waiting 1.1 to complete",visible = False)
 
 
 
 
 
 
251
 
252
+ show_segment = gr.Checkbox(label = "Waiting 1.1 to complete",visible = False)
253
  flag = gr.State(False)
254
  show_segment.select(show_segmentation,
255
  [image_loaded, segmentation, flag],
256
  [canvas, flag])
257
+ #def show_more_buttons():
258
+ # return gr.Button("1.2 Load original masks",visible = True), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")
259
+ #block_flag.change(show_more_buttons, [], [text_button,load_edit_button,show_segment ])
260
 
261
 
262
  # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
263
  mask_np_list_updated = mask_np_list
264
  with gr.Column():
265
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
266
+ slider = gr.Slider(0, 20, step=1, interactive=False)
267
  label = gr.Textbox()
268
  slider.release(slider_release,
269
  inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
 
294
  add_mask_button.click(add_mask,
295
  [mask_np_list_updated, mask_label_list] ,
296
  [mask_np_list_updated, mask_label_list] )
297
+
298
+
299
+ segment_button.click(run_segmentation,
300
+ [canvas] ,
301
+ [text_button,load_edit_button,show_segment] )
302
+ text_button.click(load_image_ui, [false] ,
303
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider] )
304
+
305
+ load_edit_button.click(load_image_ui, [ true] ,
306
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider] )
307
+
308
+ canvas.upload(image_change, inputs=[], outputs=[text_button,load_edit_button,show_segment])
309
+
310
 
311
  with gr.Tab(label="2 Optimization"):
312
  with gr.Row():
313
  with gr.Column():
314
 
315
+ #txt_box = gr.Textbox("Click the button to start optimization...", interactive = False)
316
 
317
  opt_flag = gr.State(0)
318
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
 
336
  diffusion_model_learning_rate ,
337
  max_diffusion_train_steps,
338
  train_batch_size,
339
+ gradient_accumulation_steps,
340
  ):
341
  run_optimization = partial(
342
  run_main,
 
349
  gradient_accumulation_steps=int(gradient_accumulation_steps)
350
  )
351
  run_optimization()
352
+ print('finish')
353
+ #return gr.Button.update(value="Optimization finished!", interactive=False)
354
 
355
+ def immediate_update():
356
+ return gr.Button.update("Processing...", interactive=False)
357
+
358
+ def immediate_update2():
359
+ return gr.Button.update("Finished.", interactive=False)
360
+ add_button.click(fn=immediate_update, inputs=[], outputs=[add_button])
361
+
362
  add_button.click(run_optimization_wrapper,
363
  inputs = [
364
  opt_flag,
 
370
  train_batch_size,
371
  gradient_accumulation_steps
372
  ],
373
+ outputs = [])
374
+ add_button.click(fn=immediate_update2, inputs=[], outputs=[add_button])
375
+ add_button.update()
376
+ '''txt_box.change(fn=lambda x: gr.Button.update(value="Optimization Finished!", interactive=True),
377
+ inputs=[txt_box], outputs=[add_button])
378
 
379
  def change_text(txt_box):
380
  return gr.Textbox("Optimization Finished!", interactive = False)
381
+ def change_text2():
382
  return gr.Textbox("Start optimization, check logs for progress...", interactive = False)
383
+ add_button.click(change_text2, [], txt_box)'''
384
+ #opt_flag.change(change_text, txt_box, txt_box)
385
 
386
  with gr.Tab(label="3 Editing"):
387
  with gr.Tab(label="3.1 Text-based editing"):
 
436
  tgt_prompt ,
437
  tgt_index
438
  ],
439
+ outputs = [canvas_text_edit],queue=True,
440
  )
441
 
442
  def load_pil_img():
segment.py CHANGED
@@ -10,6 +10,7 @@ import os
10
  import numpy as np
11
  import argparse
12
  import matplotlib
 
13
 
14
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
15
  if type(image_path) is str:
@@ -89,7 +90,7 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
89
 
90
 
91
 
92
- def run_segmentation(image, block_flag, name="example_tmp", size = 512, noseg=False):
93
 
94
  base_folder_path = "."
95
 
@@ -115,5 +116,5 @@ def run_segmentation(image, block_flag, name="example_tmp", size = 512, noseg=Fa
115
  os.makedirs(save_folder, exist_ok=True)
116
  draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
117
  print("Finish segment")
118
- block_flag += 1
119
- return block_flag
 
10
  import numpy as np
11
  import argparse
12
  import matplotlib
13
+ import gradio as gr
14
 
15
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
16
  if type(image_path) is str:
 
90
 
91
 
92
 
93
+ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
94
 
95
  base_folder_path = "."
96
 
 
116
  os.makedirs(save_folder, exist_ok=True)
117
  draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
118
  print("Finish segment")
119
+ #block_flag += 1
120
+ return gr.Button.update("1.2 Load original masks",visible = True), gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)