afeng commited on
Commit
8963af6
1 Parent(s): 850ea5b
Files changed (7) hide show
  1. app copy.py +3 -2
  2. app.py +104 -80
  3. img.png +0 -0
  4. main copy.py +480 -0
  5. main.py +381 -391
  6. pipeline_dedit_sd.py +4 -3
  7. segment.py +2 -1
app copy.py CHANGED
@@ -317,7 +317,7 @@ with gr.Blocks() as demo:
317
  canvas_text_edit = gr.State() # store mask
318
  with gr.Row():
319
  with gr.Column():
320
- canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
321
  # canvas_text_edit = gr.Gallery(label = "Edited results")
322
 
323
  with gr.Column():
@@ -342,8 +342,9 @@ with gr.Blocks() as demo:
342
  tgt_idx,
343
  guidance_scale
344
  ],
345
- outputs = [canvas_text_edit]
346
  )
347
 
348
 
 
349
  demo.queue().launch(share=True, debug=True)
 
317
  canvas_text_edit = gr.State() # store mask
318
  with gr.Row():
319
  with gr.Column():
320
+ canvas_text_edit = gr.Image(value = None, type="pil", label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
321
  # canvas_text_edit = gr.Gallery(label = "Edited results")
322
 
323
  with gr.Column():
 
342
  tgt_idx,
343
  guidance_scale
344
  ],
345
+ outputs = []
346
  )
347
 
348
 
349
+
350
  demo.queue().launch(share=True, debug=True)
app.py CHANGED
@@ -10,7 +10,8 @@ from utils_mask import process_mask_to_follow_priority, mask_union, visualize_ma
10
  from pathlib import Path
11
  import subprocess
12
  from PIL import Image
13
-
 
14
  LENGTH=512 #length of the square area displaying/editing images
15
  TRANSPARENCY = 150 # transparency of the mask in display
16
 
@@ -32,7 +33,7 @@ def create_segmentation(mask_np_list):
32
  segmentation = Image.fromarray(np.uint8(segmentation*255))
33
  return segmentation
34
 
35
- def load_mask_ui(input_folder,load_edit = False):
36
  if not load_edit:
37
  mask_list, mask_label_list = load_mask(input_folder)
38
  else:
@@ -44,28 +45,29 @@ def load_mask_ui(input_folder,load_edit = False):
44
 
45
  return mask_np_list, mask_label_list
46
 
47
- def load_image_ui(input_folder, load_edit):
48
  try:
49
  for img_path in Path(input_folder).iterdir():
50
- if img_path.name in ["img.png", "img_1024.png", "img_512.png"]:
51
  image = Image.open(img_path)
52
  mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
53
  image = image.convert('RGB')
54
- segmentation = create_segmentation(mask_np_list)
 
55
  return image, segmentation, mask_np_list, mask_label_list, image
56
  except:
57
  print("Image folder invalid: The folder should contain image.png")
58
  return None, None, None, None, None
59
 
60
  def run_edit_text(
61
- input_folder,
62
  num_tokens,
63
  num_sampling_steps,
64
  strength,
65
  edge_thickness,
66
  tgt_prompt,
67
  tgt_idx,
68
- guidance_scale
 
69
  ):
