3v324v23 committed
Commit 515f781
1 Parent(s): f141c64

code pushed

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +1 -1
  2. app.py +494 -0
  3. configs/model/autokl.yaml +26 -0
  4. configs/model/clip.yaml +12 -0
  5. configs/model/controlnet.yaml +18 -0
  6. configs/model/openai_unet.yaml +35 -0
  7. configs/model/pfd.yaml +33 -0
  8. configs/model/seecoder.yaml +62 -0
  9. configs/model/swin.yaml +32 -0
  10. lib/__init__.py +0 -0
  11. lib/__pycache__/__init__.cpython-310.pyc +0 -0
  12. lib/__pycache__/cfg_helper.cpython-310.pyc +0 -0
  13. lib/__pycache__/cfg_holder.cpython-310.pyc +0 -0
  14. lib/__pycache__/log_service.cpython-310.pyc +0 -0
  15. lib/__pycache__/sync.cpython-310.pyc +0 -0
  16. lib/cfg_helper.py +666 -0
  17. lib/cfg_holder.py +28 -0
  18. lib/log_service.py +165 -0
  19. lib/model_zoo/__init__.py +4 -0
  20. lib/model_zoo/__pycache__/__init__.cpython-310.pyc +0 -0
  21. lib/model_zoo/__pycache__/attention.cpython-310.pyc +0 -0
  22. lib/model_zoo/__pycache__/autokl.cpython-310.pyc +0 -0
  23. lib/model_zoo/__pycache__/autokl_modules.cpython-310.pyc +0 -0
  24. lib/model_zoo/__pycache__/autokl_utils.cpython-310.pyc +0 -0
  25. lib/model_zoo/__pycache__/controlnet.cpython-310.pyc +0 -0
  26. lib/model_zoo/__pycache__/ddim.cpython-310.pyc +0 -0
  27. lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc +0 -0
  28. lib/model_zoo/__pycache__/distributions.cpython-310.pyc +0 -0
  29. lib/model_zoo/__pycache__/ema.cpython-310.pyc +0 -0
  30. lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc +0 -0
  31. lib/model_zoo/__pycache__/pfd.cpython-310.pyc +0 -0
  32. lib/model_zoo/__pycache__/seecoder.cpython-310.pyc +0 -0
  33. lib/model_zoo/__pycache__/seecoder_utils.cpython-310.pyc +0 -0
  34. lib/model_zoo/__pycache__/swin.cpython-310.pyc +0 -0
  35. lib/model_zoo/attention.py +540 -0
  36. lib/model_zoo/autokl.py +166 -0
  37. lib/model_zoo/autokl_modules.py +835 -0
  38. lib/model_zoo/autokl_utils.py +400 -0
  39. lib/model_zoo/clip.py +788 -0
  40. lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc +0 -0
  41. lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc +0 -0
  42. lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc +0 -0
  43. lib/model_zoo/common/__pycache__/utils.cpython-310.pyc +0 -0
  44. lib/model_zoo/common/get_model.py +124 -0
  45. lib/model_zoo/common/get_optimizer.py +47 -0
  46. lib/model_zoo/common/get_scheduler.py +262 -0
  47. lib/model_zoo/common/utils.py +292 -0
  48. lib/model_zoo/controlnet.py +503 -0
  49. lib/model_zoo/controlnet_annotator/canny/__init__.py +5 -0
  50. lib/model_zoo/controlnet_annotator/hed/__init__.py +134 -0
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
  title: Prompt-Free Diffusion
  emoji: 👀
- colorFrom: orange
+ colorFrom: red
  colorTo: blue
  sdk: gradio
  sdk_version: 3.32.0
app.py ADDED
@@ -0,0 +1,494 @@
1
+ ################################################################################
2
+ # Copyright (C) 2023 Xingqian Xu - All Rights Reserved #
3
+ # #
4
+ # Please visit Prompt-Free-Diffusion's arXiv paper for more details, link at #
5
+ # arxiv.org/abs/2305.16223 #
6
+ # #
7
+ ################################################################################
8
+
9
+ import gradio as gr
10
+ import os.path as osp
11
+ from PIL import Image
12
+ import numpy as np
13
+ import time
14
+
15
+ import torch
16
+ import torchvision.transforms as tvtrans
17
+ from lib.cfg_helper import model_cfg_bank
18
+ from lib.model_zoo import get_model
19
+
20
+ from collections import OrderedDict
21
+ from lib.model_zoo.ddim import DDIMSampler
22
+
23
+ n_sample_image = 1
24
+
25
+ controlnet_path = OrderedDict([
26
+ ['canny' , ('canny' , 'pretrained/controlnet/control_sd15_canny_slimmed.safetensors')],
27
+ ['canny_v11p' , ('canny' , 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors')],
28
+ ['depth' , ('depth' , 'pretrained/controlnet/control_sd15_depth_slimmed.safetensors')],
29
+ ['hed' , ('hed' , 'pretrained/controlnet/control_sd15_hed_slimmed.safetensors')],
30
+ ['mlsd' , ('mlsd' , 'pretrained/controlnet/control_sd15_mlsd_slimmed.safetensors')],
31
+ ['mlsd_v11p' , ('mlsd' , 'pretrained/controlnet/control_v11p_sd15_mlsd_slimmed.safetensors')],
32
+ ['normal' , ('normal' , 'pretrained/controlnet/control_sd15_normal_slimmed.safetensors')],
33
+ ['openpose' , ('openpose', 'pretrained/controlnet/control_sd15_openpose_slimmed.safetensors')],
34
+ ['openpose_v11p' , ('openpose', 'pretrained/controlnet/control_v11p_sd15_openpose_slimmed.safetensors')],
35
+ ['scribble' , ('scribble', 'pretrained/controlnet/control_sd15_scribble_slimmed.safetensors')],
36
+ ['softedge_v11p' , ('scribble', 'pretrained/controlnet/control_v11p_sd15_softedge_slimmed.safetensors')],
37
+ ['seg' , ('none' , 'pretrained/controlnet/control_sd15_seg_slimmed.safetensors')],
38
+ ['lineart_v11p' , ('none' , 'pretrained/controlnet/control_v11p_sd15_lineart_slimmed.safetensors')],
39
+ ['lineart_anime_v11p', ('none' , 'pretrained/controlnet/control_v11p_sd15s2_lineart_anime_slimmed.safetensors')],
40
+ ])
41
+
42
+ preprocess_method = [
43
+ 'canny' ,
44
+ 'depth' ,
45
+ 'hed' ,
46
+ 'mlsd' ,
47
+ 'normal' ,
48
+ 'openpose' ,
49
+ 'openpose_withface' ,
50
+ 'openpose_withfacehand',
51
+ 'scribble' ,
52
+ 'none' ,
53
+ ]
54
+
55
+ diffuser_path = OrderedDict([
56
+ ['SD-v1.5' , 'pretrained/pfd/diffuser/SD-v1-5.safetensors'],
57
+ ['OpenJouney-v4' , 'pretrained/pfd/diffuser/OpenJouney-v4.safetensors'],
58
+ ['Deliberate-v2.0' , 'pretrained/pfd/diffuser/Deliberate-v2-0.safetensors'],
59
+ ['RealisticVision-v2.0', 'pretrained/pfd/diffuser/RealisticVision-v2-0.safetensors'],
60
+ ['Anything-v4' , 'pretrained/pfd/diffuser/Anything-v4.safetensors'],
61
+ ['Oam-v3' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v3.safetensors'],
62
+ ['Oam-v2' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v2.safetensors'],
63
+ ])
64
+
65
+ ctxencoder_path = OrderedDict([
66
+ ['SeeCoder' , 'pretrained/pfd/seecoder/seecoder-v1-0.safetensors'],
67
+ ['SeeCoder-PA' , 'pretrained/pfd/seecoder/seecoder-pa-v1-0.safetensors'],
68
+ ['SeeCoder-Anime', 'pretrained/pfd/seecoder/seecoder-anime-v1-0.safetensors'],
69
+ ])
70
+
71
+ ##########
72
+ # helper #
73
+ ##########
74
+
75
+ def highlight_print(info):
76
+ print('')
77
+ print(''.join(['#']*(len(info)+4)))
78
+ print('# '+info+' #')
79
+ print(''.join(['#']*(len(info)+4)))
80
+ print('')
81
+
82
+ def load_sd_from_file(target):
83
+ if osp.splitext(target)[-1] == '.ckpt':
84
+ sd = torch.load(target, map_location='cpu')['state_dict']
85
+ elif osp.splitext(target)[-1] == '.pth':
86
+ sd = torch.load(target, map_location='cpu')
87
+ elif osp.splitext(target)[-1] == '.safetensors':
88
+ from safetensors.torch import load_file as stload
89
+ sd = OrderedDict(stload(target, device='cpu'))
90
+ else:
91
+ assert False, "File type must be .ckpt or .pth or .safetensors"
92
+ return sd
93
+
94
+ ########
95
+ # main #
96
+ ########
97
+
98
+ class prompt_free_diffusion(object):
99
+ def __init__(self,
100
+ fp16=False,
101
+ tag_ctx=None,
102
+ tag_diffuser=None,
103
+ tag_ctl=None,):
104
+
105
+ self.tag_ctx = tag_ctx
106
+ self.tag_diffuser = tag_diffuser
107
+ self.tag_ctl = tag_ctl
108
+ self.strict_sd = True
109
+
110
+ cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet')
111
+ self.net = get_model()(cfgm)
112
+
113
+ self.action_load_ctx(tag_ctx)
114
+ self.action_load_diffuser(tag_diffuser)
115
+ self.action_load_ctl(tag_ctl)
116
+
117
+ if fp16:
118
+ highlight_print('Running in FP16')
119
+ self.net.ctx['image'].fp16 = True
120
+ self.net = self.net.half()
121
+ self.dtype = torch.float16
122
+ else:
123
+ self.dtype = torch.float32
124
+
125
+ self.use_cuda = torch.cuda.is_available()
126
+ if self.use_cuda:
127
+ self.net.to('cuda')
128
+
129
+ self.net.eval()
130
+ self.sampler = DDIMSampler(self.net)
131
+
132
+ self.n_sample_image = n_sample_image
133
+ self.ddim_steps = 50
134
+ self.ddim_eta = 0.0
135
+ self.image_latent_dim = 4
136
+
137
+ def load_ctx(self, pretrained):
138
+ sd = load_sd_from_file(pretrained)
139
+ sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \
140
+ if ki.find('ctx.')!=0]
141
+ sd.update(OrderedDict(sd_extra))
142
+
143
+ self.net.load_state_dict(sd, strict=True)
144
+ print('Load context encoder from [{}] strict [{}].'.format(pretrained, True))
145
+
146
+ def load_diffuser(self, pretrained):
147
+ sd = load_sd_from_file(pretrained)
148
+ if len([ki for ki in sd.keys() if ki.find('diffuser.image.context_blocks.')==0]) == 0:
149
+ sd = [(
150
+ ki.replace('diffuser.text.context_blocks.', 'diffuser.image.context_blocks.'), vi)
151
+ for ki, vi in sd.items()]
152
+ sd = OrderedDict(sd)
153
+ sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \
154
+ if ki.find('diffuser.')!=0]
155
+ sd.update(OrderedDict(sd_extra))
156
+ self.net.load_state_dict(sd, strict=True)
157
+ print('Load diffuser from [{}] strict [{}].'.format(pretrained, True))
158
+
159
+ def load_ctl(self, pretrained):
160
+ sd = load_sd_from_file(pretrained)
161
+ self.net.ctl.load_state_dict(sd, strict=True)
162
+ print('Load controlnet from [{}] strict [{}].'.format(pretrained, True))
163
+
164
+ def action_load_ctx(self, tag):
165
+ pretrained = ctxencoder_path[tag]
166
+ if tag == 'SeeCoder-PA':
167
+ from lib.model_zoo.seecoder import PPE_MLP
168
+ pe_layer = \
169
+ PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3)
170
+ if self.dtype == torch.float16:
171
+ pe_layer = pe_layer.half()
172
+ if self.use_cuda:
173
+ pe_layer.to('cuda')
174
+ pe_layer.eval()
175
+ self.net.ctx['image'].qtransformer.pe_layer = pe_layer
176
+ else:
177
+ self.net.ctx['image'].qtransformer.pe_layer = None
178
+ if pretrained is not None:
179
+ self.load_ctx(pretrained)
180
+ self.tag_ctx = tag
181
+ return tag
182
+
183
+ def action_load_diffuser(self, tag):
184
+ pretrained = diffuser_path[tag]
185
+ if pretrained is not None:
186
+ self.load_diffuser(pretrained)
187
+ self.tag_diffuser = tag
188
+ return tag
189
+
190
+ def action_load_ctl(self, tag):
191
+ pretrained = controlnet_path[tag][1]
192
+ if pretrained is not None:
193
+ self.load_ctl(pretrained)
194
+ self.tag_ctl = tag
195
+ return tag
196
+
197
+ def action_autoset_hw(self, imctl):
198
+ if imctl is None:
199
+ return 512, 512
200
+ w, h = imctl.size
201
+ w = w//64 * 64
202
+ h = h//64 * 64
203
+ w = w if w >=512 else 512
204
+ w = w if w <=1536 else 1536
205
+ h = h if h >=512 else 512
206
+ h = h if h <=1536 else 1536
207
+ return h, w
208
+
209
+ def action_autoset_method(self, tag):
210
+ return controlnet_path[tag][0]
211
+
212
+ def action_inference(
213
+ self, im, imctl, ctl_method, do_preprocess,
214
+ h, w, ugscale, seed,
215
+ tag_ctx, tag_diffuser, tag_ctl,):
216
+
217
+ if tag_ctx != self.tag_ctx:
218
+ self.action_load_ctx(tag_ctx)
219
+ if tag_diffuser != self.tag_diffuser:
220
+ self.action_load_diffuser(tag_diffuser)
221
+ if tag_ctl != self.tag_ctl:
222
+ self.action_load_ctl(tag_ctl)
223
+
224
+ n_samples = self.n_sample_image
225
+
226
+ sampler = self.sampler
227
+ device = self.net.device
228
+
229
+ w = w//64 * 64
230
+ h = h//64 * 64
231
+ if imctl is not None:
232
+ imctl = imctl.resize([w, h], Image.Resampling.BICUBIC)
233
+
234
+ craw = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
235
+ c = self.net.ctx_encode(craw, which='image').repeat(n_samples, 1, 1)
236
+ u = torch.zeros_like(c)
237
+
238
+ if tag_ctx in ["SeeCoder-Anime"]:
239
+ u = torch.load('assets/anime_ug.pth')[None].to(device).to(self.dtype)
240
+ pad = c.size(1) - u.size(1)
241
+ u = torch.cat([u, torch.zeros_like(u[:, 0:1].repeat(1, pad, 1))], axis=1)
242
+
243
+ if tag_ctl != 'none':
244
+ ccraw = tvtrans.ToTensor()(imctl)[None].to(device).to(self.dtype)
245
+ if do_preprocess:
246
+ cc = self.net.ctl.preprocess(ccraw, type=ctl_method, size=[h, w])
247
+ cc = cc.to(self.dtype)
248
+ else:
249
+ cc = ccraw
250
+ else:
251
+ cc = None
252
+
253
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
254
+
255
+ if seed < 0:
256
+ np.random.seed(int(time.time()))
257
+ torch.manual_seed(-seed + 100)
258
+ else:
259
+ np.random.seed(seed + 100)
260
+ torch.manual_seed(seed)
261
+
262
+ x, _ = sampler.sample(
263
+ steps=self.ddim_steps,
264
+ x_info={'type':'image',},
265
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
266
+ 'unconditional_guidance_scale':ugscale,
267
+ 'control':cc,},
268
+ shape=shape,
269
+ verbose=False,
270
+ eta=self.ddim_eta)
271
+
272
+ ccout = [tvtrans.ToPILImage()(i) for i in cc] if cc is not None else []
273
+ imout = self.net.vae_decode(x, which='image')
274
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
275
+ return imout + ccout
276
+
277
+ pfd_inference = prompt_free_diffusion(
278
+ fp16=True, tag_ctx = 'SeeCoder', tag_diffuser = 'Deliberate-v2.0', tag_ctl = 'canny',)
279
+
280
+ #################
281
+ # sub interface #
282
+ #################
283
+
284
+ cache_examples = True
285
+
286
+ def get_example():
287
+ case = [
288
+ [
289
+ 'assets/examples/ghibli-input.jpg',
290
+ 'assets/examples/ghibli-canny.png',
291
+ 'canny', False,
292
+ 768, 1024, 1.8, 23,
293
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
294
+ [
295
+ 'assets/examples/astronautridinghouse-input.jpg',
296
+ 'assets/examples/astronautridinghouse-canny.png',
297
+ 'canny', False,
298
+ 512, 768, 2.0, 21,
299
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
300
+ [
301
+ 'assets/examples/grassland-input.jpg',
302
+ 'assets/examples/grassland-scribble.png',
303
+ 'scribble', False,
304
+ 768, 512, 2.0, 41,
305
+ 'SeeCoder', 'Deliberate-v2.0', 'scribble', ],
306
+ [
307
+ 'assets/examples/jeep-input.jpg',
308
+ 'assets/examples/jeep-depth.png',
309
+ 'depth', False,
310
+ 512, 768, 2.0, 30,
311
+ 'SeeCoder', 'Deliberate-v2.0', 'depth', ],
312
+ [
313
+ 'assets/examples/bedroom-input.jpg',
314
+ 'assets/examples/bedroom-mlsd.png',
315
+ 'mlsd', False,
316
+ 512, 512, 2.0, 31,
317
+ 'SeeCoder', 'Deliberate-v2.0', 'mlsd', ],
318
+ [
319
+ 'assets/examples/nightstreet-input.jpg',
320
+ 'assets/examples/nightstreet-canny.png',
321
+ 'canny', False,
322
+ 768, 512, 2.3, 20,
323
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
324
+ [
325
+ 'assets/examples/woodcar-input.jpg',
326
+ 'assets/examples/woodcar-depth.png',
327
+ 'depth', False,
328
+ 768, 512, 2.0, 20,
329
+ 'SeeCoder', 'Deliberate-v2.0', 'depth', ],
330
+ [
331
+ 'assets/examples-anime/miku.jpg',
332
+ 'assets/examples-anime/miku-canny.png',
333
+ 'canny', False,
334
+ 768, 576, 1.5, 22,
335
+ 'SeeCoder-Anime', 'Anything-v4', 'canny', ],
336
+ [
337
+ 'assets/examples-anime/random0.jpg',
338
+ 'assets/examples-anime/pose.png',
339
+ 'openpose', False,
340
+ 768, 1536, 2.0, 41,
341
+ 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
342
+ [
343
+ 'assets/examples-anime/random1.jpg',
344
+ 'assets/examples-anime/pose.png',
345
+ 'openpose', False,
346
+ 768, 1536, 2.5, 28,
347
+ 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
348
+ [
349
+ 'assets/examples-anime/camping.jpg',
350
+ 'assets/examples-anime/pose.png',
351
+ 'openpose', False,
352
+ 768, 1536, 2.0, 35,
353
+ 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
354
+ [
355
+ 'assets/examples-anime/hanfu_girl.jpg',
356
+ 'assets/examples-anime/pose.png',
357
+ 'openpose', False,
358
+ 768, 1536, 2.0, 20,
359
+ 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
360
+ ]
361
+ return case
362
+
363
+ def interface():
364
+ with gr.Row():
365
+ with gr.Column():
366
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
367
+ with gr.Row():
368
+ out_width = gr.Slider(label="Width" , minimum=512, maximum=1536, value=512, step=64, visible=True)
369
+ out_height = gr.Slider(label="Height", minimum=512, maximum=1536, value=512, step=64, visible=True)
370
+ with gr.Row():
371
+ scl_lvl = gr.Slider(label="CFGScale", minimum=0, maximum=10, value=2, step=0.01, visible=True)
372
+ seed = gr.Number(20, label="Seed", precision=0)
373
+ with gr.Row():
374
+ tag_ctx = gr.Dropdown(label='Context Encoder', choices=[pi for pi in ctxencoder_path.keys()], value='SeeCoder')
375
+ tag_diffuser = gr.Dropdown(label='Diffuser', choices=[pi for pi in diffuser_path.keys()], value='Deliberate-v2.0')
376
+ button = gr.Button("Run")
377
+ with gr.Column():
378
+ ctl_input = gr.Image(label='Control Input', type='pil', elem_id='customized_imbox')
379
+ do_preprocess = gr.Checkbox(label='Preprocess', value=False)
380
+ with gr.Row():
381
+ ctl_method = gr.Dropdown(label='Preprocess Type', choices=preprocess_method, value='canny')
382
+ tag_ctl = gr.Dropdown(label='ControlNet', choices=[pi for pi in controlnet_path.keys()], value='canny')
383
+ with gr.Column():
384
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image+1)
385
+
386
+ tag_ctl.change(
387
+ pfd_inference.action_autoset_method,
388
+ inputs = [tag_ctl],
389
+ outputs = [ctl_method],)
390
+
391
+ ctl_input.change(
392
+ pfd_inference.action_autoset_hw,
393
+ inputs = [ctl_input],
394
+ outputs = [out_height, out_width],)
395
+
396
+ # tag_ctx.change(
397
+ # pfd_inference.action_load_ctx,
398
+ # inputs = [tag_ctx],
399
+ # outputs = [tag_ctx],)
400
+
401
+ # tag_diffuser.change(
402
+ # pfd_inference.action_load_diffuser,
403
+ # inputs = [tag_diffuser],
404
+ # outputs = [tag_diffuser],)
405
+
406
+ # tag_ctl.change(
407
+ # pfd_inference.action_load_ctl,
408
+ # inputs = [tag_ctl],
409
+ # outputs = [tag_ctl],)
410
+
411
+ button.click(
412
+ pfd_inference.action_inference,
413
+ inputs=[img_input, ctl_input, ctl_method, do_preprocess,
414
+ out_height, out_width, scl_lvl, seed,
415
+ tag_ctx, tag_diffuser, tag_ctl, ],
416
+ outputs=[img_output])
417
+
418
+ gr.Examples(
419
+ label='Examples',
420
+ examples=get_example(),
421
+ fn=pfd_inference.action_inference,
422
+ inputs=[img_input, ctl_input, ctl_method, do_preprocess,
423
+ out_height, out_width, scl_lvl, seed,
424
+ tag_ctx, tag_diffuser, tag_ctl, ],
425
+ outputs=[img_output],
426
+ cache_examples=cache_examples,)
427
+
428
+ #############
429
+ # Interface #
430
+ #############
431
+
432
+ css = """
433
+ #customized_imbox {
434
+ min-height: 450px;
435
+ }
436
+ #customized_imbox>div[data-testid="image"] {
437
+ min-height: 450px;
438
+ }
439
+ #customized_imbox>div[data-testid="image"]>div {
440
+ min-height: 450px;
441
+ }
442
+ #customized_imbox>div[data-testid="image"]>iframe {
443
+ min-height: 450px;
444
+ }
445
+ #customized_imbox>div.unpadded_box {
446
+ min-height: 450px;
447
+ }
448
+ #myinst {
449
+ font-size: 0.8rem;
450
+ margin: 0rem;
451
+ color: #6B7280;
452
+ }
453
+ #maskinst {
454
+ text-align: justify;
455
+ min-width: 1200px;
456
+ }
457
+ #maskinst>img {
458
+ min-width:399px;
459
+ max-width:450px;
460
+ vertical-align: top;
461
+ display: inline-block;
462
+ }
463
+ #maskinst:after {
464
+ content: "";
465
+ width: 100%;
466
+ display: inline-block;
467
+ }
468
+ """
469
+
470
+ if True:
471
+ with gr.Blocks(css=css) as demo:
472
+ gr.HTML(
473
+ """
474
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
475
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
476
+ Prompt-Free Diffusion
477
+ </h1>
478
+ </div>
479
+ """)
480
+
481
+ interface()
482
+
483
+ # gr.HTML(
484
+ # """
485
+ # <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
486
+ # <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
487
+ # <b>Version</b>: {}
488
+ # </h3>
489
+ # </div>
490
+ # """.format(' '+str(pfd_inference.pretrained)))
491
+
492
+ # demo.launch(server_name="0.0.0.0", server_port=7992)
493
+ # demo.launch()
494
+ demo.launch(debug=True)
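
For reference, the same pipeline can be driven without the Gradio UI. A minimal headless sketch, assuming the weights under pretrained/ and the example assets are present, and that `pfd` is a prompt_free_diffusion instance built exactly like pfd_inference above (the output filename is arbitrary):

    from PIL import Image

    im = Image.open('assets/examples/ghibli-input.jpg').convert('RGB')     # reference image
    imctl = Image.open('assets/examples/ghibli-canny.png').convert('RGB')  # control map

    # Argument order mirrors the first row of get_example():
    # control method, preprocess flag, height, width, CFG scale, seed, then the three model tags.
    results = pfd.action_inference(
        im, imctl, 'canny', False,
        768, 1024, 1.8, 23,
        'SeeCoder', 'Deliberate-v2.0', 'canny')

    results[0].save('ghibli-output.png')  # generated image; trailing entries are the control maps
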
configs/model/autokl.yaml ADDED
@@ -0,0 +1,26 @@
1
+ autokl:
2
+ symbol: autokl
3
+ find_unused_parameters: false
4
+
5
+ autokl_v1:
6
+ super_cfg: autokl
7
+ type: autoencoderkl
8
+ args:
9
+ embed_dim: 4
10
+ ddconfig:
11
+ double_z: true
12
+ z_channels: 4
13
+ resolution: 256
14
+ in_channels: 3
15
+ out_ch: 3
16
+ ch: 128
17
+ ch_mult: [1, 2, 4, 4]
18
+ num_res_blocks: 2
19
+ attn_resolutions: []
20
+ dropout: 0.0
21
+ lossconfig: null
22
+ pth: pretrained/kl-f8.pth
23
+
24
+ autokl_v2:
25
+ super_cfg: autokl_v1
26
+ pth: pretrained/pfd/vae/sd-v2-0-base-autokl.pth
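
These entries are consumed through model_cfg_bank in lib/cfg_helper.py (added later in this commit): a child entry such as autokl_v2 inherits every field of its super_cfg parent, with args merged rather than replaced, so autokl_v2 only overrides the checkpoint path. A short sketch of how such a config is resolved, assuming the yaml above is on disk:

    from lib.cfg_helper import model_cfg_bank
    from lib.model_zoo import get_model

    cfg = model_cfg_bank()('autokl_v2')   # autokl_v1 fields plus the overridden pth
    print(cfg.type, cfg.pth)              # 'autoencoderkl' 'pretrained/pfd/vae/sd-v2-0-base-autokl.pth'
    vae = get_model()(cfg)                # same construction pattern app.py uses for the full model
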
configs/model/clip.yaml ADDED
@@ -0,0 +1,12 @@
1
+ ################
2
+ # clip for sd1 #
3
+ ################
4
+
5
+ clip:
6
+ symbol: clip
7
+ args: {}
8
+
9
+ clip_text_context_encoder_sdv1:
10
+ super_cfg: clip
11
+ type: clip_text_context_encoder_sdv1
12
+ args: {}
configs/model/controlnet.yaml ADDED
@@ -0,0 +1,18 @@
1
+ controlnet:
2
+ symbol: controlnet
3
+ type: controlnet
4
+ find_unused_parameters: false
5
+ args:
6
+ image_size: 32 # unused
7
+ in_channels: 4
8
+ hint_channels: 3
9
+ model_channels: 320
10
+ attention_resolutions: [ 4, 2, 1 ]
11
+ num_res_blocks: 2
12
+ channel_mult: [ 1, 2, 4, 4 ]
13
+ num_heads: 8
14
+ use_spatial_transformer: True
15
+ transformer_depth: 1
16
+ context_dim: 768
17
+ use_checkpoint: True
18
+ legacy: False
configs/model/openai_unet.yaml ADDED
@@ -0,0 +1,35 @@
1
+ openai_unet_sd:
2
+ type: openai_unet
3
+ args:
4
+ image_size: null # no use
5
+ in_channels: 4
6
+ out_channels: 4
7
+ model_channels: 320
8
+ attention_resolutions: [ 4, 2, 1 ]
9
+ num_res_blocks: [ 2, 2, 2, 2 ]
10
+ channel_mult: [ 1, 2, 4, 4 ]
11
+ # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
12
+ num_heads: 8
13
+ use_spatial_transformer: True
14
+ transformer_depth: 1
15
+ context_dim: 768
16
+ use_checkpoint: True
17
+ legacy: False
18
+
19
+ #########
20
+ # v1 2d #
21
+ #########
22
+
23
+ openai_unet_2d_v1:
24
+ type: openai_unet_2d_next
25
+ args:
26
+ in_channels: 4
27
+ out_channels: 4
28
+ model_channels: 320
29
+ attention_resolutions: [ 4, 2, 1 ]
30
+ num_res_blocks: [ 2, 2, 2, 2 ]
31
+ channel_mult: [ 1, 2, 4, 4 ]
32
+ num_heads: 8
33
+ context_dim: 768
34
+ use_checkpoint: False
35
+ parts: [global, data, context]
configs/model/pfd.yaml ADDED
@@ -0,0 +1,33 @@
1
+ pfd_base:
2
+ symbol: pfd
3
+ find_unused_parameters: true
4
+ type: pfd
5
+ args:
6
+ beta_linear_start: 0.00085
7
+ beta_linear_end: 0.012
8
+ timesteps: 1000
9
+ use_ema: false
10
+
11
+ pfd_seecoder:
12
+ super_cfg: pfd_base
13
+ args:
14
+ vae_cfg_list:
15
+ - [image, MODEL(autokl_v2)]
16
+ ctx_cfg_list:
17
+ - [image, MODEL(seecoder)]
18
+ diffuser_cfg_list:
19
+ - [image, MODEL(openai_unet_2d_v1)]
20
+ latent_scale_factor:
21
+ image: 0.18215
22
+
23
+ pdf_seecoder_pa:
24
+ super_cfg: pfd_seecoder
25
+ args:
26
+ ctx_cfg_list:
27
+ - [image, MODEL(seecoder_pa)]
28
+
29
+ pfd_seecoder_with_controlnet:
30
+ super_cfg: pfd_seecoder
31
+ type: pfd_with_control
32
+ args:
33
+ ctl_cfg: MODEL(controlnet)
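
The MODEL(...) placeholders are surrogate references expanded by cfg_solvef in lib/cfg_helper.py (added later in this commit), so the composite entry pulls in the VAE, SeeCoder, UNet, and ControlNet sub-configs by name. This is the entry app.py loads at startup; a brief sketch:

    from lib.cfg_helper import model_cfg_bank

    cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet')
    # MODEL(autokl_v2), MODEL(seecoder), MODEL(openai_unet_2d_v1) and MODEL(controlnet)
    # are replaced by the corresponding entries from the other config files.
    print(cfgm.type)               # 'pfd_with_control'
    print(cfgm.args.ctl_cfg.type)  # 'controlnet'
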
configs/model/seecoder.yaml ADDED
@@ -0,0 +1,62 @@
1
+ seecoder_base:
2
+ symbol: seecoder
3
+ args: {}
4
+
5
+ seecoder:
6
+ super_cfg: seecoder_base
7
+ type: seecoder
8
+ args:
9
+ imencoder_cfg : MODEL(swin_large)
10
+ imdecoder_cfg : MODEL(seecoder_decoder)
11
+ qtransformer_cfg : MODEL(seecoder_query_transformer)
12
+
13
+ seecoder_pa:
14
+ super_cfg: seet
15
+ type: seecoder
16
+ args:
17
+ imencoder_cfg : MODEL(swin_large)
18
+ imdecoder_cfg : MODEL(seecoder_decoder)
19
+ qtransformer_cfg : MODEL(seecoder_query_transformer_position_aware)
20
+
21
+ ###########
22
+ # decoder #
23
+ ###########
24
+
25
+ seecoder_decoder:
26
+ super_cfg: seecoder_base
27
+ type: seecoder_decoder
28
+ args:
29
+ inchannels:
30
+ res3: 384
31
+ res4: 768
32
+ res5: 1536
33
+ trans_input_tags: ['res3', 'res4', 'res5']
34
+ trans_dim: 768
35
+ trans_dropout: 0.1
36
+ trans_nheads: 8
37
+ trans_feedforward_dim: 1024
38
+ trans_num_layers: 6
39
+
40
+ #####################
41
+ # query_transformer #
42
+ #####################
43
+
44
+ seecoder_query_transformer:
45
+ super_cfg: seecoder_base
46
+ type: seecoder_query_transformer
47
+ args:
48
+ in_channels : 768
49
+ hidden_dim: 768
50
+ num_queries: [4, 144]
51
+ nheads: 8
52
+ num_layers: 9
53
+ feedforward_dim: 2048
54
+ pre_norm: False
55
+ num_feature_levels: 3
56
+ enforce_input_project: False
57
+ with_fea2d_pos: false
58
+
59
+ seecoder_query_transformer_position_aware:
60
+ super_cfg: seecoder_query_transformer
61
+ args:
62
+ with_fea2d_pos: true
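
seecoder_pa swaps in the position-aware query transformer (with_fea2d_pos: true). Note that app.py additionally attaches a PPE_MLP positional-embedding layer at runtime when the SeeCoder-PA weights are selected; roughly, assuming `net` is an already-constructed pfd network:

    from lib.model_zoo.seecoder import PPE_MLP

    # Mirrors action_load_ctx('SeeCoder-PA') in app.py (dtype/device handling omitted).
    pe_layer = PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3)
    net.ctx['image'].qtransformer.pe_layer = pe_layer.eval()
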
configs/model/swin.yaml ADDED
@@ -0,0 +1,32 @@
1
+ swin:
2
+ symbol: swin
3
+ args: {}
4
+
5
+ swin_base:
6
+ super_cfg: swin
7
+ type: swin
8
+ args:
9
+ embed_dim: 128
10
+ depths: [ 2, 2, 18, 2 ]
11
+ num_heads: [ 4, 8, 16, 32 ]
12
+ window_size: 7
13
+ ape: False
14
+ drop_path_rate: 0.3
15
+ patch_norm: True
16
+ pretrained: pretrained/swin/swin_base_patch4_window7_224_22k.pth
17
+ strict_sd: False
18
+
19
+ swin_large:
20
+ super_cfg: swin
21
+ type: swin
22
+ args:
23
+ embed_dim: 192
24
+ depths: [ 2, 2, 18, 2 ]
25
+ num_heads: [ 6, 12, 24, 48 ]
26
+ window_size: 12
27
+ ape: False
28
+ drop_path_rate: 0.3
29
+ patch_norm: True
30
+ pretrained: pretrained/swin/swin_large_patch4_window12_384_22k.pth
31
+ strict_sd: False
32
+
lib/__init__.py ADDED
File without changes
lib/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes).
lib/__pycache__/cfg_helper.cpython-310.pyc ADDED
Binary file (13.2 kB).
lib/__pycache__/cfg_holder.cpython-310.pyc ADDED
Binary file (1.22 kB).
lib/__pycache__/log_service.cpython-310.pyc ADDED
Binary file (5.01 kB).
lib/__pycache__/sync.cpython-310.pyc ADDED
Binary file (7.51 kB).
lib/cfg_helper.py ADDED
@@ -0,0 +1,666 @@
1
+ import os
2
+ import os.path as osp
3
+ import shutil
4
+ import copy
5
+ import time
6
+ import pprint
7
+ import numpy as np
8
+ import torch
9
+ import matplotlib
10
+ import argparse
11
+ import json
12
+ import yaml
13
+ from easydict import EasyDict as edict
14
+
15
+ from .model_zoo import get_model
16
+
17
+ ############
18
+ # cfg_bank #
19
+ ############
20
+
21
+ def cfg_solvef(cmd, root):
22
+ if not isinstance(cmd, str):
23
+ return cmd
24
+
25
+ if cmd.find('SAME')==0:
26
+ zoom = root
27
+ p = cmd[len('SAME'):].strip('()').split('.')
28
+ p = [pi.strip() for pi in p]
29
+ for pi in p:
30
+ try:
31
+ pi = int(pi)
32
+ except:
33
+ pass
34
+
35
+ try:
36
+ zoom = zoom[pi]
37
+ except:
38
+ return cmd
39
+ return cfg_solvef(zoom, root)
40
+
41
+ if cmd.find('SEARCH')==0:
42
+ zoom = root
43
+ p = cmd[len('SEARCH'):].strip('()').split('.')
44
+ p = [pi.strip() for pi in p]
45
+ find = True
46
+ # Depth first search
47
+ for pi in p:
48
+ try:
49
+ pi = int(pi)
50
+ except:
51
+ pass
52
+
53
+ try:
54
+ zoom = zoom[pi]
55
+ except:
56
+ find = False
57
+ break
58
+
59
+ if find:
60
+ return cfg_solvef(zoom, root)
61
+ else:
62
+ if isinstance(root, dict):
63
+ for ri in root:
64
+ rv = cfg_solvef(cmd, root[ri])
65
+ if rv != cmd:
66
+ return rv
67
+ if isinstance(root, list):
68
+ for ri in root:
69
+ rv = cfg_solvef(cmd, ri)
70
+ if rv != cmd:
71
+ return rv
72
+ return cmd
73
+
74
+ if cmd.find('MODEL')==0:
75
+ goto = cmd[len('MODEL'):].strip('()')
76
+ return model_cfg_bank()(goto)
77
+
78
+ if cmd.find('DATASET')==0:
79
+ goto = cmd[len('DATASET'):].strip('()')
80
+ return dataset_cfg_bank()(goto)
81
+
82
+ return cmd
83
+
84
+ def cfg_solve(cfg, cfg_root):
85
+ # The function solve cfg element such that
86
+ # all sorrogate input are settled.
87
+ # (i.e. SAME(***) )
88
+ if isinstance(cfg, list):
89
+ for i in range(len(cfg)):
90
+ if isinstance(cfg[i], (list, dict)):
91
+ cfg[i] = cfg_solve(cfg[i], cfg_root)
92
+ else:
93
+ cfg[i] = cfg_solvef(cfg[i], cfg_root)
94
+ if isinstance(cfg, dict):
95
+ for k in cfg:
96
+ if isinstance(cfg[k], (list, dict)):
97
+ cfg[k] = cfg_solve(cfg[k], cfg_root)
98
+ else:
99
+ cfg[k] = cfg_solvef(cfg[k], cfg_root)
100
+ return cfg
101
+
102
+ class model_cfg_bank(object):
103
+ def __init__(self):
104
+ self.cfg_dir = osp.join('configs', 'model')
105
+ self.cfg_bank = edict()
106
+
107
+ def __call__(self, name):
108
+ if name not in self.cfg_bank:
109
+ cfg_path = self.get_yaml_path(name)
110
+ with open(cfg_path, 'r') as f:
111
+ cfg_new = yaml.load(
112
+ f, Loader=yaml.FullLoader)
113
+ cfg_new = edict(cfg_new)
114
+ self.cfg_bank.update(cfg_new)
115
+
116
+ cfg = self.cfg_bank[name]
117
+ cfg.name = name
118
+ if 'super_cfg' not in cfg:
119
+ cfg = cfg_solve(cfg, cfg)
120
+ self.cfg_bank[name] = cfg
121
+ return copy.deepcopy(cfg)
122
+
123
+ super_cfg = self.__call__(cfg.super_cfg)
124
+ # unlike other field,
125
+ # args will not be replaced but update.
126
+ if 'args' in cfg:
127
+ if 'args' in super_cfg:
128
+ super_cfg.args.update(cfg.args)
129
+ else:
130
+ super_cfg.args = cfg.args
131
+ cfg.pop('args')
132
+
133
+ super_cfg.update(cfg)
134
+ super_cfg.pop('super_cfg')
135
+ cfg = super_cfg
136
+ try:
137
+ delete_args = cfg.pop('delete_args')
138
+ except:
139
+ delete_args = []
140
+
141
+ for dargs in delete_args:
142
+ cfg.args.pop(dargs)
143
+
144
+ cfg = cfg_solve(cfg, cfg)
145
+ self.cfg_bank[name] = cfg
146
+ return copy.deepcopy(cfg)
147
+
148
+ def get_yaml_path(self, name):
149
+ if name.find('openai_unet')==0:
150
+ return osp.join(
151
+ self.cfg_dir, 'openai_unet.yaml')
152
+ elif name.find('clip')==0:
153
+ return osp.join(
154
+ self.cfg_dir, 'clip.yaml')
155
+ elif name.find('autokl')==0:
156
+ return osp.join(
157
+ self.cfg_dir, 'autokl.yaml')
158
+ elif name.find('controlnet')==0:
159
+ return osp.join(
160
+ self.cfg_dir, 'controlnet.yaml')
161
+ elif name.find('swin')==0:
162
+ return osp.join(
163
+ self.cfg_dir, 'swin.yaml')
164
+ elif name.find('pfd')==0:
165
+ return osp.join(
166
+ self.cfg_dir, 'pfd.yaml')
167
+ elif name.find('seecoder')==0:
168
+ return osp.join(
169
+ self.cfg_dir, 'seecoder.yaml')
170
+ else:
171
+ raise ValueError
172
+
173
+ class dataset_cfg_bank(object):
174
+ def __init__(self):
175
+ self.cfg_dir = osp.join('configs', 'dataset')
176
+ self.cfg_bank = edict()
177
+
178
+ def __call__(self, name):
179
+ if name not in self.cfg_bank:
180
+ cfg_path = self.get_yaml_path(name)
181
+ with open(cfg_path, 'r') as f:
182
+ cfg_new = yaml.load(
183
+ f, Loader=yaml.FullLoader)
184
+ cfg_new = edict(cfg_new)
185
+ self.cfg_bank.update(cfg_new)
186
+
187
+ cfg = self.cfg_bank[name]
188
+ cfg.name = name
189
+ if cfg.get('super_cfg', None) is None:
190
+ cfg = cfg_solve(cfg, cfg)
191
+ self.cfg_bank[name] = cfg
192
+ return copy.deepcopy(cfg)
193
+
194
+ super_cfg = self.__call__(cfg.super_cfg)
195
+ super_cfg.update(cfg)
196
+ cfg = super_cfg
197
+ cfg.super_cfg = None
198
+ try:
199
+ delete = cfg.pop('delete')
200
+ except:
201
+ delete = []
202
+
203
+ for dargs in delete:
204
+ cfg.pop(dargs)
205
+
206
+ cfg = cfg_solve(cfg, cfg)
207
+ self.cfg_bank[name] = cfg
208
+ return copy.deepcopy(cfg)
209
+
210
+ def get_yaml_path(self, name):
211
+ if name.find('cityscapes')==0:
212
+ return osp.join(
213
+ self.cfg_dir, 'cityscapes.yaml')
214
+ elif name.find('div2k')==0:
215
+ return osp.join(
216
+ self.cfg_dir, 'div2k.yaml')
217
+ elif name.find('gandiv2k')==0:
218
+ return osp.join(
219
+ self.cfg_dir, 'gandiv2k.yaml')
220
+ elif name.find('srbenchmark')==0:
221
+ return osp.join(
222
+ self.cfg_dir, 'srbenchmark.yaml')
223
+ elif name.find('imagedir')==0:
224
+ return osp.join(
225
+ self.cfg_dir, 'imagedir.yaml')
226
+ elif name.find('places2')==0:
227
+ return osp.join(
228
+ self.cfg_dir, 'places2.yaml')
229
+ elif name.find('ffhq')==0:
230
+ return osp.join(
231
+ self.cfg_dir, 'ffhq.yaml')
232
+ elif name.find('imcpt')==0:
233
+ return osp.join(
234
+ self.cfg_dir, 'imcpt.yaml')
235
+ elif name.find('texture')==0:
236
+ return osp.join(
237
+ self.cfg_dir, 'texture.yaml')
238
+ elif name.find('openimages')==0:
239
+ return osp.join(
240
+ self.cfg_dir, 'openimages.yaml')
241
+ elif name.find('laion2b')==0:
242
+ return osp.join(
243
+ self.cfg_dir, 'laion2b.yaml')
244
+ elif name.find('laionart')==0:
245
+ return osp.join(
246
+ self.cfg_dir, 'laionart.yaml')
247
+ elif name.find('celeba')==0:
248
+ return osp.join(
249
+ self.cfg_dir, 'celeba.yaml')
250
+ elif name.find('coyo')==0:
251
+ return osp.join(
252
+ self.cfg_dir, 'coyo.yaml')
253
+ elif name.find('pafc')==0:
254
+ return osp.join(
255
+ self.cfg_dir, 'pafc.yaml')
256
+ elif name.find('coco')==0:
257
+ return osp.join(
258
+ self.cfg_dir, 'coco.yaml')
259
+ elif name.find('genai')==0:
260
+ return osp.join(
261
+ self.cfg_dir, 'genai.yaml')
262
+ else:
263
+ raise ValueError
264
+
265
+ class experiment_cfg_bank(object):
266
+ def __init__(self):
267
+ self.cfg_dir = osp.join('configs', 'experiment')
268
+ self.cfg_bank = edict()
269
+
270
+ def __call__(self, name):
271
+ if name not in self.cfg_bank:
272
+ cfg_path = self.get_yaml_path(name)
273
+ with open(cfg_path, 'r') as f:
274
+ cfg = yaml.load(
275
+ f, Loader=yaml.FullLoader)
276
+ cfg = edict(cfg)
277
+
278
+ cfg = cfg_solve(cfg, cfg)
279
+ cfg = cfg_solve(cfg, cfg)
280
+ # twice for SEARCH
281
+ self.cfg_bank[name] = cfg
282
+ return copy.deepcopy(cfg)
283
+
284
+ def get_yaml_path(self, name):
285
+ return osp.join(
286
+ self.cfg_dir, name+'.yaml')
287
+
288
+ def load_cfg_yaml(path):
289
+ if osp.isfile(path):
290
+ cfg_path = path
291
+ elif osp.isfile(osp.join('configs', 'experiment', path)):
292
+ cfg_path = osp.join('configs', 'experiment', path)
293
+ elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
294
+ cfg_path = osp.join('configs', 'experiment', path+'.yaml')
295
+ else:
296
+ assert False, 'No such config!'
297
+
298
+ with open(cfg_path, 'r') as f:
299
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
300
+ cfg = edict(cfg)
301
+ cfg = cfg_solve(cfg, cfg)
302
+ cfg = cfg_solve(cfg, cfg)
303
+ return cfg
304
+
305
+ ##############
306
+ # cfg_helper #
307
+ ##############
308
+
309
+ def get_experiment_id(ref=None):
310
+ if ref is None:
311
+ time.sleep(0.5)
312
+ return int(time.time()*100)
313
+ else:
314
+ try:
315
+ return int(ref)
316
+ except:
317
+ pass
318
+
319
+ _, ref = osp.split(ref)
320
+ ref = ref.split('_')[0]
321
+ try:
322
+ return int(ref)
323
+ except:
324
+ assert False, 'Invalid experiment ID!'
325
+
326
+ def record_resume_cfg(path):
327
+ cnt = 0
328
+ while True:
329
+ if osp.exists(path+'.{:04d}'.format(cnt)):
330
+ cnt += 1
331
+ continue
332
+ shutil.copyfile(path, path+'.{:04d}'.format(cnt))
333
+ break
334
+
335
+ def get_command_line_args():
336
+ parser = argparse.ArgumentParser()
337
+ parser.add_argument('--debug', action='store_true', default=False)
338
+ parser.add_argument('--config', type=str)
339
+ parser.add_argument('--gpu', nargs='+', type=int)
340
+
341
+ parser.add_argument('--node_rank', type=int)
342
+ parser.add_argument('--node_list', nargs='+', type=str)
343
+ parser.add_argument('--nodes', type=int)
344
+ parser.add_argument('--addr', type=str, default='127.0.0.1')
345
+ parser.add_argument('--port', type=int, default=11233)
346
+
347
+ parser.add_argument('--signature', nargs='+', type=str)
348
+ parser.add_argument('--seed', type=int)
349
+
350
+ parser.add_argument('--eval', type=str)
351
+ parser.add_argument('--eval_subdir', type=str)
352
+ parser.add_argument('--pretrained', type=str)
353
+
354
+ parser.add_argument('--resume_dir', type=str)
355
+ parser.add_argument('--resume_step', type=int)
356
+ parser.add_argument('--resume_weight', type=str)
357
+
358
+ args = parser.parse_args()
359
+
360
+ # Special handling the resume
361
+ if args.resume_dir is not None:
362
+ cfg = edict()
363
+ cfg.env = edict()
364
+ cfg.env.debug = args.debug
365
+ cfg.env.resume = edict()
366
+ cfg.env.resume.dir = args.resume_dir
367
+ cfg.env.resume.step = args.resume_step
368
+ cfg.env.resume.weight = args.resume_weight
369
+ return cfg
370
+
371
+ cfg = load_cfg_yaml(args.config)
372
+ cfg.env.debug = args.debug
373
+ cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
374
+ cfg.env.master_addr = args.addr
375
+ cfg.env.master_port = args.port
376
+ cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
377
+
378
+ if args.node_list is None:
379
+ cfg.env.node_rank = 0 if args.node_rank is None else args.node_rank
380
+ cfg.env.nodes = 1 if args.nodes is None else args.nodes
381
+ else:
382
+ import socket
383
+ hostname = socket.gethostname()
384
+ assert cfg.env.master_addr == args.node_list[0]
385
+ cfg.env.node_rank = args.node_list.index(hostname)
386
+ cfg.env.nodes = len(args.node_list)
387
+ cfg.env.node_list = args.node_list
388
+
389
+ istrain = False if args.eval is not None else True
390
+ isdebug = cfg.env.debug
391
+
392
+ if istrain:
393
+ if isdebug:
394
+ cfg.env.experiment_id = 999999999999
395
+ cfg.train.signature = ['debug']
396
+ else:
397
+ cfg.env.experiment_id = get_experiment_id()
398
+ if args.signature is not None:
399
+ cfg.train.signature = args.signature
400
+ else:
401
+ if 'train' in cfg:
402
+ cfg.pop('train')
403
+ cfg.env.experiment_id = get_experiment_id(args.eval)
404
+ if args.signature is not None:
405
+ cfg.eval.signature = args.signature
406
+
407
+ if isdebug and (args.eval is None):
408
+ cfg.env.experiment_id = 999999999999
409
+ cfg.eval.signature = ['debug']
410
+
411
+ if args.eval_subdir is not None:
412
+ if isdebug:
413
+ cfg.eval.eval_subdir = 'debug'
414
+ else:
415
+ cfg.eval.eval_subdir = args.eval_subdir
416
+ if args.pretrained is not None:
417
+ cfg.eval.pretrained = args.pretrained
418
+ # The override pretrained over the setting in cfg.model
419
+
420
+ if args.seed is not None:
421
+ cfg.env.rnd_seed = args.seed
422
+
423
+ return cfg
424
+
425
+ def cfg_initiates(cfg):
426
+ cfge = cfg.env
427
+ isdebug = cfge.debug
428
+ isresume = 'resume' in cfge
429
+ istrain = 'train' in cfg
430
+ haseval = 'eval' in cfg
431
+ cfgt = cfg.train if istrain else None
432
+ cfgv = cfg.eval if haseval else None
433
+
434
+ ###############################
435
+ # get some environment params #
436
+ ###############################
437
+
438
+ cfge.computer = os.uname()
439
+ cfge.torch_version = str(torch.__version__)
440
+
441
+ ##########
442
+ # resume #
443
+ ##########
444
+
445
+ if isresume:
446
+ resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
447
+ record_resume_cfg(resume_cfg_path)
448
+ with open(resume_cfg_path, 'r') as f:
449
+ cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
450
+ cfg_resume = edict(cfg_resume)
451
+ cfg_resume.env.update(cfge)
452
+ cfg = cfg_resume
453
+ cfge = cfg.env
454
+ log_file = cfg.train.log_file
455
+
456
+ print('')
457
+ print('##########')
458
+ print('# resume #')
459
+ print('##########')
460
+ print('')
461
+ with open(log_file, 'a') as f:
462
+ print('', file=f)
463
+ print('##########', file=f)
464
+ print('# resume #', file=f)
465
+ print('##########', file=f)
466
+ print('', file=f)
467
+
468
+ pprint.pprint(cfg)
469
+ with open(log_file, 'a') as f:
470
+ pprint.pprint(cfg, f)
471
+
472
+ ####################
473
+ # node distributed #
474
+ ####################
475
+
476
+ if cfg.env.master_addr!='127.0.0.1':
477
+ os.environ['MASTER_ADDR'] = cfge.master_addr
478
+ os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
479
+ if cfg.env.dist_backend=='nccl':
480
+ os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
481
+ if cfg.env.dist_backend=='gloo':
482
+ os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
483
+
484
+ #######################
485
+ # cuda visible device #
486
+ #######################
487
+
488
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
489
+ [str(gid) for gid in cfge.gpu_device])
490
+
491
+ #####################
492
+ # return resume cfg #
493
+ #####################
494
+
495
+ if isresume:
496
+ return cfg
497
+
498
+ #############################################
499
+ # some misc setting that not need in resume #
500
+ #############################################
501
+
502
+ cfgm = cfg.model
503
+ cfge.gpu_count = len(cfge.gpu_device)
504
+
505
+ ##########################################
506
+ # align batch size and num worker config #
507
+ ##########################################
508
+
509
+ gpu_n = cfge.gpu_count * cfge.nodes
510
+ def align_batch_size(bs, bs_per_gpu):
511
+ assert (bs is not None) or (bs_per_gpu is not None)
512
+ bs = bs_per_gpu * gpu_n if bs is None else bs
513
+ bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
514
+ assert (bs == bs_per_gpu * gpu_n)
515
+ return bs, bs_per_gpu
516
+
517
+ if istrain:
518
+ cfgt.batch_size, cfgt.batch_size_per_gpu = \
519
+ align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
520
+ cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
521
+ align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
522
+ if haseval:
523
+ cfgv.batch_size, cfgv.batch_size_per_gpu = \
524
+ align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
525
+ cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
526
+ align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
527
+
528
+ ##################
529
+ # create log dir #
530
+ ##################
531
+
532
+ if istrain:
533
+ if not isdebug:
534
+ sig = cfgt.get('signature', [])
535
+ sig = sig + ['s{}'.format(cfge.rnd_seed)]
536
+ else:
537
+ sig = ['debug']
538
+
539
+ log_dir = [
540
+ cfge.log_root_dir,
541
+ '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
542
+ '_'.join([str(cfge.experiment_id)] + sig)
543
+ ]
544
+ log_dir = osp.join(*log_dir)
545
+ log_file = osp.join(log_dir, 'train.log')
546
+ if not osp.exists(log_file):
547
+ os.makedirs(osp.dirname(log_file))
548
+ cfgt.log_dir = log_dir
549
+ cfgt.log_file = log_file
550
+
551
+ if haseval:
552
+ cfgv.log_dir = log_dir
553
+ cfgv.log_file = log_file
554
+ else:
555
+ model_symbol = cfgm.symbol
556
+ if cfgv.get('dataset', None) is None:
557
+ dataset_symbol = 'nodataset'
558
+ else:
559
+ dataset_symbol = cfgv.dataset.symbol
560
+
561
+ log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
562
+ exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
563
+ if exp_dir is None:
564
+ if not isdebug:
565
+ sig = cfgv.get('signature', []) + ['evalonly']
566
+ else:
567
+ sig = ['debug']
568
+ exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
569
+
570
+ eval_subdir = cfgv.get('eval_subdir', None)
571
+ # override subdir in debug mode (if eval_subdir is set)
572
+ eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
573
+
574
+ if eval_subdir is not None:
575
+ log_dir = osp.join(log_dir, exp_dir, eval_subdir)
576
+ else:
577
+ log_dir = osp.join(log_dir, exp_dir)
578
+
579
+ disable_log_override = cfgv.get('disable_log_override', False)
580
+ if osp.isdir(log_dir):
581
+ if disable_log_override:
582
+ assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir)
583
+ else:
584
+ os.makedirs(log_dir)
585
+
586
+ log_file = osp.join(log_dir, 'eval.log')
587
+ cfgv.log_dir = log_dir
588
+ cfgv.log_file = log_file
589
+
590
+ ######################
591
+ # print and save cfg #
592
+ ######################
593
+
594
+ pprint.pprint(cfg)
595
+ if cfge.node_rank==0:
596
+ with open(log_file, 'w') as f:
597
+ pprint.pprint(cfg, f)
598
+ with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
599
+ yaml.dump(edict_2_dict(cfg), f)
600
+ else:
601
+ with open(osp.join(log_dir, 'config.yaml.{}'.format(cfge.node_rank)), 'w') as f:
602
+ yaml.dump(edict_2_dict(cfg), f)
603
+
604
+ #############
605
+ # save code #
606
+ #############
607
+
608
+ save_code = False
609
+ if istrain:
610
+ save_code = cfgt.get('save_code', False)
611
+ elif haseval:
612
+ save_code = cfgv.get('save_code', False)
613
+ save_code = save_code and (cfge.node_rank==0)
614
+
615
+ if save_code:
616
+ codedir = osp.join(log_dir, 'code')
617
+ if osp.exists(codedir):
618
+ shutil.rmtree(codedir)
619
+ for d in ['configs', 'lib']:
620
+ fromcodedir = d
621
+ tocodedir = osp.join(codedir, d)
622
+ shutil.copytree(
623
+ fromcodedir, tocodedir,
624
+ ignore=shutil.ignore_patterns(
625
+ '*__pycache__*', '*build*'))
626
+ for codei in os.listdir('.'):
627
+ if osp.splitext(codei)[1] == 'py':
628
+ shutil.copy(codei, codedir)
629
+
630
+ #######################
631
+ # set matplotlib mode #
632
+ #######################
633
+
634
+ if 'matplotlib_mode' in cfge:
635
+ try:
636
+ matplotlib.use(cfge.matplotlib_mode)
637
+ except:
638
+ print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
639
+
640
+ return cfg
641
+
642
+ def edict_2_dict(x):
643
+ if isinstance(x, dict):
644
+ xnew = {}
645
+ for k in x:
646
+ xnew[k] = edict_2_dict(x[k])
647
+ return xnew
648
+ elif isinstance(x, list):
649
+ xnew = []
650
+ for i in range(len(x)):
651
+ xnew.append( edict_2_dict(x[i]) )
652
+ return xnew
653
+ else:
654
+ return x
655
+
656
+ def search_experiment_folder(root, exid):
657
+ target = None
658
+ for fi in os.listdir(root):
659
+ if not osp.isdir(osp.join(root, fi)):
660
+ continue
661
+ if int(fi.split('_')[0]) == exid:
662
+ if target is not None:
663
+ return None # duplicated
664
+ elif target is None:
665
+ target = fi
666
+ return target
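
Besides the MODEL(...) and DATASET(...) lookups, cfg_solve resolves SAME(a.b.c) surrogates against the config itself (SEARCH(...) does the same with a depth-first fallback). A tiny illustration with a hypothetical config, not one from this repository:

    from easydict import EasyDict as edict
    from lib.cfg_helper import cfg_solve

    cfg = edict({
        'model': {'context_dim': 768},
        'diffuser': {'context_dim': 'SAME(model.context_dim)'},  # surrogate reference
    })
    cfg = cfg_solve(cfg, cfg)
    print(cfg.diffuser.context_dim)  # 768
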
lib/cfg_holder.py ADDED
@@ -0,0 +1,28 @@
1
+ import copy
2
+
3
+ def singleton(class_):
4
+ instances = {}
5
+ def getinstance(*args, **kwargs):
6
+ if class_ not in instances:
7
+ instances[class_] = class_(*args, **kwargs)
8
+ return instances[class_]
9
+ return getinstance
10
+
11
+ ##############
12
+ # cfg_holder #
13
+ ##############
14
+
15
+ @singleton
16
+ class cfg_unique_holder(object):
17
+ def __init__(self):
18
+ self.cfg = None
19
+ # this is use to track the main codes.
20
+ self.code = set()
21
+ def save_cfg(self, cfg):
22
+ self.cfg = copy.deepcopy(cfg)
23
+ def add_code(self, code):
24
+ """
25
+ A new main code is reached and
26
+ its name is added.
27
+ """
28
+ self.code.add(code)
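
Because of the @singleton decorator, every cfg_unique_holder() call returns the same object, which is how print_log in lib/log_service.py (next file) locates the log file without any arguments being passed around. A short usage sketch, assuming cfg is an EasyDict carrying a train.log_file field:

    from lib.cfg_holder import cfg_unique_holder as cfguh

    cfguh().save_cfg(cfg)               # store a deep copy once, e.g. at startup
    assert cfguh() is cfguh()           # the decorator guarantees one shared instance
    print(cfguh().cfg.train.log_file)   # later callers read the same configuration back
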
lib/log_service.py ADDED
@@ -0,0 +1,165 @@
1
+ import timeit
2
+ import numpy as np
3
+ import os
4
+ import os.path as osp
5
+ import shutil
6
+ import copy
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.distributed as dist
10
+ from .cfg_holder import cfg_unique_holder as cfguh
11
+ from . import sync
12
+
13
+ def print_log(*console_info):
14
+ grank, lrank, _ = sync.get_rank('all')
15
+ if lrank!=0:
16
+ return
17
+
18
+ console_info = [str(i) for i in console_info]
19
+ console_info = ' '.join(console_info)
20
+ print(console_info)
21
+
22
+ if grank!=0:
23
+ return
24
+
25
+ log_file = None
26
+ try:
27
+ log_file = cfguh().cfg.train.log_file
28
+ except:
29
+ try:
30
+ log_file = cfguh().cfg.eval.log_file
31
+ except:
32
+ return
33
+ if log_file is not None:
34
+ with open(log_file, 'a') as f:
35
+ f.write(console_info + '\n')
36
+
37
+ class distributed_log_manager(object):
38
+ def __init__(self):
39
+ self.sum = {}
40
+ self.cnt = {}
41
+ self.time_check = timeit.default_timer()
42
+
43
+ cfgt = cfguh().cfg.train
44
+ self.ddp = sync.is_ddp()
45
+ self.grank, self.lrank, _ = sync.get_rank('all')
46
+ self.gwsize = sync.get_world_size('global')
47
+
48
+ use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank==0)
49
+
50
+ self.tb = None
51
+ if use_tensorboard:
52
+ import tensorboardX
53
+ monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
54
+ self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
55
+
56
+ def accumulate(self, n, **data):
57
+ if n < 0:
58
+ raise ValueError
59
+
60
+ for itemn, di in data.items():
61
+ if itemn in self.sum:
62
+ self.sum[itemn] += di * n
63
+ self.cnt[itemn] += n
64
+ else:
65
+ self.sum[itemn] = di * n
66
+ self.cnt[itemn] = n
67
+
68
+ def get_mean_value_dict(self):
69
+ value_gather = [
70
+ self.sum[itemn]/self.cnt[itemn] \
71
+ for itemn in sorted(self.sum.keys()) ]
72
+
73
+ value_gather_tensor = torch.FloatTensor(value_gather).to(self.lrank)
74
+ if self.ddp:
75
+ dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
76
+ value_gather_tensor /= self.gwsize
77
+
78
+ mean = {}
79
+ for idx, itemn in enumerate(sorted(self.sum.keys())):
80
+ mean[itemn] = value_gather_tensor[idx].item()
81
+ return mean
82
+
83
+ def tensorboard_log(self, step, data, mode='train', **extra):
84
+ if self.tb is None:
85
+ return
86
+ if mode == 'train':
87
+ self.tb.add_scalar('other/epochn', extra['epochn'], step)
88
+ if ('lr' in extra) and (extra['lr'] is not None):
89
+ self.tb.add_scalar('other/lr', extra['lr'], step)
90
+ for itemn, di in data.items():
91
+ if itemn.find('loss') == 0:
92
+ self.tb.add_scalar('loss/'+itemn, di, step)
93
+ elif itemn == 'Loss':
94
+ self.tb.add_scalar('Loss', di, step)
95
+ else:
96
+ self.tb.add_scalar('other/'+itemn, di, step)
97
+ elif mode == 'eval':
98
+ if isinstance(data, dict):
99
+ for itemn, di in data.items():
100
+ self.tb.add_scalar('eval/'+itemn, di, step)
101
+ else:
102
+ self.tb.add_scalar('eval', data, step)
103
+ return
104
+
105
+ def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
106
+ console_info = [
107
+ 'Iter:{}'.format(itern),
108
+ 'Epoch:{}'.format(epochn),
109
+ 'Sample:{}'.format(samplen),]
110
+
111
+ if lr is not None:
112
+ console_info += ['LR:{:.4E}'.format(lr)]
113
+
114
+ mean = self.get_mean_value_dict()
115
+
116
+ tbstep = itern if tbstep is None else tbstep
117
+ self.tensorboard_log(
118
+ tbstep, mean, mode='train',
119
+ itern=itern, epochn=epochn, lr=lr)
120
+
121
+ loss = mean.pop('Loss')
122
+ mean_info = ['Loss:{:.4f}'.format(loss)] + [
123
+ '{}:{:.4f}'.format(itemn, mean[itemn]) \
124
+ for itemn in sorted(mean.keys()) \
125
+ if itemn.find('loss') == 0
126
+ ]
127
+ console_info += mean_info
128
+ console_info.append('Time:{:.2f}s'.format(
129
+ timeit.default_timer() - self.time_check))
130
+ return ' , '.join(console_info)
131
+
132
+ def clear(self):
133
+ self.sum = {}
134
+ self.cnt = {}
135
+ self.time_check = timeit.default_timer()
136
+
137
+ def tensorboard_close(self):
138
+ if self.tb is not None:
139
+ self.tb.close()
140
+
141
+ # ----- also include some small utils -----
142
+
143
+ def torch_to_numpy(*argv):
144
+ if len(argv) > 1:
145
+ data = list(argv)
146
+ else:
147
+ data = argv[0]
148
+
149
+ if isinstance(data, torch.Tensor):
150
+ return data.to('cpu').detach().numpy()
151
+
152
+ elif isinstance(data, (list, tuple)):
153
+ out = []
154
+ for di in data:
155
+ out.append(torch_to_numpy(di))
156
+ return out
157
+
158
+ elif isinstance(data, dict):
159
+ out = {}
160
+ for ni, di in data.items():
161
+ out[ni] = torch_to_numpy(di)
162
+ return out
163
+
164
+ else:
165
+ return data
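
A sketch of the intended training-loop usage of distributed_log_manager, assuming cfg_unique_holder already carries a config with a train section (as prepared by cfg_initiates) and that loader and train_step are placeholders for a DataLoader and a loss-producing step:

    from lib.log_service import distributed_log_manager, print_log

    logm = distributed_log_manager()
    for itern, batch in enumerate(loader):
        loss = train_step(batch)                       # placeholder training step
        logm.accumulate(len(batch), Loss=loss.item())  # running sums, averaged across ranks later
        if itern % 100 == 0:
            print_log(logm.train_summary(itern, epochn=0, samplen=itern, lr=1e-4))
            logm.clear()
    logm.tensorboard_close()
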
lib/model_zoo/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .common.get_model import get_model
2
+ from .common.get_optimizer import get_optimizer
3
+ from .common.get_scheduler import get_scheduler
4
+ from .common.utils import get_unit
lib/model_zoo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (371 Bytes).
lib/model_zoo/__pycache__/attention.cpython-310.pyc ADDED
Binary file (15.8 kB).
lib/model_zoo/__pycache__/autokl.cpython-310.pyc ADDED
Binary file (6.08 kB).
lib/model_zoo/__pycache__/autokl_modules.cpython-310.pyc ADDED
Binary file (20.3 kB).
lib/model_zoo/__pycache__/autokl_utils.cpython-310.pyc ADDED
Binary file (13 kB).
lib/model_zoo/__pycache__/controlnet.cpython-310.pyc ADDED
Binary file (13.1 kB).
lib/model_zoo/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (7.89 kB).
lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc ADDED
Binary file (9.53 kB).
lib/model_zoo/__pycache__/distributions.cpython-310.pyc ADDED
Binary file (3.76 kB).
lib/model_zoo/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.01 kB).
lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc ADDED
Binary file (51.1 kB).
lib/model_zoo/__pycache__/pfd.cpython-310.pyc ADDED
Binary file (15.9 kB).
lib/model_zoo/__pycache__/seecoder.cpython-310.pyc ADDED
Binary file (16.6 kB).
lib/model_zoo/__pycache__/seecoder_utils.cpython-310.pyc ADDED
Binary file (4.7 kB).
lib/model_zoo/__pycache__/swin.cpython-310.pyc ADDED
Binary file (21.2 kB).
lib/model_zoo/attention.py ADDED
@@ -0,0 +1,540 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from .diffusion_utils import checkpoint
9
+
10
+ try:
11
+ import xformers
12
+ import xformers.ops
13
+ XFORMERS_IS_AVAILBLE = True
14
+ except:
15
+ XFORMERS_IS_AVAILBLE = False
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def uniq(arr):
23
+ return {el: True for el in arr}.keys()
24
+
25
+
26
+ def default(val, d):
27
+ if exists(val):
28
+ return val
29
+ return d() if isfunction(d) else d
30
+
31
+
32
+ def max_neg_value(t):
33
+ return -torch.finfo(t.dtype).max
34
+
35
+
36
+ def init_(tensor):
37
+ dim = tensor.shape[-1]
38
+ std = 1 / math.sqrt(dim)
39
+ tensor.uniform_(-std, std)
40
+ return tensor
41
+
42
+
43
+ # feedforward
44
+ class GEGLU(nn.Module):
45
+ def __init__(self, dim_in, dim_out):
46
+ super().__init__()
47
+ self.proj = nn.Linear(dim_in, dim_out * 2)
48
+
49
+ def forward(self, x):
50
+ x, gate = self.proj(x).chunk(2, dim=-1)
51
+ return x * F.gelu(gate)
52
+
53
+
54
+ class FeedForward(nn.Module):
55
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
56
+ super().__init__()
57
+ inner_dim = int(dim * mult)
58
+ dim_out = default(dim_out, dim)
59
+ project_in = nn.Sequential(
60
+ nn.Linear(dim, inner_dim),
61
+ nn.GELU()
62
+ ) if not glu else GEGLU(dim, inner_dim)
63
+
64
+ self.net = nn.Sequential(
65
+ project_in,
66
+ nn.Dropout(dropout),
67
+ nn.Linear(inner_dim, dim_out)
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.net(x)
72
+
73
+
74
+ def zero_module(module):
75
+ """
76
+ Zero out the parameters of a module and return it.
77
+ """
78
+ for p in module.parameters():
79
+ p.detach().zero_()
80
+ return module
81
+
82
+
83
+ def Normalize(in_channels):
84
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
85
+
86
+
87
+ class LinearAttention(nn.Module):
88
+ def __init__(self, dim, heads=4, dim_head=32):
89
+ super().__init__()
90
+ self.heads = heads
91
+ hidden_dim = dim_head * heads
92
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
93
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
94
+
95
+ def forward(self, x):
96
+ b, c, h, w = x.shape
97
+ qkv = self.to_qkv(x)
98
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
99
+ k = k.softmax(dim=-1)
100
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
101
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
102
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
103
+ return self.to_out(out)
104
+
105
+
106
+ class SpatialSelfAttention(nn.Module):
107
+ def __init__(self, in_channels):
108
+ super().__init__()
109
+ self.in_channels = in_channels
110
+
111
+ self.norm = Normalize(in_channels)
112
+ self.q = torch.nn.Conv2d(in_channels,
113
+ in_channels,
114
+ kernel_size=1,
115
+ stride=1,
116
+ padding=0)
117
+ self.k = torch.nn.Conv2d(in_channels,
118
+ in_channels,
119
+ kernel_size=1,
120
+ stride=1,
121
+ padding=0)
122
+ self.v = torch.nn.Conv2d(in_channels,
123
+ in_channels,
124
+ kernel_size=1,
125
+ stride=1,
126
+ padding=0)
127
+ self.proj_out = torch.nn.Conv2d(in_channels,
128
+ in_channels,
129
+ kernel_size=1,
130
+ stride=1,
131
+ padding=0)
132
+
133
+ def forward(self, x):
134
+ h_ = x
135
+ h_ = self.norm(h_)
136
+ q = self.q(h_)
137
+ k = self.k(h_)
138
+ v = self.v(h_)
139
+
140
+ # compute attention
141
+ b,c,h,w = q.shape
142
+ q = rearrange(q, 'b c h w -> b (h w) c')
143
+ k = rearrange(k, 'b c h w -> b c (h w)')
144
+ w_ = torch.einsum('bij,bjk->bik', q, k)
145
+
146
+ w_ = w_ * (int(c)**(-0.5))
147
+ w_ = torch.nn.functional.softmax(w_, dim=2)
148
+
149
+ # attend to values
150
+ v = rearrange(v, 'b c h w -> b c (h w)')
151
+ w_ = rearrange(w_, 'b i j -> b j i')
152
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
153
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
154
+ h_ = self.proj_out(h_)
155
+
156
+ return x+h_
157
+
158
+
159
+ class CrossAttention(nn.Module):
160
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
161
+ super().__init__()
162
+ inner_dim = dim_head * heads
163
+ context_dim = default(context_dim, query_dim)
164
+
165
+ self.scale = dim_head ** -0.5
166
+ self.heads = heads
167
+ self.inner_dim = inner_dim
168
+
169
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
170
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
171
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
172
+
173
+ self.to_out = nn.Sequential(
174
+ nn.Linear(inner_dim, query_dim),
175
+ nn.Dropout(dropout)
176
+ )
177
+
178
+ def forward(self, x, context=None, mask=None):
179
+ h = self.heads
180
+
181
+ q = self.to_q(x)
182
+ context = default(context, x)
183
+ k = self.to_k(context)
184
+ v = self.to_v(context)
185
+
186
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
187
+
188
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
189
+
190
+ if exists(mask):
191
+ mask = rearrange(mask, 'b ... -> b (...)')
192
+ max_neg_value = -torch.finfo(sim.dtype).max
193
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
194
+ sim.masked_fill_(~mask, max_neg_value)
195
+
196
+ # attention, what we cannot get enough of
197
+ attn = sim.softmax(dim=-1)
198
+
199
+ out = einsum('b i j, b j d -> b i d', attn, v)
200
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
201
+ return self.to_out(out)
202
+
203
+ def forward_next(self, x, context=None, mask=None):
204
+ assert mask is None, 'not supported yet'
205
+ x0 = rearrange(x, 'b n c -> n b c')
206
+ if context is not None:
207
+ c0 = rearrange(context, 'b n c -> n b c')
208
+ else:
209
+ c0 = x0
210
+ r, _ = F.multi_head_attention_forward(
211
+ x0, c0, c0,
212
+ embed_dim_to_check = self.inner_dim,
213
+ num_heads = self.heads,
214
+ in_proj_weight = None, in_proj_bias = None,
215
+ bias_k = None, bias_v = None,
216
+ add_zero_attn = False, dropout_p = 0,
217
+ out_proj_weight = self.to_out[0].weight,
218
+ out_proj_bias = self.to_out[0].bias,
219
+ use_separate_proj_weight = True,
220
+ q_proj_weight = self.to_q.weight,
221
+ k_proj_weight = self.to_k.weight,
222
+ v_proj_weight = self.to_v.weight,)
223
+ r = rearrange(r, 'n b c -> b n c')
224
+ r = self.to_out[1](r) # dropout
225
+ return r
226
+
227
+
228
+ class MemoryEfficientCrossAttention(nn.Module):
229
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
230
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
231
+ super().__init__()
232
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
233
+ f"{heads} heads.")
234
+ inner_dim = dim_head * heads
235
+ context_dim = default(context_dim, query_dim)
236
+
237
+ self.heads = heads
238
+ self.dim_head = dim_head
239
+
240
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
241
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
242
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
243
+
244
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
245
+ self.attention_op: Optional[Any] = None
246
+
247
+ def forward(self, x, context=None, mask=None):
248
+ q = self.to_q(x)
249
+ context = default(context, x)
250
+ k = self.to_k(context)
251
+ v = self.to_v(context)
252
+
253
+ b, _, _ = q.shape
254
+ q, k, v = map(
255
+ lambda t: t.unsqueeze(3)
256
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
257
+ .permute(0, 2, 1, 3)
258
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
259
+ .contiguous(),
260
+ (q, k, v),
261
+ )
262
+
263
+ # actually compute the attention, what we cannot get enough of
264
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
265
+
266
+ if exists(mask):
267
+ raise NotImplementedError
268
+ out = (
269
+ out.unsqueeze(0)
270
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
271
+ .permute(0, 2, 1, 3)
272
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
273
+ )
274
+ return self.to_out(out)
275
+
276
+
277
+ class BasicTransformerBlock(nn.Module):
278
+ ATTENTION_MODES = {
279
+ "softmax": CrossAttention, # vanilla attention
280
+ "softmax-xformers": MemoryEfficientCrossAttention
281
+ }
282
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
283
+ disable_self_attn=False):
284
+ super().__init__()
285
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
286
+ assert attn_mode in self.ATTENTION_MODES
287
+ attn_cls = self.ATTENTION_MODES[attn_mode]
288
+ self.disable_self_attn = disable_self_attn
289
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
290
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
291
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
292
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
293
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
294
+ self.norm1 = nn.LayerNorm(dim)
295
+ self.norm2 = nn.LayerNorm(dim)
296
+ self.norm3 = nn.LayerNorm(dim)
297
+ self.checkpoint = checkpoint
298
+
299
+ def forward(self, x, context=None):
300
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
301
+
302
+ def _forward(self, x, context=None):
303
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
304
+ x = self.attn2(self.norm2(x), context=context) + x
305
+ x = self.ff(self.norm3(x)) + x
306
+ return x
307
+
308
+
309
+ class SpatialTransformer(nn.Module):
310
+ """
311
+ Transformer block for image-like data.
312
+ First, project the input (aka embedding)
313
+ and reshape to b, t, d.
314
+ Then apply standard transformer action.
315
+ Finally, reshape to image
316
+ NEW: use_linear for more efficiency instead of the 1x1 convs
317
+ """
318
+ def __init__(self, in_channels, n_heads, d_head,
319
+ depth=1, dropout=0., context_dim=None,
320
+ disable_self_attn=False, use_linear=False,
321
+ use_checkpoint=True):
322
+ super().__init__()
323
+ if exists(context_dim) and not isinstance(context_dim, list):
324
+ context_dim = [context_dim]
325
+ self.in_channels = in_channels
326
+ inner_dim = n_heads * d_head
327
+ self.norm = Normalize(in_channels)
328
+ if not use_linear:
329
+ self.proj_in = nn.Conv2d(in_channels,
330
+ inner_dim,
331
+ kernel_size=1,
332
+ stride=1,
333
+ padding=0)
334
+ else:
335
+ self.proj_in = nn.Linear(in_channels, inner_dim)
336
+
337
+ self.transformer_blocks = nn.ModuleList(
338
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
339
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
340
+ for d in range(depth)]
341
+ )
342
+ if not use_linear:
343
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
344
+ in_channels,
345
+ kernel_size=1,
346
+ stride=1,
347
+ padding=0))
348
+ else:
349
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
350
+ self.use_linear = use_linear
351
+
352
+ def forward(self, x, context=None):
353
+ # note: if no context is given, cross-attention defaults to self-attention
354
+ if not isinstance(context, list):
355
+ context = [context]
356
+ b, c, h, w = x.shape
357
+ x_in = x
358
+ x = self.norm(x)
359
+ if not self.use_linear:
360
+ x = self.proj_in(x)
361
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
362
+ if self.use_linear:
363
+ x = self.proj_in(x)
364
+ for i, block in enumerate(self.transformer_blocks):
365
+ x = block(x, context=context[i])
366
+ if self.use_linear:
367
+ x = self.proj_out(x)
368
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
369
+ if not self.use_linear:
370
+ x = self.proj_out(x)
371
+ return x + x_in
372
+
373
+
374
+ ##########################
375
+ # transformer no context #
376
+ ##########################
377
+
378
+ class BasicTransformerBlockNoContext(nn.Module):
379
+ def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
380
+ super().__init__()
381
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
382
+ dropout=dropout, context_dim=None)
383
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
384
+ self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
385
+ dropout=dropout, context_dim=None)
386
+ self.norm1 = nn.LayerNorm(dim)
387
+ self.norm2 = nn.LayerNorm(dim)
388
+ self.norm3 = nn.LayerNorm(dim)
389
+ self.checkpoint = checkpoint
390
+
391
+ def forward(self, x):
392
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
393
+
394
+ def _forward(self, x):
395
+ x = self.attn1(self.norm1(x)) + x
396
+ x = self.attn2(self.norm2(x)) + x
397
+ x = self.ff(self.norm3(x)) + x
398
+ return x
399
+
400
+ class SpatialTransformerNoContext(nn.Module):
401
+ """
402
+ Transformer block for image-like data.
403
+ First, project the input (aka embedding)
404
+ and reshape to b, t, d.
405
+ Then apply standard transformer action.
406
+ Finally, reshape to image
407
+ """
408
+ def __init__(self, in_channels, n_heads, d_head,
409
+ depth=1, dropout=0.,):
410
+ super().__init__()
411
+ self.in_channels = in_channels
412
+ inner_dim = n_heads * d_head
413
+ self.norm = Normalize(in_channels)
414
+
415
+ self.proj_in = nn.Conv2d(in_channels,
416
+ inner_dim,
417
+ kernel_size=1,
418
+ stride=1,
419
+ padding=0)
420
+
421
+ self.transformer_blocks = nn.ModuleList(
422
+ [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
423
+ for d in range(depth)]
424
+ )
425
+
426
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
427
+ in_channels,
428
+ kernel_size=1,
429
+ stride=1,
430
+ padding=0))
431
+
432
+ def forward(self, x):
433
+ # note: if no context is given, cross-attention defaults to self-attention
434
+ b, c, h, w = x.shape
435
+ x_in = x
436
+ x = self.norm(x)
437
+ x = self.proj_in(x)
438
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
439
+ for block in self.transformer_blocks:
440
+ x = block(x)
441
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
442
+ x = self.proj_out(x)
443
+ return x + x_in
444
+
445
+
446
+ #######################################
447
+ # Spatial Transformer with Two Branch #
448
+ #######################################
449
+
450
+ class DualSpatialTransformer(nn.Module):
451
+ def __init__(self, in_channels, n_heads, d_head,
452
+ depth=1, dropout=0., context_dim=None,
453
+ disable_self_attn=False):
454
+ super().__init__()
455
+ self.in_channels = in_channels
456
+ inner_dim = n_heads * d_head
457
+
458
+ # First crossattn
459
+ self.norm_0 = Normalize(in_channels)
460
+ self.proj_in_0 = nn.Conv2d(
461
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
462
+ self.transformer_blocks_0 = nn.ModuleList(
463
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
464
+ disable_self_attn=disable_self_attn)
465
+ for d in range(depth)]
466
+ )
467
+ self.proj_out_0 = zero_module(nn.Conv2d(
468
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
469
+
470
+ # Second crossattn
471
+ self.norm_1 = Normalize(in_channels)
472
+ self.proj_in_1 = nn.Conv2d(
473
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
474
+ self.transformer_blocks_1 = nn.ModuleList(
475
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
476
+ disable_self_attn=disable_self_attn)
477
+ for d in range(depth)]
478
+ )
479
+ self.proj_out_1 = zero_module(nn.Conv2d(
480
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
481
+
482
+ def forward(self, x, context=None, which=None):
483
+ # note: if no context is given, cross-attention defaults to self-attention
484
+ b, c, h, w = x.shape
485
+ x_in = x
486
+ if which==0:
487
+ norm, proj_in, blocks, proj_out = \
488
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
489
+ elif which==1:
490
+ norm, proj_in, blocks, proj_out = \
491
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
492
+ else:
493
+ # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
494
+ # import numpy.random as npr
495
+ # rwhich = 0 if npr.rand() < which else 1
496
+ # context = context[rwhich]
497
+ # if rwhich==0:
498
+ # norm, proj_in, blocks, proj_out = \
499
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
500
+ # elif rwhich==1:
501
+ # norm, proj_in, blocks, proj_out = \
502
+ # self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
503
+
504
+ # import numpy.random as npr
505
+ # rwhich = 0 if npr.rand() < 0.33 else 1
506
+ # if rwhich==0:
507
+ # context = context[rwhich]
508
+ # norm, proj_in, blocks, proj_out = \
509
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
510
+ # else:
511
+
512
+ norm, proj_in, blocks, proj_out = \
513
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
514
+ x0 = norm(x)
515
+ x0 = proj_in(x0)
516
+ x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
517
+ for block in blocks:
518
+ x0 = block(x0, context=context[0])
519
+ x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
520
+ x0 = proj_out(x0)
521
+
522
+ norm, proj_in, blocks, proj_out = \
523
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
524
+ x1 = norm(x)
525
+ x1 = proj_in(x1)
526
+ x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
527
+ for block in blocks:
528
+ x1 = block(x1, context=context[1])
529
+ x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
530
+ x1 = proj_out(x1)
531
+ return x0*which + x1*(1-which) + x_in
532
+
533
+ x = norm(x)
534
+ x = proj_in(x)
535
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
536
+ for block in blocks:
537
+ x = block(x, context=context)
538
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
539
+ x = proj_out(x)
540
+ return x + x_in
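To make the shape conventions in `attention.py` concrete, here is a small sketch of pushing an image-like feature map through a `SpatialTransformer` with a CLIP-sized context, plus a bare `CrossAttention` call. The dimensions are illustrative only (the real values come from the UNet configs elsewhere in this commit), and the sketch assumes the usual ldm-style `checkpoint` helper in `diffusion_utils`; if xformers is installed, the memory-efficient attention path is picked automatically.

```python
import torch
from lib.model_zoo.attention import SpatialTransformer, CrossAttention

feat = torch.randn(2, 320, 32, 32)   # b, c, h, w feature map
ctx = torch.randn(2, 77, 768)        # b, tokens, context_dim (e.g. CLIP hidden states)

block = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
                           depth=1, context_dim=768, use_checkpoint=False)
out = block(feat, context=ctx)       # residual output, same shape as the input feature map
assert out.shape == feat.shape

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
tokens = torch.randn(2, 32 * 32, 320)      # b, n, query_dim
print(attn(tokens, context=ctx).shape)     # torch.Size([2, 1024, 320])
```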
lib/model_zoo/autokl.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+ from lib.model_zoo.common.get_model import get_model, register
6
+
7
+ # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
8
+
9
+ from .autokl_modules import Encoder, Decoder
10
+ from .distributions import DiagonalGaussianDistribution
11
+
12
+ from .autokl_utils import LPIPSWithDiscriminator
13
+
14
+ @register('autoencoderkl')
15
+ class AutoencoderKL(nn.Module):
16
+ def __init__(self,
17
+ ddconfig,
18
+ lossconfig,
19
+ embed_dim,):
20
+ super().__init__()
21
+ self.encoder = Encoder(**ddconfig)
22
+ self.decoder = Decoder(**ddconfig)
23
+ if lossconfig is not None:
24
+ self.loss = LPIPSWithDiscriminator(**lossconfig)
25
+ assert ddconfig["double_z"]
26
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
27
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
28
+ self.embed_dim = embed_dim
29
+
30
+ @torch.no_grad()
31
+ def encode(self, x, out_posterior=False):
32
+ return self.encode_trainable(x, out_posterior)
33
+
34
+ def encode_trainable(self, x, out_posterior=False):
35
+ x = x*2-1
36
+ h = self.encoder(x)
37
+ moments = self.quant_conv(h)
38
+ posterior = DiagonalGaussianDistribution(moments)
39
+ if out_posterior:
40
+ return posterior
41
+ else:
42
+ return posterior.sample()
43
+
44
+ @torch.no_grad()
45
+ def decode(self, z):
46
+ dec = self.decode_trainable(z)
47
+ dec = torch.clamp(dec, 0, 1)
48
+ return dec
49
+
50
+ def decode_trainable(self, z):
51
+ z = self.post_quant_conv(z)
52
+ dec = self.decoder(z)
53
+ dec = (dec+1)/2
54
+ return dec
55
+
56
+ def apply_model(self, input, sample_posterior=True):
57
+ posterior = self.encode_trainable(input, out_posterior=True)
58
+ if sample_posterior:
59
+ z = posterior.sample()
60
+ else:
61
+ z = posterior.mode()
62
+ dec = self.decode_trainable(z)
63
+ return dec, posterior
64
+
65
+ def get_input(self, batch, k):
66
+ x = batch[k]
67
+ if len(x.shape) == 3:
68
+ x = x[..., None]
69
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
70
+ return x
71
+
72
+ def forward(self, x, optimizer_idx, global_step):
73
+ reconstructions, posterior = self.apply_model(x)
74
+
75
+ if optimizer_idx == 0:
76
+ # train encoder+decoder+logvar
77
+ aeloss, log_dict_ae = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
78
+ last_layer=self.get_last_layer(), split="train")
79
+ return aeloss, log_dict_ae
80
+
81
+ if optimizer_idx == 1:
82
+ # train the discriminator
83
+ discloss, log_dict_disc = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
84
+ last_layer=self.get_last_layer(), split="train")
85
+
86
+ return discloss, log_dict_disc
87
+
88
+ def validation_step(self, batch, batch_idx):
89
+ inputs = self.get_input(batch, self.image_key)
90
+ reconstructions, posterior = self(inputs)
91
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
92
+ last_layer=self.get_last_layer(), split="val")
93
+
94
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
95
+ last_layer=self.get_last_layer(), split="val")
96
+
97
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
98
+ self.log_dict(log_dict_ae)
99
+ self.log_dict(log_dict_disc)
100
+ return self.log_dict
101
+
102
+ def configure_optimizers(self):
103
+ lr = self.learning_rate
104
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
105
+ list(self.decoder.parameters())+
106
+ list(self.quant_conv.parameters())+
107
+ list(self.post_quant_conv.parameters()),
108
+ lr=lr, betas=(0.5, 0.9))
109
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
110
+ lr=lr, betas=(0.5, 0.9))
111
+ return [opt_ae, opt_disc], []
112
+
113
+ def get_last_layer(self):
114
+ return self.decoder.conv_out.weight
115
+
116
+ @torch.no_grad()
117
+ def log_images(self, batch, only_inputs=False, **kwargs):
118
+ log = dict()
119
+ x = self.get_input(batch, self.image_key)
120
+ x = x.to(self.device)
121
+ if not only_inputs:
122
+ xrec, posterior = self(x)
123
+ if x.shape[1] > 3:
124
+ # colorize with random projection
125
+ assert xrec.shape[1] > 3
126
+ x = self.to_rgb(x)
127
+ xrec = self.to_rgb(xrec)
128
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
129
+ log["reconstructions"] = xrec
130
+ log["inputs"] = x
131
+ return log
132
+
133
+ def to_rgb(self, x):
134
+ assert self.image_key == "segmentation"
135
+ if not hasattr(self, "colorize"):
136
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
137
+ x = F.conv2d(x, weight=self.colorize)
138
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
139
+ return x
140
+
141
+ @register('autoencoderkl_customnorm')
142
+ class AutoencoderKL_CustomNorm(AutoencoderKL):
143
+ def __init__(self, *args, **kwargs):
144
+ super().__init__(*args, **kwargs)
145
+ self.mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
146
+ self.std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
147
+
148
+ def encode_trainable(self, x, out_posterior=False):
149
+ m = self.mean[None, :, None, None].to(x.device).to(x.dtype)
150
+ s = self.std[None, :, None, None].to(x.device).to(x.dtype)
151
+ x = (x-m)/s
152
+ h = self.encoder(x)
153
+ moments = self.quant_conv(h)
154
+ posterior = DiagonalGaussianDistribution(moments)
155
+ if out_posterior:
156
+ return posterior
157
+ else:
158
+ return posterior.sample()
159
+
160
+ def decode_trainable(self, z):
161
+ m = self.mean[None, :, None, None].to(z.device).to(z.dtype)
162
+ s = self.std[None, :, None, None].to(z.device).to(z.dtype)
163
+ z = self.post_quant_conv(z)
164
+ dec = self.decoder(z)
165
+ dec = dec*s + m # undo the CLIP-style normalization applied in encode_trainable (m and s were otherwise unused)
166
+ return dec
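A rough round-trip sketch for the `AutoencoderKL` above. The `ddconfig` values below are in the spirit of the usual SD-style KL autoencoder and are only an assumption here (the config this repo actually uses presumably lives in `configs/model/autokl.yaml`); `lossconfig=None` skips the LPIPS/discriminator loss so nothing beyond the bare autoencoder is built.

```python
import torch
from lib.model_zoo.autokl import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
                attn_resolutions=[], dropout=0.0)
vae = AutoencoderKL(ddconfig=ddconfig, lossconfig=None, embed_dim=4).eval()

x = torch.rand(1, 3, 256, 256)   # encode expects inputs in [0, 1] and maps them to [-1, 1]
z = vae.encode(x)                # sampled latent, here 1 x 4 x 32 x 32
rec = vae.decode(z)              # reconstruction clamped back to [0, 1]
print(z.shape, rec.shape)
```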
lib/model_zoo/autokl_modules.py ADDED
@@ -0,0 +1,835 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ # from .diffusion_utils import instantiate_from_config
9
+ from .attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
lib/model_zoo/autokl_utils.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import functools
4
+
5
+ class ActNorm(nn.Module):
6
+ def __init__(self, num_features, logdet=False, affine=True,
7
+ allow_reverse_init=False):
8
+ assert affine
9
+ super().__init__()
10
+ self.logdet = logdet
11
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
12
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
13
+ self.allow_reverse_init = allow_reverse_init
14
+
15
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
16
+
17
+ def initialize(self, input):
18
+ with torch.no_grad():
19
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
20
+ mean = (
21
+ flatten.mean(1)
22
+ .unsqueeze(1)
23
+ .unsqueeze(2)
24
+ .unsqueeze(3)
25
+ .permute(1, 0, 2, 3)
26
+ )
27
+ std = (
28
+ flatten.std(1)
29
+ .unsqueeze(1)
30
+ .unsqueeze(2)
31
+ .unsqueeze(3)
32
+ .permute(1, 0, 2, 3)
33
+ )
34
+
35
+ self.loc.data.copy_(-mean)
36
+ self.scale.data.copy_(1 / (std + 1e-6))
37
+
38
+ def forward(self, input, reverse=False):
39
+ if reverse:
40
+ return self.reverse(input)
41
+ if len(input.shape) == 2:
42
+ input = input[:,:,None,None]
43
+ squeeze = True
44
+ else:
45
+ squeeze = False
46
+
47
+ _, _, height, width = input.shape
48
+
49
+ if self.training and self.initialized.item() == 0:
50
+ self.initialize(input)
51
+ self.initialized.fill_(1)
52
+
53
+ h = self.scale * (input + self.loc)
54
+
55
+ if squeeze:
56
+ h = h.squeeze(-1).squeeze(-1)
57
+
58
+ if self.logdet:
59
+ log_abs = torch.log(torch.abs(self.scale))
60
+ logdet = height*width*torch.sum(log_abs)
61
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
62
+ return h, logdet
63
+
64
+ return h
65
+
66
+ def reverse(self, output):
67
+ if self.training and self.initialized.item() == 0:
68
+ if not self.allow_reverse_init:
69
+ raise RuntimeError(
70
+ "Initializing ActNorm in reverse direction is "
71
+ "disabled by default. Use allow_reverse_init=True to enable."
72
+ )
73
+ else:
74
+ self.initialize(output)
75
+ self.initialized.fill_(1)
76
+
77
+ if len(output.shape) == 2:
78
+ output = output[:,:,None,None]
79
+ squeeze = True
80
+ else:
81
+ squeeze = False
82
+
83
+ h = output / self.scale - self.loc
84
+
85
+ if squeeze:
86
+ h = h.squeeze(-1).squeeze(-1)
87
+ return h
88
+
89
+ #################
90
+ # Discriminator #
91
+ #################
92
+
93
+ def weights_init(m):
94
+ classname = m.__class__.__name__
95
+ if classname.find('Conv') != -1:
96
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
97
+ elif classname.find('BatchNorm') != -1:
98
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
99
+ nn.init.constant_(m.bias.data, 0)
100
+
101
+ class NLayerDiscriminator(nn.Module):
102
+ """Defines a PatchGAN discriminator as in Pix2Pix
103
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
104
+ """
105
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
106
+ """Construct a PatchGAN discriminator
107
+ Parameters:
108
+ input_nc (int) -- the number of channels in input images
109
+ ndf (int) -- the number of filters in the last conv layer
110
+ n_layers (int) -- the number of conv layers in the discriminator
111
+ norm_layer -- normalization layer
112
+ """
113
+ super(NLayerDiscriminator, self).__init__()
114
+ if not use_actnorm:
115
+ norm_layer = nn.BatchNorm2d
116
+ else:
117
+ norm_layer = ActNorm
118
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
119
+ use_bias = norm_layer.func != nn.BatchNorm2d
120
+ else:
121
+ use_bias = norm_layer != nn.BatchNorm2d
122
+
123
+ kw = 4
124
+ padw = 1
125
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
126
+ nf_mult = 1
127
+ nf_mult_prev = 1
128
+ for n in range(1, n_layers): # gradually increase the number of filters
129
+ nf_mult_prev = nf_mult
130
+ nf_mult = min(2 ** n, 8)
131
+ sequence += [
132
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
133
+ norm_layer(ndf * nf_mult),
134
+ nn.LeakyReLU(0.2, True)
135
+ ]
136
+
137
+ nf_mult_prev = nf_mult
138
+ nf_mult = min(2 ** n_layers, 8)
139
+ sequence += [
140
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
141
+ norm_layer(ndf * nf_mult),
142
+ nn.LeakyReLU(0.2, True)
143
+ ]
144
+
145
+ sequence += [
146
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
147
+ self.main = nn.Sequential(*sequence)
148
+
149
+ def forward(self, input):
150
+ """Standard forward."""
151
+ return self.main(input)
152
+
153
+ #########
154
+ # LPIPS #
155
+ #########
156
+
157
+ class ScalingLayer(nn.Module):
158
+ def __init__(self):
159
+ super(ScalingLayer, self).__init__()
160
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
161
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
162
+
163
+ def forward(self, inp):
164
+ return (inp - self.shift) / self.scale
165
+
166
+ class NetLinLayer(nn.Module):
167
+ """ A single linear layer which does a 1x1 conv """
168
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
169
+ super(NetLinLayer, self).__init__()
170
+ layers = [nn.Dropout(), ] if (use_dropout) else []
171
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
172
+ self.model = nn.Sequential(*layers)
173
+
174
+ from collections import namedtuple
175
+ from torchvision import models
176
+ from torchvision.models import VGG16_Weights
177
+
178
+ class vgg16(torch.nn.Module):
179
+ def __init__(self, requires_grad=False, pretrained=True):
180
+ super(vgg16, self).__init__()
181
+ if pretrained:
182
+ vgg_pretrained_features = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
183
+ self.slice1 = torch.nn.Sequential()
184
+ self.slice2 = torch.nn.Sequential()
185
+ self.slice3 = torch.nn.Sequential()
186
+ self.slice4 = torch.nn.Sequential()
187
+ self.slice5 = torch.nn.Sequential()
188
+ self.N_slices = 5
189
+ for x in range(4):
190
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
191
+ for x in range(4, 9):
192
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
193
+ for x in range(9, 16):
194
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
195
+ for x in range(16, 23):
196
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
197
+ for x in range(23, 30):
198
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
199
+ if not requires_grad:
200
+ for param in self.parameters():
201
+ param.requires_grad = False
202
+
203
+ def forward(self, X):
204
+ h = self.slice1(X)
205
+ h_relu1_2 = h
206
+ h = self.slice2(h)
207
+ h_relu2_2 = h
208
+ h = self.slice3(h)
209
+ h_relu3_3 = h
210
+ h = self.slice4(h)
211
+ h_relu4_3 = h
212
+ h = self.slice5(h)
213
+ h_relu5_3 = h
214
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
215
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
216
+ return out
217
+
218
+ def normalize_tensor(x,eps=1e-10):
219
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
220
+ return x/(norm_factor+eps)
221
+
222
+ def spatial_average(x, keepdim=True):
223
+ return x.mean([2,3],keepdim=keepdim)
224
+
225
+ def get_ckpt_path(*args, **kwargs):
226
+ return 'pretrained/lpips.pth'
227
+
228
+ class LPIPS(nn.Module):
229
+ # Learned perceptual metric
230
+ def __init__(self, use_dropout=True):
231
+ super().__init__()
232
+ self.scaling_layer = ScalingLayer()
233
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
234
+ self.net = vgg16(pretrained=True, requires_grad=False)
235
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
236
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
237
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
238
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
239
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
240
+ self.load_from_pretrained()
241
+ for param in self.parameters():
242
+ param.requires_grad = False
243
+
244
+ def load_from_pretrained(self, name="vgg_lpips"):
245
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
246
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
247
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
248
+
249
+ @classmethod
250
+ def from_pretrained(cls, name="vgg_lpips"):
251
+ if name != "vgg_lpips":
252
+ raise NotImplementedError
253
+ model = cls()
254
+ ckpt = get_ckpt_path(name)
255
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
256
+ return model
257
+
258
+ def forward(self, input, target):
259
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
260
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
261
+ feats0, feats1, diffs = {}, {}, {}
262
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
263
+ for kk in range(len(self.chns)):
264
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
265
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
266
+
267
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
268
+ val = res[0]
269
+ for l in range(1, len(self.chns)):
270
+ val += res[l]
271
+ return val
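For reference, a minimal usage sketch of the LPIPS module above as a perceptual distance. The import path `lib.model_zoo.autokl_utils` is an assumption about where this file lives, and it presumes the `pretrained/lpips.pth` checkpoint referenced by `get_ckpt_path` is available locally:

```python
import torch
from lib.model_zoo.autokl_utils import LPIPS  # assumed module path

lpips = LPIPS().eval()                       # VGG16 features + frozen linear heads
img_a = torch.rand(1, 3, 256, 256) * 2 - 1   # inputs are expected roughly in [-1, 1]
img_b = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(img_a, img_b)               # shape [1, 1, 1, 1]; larger = more different
print(dist.item())
```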
272
+
273
+ ############
274
+ # The loss #
275
+ ############
276
+
277
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
278
+ if global_step < threshold:
279
+ weight = value
280
+ return weight
281
+
282
+ def hinge_d_loss(logits_real, logits_fake):
283
+ loss_real = torch.mean(torch.nn.functional.relu(1. - logits_real))
284
+ loss_fake = torch.mean(torch.nn.functional.relu(1. + logits_fake))
285
+ d_loss = 0.5 * (loss_real + loss_fake)
286
+ return d_loss
287
+
288
+ def vanilla_d_loss(logits_real, logits_fake):
289
+ d_loss = 0.5 * (
290
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
291
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
292
+ return d_loss
293
+
294
+ class LPIPSWithDiscriminator(nn.Module):
295
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
296
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
297
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
298
+ disc_loss="hinge"):
299
+
300
+ super().__init__()
301
+ assert disc_loss in ["hinge", "vanilla"]
302
+ self.kl_weight = kl_weight
303
+ self.pixel_weight = pixelloss_weight
304
+ self.perceptual_loss = LPIPS().eval()
305
+ self.perceptual_weight = perceptual_weight
306
+ # output log variance
307
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
308
+
309
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
310
+ n_layers=disc_num_layers,
311
+ use_actnorm=use_actnorm
312
+ ).apply(weights_init)
313
+ self.discriminator_iter_start = disc_start
314
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
315
+ self.disc_factor = disc_factor
316
+ self.discriminator_weight = disc_weight
317
+ self.disc_conditional = disc_conditional
318
+
319
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
320
+ if last_layer is not None:
321
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
322
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
323
+ else:
324
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
325
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
326
+
327
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
328
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
329
+ d_weight = d_weight * self.discriminator_weight
330
+ return d_weight
331
+
332
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
333
+ global_step, last_layer=None, cond=None, split="train",
334
+ weights=None):
335
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
336
+ if self.perceptual_weight > 0:
337
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
338
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
339
+
340
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
341
+ weighted_nll_loss = nll_loss
342
+ if weights is not None:
343
+ weighted_nll_loss = weights*nll_loss
344
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
345
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
346
+ kl_loss = posteriors.kl()
347
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
348
+
349
+ # now the GAN part
350
+ if optimizer_idx == 0:
351
+ # generator update
352
+ if cond is None:
353
+ assert not self.disc_conditional
354
+ logits_fake = self.discriminator(reconstructions.contiguous())
355
+ else:
356
+ assert self.disc_conditional
357
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
358
+ g_loss = -torch.mean(logits_fake)
359
+
360
+ if self.disc_factor > 0.0:
361
+ try:
362
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
363
+ except RuntimeError:
364
+ assert not self.training
365
+ d_weight = torch.tensor(0.0)
366
+ else:
367
+ d_weight = torch.tensor(0.0)
368
+
369
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
370
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
371
+
372
+ log = {"Loss": loss.clone().detach().mean(),
373
+ "logvar": self.logvar.detach(),
374
+ "loss_kl": kl_loss.detach().mean(),
375
+ "loss_nll": nll_loss.detach().mean(),
376
+ "loss_rec": rec_loss.detach().mean(),
377
+ "d_weight": d_weight.detach(),
378
+ "disc_factor": torch.tensor(disc_factor),
379
+ "loss_g": g_loss.detach().mean(),
380
+ }
381
+ return loss, log
382
+
383
+ if optimizer_idx == 1:
384
+ # second pass for discriminator update
385
+ if cond is None:
386
+ logits_real = self.discriminator(inputs.contiguous().detach())
387
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
388
+ else:
389
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
390
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
391
+
392
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
393
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
394
+
395
+ log = {"Loss": d_loss.clone().detach().mean(),
396
+ "loss_disc": d_loss.clone().detach().mean(),
397
+ "logits_real": logits_real.detach().mean(),
398
+ "logits_fake": logits_fake.detach().mean()
399
+ }
400
+ return d_loss, log
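A sketch of how the PatchGAN discriminator and the hinge loss above interact during the two alternating optimizer passes of `LPIPSWithDiscriminator` (generator pass at `optimizer_idx == 0`, discriminator pass at `optimizer_idx == 1`). The module path is again an assumption and the tensors are random placeholders:

```python
import torch
from lib.model_zoo.autokl_utils import NLayerDiscriminator, weights_init, hinge_d_loss  # assumed path

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
real = torch.randn(2, 3, 256, 256)       # ground-truth images
fake = torch.randn(2, 3, 256, 256)       # stand-in for autoencoder reconstructions

# optimizer_idx == 0 (generator pass): push the patch logits of the fakes up
g_loss = -torch.mean(disc(fake))

# optimizer_idx == 1 (discriminator pass): hinge loss on detached reconstructions
d_loss = hinge_d_loss(disc(real), disc(fake.detach()))
print(g_loss.item(), d_loss.item())
```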
lib/model_zoo/clip.py ADDED
@@ -0,0 +1,788 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+ from lib.model_zoo.common.get_model import register
6
+
7
+ symbol = 'clip'
8
+
9
+ class AbstractEncoder(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def encode(self, *args, **kwargs):
14
+ raise NotImplementedError
15
+
16
+ from transformers import CLIPTokenizer, CLIPTextModel
17
+
18
+ def disabled_train(self, mode=True):
19
+ """Overwrite model.train with this function to make sure train/eval mode
20
+ does not change anymore."""
21
+ return self
22
+
23
+ @register('clip_text_context_encoder_sdv1')
24
+ class CLIPTextContextEncoderSDv1(AbstractEncoder):
25
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
26
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True): # clip-vit-base-patch32
27
+ super().__init__()
28
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
29
+ self.transformer = CLIPTextModel.from_pretrained(version)
30
+ self.device = device
31
+ self.max_length = max_length
32
+ if freeze:
33
+ self.freeze()
34
+
35
+ def freeze(self):
36
+ self.transformer = self.transformer.eval()
37
+ for param in self.parameters():
38
+ param.requires_grad = False
39
+
40
+ def forward(self, text):
41
+ with torch.no_grad():
42
+ batch_encoding = self.tokenizer(
43
+ text, truncation=True, max_length=self.max_length, return_length=True,
44
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
45
+ tokens = batch_encoding["input_ids"].to(self.device)
46
+ max_token_n = self.transformer.text_model.embeddings.position_ids.shape[1]
47
+ positional_ids = torch.arange(max_token_n)[None].to(self.device)
48
+ outputs = self.transformer(
49
+ input_ids=tokens,
50
+ position_ids=positional_ids, )
51
+ z = outputs.last_hidden_state
52
+ return z
53
+
54
+ def encode(self, text):
55
+ return self(text)
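A quick usage sketch for the SDv1 text encoder above. The module path is an assumption, the weights are pulled from HuggingFace on first use, and the exact behaviour may depend on the installed transformers version:

```python
import torch
from lib.model_zoo.clip import CLIPTextContextEncoderSDv1  # assumed module path

encoder = CLIPTextContextEncoderSDv1(device='cpu')          # CPU just for illustration
with torch.no_grad():
    ctx = encoder.encode(["a photo of a cat", "an empty prompt"])
print(ctx.shape)  # expected [2, 77, 768] for clip-vit-large-patch14
```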
56
+
57
+ #############################
58
+ # copied from Justin's code #
59
+ #############################
60
+
61
+ @register('clip_image_context_encoder_justin')
62
+ class CLIPImageContextEncoderJustin(AbstractEncoder):
63
+ """
64
+ Uses the CLIP image encoder.
65
+ """
66
+ def __init__(
67
+ self,
68
+ model='ViT-L/14',
69
+ jit=False,
70
+ device='cuda' if torch.cuda.is_available() else 'cpu',
71
+ antialias=False,
72
+ ):
73
+ super().__init__()
74
+ from . import clip_justin
75
+ self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
76
+ self.device = device
77
+ self.antialias = antialias
78
+
79
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
80
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
81
+
82
+ # I didn't call this originally, but seems like it was frozen anyway
83
+ self.freeze()
84
+
85
+ def freeze(self):
86
+ self.transformer = self.model.eval()
87
+ for param in self.parameters():
88
+ param.requires_grad = False
89
+
90
+ def preprocess(self, x):
91
+ import kornia
92
+ # Expects inputs in the range -1, 1
93
+ x = kornia.geometry.resize(x, (224, 224),
94
+ interpolation='bicubic',align_corners=True,
95
+ antialias=self.antialias)
96
+ x = (x + 1.) / 2.
97
+ # renormalize according to clip
98
+ x = kornia.enhance.normalize(x, self.mean, self.std)
99
+ return x
100
+
101
+ def forward(self, x):
102
+ # x is assumed to be in range [-1,1]
103
+ return self.model.encode_image(self.preprocess(x)).float()
104
+
105
+ def encode(self, im):
106
+ return self(im).unsqueeze(1)
107
+
108
+ ###############
109
+ # for vd next #
110
+ ###############
111
+
112
+ from transformers import CLIPModel
113
+
114
+ @register('clip_text_context_encoder')
115
+ class CLIPTextContextEncoder(AbstractEncoder):
116
+ def __init__(self,
117
+ version="openai/clip-vit-large-patch14",
118
+ max_length=77,
119
+ fp16=False, ):
120
+ super().__init__()
121
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
122
+ self.model = CLIPModel.from_pretrained(version)
123
+ self.max_length = max_length
124
+ self.fp16 = fp16
125
+ self.freeze()
126
+
127
+ def get_device(self):
128
+ # A trick to get device
129
+ return self.model.text_projection.weight.device
130
+
131
+ def freeze(self):
132
+ self.model = self.model.eval()
133
+ self.train = disabled_train
134
+ for param in self.parameters():
135
+ param.requires_grad = False
136
+
137
+ def encode(self, text):
138
+ batch_encoding = self.tokenizer(
139
+ text, truncation=True, max_length=self.max_length, return_length=True,
140
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
141
+ tokens = batch_encoding["input_ids"].to(self.get_device())
142
+ outputs = self.model.text_model(input_ids=tokens)
143
+ z = self.model.text_projection(outputs.last_hidden_state)
144
+ z_pooled = self.model.text_projection(outputs.pooler_output)
145
+ z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
146
+ return z
147
+
148
+ from transformers import CLIPProcessor
149
+
150
+ @register('clip_image_context_encoder')
151
+ class CLIPImageContextEncoder(AbstractEncoder):
152
+ def __init__(self,
153
+ version="openai/clip-vit-large-patch14",
154
+ fp16=False, ):
155
+ super().__init__()
156
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
157
+ self.processor = CLIPProcessor.from_pretrained(version)
158
+ self.model = CLIPModel.from_pretrained(version)
159
+ self.fp16 = fp16
160
+ self.freeze()
161
+
162
+ def get_device(self):
163
+ # A trick to get device
164
+ return self.model.text_projection.weight.device
165
+
166
+ def freeze(self):
167
+ self.model = self.model.eval()
168
+ self.train = disabled_train
169
+ for param in self.parameters():
170
+ param.requires_grad = False
171
+
172
+ def _encode(self, images):
173
+ if isinstance(images, torch.Tensor):
174
+ import torchvision.transforms as tvtrans
175
+ images = [tvtrans.ToPILImage()(i) for i in images]
176
+ inputs = self.processor(images=images, return_tensors="pt")
177
+ pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
178
+ pixels = pixels.to(self.get_device())
179
+ outputs = self.model.vision_model(pixel_values=pixels)
180
+ z = outputs.last_hidden_state
181
+ z = self.model.vision_model.post_layernorm(z)
182
+ z = self.model.visual_projection(z)
183
+ z_pooled = z[:, 0:1]
184
+ z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
185
+ return z
186
+
187
+ @torch.no_grad()
188
+ def _encode_wmask(self, images, masks):
189
+ assert isinstance(masks, torch.Tensor)
190
+ assert (len(masks.shape)==4) and (masks.shape[1]==1)
191
+ masks = torch.clamp(masks, 0, 1)
192
+ masked_images = images*masks
193
+ masks = masks.float()
194
+ masks = F.interpolate(masks, [224, 224], mode='bilinear')
195
+ if masks.sum() == masks.numel():
196
+ return self._encode(images)
197
+
198
+ device = images.device
199
+ dtype = images.dtype
200
+ gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
201
+
202
+ vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
203
+ vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
204
+ mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
205
+ vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
206
+ vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
207
+ vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)
208
+
209
+ import types
210
+ def customized_embedding_forward(self, pixel_values):
211
+ batch_size = pixel_values.shape[0]
212
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
213
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
214
+
215
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
216
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
217
+ embeddings = embeddings + self.position_embedding(self.position_ids)
218
+ embeddings = embeddings*vtoken_mask.to(embeddings.dtype)
219
+ return embeddings
220
+
221
+ old_forward = self.model.vision_model.embeddings.forward
222
+ self.model.vision_model.embeddings.forward = types.MethodType(
223
+ customized_embedding_forward, self.model.vision_model.embeddings)
224
+
225
+ z = self._encode(images)
226
+ self.model.vision_model.embeddings.forward = old_forward
227
+ z = z * vtoken_mask.to(dtype)
228
+ return z
229
+
230
+ # def _encode_wmask(self, images, masks):
231
+ # assert isinstance(masks, torch.Tensor)
232
+ # assert (len(masks.shape)==4) and (masks.shape[1]==1)
233
+ # masks = torch.clamp(masks, 0, 1)
234
+ # masks = masks.float()
235
+ # masks = F.interpolate(masks, [224, 224], mode='bilinear')
236
+ # if masks.sum() == masks.numel():
237
+ # return self._encode(images)
238
+
239
+ # device = images.device
240
+ # dtype = images.dtype
241
+
242
+ # vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
243
+ # vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
244
+ # mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
245
+ # vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
246
+ # vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
247
+
248
+ # z = self._encode(images)
249
+ # z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
250
+ # z[:, 0, :] = 0
251
+ # return z
252
+
253
+ def encode(self, images, masks=None):
254
+ if masks is None:
255
+ return self._encode(images)
256
+ else:
257
+ return self._encode_wmask(images, masks)
258
+
259
+ @register('clip_image_context_encoder_position_agnostic')
260
+ class CLIPImageContextEncoderPA(CLIPImageContextEncoder):
261
+ def __init__(self, *args, **kwargs):
262
+ super().__init__(*args, **kwargs)
263
+ import types
264
+ def customized_embedding_forward(self, pixel_values):
265
+ batch_size = pixel_values.shape[0]
266
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
267
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
268
+
269
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
270
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
271
+ pembeddings = self.position_embedding(self.position_ids)
272
+ pembeddings = torch.cat([
273
+ pembeddings[:, 0:1],
274
+ pembeddings[:, 1: ].mean(dim=1, keepdim=True).repeat(1, 256, 1)], dim=1)
275
+ embeddings = embeddings + pembeddings
276
+ return embeddings
277
+
278
+ self.model.vision_model.embeddings.forward = types.MethodType(
279
+ customized_embedding_forward, self.model.vision_model.embeddings)
280
+
281
+ ##############
282
+ # from sd2.0 #
283
+ ##############
284
+
285
+ import open_clip
286
+ import torch.nn.functional as F
287
+
288
+ @register('openclip_text_context_encoder_sdv2')
289
+ class FrozenOpenCLIPTextEmbedderSDv2(AbstractEncoder):
290
+ """
291
+ Uses the OpenCLIP transformer encoder for text
292
+ """
293
+ LAYERS = [
294
+ #"pooled",
295
+ "last",
296
+ "penultimate"
297
+ ]
298
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
299
+ freeze=True, layer="last"):
300
+ super().__init__()
301
+ assert layer in self.LAYERS
302
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
303
+ del model.visual
304
+ self.model = model
305
+
306
+ self.device = device
307
+ self.max_length = max_length
308
+ if freeze:
309
+ self.freeze()
310
+ self.layer = layer
311
+ if self.layer == "last":
312
+ self.layer_idx = 0
313
+ elif self.layer == "penultimate":
314
+ self.layer_idx = 1
315
+ else:
316
+ raise NotImplementedError()
317
+
318
+ def freeze(self):
319
+ self.model = self.model.eval()
320
+ for param in self.parameters():
321
+ param.requires_grad = False
322
+
323
+ def forward(self, text):
324
+ tokens = open_clip.tokenize(text)
325
+ z = self.encode_with_transformer(tokens.to(self.device))
326
+ return z
327
+
328
+ def encode_with_transformer(self, text):
329
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
330
+ x = x + self.model.positional_embedding
331
+ x = x.permute(1, 0, 2) # NLD -> LND
332
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
333
+ x = x.permute(1, 0, 2) # LND -> NLD
334
+ x = self.model.ln_final(x)
335
+ return x
336
+
337
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
338
+ for i, r in enumerate(self.model.transformer.resblocks):
339
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
340
+ break
341
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
342
+ x = checkpoint(r, x, attn_mask)
343
+ else:
344
+ x = r(x, attn_mask=attn_mask)
345
+ return x
346
+
347
+ def encode(self, text):
348
+ return self(text)
349
+
350
+ @register('openclip_text_context_encoder')
351
+ class FrozenOpenCLIPTextEmbedder(AbstractEncoder):
352
+ """
353
+ Uses the OpenCLIP transformer encoder for text
354
+ """
355
+ def __init__(self,
356
+ arch="ViT-H-14",
357
+ version="laion2b_s32b_b79k",
358
+ max_length=77,
359
+ freeze=True,):
360
+ super().__init__()
361
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
362
+ del model.visual
363
+ self.model = model
364
+ self.max_length = max_length
365
+ self.device = 'cpu'
366
+ if freeze:
367
+ self.freeze()
368
+
369
+ def to(self, device):
370
+ self.device = device
371
+ return super().to(device)
372
+
373
+ def freeze(self):
374
+ self.model = self.model.eval()
375
+ for param in self.parameters():
376
+ param.requires_grad = False
377
+
378
+ def forward(self, text):
379
+ self.device = self.model.ln_final.weight.device # ugly trick
380
+ tokens = open_clip.tokenize(text)
381
+ z = self.encode_with_transformer(tokens.to(self.device))
382
+ return z
383
+
384
+ def encode_with_transformer(self, text):
385
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
386
+ x = x + self.model.positional_embedding
387
+ x = x.permute(1, 0, 2) # NLD -> LND
388
+ x = self.model.transformer(x, attn_mask=self.model.attn_mask)
389
+ x = x.permute(1, 0, 2) # LND -> NLD
390
+ x = self.model.ln_final(x)
391
+ x_pool = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
392
+ # x_pool_debug = F.normalize(x_pool, dim=-1)
393
+ x = x @ self.model.text_projection
394
+ x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
395
+ return x
396
+
397
+ def encode(self, text):
398
+ return self(text)
399
+
400
+ @register('openclip_image_context_encoder')
401
+ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
402
+ """
403
+ Uses the OpenCLIP transformer encoder for images
404
+ """
405
+ def __init__(self,
406
+ arch="ViT-H-14",
407
+ version="laion2b_s32b_b79k",
408
+ freeze=True,):
409
+ super().__init__()
410
+ model, _, preprocess = open_clip.create_model_and_transforms(
411
+ arch, device=torch.device('cpu'), pretrained=version)
412
+ self.model = model.visual
413
+ self.device = 'cpu'
414
+ import torchvision.transforms as tvtrans
415
+ # we only need resize & normalization
416
+ preprocess.transforms[0].size = [224, 224] # make it more precise
417
+ self.preprocess = tvtrans.Compose([
418
+ preprocess.transforms[0],
419
+ preprocess.transforms[4],])
420
+ if freeze:
421
+ self.freeze()
422
+
423
+ def to(self, device):
424
+ self.device = device
425
+ return super().to(device)
426
+
427
+ def freeze(self):
428
+ self.model = self.model.eval()
429
+ for param in self.parameters():
430
+ param.requires_grad = False
431
+
432
+ def forward(self, image):
433
+ z = self.preprocess(image)
434
+ z = self.encode_with_transformer(z)
435
+ return z
436
+
437
+ def encode_with_transformer(self, image):
438
+ x = self.model.conv1(image)
439
+ x = x.reshape(x.shape[0], x.shape[1], -1)
440
+ x = x.permute(0, 2, 1)
441
+ x = torch.cat([
442
+ self.model.class_embedding.to(x.dtype)
443
+ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
444
+ x], dim=1)
445
+ x = x + self.model.positional_embedding.to(x.dtype)
446
+ x = self.model.ln_pre(x)
447
+ x = x.permute(1, 0, 2)
448
+ x = self.model.transformer(x)
449
+ x = x.permute(1, 0, 2)
450
+
451
+ x = self.model.ln_post(x)
452
+ if self.model.proj is not None:
453
+ x = x @ self.model.proj
454
+
455
+ x_pool = x[:, 0, :]
456
+ # x_pool_debug = self.model(image)
457
+ # x_pooln_debug = F.normalize(x_pool_debug, dim=-1)
458
+ x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
459
+ return x
460
+
461
+ def _encode(self, image):
462
+ return self(image)
463
+
464
+ def _encode_wmask(self, images, masks):
465
+ z = self._encode(images)
466
+ device = z.device
467
+ vtoken_kernel_size = self.model.conv1.kernel_size
468
+ vtoken_stride = self.model.conv1.stride
469
+ mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, dtype=z.dtype, requires_grad=False)
470
+ mask_kernal /= np.prod(vtoken_kernel_size)
471
+
472
+ assert isinstance(masks, torch.Tensor)
473
+ assert (len(masks.shape)==4) and (masks.shape[1]==1)
474
+ masks = torch.clamp(masks, 0, 1)
475
+ masks = F.interpolate(masks, [224, 224], mode='bilinear')
476
+
477
+ vtoken_mask = torch.nn.functional.conv2d(1-masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
478
+ z[:, 1:, :] = z[:, 1:, :] * vtoken_mask
479
+ z[:, 0, :] = 0
480
+ return z
481
+
482
+ def encode(self, images, masks=None):
483
+ if masks is None:
484
+ return self._encode(images)
485
+ else:
486
+ return self._encode_wmask(images, masks)
487
+
488
+ ############################
489
+ # def customized tokenizer #
490
+ ############################
491
+
492
+ from open_clip import SimpleTokenizer
493
+
494
+ @register('openclip_text_context_encoder_sdv2_customized_tokenizer_v1')
495
+ class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1(FrozenOpenCLIPTextEmbedderSDv2):
496
+ """
497
+ Uses the OpenCLIP transformer encoder for text
498
+ """
499
+ def __init__(self, customized_tokens, *args, **kwargs):
500
+ super().__init__(*args, **kwargs)
501
+ if isinstance(customized_tokens, str):
502
+ customized_tokens = [customized_tokens]
503
+ self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
504
+ self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
505
+ self.embedding_dim = self.model.ln_final.weight.shape[0]
506
+ self.customized_token_embedding = nn.Embedding(
507
+ len(customized_tokens), embedding_dim=self.embedding_dim)
508
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
509
+
510
+ def tokenize(self, texts):
511
+ if isinstance(texts, str):
512
+ texts = [texts]
513
+ sot_token = self.tokenizer.encoder["<start_of_text>"]
514
+ eot_token = self.tokenizer.encoder["<end_of_text>"]
515
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
516
+ maxn = self.num_regular_tokens
517
+ regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
518
+ token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
519
+ customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
520
+ return regular_tokens, customized_tokens, token_mask
521
+
522
+ def pad_to_length(self, tokens, context_length=77, eot_token=None):
523
+ result = torch.zeros(len(tokens), context_length, dtype=torch.long)
524
+ eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
525
+ for i, tokens in enumerate(tokens):
526
+ if len(tokens) > context_length:
527
+ tokens = tokens[:context_length] # Truncate
528
+ tokens[-1] = eot_token
529
+ result[i, :len(tokens)] = torch.tensor(tokens)
530
+ return result
531
+
532
+ def forward(self, text):
533
+ self.device = self.model.ln_final.weight.device # ugly trick
534
+ regular_tokens, customized_tokens, token_mask = self.tokenize(text)
535
+ regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
536
+ customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
537
+ token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
538
+ z0 = self.encode_with_transformer(regular_tokens)
539
+ z1 = self.customized_token_embedding(customized_tokens)
540
+ token_mask = token_mask[:, :, None].type(z0.dtype)
541
+ z = z0 * (1-token_mask) + z1 * token_mask
542
+ return z
543
+
544
+ @register('openclip_text_context_encoder_sdv2_customized_tokenizer_v2')
545
+ class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(FrozenOpenCLIPTextEmbedderSDv2):
546
+ """
547
+ Uses the OpenCLIP transformer encoder for text
548
+ """
549
+ def __init__(self, customized_tokens, *args, **kwargs):
550
+ super().__init__(*args, **kwargs)
551
+ if isinstance(customized_tokens, str):
552
+ customized_tokens = [customized_tokens]
553
+ self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
554
+ self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
555
+ self.embedding_dim = self.model.token_embedding.weight.shape[1]
556
+ self.customized_token_embedding = nn.Embedding(
557
+ len(customized_tokens), embedding_dim=self.embedding_dim)
558
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
559
+
560
+ def tokenize(self, texts):
561
+ if isinstance(texts, str):
562
+ texts = [texts]
563
+ sot_token = self.tokenizer.encoder["<start_of_text>"]
564
+ eot_token = self.tokenizer.encoder["<end_of_text>"]
565
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
566
+ maxn = self.num_regular_tokens
567
+ regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
568
+ token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
569
+ customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
570
+ return regular_tokens, customized_tokens, token_mask
571
+
572
+ def pad_to_length(self, tokens, context_length=77, eot_token=None):
573
+ result = torch.zeros(len(tokens), context_length, dtype=torch.long)
574
+ eot_token = self.tokenizer.encoder["<end_of_text>"] if eot_token is None else eot_token
575
+ for i, tokens in enumerate(tokens):
576
+ if len(tokens) > context_length:
577
+ tokens = tokens[:context_length] # Truncate
578
+ tokens[-1] = eot_token
579
+ result[i, :len(tokens)] = torch.tensor(tokens)
580
+ return result
581
+
582
+ def forward(self, text):
583
+ self.device = self.model.token_embedding.weight.device # ugly trick
584
+ regular_tokens, customized_tokens, token_mask = self.tokenize(text)
585
+ regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
586
+ customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
587
+ token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
588
+ z = self.encode_with_transformer(regular_tokens, customized_tokens, token_mask)
589
+ return z
590
+
591
+ def encode_with_transformer(self, token, customized_token, token_mask):
592
+ x0 = self.model.token_embedding(token)
593
+ x1 = self.customized_token_embedding(customized_token)
594
+ token_mask = token_mask[:, :, None].type(x0.dtype)
595
+ x = x0 * (1-token_mask) + x1 * token_mask
596
+ x = x + self.model.positional_embedding
597
+ x = x.permute(1, 0, 2) # NLD -> LND
598
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
599
+ x = x.permute(1, 0, 2) # LND -> NLD
600
+ x = self.model.ln_final(x)
601
+ return x
602
+
603
+ class ln_freezed_temp(nn.LayerNorm):
604
+ def forward(self, x):
605
+ self.weight.requires_grad = False
606
+ self.bias.requires_grad = False
607
+ return super().forward(x)
608
+
609
+ @register('openclip_text_context_encoder_sdv2_customized_tokenizer_v3')
610
+ class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3(FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2):
611
+ """
612
+ Uses the OpenCLIP transformer encoder for text
613
+ """
614
+ def __init__(self, customized_tokens, texpand=4, lora_rank=None, lora_bias_trainable=True, *args, **kwargs):
615
+ super().__init__(customized_tokens, *args, **kwargs)
616
+ if isinstance(customized_tokens, str):
617
+ customized_tokens = [customized_tokens]
618
+ self.texpand = texpand
619
+ self.customized_token_embedding = nn.Embedding(
620
+ len(customized_tokens)*texpand, embedding_dim=self.embedding_dim)
621
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
622
+
623
+ if lora_rank is not None:
624
+ from .lora import freeze_param, freeze_module, to_lora
625
+ def convert_resattnblock(module):
626
+ module.ln_1.__class__ = ln_freezed_temp
627
+ # freeze_module(module.ln_1)
628
+ module.attn = to_lora(module.attn, lora_rank, lora_bias_trainable)
629
+ module.ln_2.__class__ = ln_freezed_temp
630
+ # freeze_module(module.ln_2)
631
+ module.mlp.c_fc = to_lora(module.mlp.c_fc, lora_rank, lora_bias_trainable)
632
+ module.mlp.c_proj = to_lora(module.mlp.c_proj, lora_rank, lora_bias_trainable)
633
+ freeze_param(self.model, 'positional_embedding')
634
+ freeze_param(self.model, 'text_projection')
635
+ freeze_param(self.model, 'logit_scale')
636
+ for idx, resattnblock in enumerate(self.model.transformer.resblocks):
637
+ convert_resattnblock(resattnblock)
638
+ freeze_module(self.model.token_embedding)
639
+ self.model.ln_final.__class__ = ln_freezed_temp
640
+ # freeze_module(self.model.ln_final)
641
+
642
+ def tokenize(self, texts):
643
+ if isinstance(texts, str):
644
+ texts = [texts]
645
+ sot_token = self.tokenizer.encoder["<start_of_text>"]
646
+ eot_token = self.tokenizer.encoder["<end_of_text>"]
647
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
648
+ maxn = self.num_regular_tokens
649
+ regular_tokens = [[[ti] if ti < maxn else [0]*self.texpand for ti in tokens] for tokens in all_tokens]
650
+ token_mask = [[[ 0] if ti < maxn else [1]*self.texpand for ti in tokens] for tokens in all_tokens]
651
+ custom_tokens = [[[ 0] if ti < maxn else [
652
+ (ti-maxn)*self.texpand+ii for ii in range(self.texpand)]
653
+ for ti in tokens] for tokens in all_tokens]
654
+
655
+ from itertools import chain
656
+ regular_tokens = [[i for i in chain(*tokens)] for tokens in regular_tokens]
657
+ token_mask = [[i for i in chain(*tokens)] for tokens in token_mask]
658
+ custom_tokens = [[i for i in chain(*tokens)] for tokens in custom_tokens]
659
+ return regular_tokens, custom_tokens, token_mask
660
+
661
+ ###################
662
+ # clip expandable #
663
+ ###################
664
+
665
+ @register('clip_text_sdv1_customized_embedding')
666
+ class CLIPTextSD1CE(nn.Module):
667
+ def __init__(
668
+ self,
669
+ replace_info="text|elon musk",
670
+ version="openai/clip-vit-large-patch14",
671
+ max_length=77):
672
+ super().__init__()
673
+
674
+ self.name = 'clip_text_sdv1_customized_embedding'
675
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
676
+ self.transformer = CLIPTextModel.from_pretrained(version)
677
+ self.reset_replace_info(replace_info)
678
+ self.max_length = max_length
679
+ self.special_token = "<new_token>"
680
+
681
+ def reset_replace_info(self, replace_info):
682
+ rtype, rpara = replace_info.split("|")
683
+ self.replace_type = rtype
684
+ if rtype == "token_embedding":
685
+ ce_num = int(rpara)
686
+ ce_dim = self.transformer.text_model.embeddings.token_embedding.weight.size(1)
687
+ self.cembedding = nn.Embedding(ce_num, ce_dim)
688
+ self.cembedding = self.cembedding.to(self.get_device())
689
+ elif rtype == "context_embedding":
690
+ ce_num = int(rpara)
691
+ ce_dim = self.transformer.text_model.encoder.layers[-1].layer_norm2.weight.size(0)
692
+ self.cembedding = nn.Embedding(ce_num, ce_dim)
693
+ self.cembedding = self.cembedding.to(self.get_device())
694
+ else:
695
+ assert rtype=="text"
696
+ self.replace_type = "text"
697
+ self.replace_string = rpara
698
+ self.cembedding = None
699
+
700
+ def get_device(self):
701
+ return self.transformer.text_model.embeddings.token_embedding.weight.device
702
+
703
+ def position_to_mask(self, tokens, positions):
704
+ mask = torch.zeros_like(tokens)
705
+ for idxb, idxs, idxe in zip(*positions):
706
+ mask[idxb, idxs:idxe] = 1
707
+ return mask
708
+
709
+ def forward(self, text):
710
+ tokens, positions = self.tokenize(text)
711
+ mask = self.position_to_mask(tokens, positions)
712
+ max_token_n = tokens.size(1)
713
+ positional_ids = torch.arange(max_token_n)[None].to(self.get_device())
714
+
715
+ if self.replace_type == 'token_embedding':
716
+ cembeds = self.cembedding(tokens * mask)
717
+
718
+ def embedding_customized_forward(
719
+ self, input_ids=None, position_ids=None, inputs_embeds=None,):
720
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
721
+ if position_ids is None:
722
+ position_ids = self.position_ids[:, :seq_length]
723
+ if inputs_embeds is None:
724
+ inputs_embeds = self.token_embedding(input_ids)
725
+ inputs_embeds = inputs_embeds * (1-mask.float())[:, :, None]
726
+ inputs_embeds = inputs_embeds + cembeds
727
+ position_embeddings = self.position_embedding(position_ids)
728
+ embeddings = inputs_embeds + position_embeddings
729
+ return embeddings
730
+
731
+ import types
732
+ self.transformer.text_model.embeddings.forward = types.MethodType(
733
+ embedding_customized_forward, self.transformer.text_model.embeddings)
734
+
735
+ else:
736
+ # TODO: Implement
737
+ assert False
738
+
739
+ outputs = self.transformer(
740
+ input_ids=tokens,
741
+ position_ids=positional_ids, )
742
+ z = outputs.last_hidden_state
743
+ return z
744
+
745
+ def encode(self, text):
746
+ return self(text)
747
+
748
+ @torch.no_grad()
749
+ def tokenize(self, text):
750
+ if isinstance(text, str):
751
+ text = [text]
752
+
753
+ bos_special_text = "<|startoftext|>"
754
+ text = [ti.replace(self.special_token, bos_special_text) for ti in text]
755
+
756
+ batch_encoding = self.tokenizer(
757
+ text, truncation=True, max_length=self.max_length, return_length=True,
758
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
759
+ tokens = batch_encoding["input_ids"]
760
+
761
+ bosid = tokens[0, 0]
762
+ eosid = tokens[0, -1]
763
+ bs, maxn = tokens.shape
764
+
765
+ if self.replace_type in ['token_embedding', 'context_embedding']:
766
+ newtokens = []
767
+ ce_num = self.cembedding.weight.size(0)
768
+ idxi = []; idxstart = []; idxend = [];
769
+ for idxii, tokeni in enumerate(tokens):
770
+ newtokeni = []
771
+ idxjj = 0
772
+ for ii, tokenii in enumerate(tokeni):
773
+ if (tokenii == bosid) and (ii != 0):
774
+ newtokeni.extend([i for i in range(ce_num)])
775
+ idxi.append(idxii); idxstart.append(idxjj);
776
+ idxjj += ce_num
777
+ idxjj_record = idxjj if idxjj<=maxn-1 else maxn-1
778
+ idxend.append(idxjj_record);
779
+ else:
780
+ newtokeni.extend([tokenii])
781
+ idxjj += 1
782
+ newtokeni = newtokeni[:maxn]
783
+ newtokeni[-1] = eosid
784
+ newtokens.append(newtokeni)
785
+ return torch.LongTensor(newtokens).to(self.get_device()), (idxi, idxstart, idxend)
786
+ else:
787
+ # TODO: Implement
788
+ assert False
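A minimal sketch of driving the masked image-context encoder defined above, where an optional binary mask restricts which visual tokens contribute to the context. The module path is an assumption and shapes are illustrative:

```python
import torch
from lib.model_zoo.clip import CLIPImageContextEncoder  # assumed module path

encoder = CLIPImageContextEncoder(version="openai/clip-vit-large-patch14")
images = torch.rand(1, 3, 512, 512)               # RGB tensors in [0, 1]
with torch.no_grad():
    ctx = encoder.encode(images)                  # all 257 visual tokens
    mask = torch.ones(1, 1, 512, 512)             # keep everything; zeros drop patches
    ctx_masked = encoder.encode(images, masks=mask)
print(ctx.shape)  # expected [1, 257, 768]
```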
lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc ADDED
Binary file (3.32 kB).

lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc ADDED
Binary file (1.95 kB).

lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc ADDED
Binary file (9.44 kB).

lib/model_zoo/common/__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.72 kB).
lib/model_zoo/common/get_model.py ADDED
@@ -0,0 +1,124 @@
2
+ import torch
3
+ import torchvision.models
4
+ import os.path as osp
5
+ import copy
6
+ from ...log_service import print_log
7
+ from .utils import \
8
+ get_total_param, get_total_param_sum, \
9
+ get_unit
10
+
11
+ # def load_state_dict(net, model_path):
12
+ # if isinstance(net, dict):
13
+ # for ni, neti in net.items():
14
+ # paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
15
+ # new_paras = neti.state_dict()
16
+ # new_paras.update(paras)
17
+ # neti.load_state_dict(new_paras)
18
+ # else:
19
+ # paras = torch.load(model_path, map_location=torch.device('cpu'))
20
+ # new_paras = net.state_dict()
21
+ # new_paras.update(paras)
22
+ # net.load_state_dict(new_paras)
23
+ # return
24
+
25
+ # def save_state_dict(net, path):
26
+ # if isinstance(net, (torch.nn.DataParallel,
27
+ # torch.nn.parallel.DistributedDataParallel)):
28
+ # torch.save(net.module.state_dict(), path)
29
+ # else:
30
+ # torch.save(net.state_dict(), path)
31
+
32
+ def singleton(class_):
33
+ instances = {}
34
+ def getinstance(*args, **kwargs):
35
+ if class_ not in instances:
36
+ instances[class_] = class_(*args, **kwargs)
37
+ return instances[class_]
38
+ return getinstance
39
+
40
+ def preprocess_model_args(args):
41
+ # If args has layer_units, get the corresponding
42
+ # units.
43
+ # If args get backbone, get the backbone model.
44
+ args = copy.deepcopy(args)
45
+ if 'layer_units' in args:
46
+ layer_units = [
47
+ get_unit()(i) for i in args.layer_units
48
+ ]
49
+ args.layer_units = layer_units
50
+ if 'backbone' in args:
51
+ args.backbone = get_model()(args.backbone)
52
+ return args
53
+
54
+ @singleton
55
+ class get_model(object):
56
+ def __init__(self):
57
+ self.model = {}
58
+
59
+ def register(self, model, name):
60
+ self.model[name] = model
61
+
62
+ def __call__(self, cfg, verbose=True):
63
+ """
64
+ Construct model based on the config.
65
+ """
66
+ if cfg is None:
67
+ return None
68
+
69
+ t = cfg.type
70
+
71
+ # the register is in each file
72
+ if t.find('pfd')==0:
73
+ from .. import pfd
74
+ elif t=='autoencoderkl':
75
+ from .. import autokl
76
+ elif (t.find('clip')==0) or (t.find('openclip')==0):
77
+ from .. import clip
78
+ elif t.find('openai_unet')==0:
79
+ from .. import openaimodel
80
+ elif t.find('controlnet')==0:
81
+ from .. import controlnet
82
+ elif t.find('seecoder')==0:
83
+ from .. import seecoder
84
+ elif t.find('swin')==0:
85
+ from .. import swin
86
+
87
+ args = preprocess_model_args(cfg.args)
88
+ net = self.model[t](**args)
89
+
90
+ pretrained = cfg.get('pretrained', None)
91
+ if pretrained is None: # backward compatible
92
+ pretrained = cfg.get('pth', None)
93
+ map_location = cfg.get('map_location', 'cpu')
94
+ strict_sd = cfg.get('strict_sd', True)
95
+
96
+ if pretrained is not None:
97
+ if osp.splitext(pretrained)[1] == '.pth':
98
+ sd = torch.load(pretrained, map_location=map_location)
99
+ elif osp.splitext(pretrained)[1] == '.ckpt':
100
+ sd = torch.load(pretrained, map_location=map_location)['state_dict']
101
+ elif osp.splitext(pretrained)[1] == '.safetensors':
102
+ from safetensors.torch import load_file
103
+ from collections import OrderedDict
104
+ sd = load_file(pretrained, map_location)
105
+ sd = OrderedDict(sd)
106
+ net.load_state_dict(sd, strict=strict_sd)
107
+ if verbose:
108
+ print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd))
109
+
110
+ # display param_num & param_sum
111
+ if verbose:
112
+ print_log(
113
+ 'Load {} with total {} parameters,'
114
+ '{:.3f} parameter sum.'.format(
115
+ t,
116
+ get_total_param(net),
117
+ get_total_param_sum(net) ))
118
+ return net
119
+
120
+ def register(name):
121
+ def wrapper(class_):
122
+ get_model().register(class_, name)
123
+ return class_
124
+ return wrapper
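The registry above is easiest to see with a toy entry. This sketch registers a hypothetical `toy_mlp` model and builds it from a config node with `type` and `args`; OmegaConf is used as a stand-in for the project's own config objects (an assumption, any attribute-style mapping with `.get()` works), and `verbose=False` skips the `print_log` bookkeeping:

```python
import torch.nn as nn
from omegaconf import OmegaConf                        # stand-in config library
from lib.model_zoo.common.get_model import get_model, register  # assumed module path

@register('toy_mlp')                                   # hypothetical model name
class ToyMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)
    def forward(self, x):
        return self.net(x)

cfg = OmegaConf.create({'type': 'toy_mlp', 'args': {'dim': 8}})
net = get_model()(cfg, verbose=False)                  # builds ToyMLP(dim=8)
```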
lib/model_zoo/common/get_optimizer.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import itertools
5
+
6
+ def singleton(class_):
7
+ instances = {}
8
+ def getinstance(*args, **kwargs):
9
+ if class_ not in instances:
10
+ instances[class_] = class_(*args, **kwargs)
11
+ return instances[class_]
12
+ return getinstance
13
+
14
+ class get_optimizer(object):
15
+ def __init__(self):
16
+ self.optimizer = {}
17
+ self.register(optim.SGD, 'sgd')
18
+ self.register(optim.Adam, 'adam')
19
+ self.register(optim.AdamW, 'adamw')
20
+
21
+ def register(self, optim, name):
22
+ self.optimizer[name] = optim
23
+
24
+ def __call__(self, net, cfg):
25
+ if cfg is None:
26
+ return None
27
+ t = cfg.type
28
+ if isinstance(net, (torch.nn.DataParallel,
29
+ torch.nn.parallel.DistributedDataParallel)):
30
+ netm = net.module
31
+ else:
32
+ netm = net
33
+ pg = getattr(netm, 'parameter_group', None)
34
+
35
+ if pg is not None:
36
+ params = []
37
+ for group_name, module_or_para in pg.items():
38
+ if not isinstance(module_or_para, list):
39
+ module_or_para = [module_or_para]
40
+
41
+ grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
42
+ grouped_params = itertools.chain(*grouped_params)
43
+ pg_dict = {'params':grouped_params, 'name':group_name}
44
+ params.append(pg_dict)
45
+ else:
46
+ params = net.parameters()
47
+ return self.optimizer[t](params, lr=0, **cfg.args)
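A sketch of the `parameter_group` convention consumed above. The optimizer is always created with `lr=0`; the schedulers in `get_scheduler.py` are expected to push the actual learning rate on every step. The OmegaConf config and module path are assumptions:

```python
import torch.nn as nn
from omegaconf import OmegaConf
from lib.model_zoo.common.get_optimizer import get_optimizer  # assumed module path

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
# Optional named groups; schedulers can later scale each group's lr separately.
net.parameter_group = {'backbone': net[0], 'head': net[2]}

cfg = OmegaConf.create({'type': 'adamw', 'args': {'weight_decay': 0.01}})
opt = get_optimizer()(net, cfg)
print([g['name'] for g in opt.param_groups])  # ['backbone', 'head']
```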
lib/model_zoo/common/get_scheduler.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import copy
5
+ from ... import sync
6
+ from ...cfg_holder import cfg_unique_holder as cfguh
7
+
8
+ def singleton(class_):
9
+ instances = {}
10
+ def getinstance(*args, **kwargs):
11
+ if class_ not in instances:
12
+ instances[class_] = class_(*args, **kwargs)
13
+ return instances[class_]
14
+ return getinstance
15
+
16
+ @singleton
17
+ class get_scheduler(object):
18
+ def __init__(self):
19
+ self.lr_scheduler = {}
20
+
21
+ def register(self, lrsf, name):
22
+ self.lr_scheduler[name] = lrsf
23
+
24
+ def __call__(self, cfg):
25
+ if cfg is None:
26
+ return None
27
+ if isinstance(cfg, list):
28
+ schedulers = []
29
+ for ci in cfg:
30
+ t = ci.type
31
+ schedulers.append(
32
+ self.lr_scheduler[t](**ci.args))
33
+ if len(schedulers) == 0:
34
+ raise ValueError
35
+ else:
36
+ return compose_scheduler(schedulers)
37
+ t = cfg.type
38
+ return self.lr_scheduler[t](**cfg.args)
39
+
40
+
41
+ def register(name):
42
+ def wrapper(class_):
43
+ get_scheduler().register(class_, name)
44
+ return class_
45
+ return wrapper
46
+
47
+ class template_scheduler(object):
48
+ def __init__(self, step):
49
+ self.step = step
50
+
51
+ def __getitem__(self, idx):
52
+ raise ValueError
53
+
54
+ def set_lr(self, optim, new_lr, pg_lrscale=None):
55
+ """
56
+ Set Each parameter_groups in optim with new_lr
57
+ New_lr can be find according to the idx.
58
+ pg_lrscale tells how to scale each pg.
59
+ """
60
+ # new_lr = self.__getitem__(idx)
61
+ pg_lrscale = copy.deepcopy(pg_lrscale)
62
+ for pg in optim.param_groups:
63
+ if pg_lrscale is None:
64
+ pg['lr'] = new_lr
65
+ else:
66
+ pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
67
+ assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
68
+ "pg_lrscale doesn't match pg"
69
+
70
+ @register('constant')
71
+ class constant_scheduler(template_scheduler):
72
+ def __init__(self, lr, step):
73
+ super().__init__(step)
74
+ self.lr = lr
75
+
76
+ def __getitem__(self, idx):
77
+ if idx >= self.step:
78
+ raise ValueError
79
+ return self.lr
80
+
81
+ @register('poly')
82
+ class poly_scheduler(template_scheduler):
83
+ def __init__(self, start_lr, end_lr, power, step):
84
+ super().__init__(step)
85
+ self.start_lr = start_lr
86
+ self.end_lr = end_lr
87
+ self.power = power
88
+
89
+ def __getitem__(self, idx):
90
+ if idx >= self.step:
91
+ raise ValueError
92
+ a, b = self.start_lr, self.end_lr
93
+ p, n = self.power, self.step
94
+ return b + (a-b)*((1-idx/n)**p)
95
+
96
+ @register('linear')
97
+ class linear_scheduler(template_scheduler):
98
+ def __init__(self, start_lr, end_lr, step):
99
+ super().__init__(step)
100
+ self.start_lr = start_lr
101
+ self.end_lr = end_lr
102
+
103
+ def __getitem__(self, idx):
104
+ if idx >= self.step:
105
+ raise ValueError
106
+ a, b, n = self.start_lr, self.end_lr, self.step
107
+ return b + (a-b)*(1-idx/n)
108
+
109
+ @register('multistage')
110
+ class multistage_scheduler(template_scheduler):
111
+ def __init__(self, start_lr, milestones, gamma, step):
112
+ super().__init__(step)
113
+ self.start_lr = start_lr
114
+ m = [0] + milestones + [step]
115
+ lr_iter = start_lr
116
+ self.lr = []
117
+ for ms, me in zip(m[0:-1], m[1:]):
118
+ for _ in range(ms, me):
119
+ self.lr.append(lr_iter)
120
+ lr_iter *= gamma
121
+
122
+ def __getitem__(self, idx):
123
+ if idx >= self.step:
124
+ raise ValueError
125
+ return self.lr[idx]
126
+
127
+ class compose_scheduler(template_scheduler):
128
+ def __init__(self, schedulers):
129
+ self.schedulers = schedulers
130
+ self.step = [si.step for si in schedulers]
131
+ self.step_milestone = []
132
+ acc = 0
133
+ for i in self.step:
134
+ acc += i
135
+ self.step_milestone.append(acc)
136
+ self.step = sum(self.step)
137
+
138
+ def __getitem__(self, idx):
139
+ if idx >= self.step:
140
+ raise ValueError
141
+ ms = [0] + self.step_milestone
142
+ for i, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
143
+ if mi <= idx < mj:
144
+ return self.schedulers[i][idx-mi]
145
+ raise ValueError
146
+
147
+ ####################
148
+ # lambda scheduler #
149
+ ####################
150
+
151
+ class LambdaWarmUpCosineScheduler(template_scheduler):
152
+ """
153
+ note: use with a base_lr of 1.0
154
+ """
155
+ def __init__(self,
156
+ base_lr,
157
+ warm_up_steps,
158
+ lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
159
+ cfgt = cfguh().cfg.train
160
+ bs = cfgt.batch_size
161
+ if 'gradacc_every' not in cfgt:
162
+ print('Warning: gradacc_every not found in the train config, using 1 as default.')
163
+ acc = cfgt.get('gradacc_every', 1)
164
+ self.lr_multi = base_lr * bs * acc
165
+ self.lr_warm_up_steps = warm_up_steps
166
+ self.lr_start = lr_start
167
+ self.lr_min = lr_min
168
+ self.lr_max = lr_max
169
+ self.lr_max_decay_steps = max_decay_steps
170
+ self.last_lr = 0.
171
+ self.verbosity_interval = verbosity_interval
172
+
173
+ def schedule(self, n):
174
+ if self.verbosity_interval > 0:
175
+ if n % self.verbosity_interval == 0:
176
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
177
+ if n < self.lr_warm_up_steps:
178
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
179
+ self.last_lr = lr
180
+ return lr
181
+ else:
182
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
183
+ t = min(t, 1.0)
184
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
185
+ 1 + np.cos(t * np.pi))
186
+ self.last_lr = lr
187
+ return lr
188
+
189
+ def __getitem__(self, idx):
190
+ return self.schedule(idx) * self.lr_multi
191
+
192
+ class LambdaWarmUpCosineScheduler2(template_scheduler):
193
+ """
194
+ supports repeated iterations, configurable via lists
195
+ note: use with a base_lr of 1.0.
196
+ """
197
+ def __init__(self,
198
+ base_lr,
199
+ warm_up_steps,
200
+ f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
201
+ cfgt = cfguh().cfg.train
202
+ # bs = cfgt.batch_size
203
+ # if 'gradacc_every' not in cfgt:
204
+ # print('Warning, gradacc_every is not found in xml, use 1 as default.')
205
+ # acc = cfgt.get('gradacc_every', 1)
206
+ # self.lr_multi = base_lr * bs * acc
207
+ self.lr_multi = base_lr
208
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
209
+ self.lr_warm_up_steps = warm_up_steps
210
+ self.f_start = f_start
211
+ self.f_min = f_min
212
+ self.f_max = f_max
213
+ self.cycle_lengths = cycle_lengths
214
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
215
+ self.last_f = 0.
216
+ self.verbosity_interval = verbosity_interval
217
+
218
+ def find_in_interval(self, n):
219
+ interval = 0
220
+ for cl in self.cum_cycles[1:]:
221
+ if n <= cl:
222
+ return interval
223
+ interval += 1
224
+
225
+ def schedule(self, n):
226
+ cycle = self.find_in_interval(n)
227
+ n = n - self.cum_cycles[cycle]
228
+ if self.verbosity_interval > 0:
229
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
230
+ f"current cycle {cycle}")
231
+ if n < self.lr_warm_up_steps[cycle]:
232
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
233
+ self.last_f = f
234
+ return f
235
+ else:
236
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
237
+ t = min(t, 1.0)
238
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
239
+ 1 + np.cos(t * np.pi))
240
+ self.last_f = f
241
+ return f
242
+
243
+ def __getitem__(self, idx):
244
+ return self.schedule(idx) * self.lr_multi
245
+
246
+ @register('stable_diffusion_linear')
247
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
248
+ def schedule(self, n):
249
+ cycle = self.find_in_interval(n)
250
+ n = n - self.cum_cycles[cycle]
251
+ if self.verbosity_interval > 0:
252
+ if n % self.verbosity_interval == 0:
253
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
254
+ f"current cycle {cycle}")
255
+ if n < self.lr_warm_up_steps[cycle]:
256
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
257
+ self.last_f = f
258
+ return f
259
+ else:
260
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
261
+ self.last_f = f
262
+ return f
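Not part of the diff: a minimal usage sketch of the schedulers above, assuming `template_scheduler.__init__(step)` (defined earlier in this file) only records the total number of steps; all values are illustrative. The Lambda* schedulers additionally read the global training config via `cfguh()`, so they are left out of the sketch.

# usage sketch (illustrative, not part of the commit)
sched = poly_scheduler(start_lr=1e-3, end_lr=1e-5, power=0.9, step=1000)
lr_first, lr_last = sched[0], sched[999]        # 1e-3 at step 0, close to 1e-5 near the end

# warm-up followed by polynomial decay, combined with compose_scheduler
warmup_then_poly = compose_scheduler([
    linear_scheduler(start_lr=0.0, end_lr=1e-3, step=100),
    poly_scheduler(start_lr=1e-3, end_lr=1e-5, power=0.9, step=900),
])
lr_mid_warmup = warmup_then_poly[50]            # 5e-4, still inside the warm-up phase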
lib/model_zoo/common/utils.py ADDED
@@ -0,0 +1,292 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import copy
6
+ import functools
7
+ import itertools
8
+
9
+ import matplotlib.pyplot as plt
10
+
11
+ ########
12
+ # unit #
13
+ ########
14
+
15
+ def singleton(class_):
16
+ instances = {}
17
+ def getinstance(*args, **kwargs):
18
+ if class_ not in instances:
19
+ instances[class_] = class_(*args, **kwargs)
20
+ return instances[class_]
21
+ return getinstance
22
+
23
+ def str2value(v):
24
+ v = v.strip()
25
+ try:
26
+ return int(v)
27
+ except:
28
+ pass
29
+ try:
30
+ return float(v)
31
+ except:
32
+ pass
33
+ if v in ('True', 'true'):
34
+ return True
35
+ elif v in ('False', 'false'):
36
+ return False
37
+ else:
38
+ return v
39
+
40
+ @singleton
41
+ class get_unit(object):
42
+ def __init__(self):
43
+ self.unit = {}
44
+ self.register('none', None)
45
+
46
+ # general convolution
47
+ self.register('conv' , nn.Conv2d)
48
+ self.register('bn' , nn.BatchNorm2d)
49
+ self.register('relu' , nn.ReLU)
50
+ self.register('relu6' , nn.ReLU6)
51
+ self.register('lrelu' , nn.LeakyReLU)
52
+ self.register('dropout' , nn.Dropout)
53
+ self.register('dropout2d', nn.Dropout2d)
54
+ self.register('sine', Sine)
55
+ self.register('relusine', ReLUSine)
56
+
57
+ def register(self,
58
+ name,
59
+ unitf,):
60
+
61
+ self.unit[name] = unitf
62
+
63
+ def __call__(self, name):
64
+ if name is None:
65
+ return None
66
+ i = name.find('(')
67
+ i = len(name) if i==-1 else i
68
+ t = name[:i]
69
+ f = self.unit[t]
70
+ args = name[i:].strip('()')
71
+ if len(args) == 0:
72
+ args = {}
73
+ return f
74
+ else:
75
+ args = args.split('=')
76
+ args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
77
+ args = list(itertools.chain.from_iterable(args))
78
+ args = [i.strip() for i in args if len(i)>0]
79
+ kwargs = {}
80
+ for k, v in zip(args[::2], args[1::2]):
81
+ if v[0]=='(' and v[-1]==')':
82
+ kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
83
+ elif v[0]=='[' and v[-1]==']':
84
+ kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
85
+ else:
86
+ kwargs[k] = str2value(v)
87
+ return functools.partial(f, **kwargs)
88
+
89
+ def register(name):
90
+ def wrapper(class_):
91
+ get_unit().register(name, class_)
92
+ return class_
93
+ return wrapper
94
+
95
+ class Sine(object):
96
+ def __init__(self, freq, gain=1):
97
+ self.freq = freq
98
+ self.gain = gain
99
+ self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
100
+
101
+ def __call__(self, x, gain=1):
102
+ act_gain = self.gain * gain
103
+ return torch.sin(self.freq * x) * act_gain
104
+
105
+ def __repr__(self,):
106
+ return self.repr
107
+
108
+ class ReLUSine(nn.Module):
109
+ def __init__(self):
110
+ super().__init__()
111
+
112
+ def forward(self, input):
113
+ a = torch.sin(30 * input)
114
+ b = nn.ReLU(inplace=False)(input)
115
+ return a+b
116
+
117
+ @register('lrelu_agc')
118
+ # class lrelu_agc(nn.Module):
119
+ class lrelu_agc(object):
120
+ """
121
+ The lrelu layer with alpha, gain and clamp
122
+ """
123
+ def __init__(self, alpha=0.1, gain=1, clamp=None):
124
+ # super().__init__()
125
+ self.alpha = alpha
126
+ if gain == 'sqrt_2':
127
+ self.gain = np.sqrt(2)
128
+ else:
129
+ self.gain = gain
130
+ self.clamp = clamp
131
+ self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
132
+ alpha, gain, clamp)
133
+
134
+ # def forward(self, x, gain=1):
135
+ def __call__(self, x, gain=1):
136
+ x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
137
+ act_gain = self.gain * gain
138
+ act_clamp = self.clamp * gain if self.clamp is not None else None
139
+ if act_gain != 1:
140
+ x = x * act_gain
141
+ if act_clamp is not None:
142
+ x = x.clamp(-act_clamp, act_clamp)
143
+ return x
144
+
145
+ def __repr__(self,):
146
+ return self.repr
147
+
148
+ ####################
149
+ # spatial encoding #
150
+ ####################
151
+
152
+ @register('se')
153
+ class SpatialEncoding(nn.Module):
154
+ def __init__(self,
155
+ in_dim,
156
+ out_dim,
157
+ sigma = 6,
158
+ cat_input=True,
159
+ require_grad=False,):
160
+
161
+ super().__init__()
162
+ assert out_dim % (2*in_dim) == 0, "out_dim must be divisible by 2*in_dim"
163
+
164
+ n = out_dim // 2 // in_dim
165
+ m = 2**np.linspace(0, sigma, n)
166
+ m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
167
+ m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
168
+ self.emb = torch.FloatTensor(m)
169
+ if require_grad:
170
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
171
+ self.in_dim = in_dim
172
+ self.out_dim = out_dim
173
+ self.sigma = sigma
174
+ self.cat_input = cat_input
175
+ self.require_grad = require_grad
176
+
177
+ def forward(self, x, format='[n x c]'):
178
+ """
179
+ Args:
180
+ x: [n x m1],
181
+ m1 is usually 2
182
+ Outputs:
183
+ y: [n x m2]
184
+ m2 is the output dimension
185
+ """
186
+ if format == '[bs x c x 2D]':
187
+ xshape = x.shape
188
+ x = x.permute(0, 2, 3, 1).contiguous()
189
+ x = x.view(-1, x.size(-1))
190
+ elif format == '[n x c]':
191
+ pass
192
+ else:
193
+ raise ValueError
194
+
195
+ if not self.require_grad:
196
+ self.emb = self.emb.to(x.device)
197
+ y = torch.mm(x, self.emb.T)
198
+ if self.cat_input:
199
+ z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
200
+ else:
201
+ z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
202
+
203
+ if format == '[bs x c x 2D]':
204
+ z = z.view(xshape[0], xshape[2], xshape[3], -1)
205
+ z = z.permute(0, 3, 1, 2).contiguous()
206
+ return z
207
+
208
+ def extra_repr(self):
209
+ outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
210
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
211
+ return outstr
212
+
213
+ @register('rffe')
214
+ class RFFEncoding(SpatialEncoding):
215
+ """
216
+ Random Fourier Features
217
+ """
218
+ def __init__(self,
219
+ in_dim,
220
+ out_dim,
221
+ sigma = 6,
222
+ cat_input=True,
223
+ require_grad=False,):
224
+
225
+ super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
226
+ n = out_dim // 2
227
+ m = np.random.normal(0, sigma, size=(n, in_dim))
228
+ self.emb = torch.FloatTensor(m)
229
+ if require_grad:
230
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
231
+
232
+ def extra_repr(self):
233
+ outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
234
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
235
+ return outstr
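Not part of the diff: an illustrative call to the positional-encoding modules above; the shapes follow directly from the constructor assertions.

# usage sketch (illustrative, not part of the commit)
pe = SpatialEncoding(in_dim=2, out_dim=32, sigma=6)    # out_dim must be divisible by 2*in_dim
xy = torch.rand(100, 2)                                # 100 two-dimensional coordinates
feat = pe(xy, format='[n x c]')                        # shape (100, 2 + 32) since cat_input=True
rff = RFFEncoding(in_dim=2, out_dim=32, sigma=6)       # same interface, random Fourier features
feat_rff = rff(xy)                                     # shape (100, 34) as well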
236
+
237
+ ##########
238
+ # helper #
239
+ ##########
240
+
241
+ def freeze(net):
242
+ for m in net.modules():
243
+ if isinstance(m, (
244
+ nn.BatchNorm2d,
245
+ nn.SyncBatchNorm,)):
246
+ # inplace_abn not supported
247
+ m.eval()
248
+ for pi in net.parameters():
249
+ pi.requires_grad = False
250
+ return net
251
+
252
+ def common_init(m):
253
+ if isinstance(m, (
254
+ nn.Conv2d,
255
+ nn.ConvTranspose2d,)):
256
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257
+ if m.bias is not None:
258
+ nn.init.constant_(m.bias, 0)
259
+ elif isinstance(m, (
260
+ nn.BatchNorm2d,
261
+ nn.SyncBatchNorm,)):
262
+ nn.init.constant_(m.weight, 1)
263
+ nn.init.constant_(m.bias, 0)
264
+ else:
265
+ pass
266
+
267
+ def init_module(module):
268
+ """
269
+ Args:
270
+ module: [nn.module] list or nn.module
271
+ a list of module to be initialized.
272
+ """
273
+ if isinstance(module, (list, tuple)):
274
+ module = list(module)
275
+ else:
276
+ module = [module]
277
+
278
+ for mi in module:
279
+ for mii in mi.modules():
280
+ common_init(mii)
281
+
282
+ def get_total_param(net):
283
+ if getattr(net, 'parameters', None) is None:
284
+ return 0
285
+ return sum(p.numel() for p in net.parameters())
286
+
287
+ def get_total_param_sum(net):
288
+ if getattr(net, 'parameters', None) is None:
289
+ return 0
290
+ with torch.no_grad():
291
+ s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
292
+ return s
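Not part of the diff: a small sketch of how the string-based unit registry above resolves layer factories; the spec strings are illustrative.

# usage sketch (illustrative, not part of the commit)
relu_cls = get_unit()('relu')                                   # nn.ReLU itself
dropout_fn = get_unit()('dropout(p=0.5)')                       # functools.partial(nn.Dropout, p=0.5)
act = get_unit()('lrelu_agc(alpha=0.2, gain=sqrt_2, clamp=256)')()   # configured lrelu_agc instance
block = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), dropout_fn(), relu_cls())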
lib/model_zoo/controlnet.py ADDED
@@ -0,0 +1,503 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import numpy.random as npr
6
+ import copy
7
+ from functools import partial
8
+ from contextlib import contextmanager
9
+ from lib.model_zoo.common.get_model import get_model, register
10
+ from lib.log_service import print_log
11
+
12
+ from .openaimodel import \
13
+ TimestepEmbedSequential, conv_nd, zero_module, \
14
+ ResBlock, AttentionBlock, SpatialTransformer, \
15
+ Downsample, timestep_embedding
16
+
17
+ ####################
18
+ # preprocess depth #
19
+ ####################
20
+
21
+ # depth_model = None
22
+
23
+ # def unload_midas_model():
24
+ # global depth_model
25
+ # if depth_model is not None:
26
+ # depth_model = depth_model.cpu()
27
+
28
+ # def apply_midas(input_image, a=np.pi*2.0, bg_th=0.1, device='cpu'):
29
+ # import cv2
30
+ # from einops import rearrange
31
+ # from .controlnet_annotators.midas import MiDaSInference
32
+ # global depth_model
33
+ # if depth_model is None:
34
+ # depth_model = MiDaSInference(model_type="dpt_hybrid")
35
+ # depth_model = depth_model.to(device)
36
+
37
+ # assert input_image.ndim == 3
38
+ # image_depth = input_image
39
+ # with torch.no_grad():
40
+ # image_depth = torch.from_numpy(image_depth).float()
41
+ # image_depth = image_depth.to(device)
42
+ # image_depth = image_depth / 127.5 - 1.0
43
+ # image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
44
+ # depth = depth_model(image_depth)[0]
45
+
46
+ # depth_pt = depth.clone()
47
+ # depth_pt -= torch.min(depth_pt)
48
+ # depth_pt /= torch.max(depth_pt)
49
+ # depth_pt = depth_pt.cpu().numpy()
50
+ # depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
51
+
52
+ # depth_np = depth.cpu().numpy()
53
+ # x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
54
+ # y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
55
+ # z = np.ones_like(x) * a
56
+ # x[depth_pt < bg_th] = 0
57
+ # y[depth_pt < bg_th] = 0
58
+ # normal = np.stack([x, y, z], axis=2)
59
+ # normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
60
+ # normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
61
+
62
+ # return depth_image, normal_image
63
+
64
+
65
+ @register('controlnet')
66
+ class ControlNet(nn.Module):
67
+ def __init__(
68
+ self,
69
+ image_size,
70
+ in_channels,
71
+ model_channels,
72
+ hint_channels,
73
+ num_res_blocks,
74
+ attention_resolutions,
75
+ dropout=0,
76
+ channel_mult=(1, 2, 4, 8),
77
+ conv_resample=True,
78
+ dims=2,
79
+ use_checkpoint=False,
80
+ use_fp16=False,
81
+ num_heads=-1,
82
+ num_head_channels=-1,
83
+ num_heads_upsample=-1,
84
+ use_scale_shift_norm=False,
85
+ resblock_updown=False,
86
+ use_new_attention_order=False,
87
+ use_spatial_transformer=False, # custom transformer support
88
+ transformer_depth=1, # custom transformer support
89
+ context_dim=None, # custom transformer support
90
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
91
+ legacy=True,
92
+ disable_self_attentions=None,
93
+ num_attention_blocks=None,
94
+ disable_middle_self_attn=False,
95
+ use_linear_in_transformer=False,
96
+ ):
97
+ super().__init__()
98
+ if use_spatial_transformer:
99
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
100
+
101
+ if context_dim is not None:
102
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
103
+ from omegaconf.listconfig import ListConfig
104
+ if type(context_dim) == ListConfig:
105
+ context_dim = list(context_dim)
106
+
107
+ if num_heads_upsample == -1:
108
+ num_heads_upsample = num_heads
109
+
110
+ if num_heads == -1:
111
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
112
+
113
+ if num_head_channels == -1:
114
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
115
+
116
+ self.dims = dims
117
+ self.image_size = image_size
118
+ self.in_channels = in_channels
119
+ self.model_channels = model_channels
120
+ if isinstance(num_res_blocks, int):
121
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
122
+ else:
123
+ if len(num_res_blocks) != len(channel_mult):
124
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
125
+ "as a list/tuple (per-level) with the same length as channel_mult")
126
+ self.num_res_blocks = num_res_blocks
127
+ if disable_self_attentions is not None:
128
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
129
+ assert len(disable_self_attentions) == len(channel_mult)
130
+ if num_attention_blocks is not None:
131
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
132
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
133
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
134
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
135
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
136
+ f"attention will still not be set.")
137
+
138
+ self.attention_resolutions = attention_resolutions
139
+ self.dropout = dropout
140
+ self.channel_mult = channel_mult
141
+ self.conv_resample = conv_resample
142
+ self.use_checkpoint = use_checkpoint
143
+ self.dtype = torch.float16 if use_fp16 else torch.float32
144
+ self.num_heads = num_heads
145
+ self.num_head_channels = num_head_channels
146
+ self.num_heads_upsample = num_heads_upsample
147
+ self.predict_codebook_ids = n_embed is not None
148
+
149
+ time_embed_dim = model_channels * 4
150
+ self.time_embed = nn.Sequential(
151
+ nn.Linear(model_channels, time_embed_dim),
152
+ nn.SiLU(),
153
+ nn.Linear(time_embed_dim, time_embed_dim),
154
+ )
155
+
156
+ self.input_blocks = nn.ModuleList(
157
+ [
158
+ TimestepEmbedSequential(
159
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
160
+ )
161
+ ]
162
+ )
163
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
164
+
165
+ self.input_hint_block = TimestepEmbedSequential(
166
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
167
+ nn.SiLU(),
168
+ conv_nd(dims, 16, 16, 3, padding=1),
169
+ nn.SiLU(),
170
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
171
+ nn.SiLU(),
172
+ conv_nd(dims, 32, 32, 3, padding=1),
173
+ nn.SiLU(),
174
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
175
+ nn.SiLU(),
176
+ conv_nd(dims, 96, 96, 3, padding=1),
177
+ nn.SiLU(),
178
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
179
+ nn.SiLU(),
180
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
181
+ )
182
+
183
+ self._feature_size = model_channels
184
+ input_block_chans = [model_channels]
185
+ ch = model_channels
186
+ ds = 1
187
+ for level, mult in enumerate(channel_mult):
188
+ for nr in range(self.num_res_blocks[level]):
189
+ layers = [
190
+ ResBlock(
191
+ ch,
192
+ time_embed_dim,
193
+ dropout,
194
+ out_channels=mult * model_channels,
195
+ dims=dims,
196
+ use_checkpoint=use_checkpoint,
197
+ use_scale_shift_norm=use_scale_shift_norm,
198
+ )
199
+ ]
200
+ ch = mult * model_channels
201
+ if ds in attention_resolutions:
202
+ if num_head_channels == -1:
203
+ dim_head = ch // num_heads
204
+ else:
205
+ num_heads = ch // num_head_channels
206
+ dim_head = num_head_channels
207
+ if legacy:
208
+ # num_heads = 1
209
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
210
+ if disable_self_attentions is not None:
211
+ disabled_sa = disable_self_attentions[level]
212
+ else:
213
+ disabled_sa = False
214
+
215
+ if (num_attention_blocks is None) or nr < num_attention_blocks[level]:
216
+ layers.append(
217
+ AttentionBlock(
218
+ ch,
219
+ use_checkpoint=use_checkpoint,
220
+ num_heads=num_heads,
221
+ num_head_channels=dim_head,
222
+ use_new_attention_order=use_new_attention_order,
223
+ ) if not use_spatial_transformer else SpatialTransformer(
224
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
225
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
226
+ use_checkpoint=use_checkpoint
227
+ )
228
+ )
229
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
230
+ self.zero_convs.append(self.make_zero_conv(ch))
231
+ self._feature_size += ch
232
+ input_block_chans.append(ch)
233
+ if level != len(channel_mult) - 1:
234
+ out_ch = ch
235
+ self.input_blocks.append(
236
+ TimestepEmbedSequential(
237
+ ResBlock(
238
+ ch,
239
+ time_embed_dim,
240
+ dropout,
241
+ out_channels=out_ch,
242
+ dims=dims,
243
+ use_checkpoint=use_checkpoint,
244
+ use_scale_shift_norm=use_scale_shift_norm,
245
+ down=True,
246
+ )
247
+ if resblock_updown
248
+ else Downsample(
249
+ ch, conv_resample, dims=dims, out_channels=out_ch
250
+ )
251
+ )
252
+ )
253
+ ch = out_ch
254
+ input_block_chans.append(ch)
255
+ self.zero_convs.append(self.make_zero_conv(ch))
256
+ ds *= 2
257
+ self._feature_size += ch
258
+
259
+ if num_head_channels == -1:
260
+ dim_head = ch // num_heads
261
+ else:
262
+ num_heads = ch // num_head_channels
263
+ dim_head = num_head_channels
264
+ if legacy:
265
+ # num_heads = 1
266
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
267
+ self.middle_block = TimestepEmbedSequential(
268
+ ResBlock(
269
+ ch,
270
+ time_embed_dim,
271
+ dropout,
272
+ dims=dims,
273
+ use_checkpoint=use_checkpoint,
274
+ use_scale_shift_norm=use_scale_shift_norm,
275
+ ),
276
+ AttentionBlock(
277
+ ch,
278
+ use_checkpoint=use_checkpoint,
279
+ num_heads=num_heads,
280
+ num_head_channels=dim_head,
281
+ use_new_attention_order=use_new_attention_order,
282
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
283
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
284
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
285
+ use_checkpoint=use_checkpoint
286
+ ),
287
+ ResBlock(
288
+ ch,
289
+ time_embed_dim,
290
+ dropout,
291
+ dims=dims,
292
+ use_checkpoint=use_checkpoint,
293
+ use_scale_shift_norm=use_scale_shift_norm,
294
+ ),
295
+ )
296
+ self.middle_block_out = self.make_zero_conv(ch)
297
+ self._feature_size += ch
298
+
299
+ def make_zero_conv(self, channels):
300
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
301
+
302
+ def forward(self, x, hint, timesteps, context, **kwargs):
303
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
304
+ t_emb = t_emb.to(x.dtype)
305
+ emb = self.time_embed(t_emb)
306
+
307
+ guided_hint = self.input_hint_block(hint, emb, context)
308
+
309
+ outs = []
310
+
311
+ h = x
312
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
313
+ if guided_hint is not None:
314
+ h = module(h, emb, context)
315
+ h += guided_hint
316
+ guided_hint = None
317
+ else:
318
+ h = module(h, emb, context)
319
+ outs.append(zero_conv(h, emb, context))
320
+
321
+ h = self.middle_block(h, emb, context)
322
+ outs.append(self.middle_block_out(h, emb, context))
323
+
324
+ return outs
325
+
326
+ def get_device(self):
327
+ return self.time_embed[0].weight.device
328
+
329
+ def get_dtype(self):
330
+ return self.time_embed[0].weight.dtype
331
+
332
+ def preprocess(self, x, type='canny', **kwargs):
333
+ import torchvision.transforms as tvtrans
334
+ if isinstance(x, str):
335
+ import PIL.Image
336
+ device, dtype = self.get_device(), self.get_dtype()
337
+ x_list = [PIL.Image.open(x)]
338
+ elif isinstance(x, torch.Tensor):
339
+ x_list = [tvtrans.ToPILImage()(xi) for xi in x]
340
+ device, dtype = x.device, x.dtype
341
+ else:
342
+ assert False
343
+
344
+ if type == 'none' or type is None:
345
+ return None
346
+
347
+ elif type in ['input', 'shuffle_v11e']:
348
+ y_torch = torch.stack([tvtrans.ToTensor()(xi) for xi in x_list])
349
+ y_torch = y_torch.to(device).to(torch.float32)
350
+ return y_torch
351
+
352
+ elif type in ['canny', 'canny_v11p']:
353
+ low_threshold = kwargs.pop('low_threshold', 100)
354
+ high_threshold = kwargs.pop('high_threshold', 200)
355
+ from .controlnet_annotator.canny import apply_canny
356
+ y_list = [apply_canny(np.array(xi), low_threshold, high_threshold) for xi in x_list]
357
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
358
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
359
+ y_torch = y_torch.to(device).to(torch.float32)
360
+ return y_torch
361
+
362
+ elif type == 'depth':
363
+ from .controlnet_annotator.midas import apply_midas
364
+ y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, device=device) for xi in x_list])
365
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
366
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
367
+ y_torch = y_torch.to(device).to(torch.float32)
368
+ return y_torch
369
+
370
+ elif type in ['hed', 'softedge_v11p']:
371
+ from .controlnet_annotator.hed import apply_hed
372
+ y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
373
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
374
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
375
+ y_torch = y_torch.to(device).to(torch.float32)
376
+ return y_torch
377
+
378
+ elif type in ['mlsd', 'mlsd_v11p']:
379
+ thr_v = kwargs.pop('thr_v', 0.1)
380
+ thr_d = kwargs.pop('thr_d', 0.1)
381
+ from .controlnet_annotator.mlsd import apply_mlsd
382
+ y_list = [apply_mlsd(np.array(xi), thr_v=thr_v, thr_d=thr_d, device=device) for xi in x_list]
383
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
384
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
385
+ y_torch = y_torch.to(device).to(torch.float32)
386
+ return y_torch
387
+
388
+ elif type == 'normal':
389
+ bg_th = kwargs.pop('bg_th', 0.4)
390
+ from .controlnet_annotator.midas import apply_midas
391
+ _, y_list = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, bg_th=bg_th, device=device) for xi in x_list])
392
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
393
+ y_torch = y_torch.to(device).to(torch.float32)
394
+ return y_torch
395
+
396
+ elif type in ['openpose', 'openpose_v11p']:
397
+ from .controlnet_annotator.openpose import OpenposeModel
398
+ from functools import partial
399
+ wrapper = OpenposeModel()
400
+ apply_openpose = partial(
401
+ wrapper.run_model, include_body=True, include_hand=False, include_face=False,
402
+ json_pose_callback=None, device=device)
403
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
404
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
405
+ y_torch = y_torch.to(device).to(torch.float32)
406
+ return y_torch
407
+
408
+ elif type in ['openpose_withface', 'openpose_withface_v11p']:
409
+ from .controlnet_annotator.openpose import OpenposeModel
410
+ from functools import partial
411
+ wrapper = OpenposeModel()
412
+ apply_openpose = partial(
413
+ wrapper.run_model, include_body=True, include_hand=False, include_face=True,
414
+ json_pose_callback=None, device=device)
415
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
416
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
417
+ y_torch = y_torch.to(device).to(torch.float32)
418
+ return y_torch
419
+
420
+ elif type in ['openpose_withfacehand', 'openpose_withfacehand_v11p']:
421
+ from .controlnet_annotator.openpose import OpenposeModel
422
+ from functools import partial
423
+ wrapper = OpenposeModel()
424
+ apply_openpose = partial(
425
+ wrapper.run_model, include_body=True, include_hand=True, include_face=True,
426
+ json_pose_callback=None, device=device)
427
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
428
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
429
+ y_torch = y_torch.to(device).to(torch.float32)
430
+ return y_torch
431
+
432
+ elif type == 'scribble':
433
+ method = kwargs.pop('method', 'pidinet')
434
+
435
+ import cv2
436
+ def nms(x, t, s):
437
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
438
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
439
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
440
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
441
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
442
+ y = np.zeros_like(x)
443
+ for f in [f1, f2, f3, f4]:
444
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
445
+ z = np.zeros_like(y, dtype=np.uint8)
446
+ z[y > t] = 255
447
+ return z
448
+
449
+ def make_scribble(result):
450
+ result = nms(result, 127, 3.0)
451
+ result = cv2.GaussianBlur(result, (0, 0), 3.0)
452
+ result[result > 4] = 255
453
+ result[result < 255] = 0
454
+ return result
455
+
456
+ if method == 'hed':
457
+ from .controlnet_annotator.hed import apply_hed
458
+ y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
459
+ y_list = [make_scribble(yi) for yi in y_list]
460
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
461
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
462
+ y_torch = y_torch.to(device).to(torch.float32)
463
+ return y_torch
464
+
465
+ elif method == 'pidinet':
466
+ from .controlnet_annotator.pidinet import apply_pidinet
467
+ y_list = [apply_pidinet(np.array(xi), device=device) for xi in x_list]
468
+ y_list = [make_scribble(yi) for yi in y_list]
469
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
470
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
471
+ y_torch = y_torch.to(device).to(torch.float32)
472
+ return y_torch
473
+
474
+ elif method == 'xdog':
475
+ threshold = kwargs.pop('threshold', 32)
476
+ def apply_scribble_xdog(img):
477
+ g1 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 0.5)
478
+ g2 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 5.0)
479
+ dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
480
+ result = np.zeros_like(dog, dtype=np.uint8)
481
+ result[2 * (255 - dog) > threshold] = 255
482
+ return result
483
+
484
+ y_list = [apply_scribble_xdog(np.array(xi)) for xi in x_list]
485
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
486
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make it RGB
487
+ y_torch = y_torch.to(device).to(torch.float32)
488
+ return y_torch
489
+
490
+ else:
491
+ raise ValueError
492
+
493
+ elif type == 'seg':
494
+ method = kwargs.pop('method', 'ufade20k')
495
+ if method == 'ufade20k':
496
+ from .controlnet_annotator.uniformer import apply_uniformer
497
+ y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
498
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
499
+ y_torch = y_torch.to(device).to(torch.float32)
500
+ return y_torch
501
+
502
+ else:
503
+ raise ValueError
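Not part of the diff: an illustrative sketch of how the ControlNet module above is driven. The hyperparameters below are placeholders (the real values live in configs/model/controlnet.yaml), the hint tensor stands in for an output of `preprocess`, and omegaconf must be installed since it is imported when `context_dim` is set.

# usage sketch (illustrative placeholders, not part of the commit)
net = ControlNet(
    image_size=32, in_channels=4, model_channels=320, hint_channels=3,
    num_res_blocks=2, attention_resolutions=[4, 2, 1], channel_mult=[1, 2, 4, 4],
    use_spatial_transformer=True, transformer_depth=1, context_dim=768, num_heads=8)
x = torch.randn(1, 4, 64, 64)            # latent for a 512x512 image (1/8 resolution)
hint = torch.randn(1, 3, 512, 512)       # stand-in for net.preprocess(image, type='canny')
timesteps = torch.randint(0, 1000, (1,))
context = torch.randn(1, 77, 768)        # cross-attention conditioning features
controls = net(x, hint, timesteps, context)   # list of residuals: one per zero-conv, plus the middle block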
lib/model_zoo/controlnet_annotator/canny/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ import cv2
2
+
3
+
4
+ def apply_canny(img, low_threshold, high_threshold):
5
+ return cv2.Canny(img, low_threshold, high_threshold)
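Not part of the diff: `apply_canny` is a thin wrapper over cv2.Canny, so it takes an 8-bit numpy image and returns a single-channel uint8 edge map; the input below is a placeholder.

# usage sketch (illustrative, not part of the commit)
import numpy as np
img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)       # placeholder RGB image
edges = apply_canny(img, low_threshold=100, high_threshold=200)   # (256, 256) uint8 map, values 0/255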
lib/model_zoo/controlnet_annotator/hed/__init__.py ADDED
@@ -0,0 +1,134 @@
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+
13
+ from einops import rearrange
14
+ import os
15
+
16
+ models_path = 'pretrained/controlnet/preprocess'
17
+
18
+ def safe_step(x, step=2):
19
+ y = x.astype(np.float32) * float(step + 1)
20
+ y = y.astype(np.int32).astype(np.float32) / float(step)
21
+ return y
22
+
23
+ class DoubleConvBlock(torch.nn.Module):
24
+ def __init__(self, input_channel, output_channel, layer_number):
25
+ super().__init__()
26
+ self.convs = torch.nn.Sequential()
27
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
28
+ for i in range(1, layer_number):
29
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
30
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
31
+
32
+ def __call__(self, x, down_sampling=False):
33
+ h = x
34
+ if down_sampling:
35
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
36
+ for conv in self.convs:
37
+ h = conv(h)
38
+ h = torch.nn.functional.relu(h)
39
+ return h, self.projection(h)
40
+
41
+
42
+ class ControlNetHED_Apache2(torch.nn.Module):
43
+ def __init__(self):
44
+ super().__init__()
45
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
46
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
47
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
48
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
49
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
50
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
51
+
52
+ def __call__(self, x):
53
+ h = x - self.norm
54
+ h, projection1 = self.block1(h)
55
+ h, projection2 = self.block2(h, down_sampling=True)
56
+ h, projection3 = self.block3(h, down_sampling=True)
57
+ h, projection4 = self.block4(h, down_sampling=True)
58
+ h, projection5 = self.block5(h, down_sampling=True)
59
+ return projection1, projection2, projection3, projection4, projection5
60
+
61
+
62
+ netNetwork = None
63
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
64
+ modeldir = os.path.join(models_path, "hed")
65
+ old_modeldir = os.path.dirname(os.path.realpath(__file__))
66
+
67
+
68
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
69
+ """Load file from an http url, downloading the model if necessary.
70
+
71
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
72
+
73
+ Args:
74
+ url (str): URL to be downloaded.
75
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
76
+ Default: None.
77
+ progress (bool): Whether to show the download progress. Default: True.
78
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
79
+
80
+ Returns:
81
+ str: The path to the downloaded file.
82
+ """
83
+ from torch.hub import download_url_to_file, get_dir
84
+ from urllib.parse import urlparse
85
+ if model_dir is None: # use the pytorch hub_dir
86
+ hub_dir = get_dir()
87
+ model_dir = os.path.join(hub_dir, 'checkpoints')
88
+
89
+ os.makedirs(model_dir, exist_ok=True)
90
+
91
+ parts = urlparse(url)
92
+ filename = os.path.basename(parts.path)
93
+ if file_name is not None:
94
+ filename = file_name
95
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
96
+ if not os.path.exists(cached_file):
97
+ print(f'Downloading: "{url}" to {cached_file}\n')
98
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
99
+ return cached_file
100
+
101
+
102
+ def apply_hed(input_image, is_safe=False, device='cpu'):
103
+ global netNetwork
104
+ if netNetwork is None:
105
+ modelpath = os.path.join(modeldir, "ControlNetHED.pth")
106
+ old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
107
+ if os.path.exists(old_modelpath):
108
+ modelpath = old_modelpath
109
+ elif not os.path.exists(modelpath):
110
+ load_file_from_url(remote_model_path, model_dir=modeldir)
111
+ netNetwork = ControlNetHED_Apache2().to(device)
112
+ netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
113
+ netNetwork.to(device).float().eval()
114
+
115
+ assert input_image.ndim == 3
116
+ H, W, C = input_image.shape
117
+ with torch.no_grad():
118
+ image_hed = torch.from_numpy(input_image.copy()).float().to(device)
119
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
120
+ edges = netNetwork(image_hed)
121
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
122
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
123
+ edges = np.stack(edges, axis=2)
124
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
125
+ if is_safe:
126
+ edge = safe_step(edge)
127
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
128
+ return edge
129
+
130
+
131
+ def unload_hed_model():
132
+ global netNetwork
133
+ if netNetwork is not None:
134
+ netNetwork.cpu()