afeng committed on
Commit
01d1b1f
1 Parent(s): a900192
Files changed (7) hide show
  1. .gitignore +1 -0
  2. app copy 2.py +0 -385
  3. app copy.py +0 -350
  4. app.py +64 -55
  5. main copy.py +0 -480
  6. main.py +3 -2
  7. pipeline_dedit_sd.py +2 -2
.gitignore CHANGED
@@ -5,6 +5,7 @@ example1_example2_1024/
5
  example1/
6
  old/
7
  example_tmp/
 
8
 
9
  out_active.png
10
  out_mask.png
 
5
  example1/
6
  old/
7
  example_tmp/
8
+ z_*
9
 
10
  out_active.png
11
  out_mask.png
app copy 2.py DELETED
@@ -1,385 +0,0 @@
1
-
2
- import os
3
- import copy
4
- from PIL import Image
5
- import matplotlib
6
- import numpy as np
7
- import gradio as gr
8
- from utils import load_mask, load_mask_edit
9
- from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
10
- from pathlib import Path
11
- import subprocess
12
- from PIL import Image
13
- from functools import partial
14
- from main import run_main
15
- LENGTH=512 #length of the square area displaying/editing images
16
- TRANSPARENCY = 150 # transparency of the mask in display
17
-
18
- def add_mask(mask_np_list_updated, mask_label_list):
19
- mask_new = np.zeros_like(mask_np_list_updated[0])
20
- mask_np_list_updated.append(mask_new)
21
- mask_label_list.append("new")
22
- return mask_np_list_updated, mask_label_list
23
-
24
- def create_segmentation(mask_np_list):
25
- viridis = matplotlib.pyplot.get_cmap(name = 'viridis', lut = len(mask_np_list))
26
- segmentation = 0
27
- for i, m in enumerate(mask_np_list):
28
- color = matplotlib.colors.to_rgb(viridis(i))
29
- color_mat = np.ones_like(m)
30
- color_mat = np.stack([color_mat*color[0], color_mat*color[1],color_mat*color[2] ], axis = 2)
31
- color_mat = color_mat * m[:,:,np.newaxis]
32
- segmentation += color_mat
33
- segmentation = Image.fromarray(np.uint8(segmentation*255))
34
- return segmentation
35
-
36
- def load_mask_ui(input_folder="example_tmp",load_edit = False):
37
- if not load_edit:
38
- mask_list, mask_label_list = load_mask(input_folder)
39
- else:
40
- mask_list, mask_label_list = load_mask_edit(input_folder)
41
-
42
- mask_np_list = []
43
- for m in mask_list:
44
- mask_np_list. append( m.cpu().numpy())
45
-
46
- return mask_np_list, mask_label_list
47
-
48
- def load_image_ui(load_edit, input_folder="example_tmp"):
49
- try:
50
- for img_path in Path(input_folder).iterdir():
51
- if img_path.name in ["img_512.png"]:
52
- image = Image.open(img_path)
53
- mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
54
- image = image.convert('RGB')
55
- segmentation = create_segmentation(mask_np_list)
56
- print("!!", len(mask_np_list))
57
- return image, segmentation, mask_np_list, mask_label_list, image
58
- except:
59
- print("Image folder invalid: The folder should contain image.png")
60
- return None, None, None, None, None
61
-
62
- def run_edit_text(
63
- num_tokens,
64
- num_sampling_steps,
65
- strength,
66
- edge_thickness,
67
- tgt_prompt,
68
- tgt_idx,
69
- guidance_scale,
70
- input_folder="example_tmp"
71
- ):
72
- subprocess.run(["python",
73
- "main.py" ,
74
- "--text",
75
- "--name={}".format(input_folder),
76
- "--dpm={}".format("sd"),
77
- "--resolution={}".format(512),
78
- "--load_trained",
79
- "--num_tokens={}".format(num_tokens),
80
- "--seed={}".format(2024),
81
- "--guidance_scale={}".format(guidance_scale),
82
- "--num_sampling_step={}".format(num_sampling_steps),
83
- "--strength={}".format(strength),
84
- "--edge_thickness={}".format(edge_thickness),
85
- "--num_imgs={}".format(2),
86
- "--tgt_prompt={}".format(tgt_prompt) ,
87
- "--tgt_index={}".format(tgt_idx)
88
- ])
89
-
90
- return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
91
-
92
-
93
- def run_optimization(
94
- num_tokens,
95
- embedding_learning_rate,
96
- max_emb_train_steps,
97
- diffusion_model_learning_rate,
98
- max_diffusion_train_steps,
99
- train_batch_size,
100
- gradient_accumulation_steps,
101
- input_folder = "example_tmp"
102
- ):
103
- subprocess.run(["python",
104
- "main.py" ,
105
- "--name={}".format(input_folder),
106
- "--dpm={}".format("sd"),
107
- "--resolution={}".format(512),
108
- "--num_tokens={}".format(num_tokens),
109
- "--embedding_learning_rate={}".format(embedding_learning_rate),
110
- "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
111
- "--max_emb_train_steps={}".format(max_emb_train_steps),
112
- "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
113
- "--train_batch_size={}".format(train_batch_size),
114
- "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
115
-
116
- ])
117
- return
118
-
119
-
120
- def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
121
- backimg_solid_np = np.array(backimg)
122
- bimg = backimg.copy()
123
- fimg = foreimg.copy()
124
- fimg.putalpha(transparency)
125
- bimg.paste(fimg, (0,0), fimg)
126
-
127
- bimg_np = np.array(bimg)
128
- mask_np = mask_np[:,:,np.newaxis]
129
-
130
- try:
131
- new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
132
- return Image.fromarray(new_img_np)
133
- except:
134
- import pdb; pdb.set_trace()
135
-
136
- def show_segmentation(image, segmentation, flag):
137
- if flag is False:
138
- flag = True
139
- mask_np = np.ones([image.size[0],image.size[1]]).astype(np.uint8)
140
- image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
141
- return image_edit, flag
142
- else:
143
- flag = False
144
- return image,flag
145
-
146
- def edit_mask_add(canvas, image, idx, mask_np_list):
147
- mask_sel = mask_np_list[idx]
148
- mask_new = np.uint8(canvas["mask"][:, :, 0]/ 255.)
149
- mask_np_list_updated = []
150
- for midx, m in enumerate(mask_np_list):
151
- if midx == idx:
152
- mask_np_list_updated.append(mask_union(mask_sel, mask_new))
153
- else:
154
- mask_np_list_updated.append(m)
155
-
156
- priority_list = [0 for _ in range(len(mask_np_list_updated))]
157
- priority_list[idx] = 1
158
- mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
159
- mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
160
- segmentation = create_segmentation(mask_np_list_updated)
161
- image_edit = transparent_paste_with_mask(image, segmentation, mask_ones ,transparency = TRANSPARENCY)
162
- return mask_np_list_updated, image_edit
163
-
164
- def slider_release(index, image, mask_np_list_updated, mask_label_list):
165
-
166
- if index > len(mask_np_list_updated):
167
- return image, "out of range"
168
- else:
169
- mask_np = mask_np_list_updated[index]
170
- mask_label = mask_label_list[index]
171
- segmentation = create_segmentation(mask_np_list_updated)
172
- new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
173
- return new_image, mask_label
174
-
175
- def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
176
- try:
177
- assert np.all(sum(mask_np_list_updated)==1)
178
- except:
179
- print("please check mask")
180
- # plt.imsave( "out_mask.png", mask_list_edit[0])
181
- import pdb; pdb.set_trace()
182
-
183
- for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
184
- # np.save(os.path.join(input_folder, "maskEDIT{}_{}.npy".format(midx, mask_label)),mask )
185
- np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
186
- savepath = os.path.join(input_folder, "seg_current.png")
187
- visualize_mask_list_clean(mask_np_list_updated, savepath)
188
-
189
- def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
190
- try:
191
- assert np.all(sum(mask_np_list_updated)==1)
192
- except:
193
- print("please check mask")
194
- # plt.imsave( "out_mask.png", mask_list_edit[0])
195
- import pdb; pdb.set_trace()
196
- for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
197
- np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
198
- savepath = os.path.join(input_folder, "seg_edited.png")
199
- visualize_mask_list_clean(mask_np_list_updated, savepath)
200
-
201
-
202
- import shutil
203
- if os.path.isdir("./example_tmp"):
204
- shutil.rmtree("./example_tmp")
205
-
206
- from segment import run_segmentation
207
- with gr.Blocks() as demo:
208
- image = gr.State() # store mask
209
- image_loaded = gr.State()
210
- segmentation = gr.State()
211
-
212
- mask_np_list = gr.State([])
213
- mask_label_list = gr.State([])
214
- mask_np_list_updated = gr.State([])
215
- true = gr.State(True)
216
- false = gr.State(False)
217
-
218
- with gr.Row():
219
- gr.Markdown("""# D-Edit""")
220
-
221
- with gr.Tab(label="1 Edit mask"):
222
- with gr.Row():
223
- with gr.Column():
224
- canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
225
-
226
- segment_button = gr.Button("1.1 Run segmentation")
227
- segment_button.click(run_segmentation,
228
- [canvas] ,
229
- [] )
230
-
231
- text_button = gr.Button("1.2 Load original masks")
232
- text_button.click(load_image_ui,
233
- [ false] ,
234
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
235
-
236
- load_edit_button = gr.Button("1.2 Load edited masks")
237
- load_edit_button.click(load_image_ui,
238
- [ true] ,
239
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
240
-
241
- show_segment = gr.Checkbox(label = "Show Segmentation")
242
- flag = gr.State(False)
243
- show_segment.select(show_segmentation,
244
- [image_loaded, segmentation, flag],
245
- [canvas, flag])
246
-
247
- # mask_np_list_updated.value = copy.deepcopy(mask_np_list.value) #!!
248
- mask_np_list_updated = mask_np_list
249
- with gr.Column():
250
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
- slider = gr.Slider(0, 20, step=1, interactive=True)
252
- label = gr.Textbox()
253
- slider.release(slider_release,
254
- inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
255
- outputs= [canvas, label]
256
- )
257
- add_button = gr.Button("Add")
258
- add_button.click( edit_mask_add,
259
- [canvas, image_loaded, slider, mask_np_list_updated] ,
260
- [mask_np_list_updated, canvas]
261
- )
262
-
263
- save_button2 = gr.Button("Set and Save as edited masks")
264
- save_button2.click( save_as_edit_mask,
265
- [mask_np_list_updated, mask_label_list] ,
266
- [] )
267
-
268
- save_button = gr.Button("Set and Save as original masks")
269
- save_button.click( save_as_orig_mask,
270
- [mask_np_list_updated, mask_label_list] ,
271
- [] )
272
-
273
- back_button = gr.Button("Back to current seg")
274
- back_button.click( load_mask_ui,
275
- [] ,
276
- [ mask_np_list_updated,mask_label_list] )
277
-
278
- add_mask_button = gr.Button("Add new empty mask")
279
- add_mask_button.click(add_mask,
280
- [mask_np_list_updated, mask_label_list] ,
281
- [mask_np_list_updated, mask_label_list] )
282
-
283
- with gr.Tab(label="2 Optimization"):
284
- with gr.Row():
285
-
286
- with gr.Column():
287
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
288
- num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
289
- embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
290
- max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
291
-
292
- diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
293
- max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )
294
-
295
- train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
296
- gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )
297
-
298
- add_button = gr.Button("Run optimization")
299
- def run_optimization_wrapper (
300
- num_tokens,
301
- embedding_learning_rate ,
302
- max_emb_train_steps ,
303
- diffusion_model_learning_rate ,
304
- max_diffusion_train_steps,
305
- train_batch_size,
306
- gradient_accumulation_steps
307
- ):
308
- run_optimization = partial(
309
- run_main,
310
- num_tokens=int(num_tokens),
311
- embedding_learning_rate = float(embedding_learning_rate),
312
- max_emb_train_steps = int(max_emb_train_steps),
313
- diffusion_model_learning_rate= float(diffusion_model_learning_rate),
314
- max_diffusion_train_steps = int(max_diffusion_train_steps),
315
- train_batch_size=int(train_batch_size),
316
- gradient_accumulation_steps=int(gradient_accumulation_steps)
317
- )
318
- run_optimization()
319
-
320
- add_button.click(run_optimization_wrapper,
321
- inputs = [
322
- num_tokens,
323
- embedding_learning_rate ,
324
- max_emb_train_steps ,
325
- diffusion_model_learning_rate ,
326
- max_diffusion_train_steps,
327
- train_batch_size,
328
- gradient_accumulation_steps
329
- ],
330
- outputs = []
331
- )
332
-
333
-
334
- with gr.Tab(label="3 Editing"):
335
- with gr.Tab(label="3.1 Text-based editing"):
336
-
337
- with gr.Row():
338
- with gr.Column():
339
- canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)
340
- # canvas_text_edit = gr.Gallery(label = "Edited results")
341
-
342
- with gr.Column():
343
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
344
-
345
- tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
346
- tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
347
- guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
348
- num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
349
- edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
350
- strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
351
-
352
- add_button = gr.Button("Run Editing")
353
- run_edit_text = partial(
354
- run_main,
355
- load_trained=True,
356
- text=True,
357
- num_tokens = int(num_tokens.value),
358
- guidance_scale = float(guidance_scale.value),
359
- num_sampling_steps = int(num_sampling_steps.value),
360
- strength = float(strength.value),
361
- edge_thickness = int(edge_thickness.value),
362
- num_imgs = 1,
363
- tgt_prompt = tgt_prompt.value,
364
- tgt_index = int(tgt_index.value)
365
- )
366
-
367
- add_button.click(run_edit_text,
368
- inputs = [],
369
- outputs = [canvas_text_edit]
370
- )
371
-
372
- def load_pil_img():
373
- from PIL import Image
374
- return Image.open("example_tmp/text/out_text_0.png")
375
-
376
- load_button = gr.Button("Load results")
377
- load_button.click(load_pil_img,
378
- inputs = [],
379
- outputs = [canvas_text_edit]
380
- )
381
-
382
-
383
-
384
-
385
- demo.queue().launch(share=True, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app copy.py DELETED
@@ -1,350 +0,0 @@
1
-
2
- import os
3
- import copy
4
- from PIL import Image
5
- import matplotlib
6
- import numpy as np
7
- import gradio as gr
8
- from utils import load_mask, load_mask_edit
9
- from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
10
- from pathlib import Path
11
- import subprocess
12
- from PIL import Image
13
-
14
- LENGTH=512 #length of the square area displaying/editing images
15
- TRANSPARENCY = 150 # transparency of the mask in display
16
-
17
- def add_mask(mask_np_list_updated, mask_label_list):
18
- mask_new = np.zeros_like(mask_np_list_updated[0])
19
- mask_np_list_updated.append(mask_new)
20
- mask_label_list.append("new")
21
- return mask_np_list_updated, mask_label_list
22
-
23
- def create_segmentation(mask_np_list):
24
- viridis = matplotlib.pyplot.get_cmap(name = 'viridis', lut = len(mask_np_list))
25
- segmentation = 0
26
- for i, m in enumerate(mask_np_list):
27
- color = matplotlib.colors.to_rgb(viridis(i))
28
- color_mat = np.ones_like(m)
29
- color_mat = np.stack([color_mat*color[0], color_mat*color[1],color_mat*color[2] ], axis = 2)
30
- color_mat = color_mat * m[:,:,np.newaxis]
31
- segmentation += color_mat
32
- segmentation = Image.fromarray(np.uint8(segmentation*255))
33
- return segmentation
34
-
35
- def load_mask_ui(input_folder,load_edit = False):
36
- if not load_edit:
37
- mask_list, mask_label_list = load_mask(input_folder)
38
- else:
39
- mask_list, mask_label_list = load_mask_edit(input_folder)
40
-
41
- mask_np_list = []
42
- for m in mask_list:
43
- mask_np_list. append( m.cpu().numpy())
44
-
45
- return mask_np_list, mask_label_list
46
-
47
- def load_image_ui(input_folder, load_edit):
48
- try:
49
- for img_path in Path(input_folder).iterdir():
50
- if img_path.name in ["img.png", "img_1024.png", "img_512.png"]:
51
- image = Image.open(img_path)
52
- mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
53
- image = image.convert('RGB')
54
- segmentation = create_segmentation(mask_np_list)
55
- return image, segmentation, mask_np_list, mask_label_list, image
56
- except:
57
- print("Image folder invalid: The folder should contain image.png")
58
- return None, None, None, None, None
59
-
60
- def run_segmentation(input_folder):
61
- subprocess.run(["python", "segment.py" , "--name={}".format(input_folder)])
62
- return
63
-
64
-
65
-
66
- def run_edit_text(
67
- input_folder,
68
- num_tokens,
69
- num_sampling_steps,
70
- strength,
71
- edge_thickness,
72
- tgt_prompt,
73
- tgt_idx,
74
- guidance_scale
75
- ):
76
- subprocess.run(["python",
77
- "main.py" ,
78
- "--text",
79
- "--name={}".format(input_folder),
80
- "--dpm={}".format("sd"),
81
- "--resolution={}".format(512),
82
- "--load_trained",
83
- "--num_tokens={}".format(num_tokens),
84
- "--seed={}".format(2024),
85
- "--guidance_scale={}".format(guidance_scale),
86
- "--num_sampling_step={}".format(num_sampling_steps),
87
- "--strength={}".format(strength),
88
- "--edge_thickness={}".format(edge_thickness),
89
- "--num_imgs={}".format(2),
90
- "--tgt_prompt={}".format(tgt_prompt) ,
91
- "--tgt_index={}".format(tgt_idx)
92
- ])
93
-
94
- return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
95
-
96
-
97
- def run_optimization(
98
- input_folder,
99
- num_tokens,
100
- embedding_learning_rate,
101
- max_emb_train_steps,
102
- diffusion_model_learning_rate,
103
- max_diffusion_train_steps,
104
- train_batch_size,
105
- gradient_accumulation_steps
106
- ):
107
- subprocess.run(["python",
108
- "main.py" ,
109
- "--name={}".format(input_folder),
110
- "--dpm={}".format("sd"),
111
- "--resolution={}".format(512),
112
- "--num_tokens={}".format(num_tokens),
113
- "--embedding_learning_rate={}".format(embedding_learning_rate),
114
- "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
115
- "--max_emb_train_steps={}".format(max_emb_train_steps),
116
- "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
117
- "--train_batch_size={}".format(train_batch_size),
118
- "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
119
-
120
- ])
121
- return
122
-
123
-
124
- def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
125
- backimg_solid_np = np.array(backimg)
126
- bimg = backimg.copy()
127
- fimg = foreimg.copy()
128
- fimg.putalpha(transparency)
129
- bimg.paste(fimg, (0,0), fimg)
130
-
131
- bimg_np = np.array(bimg)
132
- mask_np = mask_np[:,:,np.newaxis]
133
- try:
134
- new_img_np = bimg_np*mask_np + (1-mask_np)* backimg_solid_np
135
- return Image.fromarray(new_img_np)
136
- except:
137
- import pdb; pdb.set_trace()
138
-
139
- def show_segmentation(image, segmentation, flag):
140
- if flag is False:
141
- flag = True
142
- mask_np = np.ones([image.size[0],image.size[1]]).astype(np.uint8)
143
- image_edit = transparent_paste_with_mask(image, segmentation, mask_np ,transparency = TRANSPARENCY)
144
- return image_edit, flag
145
- else:
146
- flag = False
147
- return image,flag
148
-
149
- def edit_mask_add(canvas, image, idx, mask_np_list):
150
- mask_sel = mask_np_list[idx]
151
- mask_new = np.uint8(canvas["mask"][:, :, 0]/ 255.)
152
- mask_np_list_updated = []
153
- for midx, m in enumerate(mask_np_list):
154
- if midx == idx:
155
- mask_np_list_updated.append(mask_union(mask_sel, mask_new))
156
- else:
157
- mask_np_list_updated.append(m)
158
-
159
- priority_list = [0 for _ in range(len(mask_np_list_updated))]
160
- priority_list[idx] = 1
161
- mask_np_list_updated = process_mask_to_follow_priority(mask_np_list_updated, priority_list)
162
- mask_ones = np.ones([mask_sel.shape[0], mask_sel.shape[1]]).astype(np.uint8)
163
- segmentation = create_segmentation(mask_np_list_updated)
164
- image_edit = transparent_paste_with_mask(image, segmentation, mask_ones ,transparency = TRANSPARENCY)
165
- return mask_np_list_updated, image_edit
166
-
167
- def slider_release(index, image, mask_np_list_updated, mask_label_list):
168
- if index > len(mask_np_list_updated):
169
- return image, "out of range"
170
- else:
171
- mask_np = mask_np_list_updated[index]
172
- mask_label = mask_label_list[index]
173
- segmentation = create_segmentation(mask_np_list_updated)
174
- new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
175
- return new_image, mask_label
176
-
177
- def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder):
178
- try:
179
- assert np.all(sum(mask_np_list_updated)==1)
180
- except:
181
- print("please check mask")
182
- # plt.imsave( "out_mask.png", mask_list_edit[0])
183
- import pdb; pdb.set_trace()
184
-
185
- for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
186
- # np.save(os.path.join(input_folder, "maskEDIT{}_{}.npy".format(midx, mask_label)),mask )
187
- np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)),mask )
188
- savepath = os.path.join(input_folder, "seg_current.png")
189
- visualize_mask_list_clean(mask_np_list_updated, savepath)
190
-
191
- def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder):
192
- try:
193
- assert np.all(sum(mask_np_list_updated)==1)
194
- except:
195
- print("please check mask")
196
- # plt.imsave( "out_mask.png", mask_list_edit[0])
197
- import pdb; pdb.set_trace()
198
- for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
199
- np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
200
- savepath = os.path.join(input_folder, "seg_edited.png")
201
- visualize_mask_list_clean(mask_np_list_updated, savepath)
202
-
203
- with gr.Blocks() as demo:
204
- image = gr.State() # store mask
205
- image_loaded = gr.State()
206
- segmentation = gr.State()
207
-
208
- mask_np_list = gr.State([])
209
- mask_label_list = gr.State([])
210
- mask_np_list_updated = gr.State([])
211
- true = gr.State(True)
212
- false = gr.State(False)
213
-
214
-
215
- with gr.Row():
216
- gr.Markdown("""# D-Edit""")
217
-
218
- with gr.Tab(label="1 Edit mask"):
219
- with gr.Row():
220
- with gr.Column():
221
- canvas = gr.Image(value = None, type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
222
- input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
223
-
224
- segment_button = gr.Button("1.1 Run segmentation")
225
- segment_button.click(run_segmentation,
226
- [input_folder] ,
227
- [] )
228
-
229
-
230
- text_button = gr.Button("1.2 Load original masks")
231
- text_button.click(load_image_ui,
232
- [input_folder, false] ,
233
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
234
-
235
- load_edit_button = gr.Button("1.2 Load edited masks")
236
- load_edit_button.click(load_image_ui,
237
- [input_folder, true] ,
238
- [image_loaded, segmentation, mask_np_list, mask_label_list, canvas] )
239
-
240
- show_segment = gr.Checkbox(label = "Show Segmentation")
241
-
242
- flag = gr.State(False)
243
- show_segment.select(show_segmentation,
244
- [image_loaded, segmentation, flag],
245
- [canvas, flag])
246
-
247
- mask_np_list_updated = copy.deepcopy(mask_np_list)
248
-
249
- with gr.Column():
250
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
251
- slider = gr.Slider(0, 20, step=1, interactive=True)
252
- label = gr.Textbox()
253
- slider.release(slider_release,
254
- inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
255
- outputs= [canvas, label]
256
- )
257
- add_button = gr.Button("Add")
258
- add_button.click( edit_mask_add,
259
- [canvas, image_loaded, slider, mask_np_list_updated] ,
260
- [mask_np_list_updated, canvas]
261
- )
262
-
263
- save_button2 = gr.Button("Set and Save as edited masks")
264
- save_button2.click( save_as_edit_mask,
265
- [mask_np_list_updated, mask_label_list, input_folder] ,
266
- [] )
267
-
268
- save_button = gr.Button("Set and Save as original masks")
269
- save_button.click( save_as_orig_mask,
270
- [mask_np_list_updated, mask_label_list, input_folder] ,
271
- [] )
272
-
273
- back_button = gr.Button("Back to current seg")
274
- back_button.click( load_mask_ui,
275
- [input_folder] ,
276
- [ mask_np_list_updated,mask_label_list] )
277
-
278
- add_mask_button = gr.Button("Add new empty mask")
279
- add_mask_button.click(add_mask,
280
- [mask_np_list_updated, mask_label_list] ,
281
- [mask_np_list_updated, mask_label_list] )
282
-
283
- with gr.Tab(label="2 Optimization"):
284
- with gr.Row():
285
- with gr.Column():
286
- canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
287
-
288
- with gr.Column():
289
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
290
- num_tokens = gr.Textbox(value="5", label="num tokens to represent each object", interactive= True)
291
- embedding_learning_rate = gr.Textbox(value="1e-4", label="Embedding optimization: Learning rate", interactive= True )
292
- max_emb_train_steps = gr.Textbox(value="500", label="embedding optimization: Training steps", interactive= True )
293
-
294
- diffusion_model_learning_rate = gr.Textbox(value="5e-5", label="UNet Optimization: Learning rate", interactive= True )
295
- max_diffusion_train_steps = gr.Textbox(value="500", label="UNet Optimization: Learning rate: Training steps", interactive= True )
296
-
297
- train_batch_size = gr.Textbox(value="5", label="Batch size", interactive= True )
298
- gradient_accumulation_steps=gr.Textbox(value="5", label="Gradient accumulation", interactive= True )
299
-
300
- add_button = gr.Button("Run optimization")
301
- add_button.click(run_optimization,
302
- inputs = [
303
- input_folder,
304
- num_tokens,
305
- embedding_learning_rate,
306
- max_emb_train_steps,
307
- diffusion_model_learning_rate,
308
- max_diffusion_train_steps,
309
- train_batch_size,gradient_accumulation_steps
310
- ],
311
- outputs = []
312
- )
313
-
314
-
315
- with gr.Tab(label="3 Editing"):
316
- with gr.Tab(label="3.1 Text-based editing"):
317
- canvas_text_edit = gr.State() # store mask
318
- with gr.Row():
319
- with gr.Column():
320
- canvas_text_edit = gr.Image(value = None, type="pil", label="Editing results", show_label=True, height=LENGTH, width=LENGTH)
321
- # canvas_text_edit = gr.Gallery(label = "Edited results")
322
-
323
- with gr.Column():
324
- gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
325
-
326
- tgt_prompt = gr.Textbox(value="Dog", label="Editing: Text prompt", interactive= True )
327
- tgt_idx = gr.Textbox(value="0", label="Editing: Object index", interactive= True )
328
- guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
329
- num_sampling_steps = gr.Textbox(value="50", label="Editing: Sampling steps", interactive= True )
330
- edge_thickness = gr.Textbox(value="10", label="Editing: Edge thickness", interactive= True )
331
- strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )
332
-
333
- add_button = gr.Button("Run Editing")
334
- add_button.click(run_edit_text,
335
- inputs = [
336
- input_folder,
337
- num_tokens,
338
- num_sampling_steps,
339
- strength,
340
- edge_thickness,
341
- tgt_prompt,
342
- tgt_idx,
343
- guidance_scale
344
- ],
345
- outputs = []
346
- )
347
-
348
-
349
-
350
- demo.queue().launch(share=True, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -59,62 +59,62 @@ def load_image_ui(load_edit, input_folder="example_tmp"):
59
  print("Image folder invalid: The folder should contain image.png")
60
  return None, None, None, None, None
61
 
62
- def run_edit_text(
63
- num_tokens,
64
- num_sampling_steps,
65
- strength,
66
- edge_thickness,
67
- tgt_prompt,
68
- tgt_idx,
69
- guidance_scale,
70
- input_folder="example_tmp"
71
- ):
72
- subprocess.run(["python",
73
- "main.py" ,
74
- "--text",
75
- "--name={}".format(input_folder),
76
- "--dpm={}".format("sd"),
77
- "--resolution={}".format(512),
78
- "--load_trained",
79
- "--num_tokens={}".format(num_tokens),
80
- "--seed={}".format(2024),
81
- "--guidance_scale={}".format(guidance_scale),
82
- "--num_sampling_step={}".format(num_sampling_steps),
83
- "--strength={}".format(strength),
84
- "--edge_thickness={}".format(edge_thickness),
85
- "--num_imgs={}".format(2),
86
- "--tgt_prompt={}".format(tgt_prompt) ,
87
- "--tgt_index={}".format(tgt_idx)
88
- ])
89
 
90
- return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
91
 
92
 
93
- def run_optimization(
94
- num_tokens,
95
- embedding_learning_rate,
96
- max_emb_train_steps,
97
- diffusion_model_learning_rate,
98
- max_diffusion_train_steps,
99
- train_batch_size,
100
- gradient_accumulation_steps,
101
- input_folder = "example_tmp"
102
- ):
103
- subprocess.run(["python",
104
- "main.py" ,
105
- "--name={}".format(input_folder),
106
- "--dpm={}".format("sd"),
107
- "--resolution={}".format(512),
108
- "--num_tokens={}".format(num_tokens),
109
- "--embedding_learning_rate={}".format(embedding_learning_rate),
110
- "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
111
- "--max_emb_train_steps={}".format(max_emb_train_steps),
112
- "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
113
- "--train_batch_size={}".format(train_batch_size),
114
- "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
115
 
116
- ])
117
- return
118
 
119
 
120
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
@@ -215,6 +215,7 @@ with gr.Blocks() as demo:
215
  true = gr.State(True)
216
  false = gr.State(False)
217
  block_flag = gr.State(0)
 
218
  with gr.Row():
219
  gr.Markdown("""# D-Edit""")
220
 
@@ -293,6 +294,7 @@ with gr.Blocks() as demo:
293
  opt_flag = gr.State(0)
294
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
295
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
 
296
  embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
297
  max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
298
 
@@ -380,7 +382,7 @@ with gr.Blocks() as demo:
380
  run_main,
381
  load_trained=True,
382
  text=True,
383
- num_tokens = int(num_tokens),
384
  guidance_scale = float(guidance_scale),
385
  num_sampling_steps = int(num_sampling_steps),
386
  strength = float(strength),
@@ -391,8 +393,15 @@ with gr.Blocks() as demo:
391
  )
392
  return run_edit_text()
393
 
394
- add_button.click(run_edit_text,
395
- inputs = [],
 
 
 
 
 
 
 
396
  outputs = [canvas_text_edit]
397
  )
398
 
 
59
  print("Image folder invalid: The folder should contain image.png")
60
  return None, None, None, None, None
61
 
62
+ # def run_edit_text(
63
+ # num_tokens,
64
+ # num_sampling_steps,
65
+ # strength,
66
+ # edge_thickness,
67
+ # tgt_prompt,
68
+ # tgt_idx,
69
+ # guidance_scale,
70
+ # input_folder="example_tmp"
71
+ # ):
72
+ # subprocess.run(["python",
73
+ # "main.py" ,
74
+ # "--text=True",
75
+ # "--name={}".format(input_folder),
76
+ # "--dpm={}".format("sd"),
77
+ # "--resolution={}".format(512),
78
+ # "--load_trained",
79
+ # "--num_tokens={}".format(num_tokens),
80
+ # "--seed={}".format(2024),
81
+ # "--guidance_scale={}".format(guidance_scale),
82
+ # "--num_sampling_step={}".format(num_sampling_steps),
83
+ # "--strength={}".format(strength),
84
+ # "--edge_thickness={}".format(edge_thickness),
85
+ # "--num_imgs={}".format(2),
86
+ # "--tgt_prompt={}".format(tgt_prompt) ,
87
+ # "--tgt_index={}".format(tgt_idx)
88
+ # ])
89
 
90
+ # return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
91
 
92
 
93
+ # def run_optimization(
94
+ # num_tokens,
95
+ # embedding_learning_rate,
96
+ # max_emb_train_steps,
97
+ # diffusion_model_learning_rate,
98
+ # max_diffusion_train_steps,
99
+ # train_batch_size,
100
+ # gradient_accumulation_steps,
101
+ # input_folder = "example_tmp"
102
+ # ):
103
+ # subprocess.run(["python",
104
+ # "main.py" ,
105
+ # "--name={}".format(input_folder),
106
+ # "--dpm={}".format("sd"),
107
+ # "--resolution={}".format(512),
108
+ # "--num_tokens={}".format(num_tokens),
109
+ # "--embedding_learning_rate={}".format(embedding_learning_rate),
110
+ # "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
111
+ # "--max_emb_train_steps={}".format(max_emb_train_steps),
112
+ # "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
113
+ # "--train_batch_size={}".format(train_batch_size),
114
+ # "--gradient_accumulation_steps={}".format(gradient_accumulation_steps)
115
 
116
+ # ])
117
+ # return
118
 
119
 
120
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
 
215
  true = gr.State(True)
216
  false = gr.State(False)
217
  block_flag = gr.State(0)
218
+ num_tokens_global = gr.State(5)
219
  with gr.Row():
220
  gr.Markdown("""# D-Edit""")
221
 
 
294
  opt_flag = gr.State(0)
295
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
296
  num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
297
+ num_tokens_global = num_tokens
298
  embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
299
  max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
300
 
 
382
  run_main,
383
  load_trained=True,
384
  text=True,
385
+ num_tokens = int(num_tokens_global.value),
386
  guidance_scale = float(guidance_scale),
387
  num_sampling_steps = int(num_sampling_steps),
388
  strength = float(strength),
 
393
  )