70
  subprocess.run(["python",
71
  "main.py" ,
@@ -89,14 +91,14 @@ def run_edit_text(
89
 
90
 
91
  def run_optimization(
92
- input_folder,
93
  num_tokens,
94
  embedding_learning_rate,
95
  max_emb_train_steps,
96
  diffusion_model_learning_rate,
97
  max_diffusion_train_steps,
98
  train_batch_size,
99
- gradient_accumulation_steps
 
100
  ):
101
  subprocess.run(["python",
102
  "main.py" ,
@@ -124,6 +126,7 @@ def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
124
 
125
  bimg_np = np.array(bimg)
126
  mask_np = mask_np[:,:,np.newaxis]
 
127
  try:
128
  new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
129
  return Image.fromarray(new_img_np)
@@ -159,6 +162,7 @@ def edit_mask_add(canvas, image, idx, mask_np_list):
159
  return mask_np_list_updated, image_edit
160
 
161
  def slider_release(index, image, mask_np_list_updated, mask_label_list):
 
162
  if index > len(mask_np_list_updated):
163
  return image, "out of range"
164
  else:
@@ -168,7 +172,7 @@ def slider_release(index, image, mask_np_list_updated, mask_label_list):
168
  new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
169
  return new_image, mask_label
170
 
171
- def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
172
  try:
173
  assert np.all(sum(mask_np_list_updated)==1)
174
  except:
@@ -182,7 +186,7 @@ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
182
  savepath = os.path.join(input_folder, "seg_current.png")
183
  visualize_mask_list_clean(mask_np_list_updated, savepath)
184
 
185
- def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
186
  try:
187
  assert np.all(sum(mask_np_list_updated)==1)
188
  except:
@@ -195,6 +199,10 @@ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
195
  visualize_mask_list_clean(mask_np_list_updated, savepath)
196
 
197
 
 
 
 
 
198
  from segment import run_segmentation
199
  with gr.Blocks() as demo:
200
  image = gr.State() # store mask
@@ -213,8 +221,7 @@ with gr.Blocks() as demo:
213
  with gr.Tab(label="1 Edit mask"):
214
  with gr.Row():
215
  with gr.Column():
216
- canvas = gr.Image(value = None, type="pil", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
217
- input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
218
 
219
  segment_button = gr.Button("1.1 Run segmentation")
220
  segment_button.click(run_segmentation,
@@ -223,23 +230,22 @@ with gr.Blocks() as demo:
223
 
224
  text_button = gr.Button("1.2 Load original masks")
225
  text_button.click(load_image_ui,
226
- [input_folder, false] ,
227
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
228
-
229
  load_edit_button = gr.Button("1.2 Load edited masks")
230
  load_edit_button.click(load_image_ui,
231
- [input_folder, true] ,
232
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
233
 
234
  show_segment = gr.Checkbox(label = "Show Segmentation")
235
-
236
  flag = gr.State(False)
237
  show_segment.select(show_segmentation,
238
  [image_loaded, segmentation, flag],
239
  [canvas, flag])
240
-
241
- mask_np_list_updated = copy.deepcopy(mask_np_list)
242
-
243
  with gr.Column():
244
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
245
  slider = gr.Slider(0, 20, step=1, interactive=True)
@@ -256,17 +262,17 @@ with gr.Blocks() as demo:
256
 
257
  save_button2 = gr.Button("Set and Save as edited masks")
258
  save_button2.click( save_as_edit_mask,
259
- [mask_np_list_updated, mask_label_list, input_folder] ,
260
  [] )
261
 
262
  save_button = gr.Button("Set and Save as original masks")
263
  save_button.click( save_as_orig_mask,
264
- [mask_np_list_updated, mask_label_list, input_folder] ,
265
  [] )
266
 
267
  back_button = gr.Button("Back to current seg")
268
  back_button.click( load_mask_ui,
269
- [input_folder] ,
270
  [ mask_np_list_updated,mask_label_list] )
271
 
272
  add_mask_button = gr.Button("Add new empty mask")
@@ -274,70 +280,88 @@ with gr.Blocks() as demo:
274
  [mask_np_list_updated, mask_label_list] ,
275
  [mask_np_list_updated, mask_label_list] )
276
 
277
- # with gr.Tab(label="2 Optimization"):
278
- # with gr.Row():
279
- # with gr.Column():
280
- # canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
281
 
282
- # with gr.Column():
283
- # gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
284
- # num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
285
- # embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
286
- # max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
287
 
288
- # diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
289
- # max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
290
 
291
- # train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
292
- # gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
293
 
294
- # add_button = gr.Button("Run optimization")
295
- # add_button.click(run_optimization,
296
- # inputs = [
297
- # input_folder,
298
- # num_tokens,
299
- # embedding_learning_rate,
300
- # max_emb_train_steps,
301
- # diffusion_model_learning_rate,
302
- # max_diffusion_train_steps,
303
- # train_batch_size,gradient_accumulation_steps
304
- # ],
305
- # outputs = []
306
- # )
307
-
308
-
309
- # with gr.Tab(label="3 Editing"):
310
- # with gr.Tab(label="3.1 Text-based editing"):
311
- # canvas_text_edit = gr.State() # store mask
312
- # with gr.Row():
313
- # with gr.Column():
314
- # canvas_text_edit = gr.Image(value = None, label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
315
- # # canvas_text_edit = gr.Gallery(label = "Edited results")
 
 
 
316
 
317
- # with gr.Column():
318
- # gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- # tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
321
- # tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
322
- # guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
323
- # num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
324
- # edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
325
- # strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
326
 
327
- # add_button = gr.Button("Run Editing")
328
- # add_button.click(run_edit_text,
329
- # inputs = [
330
- # input_folder,
331
- # num_tokens,
332
- # num_sampling_steps,
333
- # strength,
334
- # edge_thickness,
335
- # tgt_prompt,
336
- # tgt_idx,
337
- # guidance_scale
338
- # ],
339
- # outputs = [canvas_text_edit]
340
- # )
341
 
342
 
343
  demo.queue().launch(share=True, debug=True)
 
10
  from pathlib import Path
11
  import subprocess
12
  from PIL import Image
13
+ from functools import partial
14
+ from main import run_main
15
  LENGTH=512 #length of the square area displaying/editing images
16
  TRANSPARENCY = 150 # transparency of the mask in display
17
 
 
33
  segmentation = Image.fromarray(np.uint8(segmentation*255))
34
  return segmentation
35
 
36
+ def load_mask_ui(input_folder="example_tmp",load_edit = False):
37
  if not load_edit:
38
  mask_list, mask_label_list = load_mask(input_folder)
39
  else:
 
45
 
46
  return mask_np_list, mask_label_list
47
 
48
+ def load_image_ui(load_edit, input_folder="example_tmp"):
49
  try:
50
  for img_path in Path(input_folder).iterdir():
51
+ if img_path.name in ["img_512.png"]:
52
  image = Image.open(img_path)
53
  mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
54
  image = image.convert('RGB')
55
+ segmentation = create_segmentation(mask_np_list)
56
+ print("!!", len(mask_np_list))
57
  return image, segmentation, mask_np_list, mask_label_list, image
58
  except:
59
  print("Image folder invalid: The folder should contain image.png")
60
  return None, None, None, None, None
61
 
62
  def run_edit_text(
 
63
  num_tokens,
64
  num_sampling_steps,
65
  strength,
66
  edge_thickness,
67
  tgt_prompt,
68
  tgt_idx,
69
+ guidance_scale,
70
+ input_folder="example_tmp"
71
  ):
72
  subprocess.run(["python",
73
  "main.py" ,
 
91
 
92
 
93
  def run_optimization(
 
94
  num_tokens,
95
  embedding_learning_rate,
96
  max_emb_train_steps,
97
  diffusion_model_learning_rate,
98
  max_diffusion_train_steps,
99
  train_batch_size,
100
+ gradient_accumulation_steps,
101
+ input_folder = "example_tmp"
102
  ):
103
  subprocess.run(["python",
104
  "main.py" ,
 
126
 
127
  bimg_np = np.array(bimg)
128
  mask_np = mask_np[:,:,np.newaxis]
129
+
130
  try:
131
  new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
132
  return Image.fromarray(new_img_np)
 
162
  return mask_np_list_updated, image_edit
163
 
164
  def slider_release(index, image, mask_np_list_updated, mask_label_list):
165
+
166
  if index > len(mask_np_list_updated):
167
  return image, "out of range"
168
  else:
 
172
  new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
173
  return new_image, mask_label
174
 
175
+ def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
176
  try:
177
  assert np.all(sum(mask_np_list_updated)==1)
178
  except:
 
186
  savepath = os.path.join(input_folder, "seg_current.png")
187
  visualize_mask_list_clean(mask_np_list_updated, savepath)
188
 
189
+ def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
190
  try:
191
  assert np.all(sum(mask_np_list_updated)==1)
192
  except:
 
199
  visualize_mask_list_clean(mask_np_list_updated, savepath)
200
 
201
 
202
+ import shutil
203
+ if os.path.isdir("./example_tmp"):
204
+ shutil.rmtree("./example_tmp")
205
+
206
  from segment import run_segmentation
207
  with gr.Blocks() as demo:
208
  image = gr.State() # store mask
 
221
  with gr.Tab(label="1 Edit mask"):
222
  with gr.Row():
223
  with gr.Column():
224
+ canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
 
225
 
226
  segment_button = gr.Button("1.1 Run segmentation")
227
  segment_button.click(run_segmentation,
 
230
 
231
  text_button = gr.Button("1.2 Load original masks")
232
  text_button.click(load_image_ui,
233
+ [ false] ,
234
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
+
236
  load_edit_button = gr.Button("1.2 Load edited masks")
237
  load_edit_button.click(load_image_ui,
238
+ [ true] ,
239
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
 
241
  show_segment = gr.Checkbox(label = "Show Segmentation")
 
242
  flag = gr.State(False)
243
  show_segment.select(show_segmentation,
244
  [image_loaded, segmentation, flag],
245
  [canvas, flag])
246
+
247
+ # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
248
+ mask_np_list_updated = mask_np_list
249
  with gr.Column():
250
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
  slider = gr.Slider(0, 20, step=1, interactive=True)
 
262
 
263
  save_button2 = gr.Button("Set and Save as edited masks")
264
  save_button2.click( save_as_edit_mask,
265
+ [mask_np_list_updated, mask_label_list] ,
266
  [] )
267
 
268
  save_button = gr.Button("Set and Save as original masks")
269
  save_button.click( save_as_orig_mask,
270
+ [mask_np_list_updated, mask_label_list] ,
271
  [] )
272
 
273
  back_button = gr.Button("Back to current seg")
274
  back_button.click( load_mask_ui,
275
+ [] ,
276
  [ mask_np_list_updated,mask_label_list] )
277
 
278
  add_mask_button = gr.Button("Add new empty mask")
 
280
  [mask_np_list_updated, mask_label_list] ,
281
  [mask_np_list_updated, mask_label_list] )
282
 
283
+ with gr.Tab(label="2 Optimization"):
284
+ with gr.Row():
 
 
285
 
286
+ with gr.Column():
287
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
288
+ num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
289
+ embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
290
+ max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
291
 
292
+ diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
293
+ max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )
294
 
295
+ train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
296
+ gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
 
298
+ add_button = gr.Button("Run optimization")
299
+
300
+ run_optimization = partial(
301
+ run_main,
302
+ num_tokens=int(num_tokens.value),
303
+ embedding_learning_rate = float(embedding_learning_rate.value),
304
+ max_emb_train_steps = int(max_emb_train_steps.value),
305
+ diffusion_model_learning_rate= float(diffusion_model_learning_rate.value),
306
+ max_diffusion_train_steps = int(max_diffusion_train_steps.value),
307
+ train_batch_size=int(train_batch_size.value),
308
+ gradient_accumulation_steps=int(gradient_accumulation_steps.value)
309
+ )
310
+ add_button.click(run_optimization,
311
+ inputs = [],
312
+ outputs = []
313
+ )
314
+
315
+
316
+ with gr.Tab(label="3 Editing"):
317
+ with gr.Tab(label="3.1 Text-based editing"):
318
+
319
+ with gr.Row():
320
+ with gr.Column():
321
+ canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
322
+ # canvas_text_edit = gr.Gallery(label = "Edited results")
323
 
324
+ with gr.Column():
325
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
326
+
327
+ tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
328
+ tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
329
+ guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
330
+ num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
331
+ edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
332
+ strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
333
+
334
+ add_button = gr.Button("Run Editing")
335
+ run_edit_text = partial(
336
+ run_main,
337
+ load_trained=True,
338
+ text=True,
339
+ num_tokens = int(num_tokens.value),
340
+ guidance_scale = float(guidance_scale.value),
341
+ num_sampling_steps = int(num_sampling_steps.value),
342
+ strength = float(strength.value),
343
+ edge_thickness = int(edge_thickness.value),
344
+ num_imgs = 1,
345
+ tgt_prompt = tgt_prompt.value,
346
+ tgt_index = int(tgt_index.value)
347
+ )
348
+
349
+ add_button.click(run_edit_text,
350
+ inputs = [],
351
+ outputs = [canvas_text_edit]
352
+ )
353
 
354
+ def load_pil_img():
355
+ from PIL import Image
356
+ return Image.open("example_tmp/text/out_text_0.png")
 
 
 
357
 
358
+ load_button = gr.Button("Load results")
359
+ load_button.click(load_pil_img,
360
+ inputs = [],
361
+ outputs = [canvas_text_edit]
362
+ )
363
+
364
+
 
 
 
 
 
 
 
365
 
366
 
367
  demo.queue().launch(share=True, debug=True)
img.png ADDED
main copy.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import argparse
5
+ from peft import LoraConfig
6
+ from old.pipeline_dedit_sdxl import DEditSDXLPipeline
7
+ from pipeline_dedit_sd import DEditSDPipeline
8
+ from utils import load_image, load_mask, load_mask_edit
9
+ from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
+ from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--name", type=str,required=True, default=None)
14
+ parser.add_argument("--name_2", type=str,required=False, default=None)
15
+ parser.add_argument("--dpm", type=str,required=True, default="sd")
16
+ parser.add_argument("--resolution", type=int, default=1024)
17
+ parser.add_argument("--seed", type=int, default=42)
18
+ parser.add_argument("--embedding_learning_rate", type=float, default=1e-4)
19
+ parser.add_argument("--max_emb_train_steps", type=int, default=200)
20
+ parser.add_argument("--diffusion_model_learning_rate", type=float, default=5e-5)
21
+ parser.add_argument("--max_diffusion_train_steps", type=int, default=200)
22
+ parser.add_argument("--train_batch_size", type=int, default=1)
23
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
24
+ parser.add_argument("--num_tokens", type=int, default=1)
25
+
26
+
27
+ parser.add_argument("--load_trained", default=False, action="store_true" )
28
+ parser.add_argument("--num_sampling_steps", type=int, default=50)
29
+ parser.add_argument("--guidance_scale", type=float, default = 3 )
30
+ parser.add_argument("--strength", type=float, default=0.8)
31
+
32
+ parser.add_argument("--train_full_lora", default=False, action="store_true" )
33
+ parser.add_argument("--lora_rank", type=int, default=4)
34
+ parser.add_argument("--lora_alpha", type=int, default=4)
35
+
36
+ parser.add_argument("--prompt_auxin_list", nargs="+", type=str, default = None)
37
+ parser.add_argument("--prompt_auxin_idx_list", nargs="+", type=int, default = None)
38
+
39
+ # general editing configs
40
+ parser.add_argument("--load_edited_mask", default=False, action="store_true")
41
+ parser.add_argument("--load_edited_processed_mask", default=False, action="store_true")
42
+ parser.add_argument("--edge_thickness", type=int, default=20)
43
+ parser.add_argument("--num_imgs", type=int, default = 1 )
44
+ parser.add_argument('--active_mask_list', nargs="+", type=int)
45
+ parser.add_argument("--tgt_index", type=int, default=None)
46
+
47
+ # recon
48
+ parser.add_argument("--recon", default=False, action="store_true" )
49
+ parser.add_argument("--recon_an_item", default=False, action="store_true" )
50
+ parser.add_argument("--recon_prompt", type=str, default=None)
51
+
52
+ # text-based editing
53
+ parser.add_argument("--text", default=False, action="store_true")
54
+ parser.add_argument("--tgt_prompt", type=str, default=None)
55
+
56
+ # image-based editing
57
+ parser.add_argument("--image", default=False, action="store_true" )
58
+ parser.add_argument("--src_index", type=int, default=None)
59
+ parser.add_argument("--tgt_name", type=str, default=None)
60
+
61
+ # mask-based move
62
+ parser.add_argument("--move_resize", default=False, action="store_true" )
63
+ parser.add_argument('--tgt_indices_list', nargs="+", type=int)
64
+ parser.add_argument("--delta_x_list", nargs="+", type=int)
65
+ parser.add_argument("--delta_y_list", nargs="+", type=int)
66
+ parser.add_argument("--priority_list", nargs="+", type=int)
67
+ parser.add_argument("--force_mask_remain", type=int, default=None)
68
+ parser.add_argument("--resize_list", nargs="+", type=float)
69
+
70
+ # remove
71
+ parser.add_argument("--remove", default=False, action="store_true" )
72
+ parser.add_argument("--load_edited_removemask", default=False, action="store_true")
73
+
74
+ args = parser.parse_args()
75
+
76
+
77
+ def run_main(
78
+ name=None,
79
+ name_2=None,
80
+ dpm="sd",
81
+ resolution=1024,
82
+ seed=42,
83
+ embedding_learning_rate=1e-4,
84
+ max_emb_train_steps=200,
85
+ diffusion_model_learning_rate=5e-5,
86
+ max_diffusion_train_steps=200,
87
+ train_batch_size=1,
88
+ gradient_accumulation_steps=1,
89
+ num_tokens=1,
90
+
91
+ load_trained="store_true" ,
92
+ num_sampling_steps=50,
93
+ guidance_scale= 3 ,
94
+ strength=0.8,
95
+
96
+ train_full_lora="store_true" ,
97
+ lora_rank=4,
98
+ lora_alpha=4,
99
+
100
+ prompt_auxin_list = None,
101
+ prompt_auxin_idx_list= None,
102
+
103
+ load_edited_mask="store_true",
104
+ load_edited_processed_mask="store_true",
105
+ edge_thickness=20,
106
+ num_imgs= 1 ,
107
+ active_mask_list = None,
108
+ tgt_index=None,
109
+
110
+ recon=False ,
111
+ recon_an_item=False,
112
+ recon_prompt=None,
113
+
114
+ text="store_true",
115
+ tgt_prompt=None,
116
+
117
+ image="store_true" ,
118
+ src_index=None,
119
+ tgt_name=None,
120
+
121
+ move_resize="store_true" ,
122
+ tgt_indices_list=None,
123
+ delta_x_list=None,
124
+ delta_y_list=None,
125
+ priority_list=None,
126
+ force_mask_remain=None,
127
+ resize_list=None,
128
+
129
+ remove=False,
130
+ load_edited_removemask=False
131
+ ):
132
+ torch.cuda.manual_seed_all(args.seed)
133
+ torch.manual_seed(args.seed)
134
+ base_input_folder = "."
135
+ base_output_folder = "."
136
+
137
+ input_folder = os.path.join(base_input_folder, args.name)
138
+
139
+
140
+ mask_list, mask_label_list = load_mask(input_folder)
141
+ assert mask_list[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
142
+ try:
143
+ image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution) ), size = args.resolution)
144
+ except:
145
+ image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
146
+
147
+ if args.image:
148
+ input_folder_2 = os.path.join(base_input_folder, args.name_2)
149
+ mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
150
+ assert mask_list_2[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
151
+ try:
152
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(args.resolution) ), size = args.resolution)
153
+ except:
154
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
155
+ output_dir = os.path.join(base_output_folder, args.name + "_" + args.name_2)
156
+ os.makedirs(output_dir, exist_ok = True)
157
+ else:
158
+ output_dir = os.path.join(base_output_folder, args.name)
159
+ os.makedirs(output_dir, exist_ok = True)
160
+
161
+ if args.dpm == "sd":
162
+ if args.image:
163
+ pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
164
+ else:
165
+ pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
166
+
167
+ elif args.dpm == "sdxl":
168
+ if args.image:
169
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
170
+ else:
171
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
172
+
173
+ else:
174
+ raise NotImplementedError
175
+
176
+ set_string_list = pipe.set_string_list
177
+ if args.prompt_auxin_list is not None:
178
+ for auxin_idx, auxin_prompt in zip(args.prompt_auxin_idx_list, args.prompt_auxin_list):
179
+ set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
180
+ print(set_string_list)
181
+
182
+ if args.image:
183
+ set_string_list_2 = pipe.set_string_list_2
184
+ print(set_string_list_2)
185
+
186
+ if args.load_trained:
187
+ unet_save_path = os.path.join(output_dir, "unet.pt")
188
+ unet_state_dict = torch.load(unet_save_path)
189
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
190
+ text_encoder1_state_dict = torch.load(text_encoder1_save_path)
191
+ if args.dpm == "sdxl":
192
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
193
+ text_encoder2_state_dict = torch.load(text_encoder2_save_path)
194
+
195
+ if 'lora' in ''.join(unet_state_dict.keys()):
196
+ unet_lora_config = LoraConfig(
197
+ r=args.lora_rank,
198
+ lora_alpha=args.lora_alpha,
199
+ init_lora_weights="gaussian",
200
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
201
+ )
202
+ pipe.unet.add_adapter(unet_lora_config)
203
+
204
+ pipe.unet.load_state_dict(unet_state_dict)
205
+ pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
206
+ if args.dpm == "sdxl":
207
+ pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
208
+ else:
209
+ if args.image:
210
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
211
+ pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
212
+ pipe.train_emb_2imgs(
213
+ image_gt,
214
+ image_gt_2,
215
+ set_string_list,
216
+ set_string_list_2,
217
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
218
+ embedding_learning_rate = args.embedding_learning_rate,
219
+ max_emb_train_steps = args.max_emb_train_steps,
220
+ train_batch_size = args.train_batch_size,
221
+ )
222
+
223
+ pipe.train_model_2imgs(
224
+ image_gt,
225
+ image_gt_2,
226
+ set_string_list,
227
+ set_string_list_2,
228
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
229
+ max_diffusion_train_steps = args.max_diffusion_train_steps,
230
+ diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
231
+ train_batch_size =args.train_batch_size,
232
+ train_full_lora = args.train_full_lora,
233
+ lora_rank = args.lora_rank,
234
+ lora_alpha = args.lora_alpha
235
+ )
236
+
237
+ else:
238
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
239
+ pipe.train_emb(
240
+ image_gt,
241
+ set_string_list,
242
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
243
+ embedding_learning_rate = args.embedding_learning_rate,
244
+ max_emb_train_steps = args.max_emb_train_steps,
245
+ train_batch_size = args.train_batch_size,
246
+ )
247
+
248
+ pipe.train_model(
249
+ image_gt,
250
+ set_string_list,
251
+ gradient_accumulation_steps = args.gradient_accumulation_steps,
252
+ max_diffusion_train_steps = args.max_diffusion_train_steps,
253
+ diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
254
+ train_batch_size = args.train_batch_size,
255
+ train_full_lora = args.train_full_lora,
256
+ lora_rank = args.lora_rank,
257
+ lora_alpha = args.lora_alpha
258
+ )
259
+
260
+
261
+ unet_save_path = os.path.join(output_dir, "unet.pt")
262
+ torch.save(pipe.unet.state_dict(),unet_save_path )
263
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
264
+ torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
265
+ if args.dpm == "sdxl":
266
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
267
+ torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
268
+
269
+
270
+ if args.recon:
271
+ output_dir = os.path.join(output_dir, "recon")
272
+ os.makedirs(output_dir, exist_ok = True)
273
+ if args.recon_an_item:
274
+ mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
275
+ tgt_string = set_string_list[args.tgt_index]
276
+ tgt_string = args.recon_prompt.replace("*", tgt_string)
277
+ set_string_list = [tgt_string]
278
+ print(set_string_list)
279
+ save_path = os.path.join(output_dir, "out_recon.png")
280
+ x_np = pipe.inference_with_mask(
281
+ save_path,
282
+ guidance_scale = args.guidance_scale,
283
+ num_sampling_steps = args.num_sampling_steps,
284
+ seed = args.seed,
285
+ num_imgs = args.num_imgs,
286
+ set_string_list = set_string_list,
287
+ mask_list = mask_list
288
+ )
289
+
290
+ if args.text:
291
+ print("Text-guided editing ")
292
+ output_dir = os.path.join(output_dir, "text")
293
+ os.makedirs(output_dir, exist_ok = True)
294
+ save_path = os.path.join(output_dir, "out_text.png")
295
+ set_string_list[args.tgt_index] = args.tgt_prompt
296
+ mask_active = torch.zeros_like(mask_list[0])
297
+ mask_active = mask_union_torch(mask_active, mask_list[args.tgt_index])
298
+
299
+ if args.active_mask_list is not None:
300
+ for midx in args.active_mask_list:
301
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
302
+
303
+ if args.load_edited_mask:
304
+ mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
305
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
306
+ mask_active = mask_union_torch(mask_active, mask_diff)
307
+ mask_list = mask_list_edited
308
+ save_path = os.path.join(output_dir, "out_textEdited.png")
309
+
310
+ mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
311
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
312
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
313
+
314
+ pipe.inference_with_mask(
315
+ save_path,
316
+ orig_image = image_gt,
317
+ set_string_list = set_string_list,
318
+ guidance_scale = args.guidance_scale,
319
+ strength = args.strength,
320
+ num_imgs = args.num_imgs,
321
+ mask_hard= mask_hard,
322
+ mask_soft = mask_soft,
323
+ mask_list = mask_list,
324
+ seed = args.seed,
325
+ num_sampling_steps = args.num_sampling_steps
326
+ )
327
+
328
+ if args.remove:
329
+ output_dir = os.path.join(output_dir, "remove")
330
+ save_path = os.path.join(output_dir, "out_remove.png")
331
+ os.makedirs(output_dir, exist_ok = True)
332
+ mask_active = torch.zeros_like(mask_list[0])
333
+
334
+ if args.load_edited_mask:
335
+ mask_list_edited, _ = load_mask_edit(input_folder)
336
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
337
+ mask_active = mask_union_torch(mask_active, mask_diff)
338
+ mask_list = mask_list_edited
339
+
340
+ if args.load_edited_processed_mask:
341
+ # manually edit or draw masks after removing one index, then load
342
+ mask_list_processed, _ = load_mask_edit(output_dir)
343
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
344
+ else:
345
+ # generate masks after removing one index, using nearest neighbor algorithm
346
+ mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, args.tgt_index)
347
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
348
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
349
+ check_cover_all_torch(*mask_list_processed)
350
+ mask_active = mask_union_torch(mask_active, mask_remain)
351
+
352
+ if args.active_mask_list is not None:
353
+ for midx in args.active_mask_list:
354
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
355
+
356
+ mask_hard = 1 - mask_active
357
+ mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = args.edge_thickness)
358
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
359
+
360
+ pipe.inference_with_mask(
361
+ save_path,
362
+ orig_image = image_gt,
363
+ guidance_scale = args.guidance_scale,
364
+ strength = args.strength,
365
+ num_imgs = args.num_imgs,
366
+ mask_hard= mask_hard,
367
+ mask_soft = mask_soft,
368
+ mask_list = mask_list_processed,
369
+ seed = args.seed,
370
+ num_sampling_steps = args.num_sampling_steps
371
+ )
372
+
373
+ if args.image:
374
+ output_dir = os.path.join(output_dir, "image")
375
+ save_path = os.path.join(output_dir, "out_image.png")
376
+ os.makedirs(output_dir, exist_ok = True)
377
+ mask_active = torch.zeros_like(mask_list[0])
378
+
379
+ if None not in (args.tgt_name, args.src_index, args.tgt_index):
380
+ if args.tgt_name == args.name:
381
+ set_string_list_tgt = set_string_list
382
+ set_string_list_src = set_string_list_2
383
+ image_tgt = image_gt
384
+ if args.load_edited_mask:
385
+ mask_list_edited, _ = load_mask_edit(input_folder)
386
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
387
+ mask_active = mask_union_torch(mask_active, mask_diff)
388
+ mask_list = mask_list_edited
389
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
390
+ mask_list_tgt = mask_list
391
+
392
+ elif args.tgt_name == args.name_2:
393
+ set_string_list_tgt = set_string_list_2
394
+ set_string_list_src = set_string_list
395
+ image_tgt = image_gt_2
396
+ if args.load_edited_mask:
397
+ mask_list_2_edited, _ = load_mask_edit(input_folder_2)
398
+ mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
399
+ mask_active = mask_union_torch(mask_active, mask_diff)
400
+ mask_list_2 = mask_list_2_edited
401
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
402
+ mask_list_tgt = mask_list_2
403
+ else:
404
+ exit("tgt_name should be either name or name_2")
405
+
406
+ set_string_list_tgt[args.tgt_index] = set_string_list_src[args.src_index]
407
+
408
+ mask_active = mask_list_tgt[args.tgt_index]
409
+ mask_frozen = (1-mask_active.float()).to(mask_active.device)
410
+ mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = args.edge_thickness)
411
+ mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
412
+
413
+ mask_list_tgt = [m.cuda() for m in mask_list_tgt]
414
+
415
+ pipe.inference_with_mask(
416
+ save_path,
417
+ set_string_list = set_string_list_tgt,
418
+ mask_list = mask_list_tgt,
419
+ guidance_scale = args.guidance_scale,
420
+ num_sampling_steps = args.num_sampling_steps,
421
+ mask_hard = mask_hard.cuda(),
422
+ mask_soft = mask_soft.cuda(),
423
+ num_imgs = args.num_imgs,
424
+ orig_image = image_tgt,
425
+ strength = args.strength,
426
+ )
427
+
428
+ if args.move_resize:
429
+ output_dir = os.path.join(output_dir, "move_resize")
430
+ os.makedirs(output_dir, exist_ok = True)
431
+ save_path = os.path.join(output_dir, "out_moveresize.png")
432
+ mask_active = torch.zeros_like(mask_list[0])
433
+
434
+ if args.load_edited_mask:
435
+ mask_list_edited, _ = load_mask_edit(input_folder)
436
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
437
+ mask_active = mask_union_torch(mask_active, mask_diff)
438
+ mask_list = mask_list_edited
439
+ # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
440
+
441
+ if args.load_edited_processed_mask:
442
+ mask_list_processed, _ = load_mask_edit(output_dir)
443
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
444
+ else:
445
+ mask_list_processed, mask_remain = process_mask_move_torch(
446
+ mask_list,
447
+ args.tgt_indices_list,
448
+ args.delta_x_list,
449
+ args.delta_y_list, args.priority_list,
450
+ force_mask_remain = args.force_mask_remain,
451
+ resize_list = args.resize_list
452
+ )
453
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
454
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
455
+ active_idxs = args.tgt_indices_list
456
+
457
+ mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
458
+ mask_active = mask_union_torch(mask_remain, mask_active)
459
+ if args.active_mask_list is not None:
460
+ for midx in args.active_mask_list:
461
+ mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
462
+
463
+ mask_frozen =(1 - mask_active.float())
464
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
465
+ mask_hard = mask_substract_torch(mask_frozen, mask_soft)
466
+
467
+ check_mask_overlap_torch(mask_hard, mask_soft)
468
+
469
+ pipe.inference_with_mask(
470
+ save_path,
471
+ strength = args.strength,
472
+ orig_image = image_gt,
473
+ guidance_scale = args.guidance_scale,
474
+ num_sampling_steps = args.num_sampling_steps,
475
+ num_imgs = args.num_imgs,
476
+ mask_hard= mask_hard,
477
+ mask_soft = mask_soft,
478
+ mask_list = mask_list_processed,
479
+ seed = args.seed
480
+ )
main.py CHANGED
@@ -9,416 +9,406 @@ from utils import load_image, load_mask, load_mask_edit
9
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
  from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
 
12
- parser = argparse.ArgumentParser()
13
- parser.add_argument("--name", type=str,required=True, default=None)
14
- parser.add_argument("--name_2", type=str,required=False, default=None)
15
- parser.add_argument("--dpm", type=str,required=True, default="sd")
16
- parser.add_argument("--resolution", type=int, default=1024)
17
- parser.add_argument("--seed", type=int, default=42)
18
- parser.add_argument("--embedding_learning_rate", type=float, default=1e-4)
19
- parser.add_argument("--max_emb_train_steps", type=int, default=200)
20
- parser.add_argument("--diffusion_model_learning_rate", type=float, default=5e-5)
21
- parser.add_argument("--max_diffusion_train_steps", type=int, default=200)
22
- parser.add_argument("--train_batch_size", type=int, default=1)
23
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
24
- parser.add_argument("--num_tokens", type=int, default=1)
25
-
26
-
27
- parser.add_argument("--load_trained", default=False, action="store_true" )
28
- parser.add_argument("--num_sampling_steps", type=int, default=50)
29
- parser.add_argument("--guidance_scale", type=float, default = 3 )
30
- parser.add_argument("--strength", type=float, default=0.8)
31
-
32
- parser.add_argument("--train_full_lora", default=False, action="store_true" )
33
- parser.add_argument("--lora_rank", type=int, default=4)
34
- parser.add_argument("--lora_alpha", type=int, default=4)
35
-
36
- parser.add_argument("--prompt_auxin_list", nargs="+", type=str, default = None)
37
- parser.add_argument("--prompt_auxin_idx_list", nargs="+", type=int, default = None)
38
-
39
- # general editing configs
40
- parser.add_argument("--load_edited_mask", default=False, action="store_true")
41
- parser.add_argument("--load_edited_processed_mask", default=False, action="store_true")
42
- parser.add_argument("--edge_thickness", type=int, default=20)
43
- parser.add_argument("--num_imgs", type=int, default = 1 )
44
- parser.add_argument('--active_mask_list', nargs="+", type=int)
45
- parser.add_argument("--tgt_index", type=int, default=None)
46
-
47
- # recon
48
- parser.add_argument("--recon", default=False, action="store_true" )
49
- parser.add_argument("--recon_an_item", default=False, action="store_true" )
50
- parser.add_argument("--recon_prompt", type=str, default=None)
51
-
52
- # text-based editing
53
- parser.add_argument("--text", default=False, action="store_true")
54
- parser.add_argument("--tgt_prompt", type=str, default=None)
55
-
56
- # image-based editing
57
- parser.add_argument("--image", default=False, action="store_true" )
58
- parser.add_argument("--src_index", type=int, default=None)
59
- parser.add_argument("--tgt_name", type=str, default=None)
60
-
61
- # mask-based move
62
- parser.add_argument("--move_resize", default=False, action="store_true" )
63
- parser.add_argument('--tgt_indices_list', nargs="+", type=int)
64
- parser.add_argument("--delta_x_list", nargs="+", type=int)
65
- parser.add_argument("--delta_y_list", nargs="+", type=int)
66
- parser.add_argument("--priority_list", nargs="+", type=int)
67
- parser.add_argument("--force_mask_remain", type=int, default=None)
68
- parser.add_argument("--resize_list", nargs="+", type=float)
69
-
70
- # remove
71
- parser.add_argument("--remove", default=False, action="store_true" )
72
- parser.add_argument("--load_edited_removemask", default=False, action="store_true")
73
-
74
- args = parser.parse_args()
75
-
76
- torch.cuda.manual_seed_all(args.seed)
77
- torch.manual_seed(args.seed)
78
- base_input_folder = "."
79
- base_output_folder = "."
80
-
81
- input_folder = os.path.join(base_input_folder, args.name)
82
-
83
-
84
- mask_list, mask_label_list = load_mask(input_folder)
85
- assert mask_list[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
86
- try:
87
- image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution) ), size = args.resolution)
88
- except:
89
- image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
90
-
91
- if args.image:
92
- input_folder_2 = os.path.join(base_input_folder, args.name_2)
93
- mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
94
- assert mask_list_2[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
95
  try:
96
- image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(args.resolution) ), size = args.resolution)
97
  except:
98
- image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
99
- output_dir = os.path.join(base_output_folder, args.name + "_" + args.name_2)
100
- os.makedirs(output_dir, exist_ok = True)
101
- else:
102
- output_dir = os.path.join(base_output_folder, args.name)
103
- os.makedirs(output_dir, exist_ok = True)
104
-
105
- if args.dpm == "sd":
106
- if args.image:
107
- pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
 
 
108
  else:
109
- pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- elif args.dpm == "sdxl":
112
- if args.image:
113
- pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
 
114
  else:
115
- pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
116
-
117
- else:
118
- raise NotImplementedError
119
-
120
- set_string_list = pipe.set_string_list
121
- if args.prompt_auxin_list is not None:
122
- for auxin_idx, auxin_prompt in zip(args.prompt_auxin_idx_list, args.prompt_auxin_list):
123
- set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
124
- print(set_string_list)
125
-
126
- if args.image:
127
- set_string_list_2 = pipe.set_string_list_2
128
- print(set_string_list_2)
129
-
130
- if args.load_trained:
131
- unet_save_path = os.path.join(output_dir, "unet.pt")
132
- unet_state_dict = torch.load(unet_save_path)
133
- text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
134
- text_encoder1_state_dict = torch.load(text_encoder1_save_path)
135
- if args.dpm == "sdxl":
136
- text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
137
- text_encoder2_state_dict = torch.load(text_encoder2_save_path)
138
-
139
- if 'lora' in ''.join(unet_state_dict.keys()):
140
- unet_lora_config = LoraConfig(
141
- r=args.lora_rank,
142
- lora_alpha=args.lora_alpha,
143
- init_lora_weights="gaussian",
144
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
145
  )
146
- pipe.unet.add_adapter(unet_lora_config)
147
-
148
- pipe.unet.load_state_dict(unet_state_dict)
149
- pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
150
- if args.dpm == "sdxl":
151
- pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
152
- else:
153
- if args.image:
154
- pipe.mask_list = [m.cuda() for m in pipe.mask_list]
155
- pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
156
- pipe.train_emb_2imgs(
157
- image_gt,
158
- image_gt_2,
159
- set_string_list,
160
- set_string_list_2,
161
- gradient_accumulation_steps = args.gradient_accumulation_steps,
162
- embedding_learning_rate = args.embedding_learning_rate,
163
- max_emb_train_steps = args.max_emb_train_steps,
164
- train_batch_size = args.train_batch_size,
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- pipe.train_model_2imgs(
168
- image_gt,
169
- image_gt_2,
170
- set_string_list,
171
- set_string_list_2,
172
- gradient_accumulation_steps = args.gradient_accumulation_steps,
173
- max_diffusion_train_steps = args.max_diffusion_train_steps,
174
- diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
175
- train_batch_size =args.train_batch_size,
176
- train_full_lora = args.train_full_lora,
177
- lora_rank = args.lora_rank,
178
- lora_alpha = args.lora_alpha
179
- )
180
 
181
- else:
182
- pipe.mask_list = [m.cuda() for m in pipe.mask_list]
183
- pipe.train_emb(
184
- image_gt,
185
- set_string_list,
186
- gradient_accumulation_steps = args.gradient_accumulation_steps,
187
- embedding_learning_rate = args.embedding_learning_rate,
188
- max_emb_train_steps = args.max_emb_train_steps,
189
- train_batch_size = args.train_batch_size,
 
 
 
 
 
 
 
 
 
 
190
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- pipe.train_model(
193
- image_gt,
194
- set_string_list,
195
- gradient_accumulation_steps = args.gradient_accumulation_steps,
196
- max_diffusion_train_steps = args.max_diffusion_train_steps,
197
- diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
198
- train_batch_size = args.train_batch_size,
199
- train_full_lora = args.train_full_lora,
200
- lora_rank = args.lora_rank,
201
- lora_alpha = args.lora_alpha
 
 
202
  )
203
 
204
-
205
- unet_save_path = os.path.join(output_dir, "unet.pt")
206
- torch.save(pipe.unet.state_dict(),unet_save_path )
207
- text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
208
- torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
209
- if args.dpm == "sdxl":
210
- text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
211
- torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
212
-
213
-
214
- if args.recon:
215
- output_dir = os.path.join(output_dir, "recon")
216
- os.makedirs(output_dir, exist_ok = True)
217
- if args.recon_an_item:
218
- mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
219
- tgt_string = set_string_list[args.tgt_index]
220
- tgt_string = args.recon_prompt.replace("*", tgt_string)
221
- set_string_list = [tgt_string]
222
- print(set_string_list)
223
- save_path = os.path.join(output_dir, "out_recon.png")
224
- x_np = pipe.inference_with_mask(
225
- save_path,
226
- guidance_scale = args.guidance_scale,
227
- num_sampling_steps = args.num_sampling_steps,
228
- seed = args.seed,
229
- num_imgs = args.num_imgs,
230
- set_string_list = set_string_list,
231
- mask_list = mask_list
232
- )
233
-
234
- if args.text:
235
- print("Text-guided editing ")
236
- output_dir = os.path.join(output_dir, "text")
237
- os.makedirs(output_dir, exist_ok = True)
238
- save_path = os.path.join(output_dir, "out_text.png")
239
- set_string_list[args.tgt_index] = args.tgt_prompt
240
- mask_active = torch.zeros_like(mask_list[0])
241
- mask_active = mask_union_torch(mask_active, mask_list[args.tgt_index])
242
-
243
- if args.active_mask_list is not None:
244
- for midx in args.active_mask_list:
245
- mask_active = mask_union_torch(mask_active, mask_list[midx])
246
-
247
- if args.load_edited_mask:
248
- mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
249
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
250
- mask_active = mask_union_torch(mask_active, mask_diff)
251
- mask_list = mask_list_edited
252
- save_path = os.path.join(output_dir, "out_textEdited.png")
253
-
254
- mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
255
- mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
256
- mask_hard = mask_substract_torch(mask_hard, mask_soft)
257
-
258
- pipe.inference_with_mask(
259
- save_path,
260
- orig_image = image_gt,
261
- set_string_list = set_string_list,
262
- guidance_scale = args.guidance_scale,
263
- strength = args.strength,
264
- num_imgs = args.num_imgs,
265
- mask_hard= mask_hard,
266
- mask_soft = mask_soft,
267
- mask_list = mask_list,
268
- seed = args.seed,
269
- num_sampling_steps = args.num_sampling_steps
270
- )
271
-
272
- if args.remove:
273
- output_dir = os.path.join(output_dir, "remove")
274
- save_path = os.path.join(output_dir, "out_remove.png")
275
- os.makedirs(output_dir, exist_ok = True)
276
- mask_active = torch.zeros_like(mask_list[0])
277
-
278
- if args.load_edited_mask:
279
- mask_list_edited, _ = load_mask_edit(input_folder)
280
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
281
- mask_active = mask_union_torch(mask_active, mask_diff)
282
- mask_list = mask_list_edited
283
 
284
- if args.load_edited_processed_mask:
285
- # manually edit or draw masks after removing one index, then load
286
- mask_list_processed, _ = load_mask_edit(output_dir)
287
- mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
288
- else:
289
- # generate masks after removing one index, using nearest neighbor algorithm
290
- mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, args.tgt_index)
291
- save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
292
- visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
293
- check_cover_all_torch(*mask_list_processed)
294
- mask_active = mask_union_torch(mask_active, mask_remain)
295
-
296
- if args.active_mask_list is not None:
297
- for midx in args.active_mask_list:
298
- mask_active = mask_union_torch(mask_active, mask_list[midx])
299
-
300
- mask_hard = 1 - mask_active
301
- mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = args.edge_thickness)
302
- mask_hard = mask_substract_torch(mask_hard, mask_soft)
303
-
304
- pipe.inference_with_mask(
305
- save_path,
306
- orig_image = image_gt,
307
- guidance_scale = args.guidance_scale,
308
- strength = args.strength,
309
- num_imgs = args.num_imgs,
310
- mask_hard= mask_hard,
311
- mask_soft = mask_soft,
312
- mask_list = mask_list_processed,
313
- seed = args.seed,
314
- num_sampling_steps = args.num_sampling_steps
315
- )
316
-
317
- if args.image:
318
- output_dir = os.path.join(output_dir, "image")
319
- save_path = os.path.join(output_dir, "out_image.png")
320
- os.makedirs(output_dir, exist_ok = True)
321
- mask_active = torch.zeros_like(mask_list[0])
322
-
323
- if None not in (args.tgt_name, args.src_index, args.tgt_index):
324
- if args.tgt_name == args.name:
325
- set_string_list_tgt = set_string_list
326
- set_string_list_src = set_string_list_2
327
- image_tgt = image_gt
328
- if args.load_edited_mask:
329
- mask_list_edited, _ = load_mask_edit(input_folder)
330
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
331
- mask_active = mask_union_torch(mask_active, mask_diff)
332
- mask_list = mask_list_edited
333
- save_path = os.path.join(output_dir, "out_imageEdited.png")
334
- mask_list_tgt = mask_list
335
 
336
- elif args.tgt_name == args.name_2:
337
- set_string_list_tgt = set_string_list_2
338
- set_string_list_src = set_string_list
339
- image_tgt = image_gt_2
340
- if args.load_edited_mask:
341
- mask_list_2_edited, _ = load_mask_edit(input_folder_2)
342
- mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
343
- mask_active = mask_union_torch(mask_active, mask_diff)
344
- mask_list_2 = mask_list_2_edited
345
- save_path = os.path.join(output_dir, "out_imageEdited.png")
346
- mask_list_tgt = mask_list_2
347
  else:
348
- exit("tgt_name should be either name or name_2")
349
-
350
- set_string_list_tgt[args.tgt_index] = set_string_list_src[args.src_index]
 
 
 
351
 
352
- mask_active = mask_list_tgt[args.tgt_index]
353
- mask_frozen = (1-mask_active.float()).to(mask_active.device)
354
- mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = args.edge_thickness)
355
- mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
356
-
357
- mask_list_tgt = [m.cuda() for m in mask_list_tgt]
 
358
 
359
  pipe.inference_with_mask(
360
- save_path,
361
- set_string_list = set_string_list_tgt,
362
- mask_list = mask_list_tgt,
363
- guidance_scale = args.guidance_scale,
364
- num_sampling_steps = args.num_sampling_steps,
365
- mask_hard = mask_hard.cuda(),
366
- mask_soft = mask_soft.cuda(),
367
- num_imgs = args.num_imgs,
368
- orig_image = image_tgt,
369
- strength = args.strength,
370
  )
371
 
372
- if args.move_resize:
373
- output_dir = os.path.join(output_dir, "move_resize")
374
- os.makedirs(output_dir, exist_ok = True)
375
- save_path = os.path.join(output_dir, "out_moveresize.png")
376
- mask_active = torch.zeros_like(mask_list[0])
377
-
378
- if args.load_edited_mask:
379
- mask_list_edited, _ = load_mask_edit(input_folder)
380
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
381
- mask_active = mask_union_torch(mask_active, mask_diff)
382
- mask_list = mask_list_edited
383
- # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
384
 
385
- if args.load_edited_processed_mask:
386
- mask_list_processed, _ = load_mask_edit(output_dir)
387
- mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
388
- else:
389
- mask_list_processed, mask_remain = process_mask_move_torch(
390
- mask_list,
391
- args.tgt_indices_list,
392
- args.delta_x_list,
393
- args.delta_y_list, args.priority_list,
394
- force_mask_remain = args.force_mask_remain,
395
- resize_list = args.resize_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  )
397
- save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
398
- visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
399
- active_idxs = args.tgt_indices_list
400
-
401
- mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
402
- mask_active = mask_union_torch(mask_remain, mask_active)
403
- if args.active_mask_list is not None:
404
- for midx in args.active_mask_list:
405
- mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
406
-
407
- mask_frozen =(1 - mask_active.float())
408
- mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
409
- mask_hard = mask_substract_torch(mask_frozen, mask_soft)
410
-
411
- check_mask_overlap_torch(mask_hard, mask_soft)
412
-
413
- pipe.inference_with_mask(
414
- save_path,
415
- strength = args.strength,
416
- orig_image = image_gt,
417
- guidance_scale = args.guidance_scale,
418
- num_sampling_steps = args.num_sampling_steps,
419
- num_imgs = args.num_imgs,
420
- mask_hard= mask_hard,
421
- mask_soft = mask_soft,
422
- mask_list = mask_list_processed,
423
- seed = args.seed
424
- )
 
9
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
  from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
 
12
+ def run_main(
13
+ name="example_tmp",
14
+ name_2=None,
15
+ dpm="sd",
16
+ resolution=512,
17
+ seed=42,
18
+ embedding_learning_rate=1e-4,
19
+ max_emb_train_steps=200,
20
+ diffusion_model_learning_rate=5e-5,
21
+ max_diffusion_train_steps=200,
22
+ train_batch_size=1,
23
+ gradient_accumulation_steps=1,
24
+ num_tokens=1,
25
+
26
+ load_trained=False ,
27
+ num_sampling_steps=50,
28
+ guidance_scale= 3 ,
29
+ strength=0.8,
30
+
31
+ train_full_lora=False ,
32
+ lora_rank=4,
33
+ lora_alpha=4,
34
+
35
+ prompt_auxin_list = None,
36
+ prompt_auxin_idx_list= None,
37
+
38
+ load_edited_mask=False,
39
+ load_edited_processed_mask=False,
40
+ edge_thickness=20,
41
+ num_imgs= 1 ,
42
+ active_mask_list = None,
43
+ tgt_index=None,
44
+
45
+ recon=False ,
46
+ recon_an_item=False,
47
+ recon_prompt=None,
48
+
49
+ text=False,
50
+ tgt_prompt=None,
51
+
52
+ image=False ,
53
+ src_index=None,
54
+ tgt_name=None,
55
+
56
+ move_resize=False ,
57
+ tgt_indices_list=None,
58
+ delta_x_list=None,
59
+ delta_y_list=None,
60
+ priority_list=None,
61
+ force_mask_remain=None,
62
+ resize_list=None,
63
+
64
+ remove=False,
65
+ load_edited_removemask=False
66
+ ):
67
+ torch.cuda.manual_seed_all(seed)
68
+ torch.manual_seed(seed)
69
+ base_input_folder = "."
70
+ base_output_folder = "."
71
+
72
+ input_folder = os.path.join(base_input_folder, name)
73
+
74
+ mask_list, mask_label_list = load_mask(input_folder)
75
+ assert mask_list[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
+ image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(resolution) ), size = resolution)
78
  except:
79
+ image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(resolution) ), size = resolution)
80
+
81
+ if image:
82
+ input_folder_2 = os.path.join(base_input_folder, name_2)
83
+ mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
84
+ assert mask_list_2[0].shape[0] == resolution, "Segmentation should be done on size {}".format(resolution)
85
+ try:
86
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(resolution) ), size = resolution)
87
+ except:
88
+ image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(resolution) ), size = resolution)
89
+ output_dir = os.path.join(base_output_folder, name + "_" + name_2)
90
+ os.makedirs(output_dir, exist_ok = True)
91
  else:
92
+ output_dir = os.path.join(base_output_folder, name)
93
+ os.makedirs(output_dir, exist_ok = True)
94
+
95
+ if dpm == "sd":
96
+ if image:
97
+ pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
98
+ else:
99
+ pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
100
+
101
+ elif dpm == "sdxl":
102
+ if image:
103
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
104
+ else:
105
+ pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
106
+
107
+ else:
108
+ raise NotImplementedError
109
+
110
+ set_string_list = pipe.set_string_list
111
+ if prompt_auxin_list is not None:
112
+ for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
113
+ set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
114
+ print(set_string_list)
115
+
116
+ if image:
117
+ set_string_list_2 = pipe.set_string_list_2
118
+ print(set_string_list_2)
119
+
120
+ if load_trained:
121
+ unet_save_path = os.path.join(output_dir, "unet.pt")
122
+ unet_state_dict = torch.load(unet_save_path)
123
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
124
+ text_encoder1_state_dict = torch.load(text_encoder1_save_path)
125
+ if dpm == "sdxl":
126
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
127
+ text_encoder2_state_dict = torch.load(text_encoder2_save_path)
128
+
129
+ if 'lora' in ''.join(unet_state_dict.keys()):
130
+ unet_lora_config = LoraConfig(
131
+ r=lora_rank,
132
+ lora_alpha=lora_alpha,
133
+ init_lora_weights="gaussian",
134
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
135
+ )
136
+ pipe.unet.add_adapter(unet_lora_config)
137
 
138
+ pipe.unet.load_state_dict(unet_state_dict)
139
+ pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
140
+ if dpm == "sdxl":
141
+ pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
142
  else:
143
+ if image:
144
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
145
+ pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
146
+ pipe.train_emb_2imgs(
147
+ image_gt,
148
+ image_gt_2,
149
+ set_string_list,
150
+ set_string_list_2,
151
+ gradient_accumulation_steps = gradient_accumulation_steps,
152
+ embedding_learning_rate = embedding_learning_rate,
153
+ max_emb_train_steps = max_emb_train_steps,
154
+ train_batch_size = train_batch_size,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
+
157
+ pipe.train_model_2imgs(
158
+ image_gt,
159
+ image_gt_2,
160
+ set_string_list,
161
+ set_string_list_2,
162
+ gradient_accumulation_steps = gradient_accumulation_steps,
163
+ max_diffusion_train_steps = max_diffusion_train_steps,
164
+ diffusion_model_learning_rate = diffusion_model_learning_rate ,
165
+ train_batch_size =train_batch_size,
166
+ train_full_lora = train_full_lora,
167
+ lora_rank = lora_rank,
168
+ lora_alpha = lora_alpha
169
+ )
170
+
171
+ else:
172
+ pipe.mask_list = [m.cuda() for m in pipe.mask_list]
173
+ pipe.train_emb(
174
+ image_gt,
175
+ set_string_list,
176
+ gradient_accumulation_steps = gradient_accumulation_steps,
177
+ embedding_learning_rate = embedding_learning_rate,
178
+ max_emb_train_steps = max_emb_train_steps,
179
+ train_batch_size = train_batch_size,
180
+ )
181
+
182
+ pipe.train_model(
183
+ image_gt,
184
+ set_string_list,
185
+ gradient_accumulation_steps = gradient_accumulation_steps,
186
+ max_diffusion_train_steps = max_diffusion_train_steps,
187
+ diffusion_model_learning_rate = diffusion_model_learning_rate ,
188
+ train_batch_size = train_batch_size,
189
+ train_full_lora = train_full_lora,
190
+ lora_rank = lora_rank,
191
+ lora_alpha = lora_alpha
192
+ )
193
+
194
 
195
+ unet_save_path = os.path.join(output_dir, "unet.pt")
196
+ torch.save(pipe.unet.state_dict(),unet_save_path )
197
+ text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
198
+ torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
199
+ if dpm == "sdxl":
200
+ text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
201
+ torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
 
 
 
 
 
 
202
 
203
+
204
+ if recon:
205
+ output_dir = os.path.join(output_dir, "recon")
206
+ os.makedirs(output_dir, exist_ok = True)
207
+ if recon_an_item:
208
+ mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
209
+ tgt_string = set_string_list[tgt_index]
210
+ tgt_string = recon_prompt.replace("*", tgt_string)
211
+ set_string_list = [tgt_string]
212
+ print(set_string_list)
213
+ save_path = os.path.join(output_dir, "out_recon.png")
214
+ x_np = pipe.inference_with_mask(
215
+ save_path,
216
+ guidance_scale = guidance_scale,
217
+ num_sampling_steps = num_sampling_steps,
218
+ seed = seed,
219
+ num_imgs = num_imgs,
220
+ set_string_list = set_string_list,
221
+ mask_list = mask_list
222
  )
223
+
224
+ if text:
225
+ print("Text-guided editing ")
226
+ output_dir = os.path.join(output_dir, "text")
227
+ os.makedirs(output_dir, exist_ok = True)
228
+ save_path = os.path.join(output_dir, "out_text.png")
229
+ set_string_list[tgt_index] = tgt_prompt
230
+ mask_active = torch.zeros_like(mask_list[0])
231
+ mask_active = mask_union_torch(mask_active, mask_list[tgt_index])
232
+
233
+ if active_mask_list is not None:
234
+ for midx in active_mask_list:
235
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
236
+
237
+ if load_edited_mask:
238
+ mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
239
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
240
+ mask_active = mask_union_torch(mask_active, mask_diff)
241
+ mask_list = mask_list_edited
242
+ save_path = os.path.join(output_dir, "out_textEdited.png")
243
+
244
+ mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
245
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
246
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
247
 
248
+ pipe.inference_with_mask(
249
+ save_path,
250
+ orig_image = image_gt,
251
+ set_string_list = set_string_list,
252
+ guidance_scale = guidance_scale,
253
+ strength = strength,
254
+ num_imgs = num_imgs,
255
+ mask_hard= mask_hard,
256
+ mask_soft = mask_soft,
257
+ mask_list = mask_list,
258
+ seed = seed,
259
+ num_sampling_steps = num_sampling_steps
260
  )
261
 
262
+ if remove:
263
+ output_dir = os.path.join(output_dir, "remove")
264
+ save_path = os.path.join(output_dir, "out_remove.png")
265
+ os.makedirs(output_dir, exist_ok = True)
266
+ mask_active = torch.zeros_like(mask_list[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ if load_edited_mask:
269
+ mask_list_edited, _ = load_mask_edit(input_folder)
270
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
271
+ mask_active = mask_union_torch(mask_active, mask_diff)
272
+ mask_list = mask_list_edited
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ if load_edited_processed_mask:
275
+ # manually edit or draw masks after removing one index, then load
276
+ mask_list_processed, _ = load_mask_edit(output_dir)
277
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
 
 
 
 
 
 
 
278
  else:
279
+ # generate masks after removing one index, using nearest neighbor algorithm
280
+ mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, tgt_index)
281
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
282
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
283
+ check_cover_all_torch(*mask_list_processed)
284
+ mask_active = mask_union_torch(mask_active, mask_remain)
285
 
286
+ if active_mask_list is not None:
287
+ for midx in active_mask_list:
288
+ mask_active = mask_union_torch(mask_active, mask_list[midx])
289
+
290
+ mask_hard = 1 - mask_active
291
+ mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = edge_thickness)
292
+ mask_hard = mask_substract_torch(mask_hard, mask_soft)
293
 
294
  pipe.inference_with_mask(
295
+ save_path,
296
+ orig_image = image_gt,
297
+ guidance_scale = guidance_scale,
298
+ strength = strength,
299
+ num_imgs = num_imgs,
300
+ mask_hard= mask_hard,
301
+ mask_soft = mask_soft,
302
+ mask_list = mask_list_processed,
303
+ seed = seed,
304
+ num_sampling_steps = num_sampling_steps
305
  )
306
 
307
+ if image:
308
+ output_dir = os.path.join(output_dir, "image")
309
+ save_path = os.path.join(output_dir, "out_image.png")
310
+ os.makedirs(output_dir, exist_ok = True)
311
+ mask_active = torch.zeros_like(mask_list[0])
 
 
 
 
 
 
 
312
 
313
+ if None not in (tgt_name, src_index, tgt_index):
314
+ if tgt_name == name:
315
+ set_string_list_tgt = set_string_list
316
+ set_string_list_src = set_string_list_2
317
+ image_tgt = image_gt
318
+ if load_edited_mask:
319
+ mask_list_edited, _ = load_mask_edit(input_folder)
320
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
321
+ mask_active = mask_union_torch(mask_active, mask_diff)
322
+ mask_list = mask_list_edited
323
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
324
+ mask_list_tgt = mask_list
325
+
326
+ elif tgt_name == name_2:
327
+ set_string_list_tgt = set_string_list_2
328
+ set_string_list_src = set_string_list
329
+ image_tgt = image_gt_2
330
+ if load_edited_mask:
331
+ mask_list_2_edited, _ = load_mask_edit(input_folder_2)
332
+ mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
333
+ mask_active = mask_union_torch(mask_active, mask_diff)
334
+ mask_list_2 = mask_list_2_edited
335
+ save_path = os.path.join(output_dir, "out_imageEdited.png")
336
+ mask_list_tgt = mask_list_2
337
+ else:
338
+ exit("tgt_name should be either name or name_2")
339
+
340
+ set_string_list_tgt[tgt_index] = set_string_list_src[src_index]
341
+
342
+ mask_active = mask_list_tgt[tgt_index]
343
+ mask_frozen = (1-mask_active.float()).to(mask_active.device)
344
+ mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = edge_thickness)
345
+ mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
346
+
347
+ mask_list_tgt = [m.cuda() for m in mask_list_tgt]
348
+
349
+ pipe.inference_with_mask(
350
+ save_path,
351
+ set_string_list = set_string_list_tgt,
352
+ mask_list = mask_list_tgt,
353
+ guidance_scale = guidance_scale,
354
+ num_sampling_steps = num_sampling_steps,
355
+ mask_hard = mask_hard.cuda(),
356
+ mask_soft = mask_soft.cuda(),
357
+ num_imgs = num_imgs,
358
+ orig_image = image_tgt,
359
+ strength = strength,
360
+ )
361
+
362
+ if move_resize:
363
+ output_dir = os.path.join(output_dir, "move_resize")
364
+ os.makedirs(output_dir, exist_ok = True)
365
+ save_path = os.path.join(output_dir, "out_moveresize.png")
366
+ mask_active = torch.zeros_like(mask_list[0])
367
+
368
+ if load_edited_mask:
369
+ mask_list_edited, _ = load_mask_edit(input_folder)
370
+ mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
371
+ mask_active = mask_union_torch(mask_active, mask_diff)
372
+ mask_list = mask_list_edited
373
+ # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
374
+
375
+ if load_edited_processed_mask:
376
+ mask_list_processed, _ = load_mask_edit(output_dir)
377
+ mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
378
+ else:
379
+ mask_list_processed, mask_remain = process_mask_move_torch(
380
+ mask_list,
381
+ tgt_indices_list,
382
+ delta_x_list,
383
+ delta_y_list, priority_list,
384
+ force_mask_remain = force_mask_remain,
385
+ resize_list = resize_list
386
+ )
387
+ save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
388
+ visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
389
+ active_idxs = tgt_indices_list
390
+
391
+ mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
392
+ mask_active = mask_union_torch(mask_remain, mask_active)
393
+ if active_mask_list is not None:
394
+ for midx in active_mask_list:
395
+ mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
396
+
397
+ mask_frozen =(1 - mask_active.float())
398
+ mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
399
+ mask_hard = mask_substract_torch(mask_frozen, mask_soft)
400
+
401
+ check_mask_overlap_torch(mask_hard, mask_soft)
402
+
403
+ pipe.inference_with_mask(
404
+ save_path,
405
+ strength = strength,
406
+ orig_image = image_gt,
407
+ guidance_scale = guidance_scale,
408
+ num_sampling_steps = num_sampling_steps,
409
+ num_imgs = num_imgs,
410
+ mask_hard= mask_hard,
411
+ mask_soft = mask_soft,
412
+ mask_list = mask_list_processed,
413
+ seed = seed
414
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline_dedit_sd.py CHANGED
@@ -27,11 +27,11 @@ class DEditSDPipeline:
27
  mask_label_list,
28
  mask_list_2 = None,
29
  mask_label_list_2 = None,
30
- resolution = 1024,
31
  num_tokens = 1
32
  ):