394
  return run_edit_text()
395
 
396
+ add_button.click(run_edit_text_wrapper,
397
+ inputs = [num_tokens_global,
398
+ guidance_scale,
399
+ num_sampling_steps,
400
+ strength ,
401
+ edge_thickness,
402
+ tgt_prompt ,
403
+ tgt_index
404
+ ],
405
  outputs = [canvas_text_edit]
406
  )
407
 
main copy.py DELETED
@@ -1,480 +0,0 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- import argparse
5
- from peft import LoraConfig
6
- from old.pipeline_dedit_sdxl import DEditSDXLPipeline
7
- from pipeline_dedit_sd import DEditSDPipeline
8
- from utils import load_image, load_mask, load_mask_edit
9
- from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
10
- from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys
11
-
12
- parser = argparse.ArgumentParser()
13
- parser.add_argument("--name", type=str,required=True, default=None)
14
- parser.add_argument("--name_2", type=str,required=False, default=None)
15
- parser.add_argument("--dpm", type=str,required=True, default="sd")
16
- parser.add_argument("--resolution", type=int, default=1024)
17
- parser.add_argument("--seed", type=int, default=42)
18
- parser.add_argument("--embedding_learning_rate", type=float, default=1e-4)
19
- parser.add_argument("--max_emb_train_steps", type=int, default=200)
20
- parser.add_argument("--diffusion_model_learning_rate", type=float, default=5e-5)
21
- parser.add_argument("--max_diffusion_train_steps", type=int, default=200)
22
- parser.add_argument("--train_batch_size", type=int, default=1)
23
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
24
- parser.add_argument("--num_tokens", type=int, default=1)
25
-
26
-
27
- parser.add_argument("--load_trained", default=False, action="store_true" )
28
- parser.add_argument("--num_sampling_steps", type=int, default=50)
29
- parser.add_argument("--guidance_scale", type=float, default = 3 )
30
- parser.add_argument("--strength", type=float, default=0.8)
31
-
32
- parser.add_argument("--train_full_lora", default=False, action="store_true" )
33
- parser.add_argument("--lora_rank", type=int, default=4)
34
- parser.add_argument("--lora_alpha", type=int, default=4)
35
-
36
- parser.add_argument("--prompt_auxin_list", nargs="+", type=str, default = None)
37
- parser.add_argument("--prompt_auxin_idx_list", nargs="+", type=int, default = None)
38
-
39
- # general editing configs
40
- parser.add_argument("--load_edited_mask", default=False, action="store_true")
41
- parser.add_argument("--load_edited_processed_mask", default=False, action="store_true")
42
- parser.add_argument("--edge_thickness", type=int, default=20)
43
- parser.add_argument("--num_imgs", type=int, default = 1 )
44
- parser.add_argument('--active_mask_list', nargs="+", type=int)
45
- parser.add_argument("--tgt_index", type=int, default=None)
46
-
47
- # recon
48
- parser.add_argument("--recon", default=False, action="store_true" )
49
- parser.add_argument("--recon_an_item", default=False, action="store_true" )
50
- parser.add_argument("--recon_prompt", type=str, default=None)
51
-
52
- # text-based editing
53
- parser.add_argument("--text", default=False, action="store_true")
54
- parser.add_argument("--tgt_prompt", type=str, default=None)
55
-
56
- # image-based editing
57
- parser.add_argument("--image", default=False, action="store_true" )
58
- parser.add_argument("--src_index", type=int, default=None)
59
- parser.add_argument("--tgt_name", type=str, default=None)
60
-
61
- # mask-based move
62
- parser.add_argument("--move_resize", default=False, action="store_true" )
63
- parser.add_argument('--tgt_indices_list', nargs="+", type=int)
64
- parser.add_argument("--delta_x_list", nargs="+", type=int)
65
- parser.add_argument("--delta_y_list", nargs="+", type=int)
66
- parser.add_argument("--priority_list", nargs="+", type=int)
67
- parser.add_argument("--force_mask_remain", type=int, default=None)
68
- parser.add_argument("--resize_list", nargs="+", type=float)
69
-
70
- # remove
71
- parser.add_argument("--remove", default=False, action="store_true" )
72
- parser.add_argument("--load_edited_removemask", default=False, action="store_true")
73
-
74
- args = parser.parse_args()
75
-
76
-
77
- def run_main(
78
- name=None,
79
- name_2=None,
80
- dpm="sd",
81
- resolution=1024,
82
- seed=42,
83
- embedding_learning_rate=1e-4,
84
- max_emb_train_steps=200,
85
- diffusion_model_learning_rate=5e-5,
86
- max_diffusion_train_steps=200,
87
- train_batch_size=1,
88
- gradient_accumulation_steps=1,
89
- num_tokens=1,
90
-
91
- load_trained="store_true" ,
92
- num_sampling_steps=50,
93
- guidance_scale= 3 ,
94
- strength=0.8,
95
-
96
- train_full_lora="store_true" ,
97
- lora_rank=4,
98
- lora_alpha=4,
99
-
100
- prompt_auxin_list = None,
101
- prompt_auxin_idx_list= None,
102
-
103
- load_edited_mask="store_true",
104
- load_edited_processed_mask="store_true",
105
- edge_thickness=20,
106
- num_imgs= 1 ,
107
- active_mask_list = None,
108
- tgt_index=None,
109
-
110
- recon=False ,
111
- recon_an_item=False,
112
- recon_prompt=None,
113
-
114
- text="store_true",
115
- tgt_prompt=None,
116
-
117
- image="store_true" ,
118
- src_index=None,
119
- tgt_name=None,
120
-
121
- move_resize="store_true" ,
122
- tgt_indices_list=None,
123
- delta_x_list=None,
124
- delta_y_list=None,
125
- priority_list=None,
126
- force_mask_remain=None,
127
- resize_list=None,
128
-
129
- remove=False,
130
- load_edited_removemask=False
131
- ):
132
- torch.cuda.manual_seed_all(args.seed)
133
- torch.manual_seed(args.seed)
134
- base_input_folder = "."
135
- base_output_folder = "."
136
-
137
- input_folder = os.path.join(base_input_folder, args.name)
138
-
139
-
140
- mask_list, mask_label_list = load_mask(input_folder)
141
- assert mask_list[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
142
- try:
143
- image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution) ), size = args.resolution)
144
- except:
145
- image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
146
-
147
- if args.image:
148
- input_folder_2 = os.path.join(base_input_folder, args.name_2)
149
- mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
150
- assert mask_list_2[0].shape[0] == args.resolution, "Segmentation should be done on size {}".format(args.resolution)
151
- try:
152
- image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(args.resolution) ), size = args.resolution)
153
- except:
154
- image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(args.resolution) ), size = args.resolution)
155
- output_dir = os.path.join(base_output_folder, args.name + "_" + args.name_2)
156
- os.makedirs(output_dir, exist_ok = True)
157
- else:
158
- output_dir = os.path.join(base_output_folder, args.name)
159
- os.makedirs(output_dir, exist_ok = True)
160
-
161
- if args.dpm == "sd":
162
- if args.image:
163
- pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
164
- else:
165
- pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
166
-
167
- elif args.dpm == "sdxl":
168
- if args.image:
169
- pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = args.resolution, num_tokens = args.num_tokens)
170
- else:
171
- pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = args.resolution, num_tokens = args.num_tokens)
172
-
173
- else:
174
- raise NotImplementedError
175
-
176
- set_string_list = pipe.set_string_list
177
- if args.prompt_auxin_list is not None:
178
- for auxin_idx, auxin_prompt in zip(args.prompt_auxin_idx_list, args.prompt_auxin_list):
179
- set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx] )
180
- print(set_string_list)
181
-
182
- if args.image:
183
- set_string_list_2 = pipe.set_string_list_2
184
- print(set_string_list_2)
185
-
186
- if args.load_trained:
187
- unet_save_path = os.path.join(output_dir, "unet.pt")
188
- unet_state_dict = torch.load(unet_save_path)
189
- text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
190
- text_encoder1_state_dict = torch.load(text_encoder1_save_path)
191
- if args.dpm == "sdxl":
192
- text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
193
- text_encoder2_state_dict = torch.load(text_encoder2_save_path)
194
-
195
- if 'lora' in ''.join(unet_state_dict.keys()):
196
- unet_lora_config = LoraConfig(
197
- r=args.lora_rank,
198
- lora_alpha=args.lora_alpha,
199
- init_lora_weights="gaussian",
200
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
201
- )
202
- pipe.unet.add_adapter(unet_lora_config)
203
-
204
- pipe.unet.load_state_dict(unet_state_dict)
205
- pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
206
- if args.dpm == "sdxl":
207
- pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
208
- else:
209
- if args.image:
210
- pipe.mask_list = [m.cuda() for m in pipe.mask_list]
211
- pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
212
- pipe.train_emb_2imgs(
213
- image_gt,
214
- image_gt_2,
215
- set_string_list,
216
- set_string_list_2,
217
- gradient_accumulation_steps = args.gradient_accumulation_steps,
218
- embedding_learning_rate = args.embedding_learning_rate,
219
- max_emb_train_steps = args.max_emb_train_steps,
220
- train_batch_size = args.train_batch_size,
221
- )
222
-
223
- pipe.train_model_2imgs(
224
- image_gt,
225
- image_gt_2,
226
- set_string_list,
227
- set_string_list_2,
228
- gradient_accumulation_steps = args.gradient_accumulation_steps,
229
- max_diffusion_train_steps = args.max_diffusion_train_steps,
230
- diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
231
- train_batch_size =args.train_batch_size,
232
- train_full_lora = args.train_full_lora,
233
- lora_rank = args.lora_rank,
234
- lora_alpha = args.lora_alpha
235
- )
236
-
237
- else:
238
- pipe.mask_list = [m.cuda() for m in pipe.mask_list]
239
- pipe.train_emb(
240
- image_gt,
241
- set_string_list,
242
- gradient_accumulation_steps = args.gradient_accumulation_steps,
243
- embedding_learning_rate = args.embedding_learning_rate,
244
- max_emb_train_steps = args.max_emb_train_steps,
245
- train_batch_size = args.train_batch_size,
246
- )
247
-
248
- pipe.train_model(
249
- image_gt,
250
- set_string_list,
251
- gradient_accumulation_steps = args.gradient_accumulation_steps,
252
- max_diffusion_train_steps = args.max_diffusion_train_steps,
253
- diffusion_model_learning_rate = args.diffusion_model_learning_rate ,
254
- train_batch_size = args.train_batch_size,
255
- train_full_lora = args.train_full_lora,
256
- lora_rank = args.lora_rank,
257
- lora_alpha = args.lora_alpha
258
- )
259
-
260
-
261
- unet_save_path = os.path.join(output_dir, "unet.pt")
262
- torch.save(pipe.unet.state_dict(),unet_save_path )
263
- text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
264
- torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
265
- if args.dpm == "sdxl":
266
- text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
267
- torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path )
268
-
269
-
270
- if args.recon:
271
- output_dir = os.path.join(output_dir, "recon")
272
- os.makedirs(output_dir, exist_ok = True)
273
- if args.recon_an_item:
274
- mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
275
- tgt_string = set_string_list[args.tgt_index]
276
- tgt_string = args.recon_prompt.replace("*", tgt_string)
277
- set_string_list = [tgt_string]
278
- print(set_string_list)
279
- save_path = os.path.join(output_dir, "out_recon.png")
280
- x_np = pipe.inference_with_mask(
281
- save_path,
282
- guidance_scale = args.guidance_scale,
283
- num_sampling_steps = args.num_sampling_steps,
284
- seed = args.seed,
285
- num_imgs = args.num_imgs,
286
- set_string_list = set_string_list,
287
- mask_list = mask_list
288
- )
289
-
290
- if args.text:
291
- print("Text-guided editing ")
292
- output_dir = os.path.join(output_dir, "text")
293
- os.makedirs(output_dir, exist_ok = True)
294
- save_path = os.path.join(output_dir, "out_text.png")
295
- set_string_list[args.tgt_index] = args.tgt_prompt
296
- mask_active = torch.zeros_like(mask_list[0])
297
- mask_active = mask_union_torch(mask_active, mask_list[args.tgt_index])
298
-
299
- if args.active_mask_list is not None:
300
- for midx in args.active_mask_list:
301
- mask_active = mask_union_torch(mask_active, mask_list[midx])
302
-
303
- if args.load_edited_mask:
304
- mask_list_edited, mask_label_list_edited = load_mask_edit(input_folder)
305
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
306
- mask_active = mask_union_torch(mask_active, mask_diff)
307
- mask_list = mask_list_edited
308
- save_path = os.path.join(output_dir, "out_textEdited.png")
309
-
310
- mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
311
- mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
312
- mask_hard = mask_substract_torch(mask_hard, mask_soft)
313
-
314
- pipe.inference_with_mask(
315
- save_path,
316
- orig_image = image_gt,
317
- set_string_list = set_string_list,
318
- guidance_scale = args.guidance_scale,
319
- strength = args.strength,
320
- num_imgs = args.num_imgs,
321
- mask_hard= mask_hard,
322
- mask_soft = mask_soft,
323
- mask_list = mask_list,
324
- seed = args.seed,
325
- num_sampling_steps = args.num_sampling_steps
326
- )
327
-
328
- if args.remove:
329
- output_dir = os.path.join(output_dir, "remove")
330
- save_path = os.path.join(output_dir, "out_remove.png")
331
- os.makedirs(output_dir, exist_ok = True)
332
- mask_active = torch.zeros_like(mask_list[0])
333
-
334
- if args.load_edited_mask:
335
- mask_list_edited, _ = load_mask_edit(input_folder)
336
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
337
- mask_active = mask_union_torch(mask_active, mask_diff)
338
- mask_list = mask_list_edited
339
-
340
- if args.load_edited_processed_mask:
341
- # manually edit or draw masks after removing one index, then load
342
- mask_list_processed, _ = load_mask_edit(output_dir)
343
- mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
344
- else:
345
- # generate masks after removing one index, using nearest neighbor algorithm
346
- mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, args.tgt_index)
347
- save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
348
- visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
349
- check_cover_all_torch(*mask_list_processed)
350
- mask_active = mask_union_torch(mask_active, mask_remain)
351
-
352
- if args.active_mask_list is not None:
353
- for midx in args.active_mask_list:
354
- mask_active = mask_union_torch(mask_active, mask_list[midx])
355
-
356
- mask_hard = 1 - mask_active
357
- mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = args.edge_thickness)
358
- mask_hard = mask_substract_torch(mask_hard, mask_soft)
359
-
360
- pipe.inference_with_mask(
361
- save_path,
362
- orig_image = image_gt,
363
- guidance_scale = args.guidance_scale,
364
- strength = args.strength,
365
- num_imgs = args.num_imgs,
366
- mask_hard= mask_hard,
367
- mask_soft = mask_soft,
368
- mask_list = mask_list_processed,
369
- seed = args.seed,
370
- num_sampling_steps = args.num_sampling_steps
371
- )
372
-
373
- if args.image:
374
- output_dir = os.path.join(output_dir, "image")
375
- save_path = os.path.join(output_dir, "out_image.png")
376
- os.makedirs(output_dir, exist_ok = True)
377
- mask_active = torch.zeros_like(mask_list[0])
378
-
379
- if None not in (args.tgt_name, args.src_index, args.tgt_index):
380
- if args.tgt_name == args.name:
381
- set_string_list_tgt = set_string_list
382
- set_string_list_src = set_string_list_2
383
- image_tgt = image_gt
384
- if args.load_edited_mask:
385
- mask_list_edited, _ = load_mask_edit(input_folder)
386
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
387
- mask_active = mask_union_torch(mask_active, mask_diff)
388
- mask_list = mask_list_edited
389
- save_path = os.path.join(output_dir, "out_imageEdited.png")
390
- mask_list_tgt = mask_list
391
-
392
- elif args.tgt_name == args.name_2:
393
- set_string_list_tgt = set_string_list_2
394
- set_string_list_src = set_string_list
395
- image_tgt = image_gt_2
396
- if args.load_edited_mask:
397
- mask_list_2_edited, _ = load_mask_edit(input_folder_2)
398
- mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
399
- mask_active = mask_union_torch(mask_active, mask_diff)
400
- mask_list_2 = mask_list_2_edited
401
- save_path = os.path.join(output_dir, "out_imageEdited.png")
402
- mask_list_tgt = mask_list_2
403
- else:
404
- exit("tgt_name should be either name or name_2")
405
-
406
- set_string_list_tgt[args.tgt_index] = set_string_list_src[args.src_index]
407
-
408
- mask_active = mask_list_tgt[args.tgt_index]
409
- mask_frozen = (1-mask_active.float()).to(mask_active.device)
410
- mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = args.edge_thickness)
411
- mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())
412
-
413
- mask_list_tgt = [m.cuda() for m in mask_list_tgt]
414
-
415
- pipe.inference_with_mask(
416
- save_path,
417
- set_string_list = set_string_list_tgt,
418
- mask_list = mask_list_tgt,
419
- guidance_scale = args.guidance_scale,
420
- num_sampling_steps = args.num_sampling_steps,
421
- mask_hard = mask_hard.cuda(),
422
- mask_soft = mask_soft.cuda(),
423
- num_imgs = args.num_imgs,
424
- orig_image = image_tgt,
425
- strength = args.strength,
426
- )
427
-
428
- if args.move_resize:
429
- output_dir = os.path.join(output_dir, "move_resize")
430
- os.makedirs(output_dir, exist_ok = True)
431
- save_path = os.path.join(output_dir, "out_moveresize.png")
432
- mask_active = torch.zeros_like(mask_list[0])
433
-
434
- if args.load_edited_mask:
435
- mask_list_edited, _ = load_mask_edit(input_folder)
436
- mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
437
- mask_active = mask_union_torch(mask_active, mask_diff)
438
- mask_list = mask_list_edited
439
- # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
440
-
441
- if args.load_edited_processed_mask:
442
- mask_list_processed, _ = load_mask_edit(output_dir)
443
- mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
444
- else:
445
- mask_list_processed, mask_remain = process_mask_move_torch(
446
- mask_list,
447
- args.tgt_indices_list,
448
- args.delta_x_list,
449
- args.delta_y_list, args.priority_list,
450
- force_mask_remain = args.force_mask_remain,
451
- resize_list = args.resize_list
452
- )
453
- save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
454
- visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
455
- active_idxs = args.tgt_indices_list
456
-
457
- mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
458
- mask_active = mask_union_torch(mask_remain, mask_active)
459
- if args.active_mask_list is not None:
460
- for midx in args.active_mask_list:
461
- mask_active = mask_union_torch(mask_active, mask_list_processed[midx])
462
-
463
- mask_frozen =(1 - mask_active.float())
464
- mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = args.edge_thickness)
465
- mask_hard = mask_substract_torch(mask_frozen, mask_soft)
466
-
467
- check_mask_overlap_torch(mask_hard, mask_soft)
468
-
469
- pipe.inference_with_mask(
470
- save_path,
471
- strength = args.strength,
472
- orig_image = image_gt,
473
- guidance_scale = args.guidance_scale,
474
- num_sampling_steps = args.num_sampling_steps,
475
- num_imgs = args.num_imgs,
476
- mask_hard= mask_hard,
477
- mask_soft = mask_soft,
478
- mask_list = mask_list_processed,
479
- seed = args.seed
480
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -64,6 +64,7 @@ def run_main(
64
  remove=False,
65
  load_edited_removemask=False
66
  ):
 
67
  torch.cuda.manual_seed_all(seed)
68
  torch.manual_seed(seed)
69
  base_input_folder = "."
@@ -220,9 +221,9 @@ def run_main(
220
  set_string_list = set_string_list,
221
  mask_list = mask_list
222
  )
223
-
224
  if text:
225
- print("Text-guided editing ")
226
  output_dir = os.path.join(output_dir, "text")
227
  os.makedirs(output_dir, exist_ok = True)
228
  save_path = os.path.join(output_dir, "out_text.png")
 
64
  remove=False,
65
  load_edited_removemask=False
66
  ):
67
+
68
  torch.cuda.manual_seed_all(seed)
69
  torch.manual_seed(seed)
70
  base_input_folder = "."
 
221
  set_string_list = set_string_list,
222
  mask_list = mask_list
223
  )
224
+
225
  if text:
226
+ print("*** Text-guided editing ")
227
  output_dir = os.path.join(output_dir, "text")
228
  os.makedirs(output_dir, exist_ok = True)
229
  save_path = os.path.join(output_dir, "out_text.png")
pipeline_dedit_sd.py CHANGED
@@ -810,5 +810,5 @@ class DEditSDPipeline:
810
  seed = seed
811
  )
812
  save_images(x0, save_path)
813
- # from PIL import Image
814
- # return Image.open("example_tmp/text/out_text_0.png")
 
810
  seed = seed
811
  )
812
  save_images(x0, save_path)
813
+ from PIL import Image
814
+ return Image.open("example_tmp/text/out_text_0.png")