33
  super().__init__()
34
- model_id = "./stable-diffusion-v1-5"
35
  self.model_id = model_id
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
37
  text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
@@ -810,4 +810,5 @@ class DEditSDPipeline:
810
  seed = seed
811
  )
812
  save_images(x0, save_path)
813
- return x0
 
 
27
  mask_label_list,
28
  mask_list_2 = None,
29
  mask_label_list_2 = None,
30
+ resolution = 512,
31
  num_tokens = 1
32
  ):
33
  super().__init__()
34
+ model_id = "CompVis/stable-diffusion-v1-4"
35
  self.model_id = model_id
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
37
  text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
 
810
  seed = seed
811
  )
812
  save_images(x0, save_path)
813
+ # from PIL import Image
814
+ # return Image.open("example_tmp/text/out_text_0.png")
segment.py CHANGED
@@ -102,7 +102,8 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
102
  # image = load_image(os.path.join(input_folder, "img.png" ), size = size)
103
  # except:
104
  # image = load_image(os.path.join(input_folder, "img.jpg" ), size = size)
105
- # image =Image.fromarray(image)
 
106
  os.makedirs(name, exist_ok=True)
107
  image.save(os.path.join(name,"img_{}.png".format(size)))
108
  inputs = processor(image, return_tensors="pt")
 
102
  # image = load_image(os.path.join(input_folder, "img.png" ), size = size)
103
  # except:
104
  # image = load_image(os.path.join(input_folder, "img.jpg" ), size = size)
105
+ image =Image.fromarray(image)
106
+ image = image.resize((size, size))
107
  os.makedirs(name, exist_ok=True)
108
  image.save(os.path.join(name,"img_{}.png".format(size)))
109
  inputs = processor(image, return_tensors="pt")