yibolu committed
Commit 6eca12e
1 Parent(s): a38262d

update pipeline and demos

Files changed (36)
  1. README.md +19 -11
  2. controlnet_img2img_demo.py +6 -4
  3. controlnet_txt2img_demo.py +11 -5
  4. controlnet_txt2img_sdxl_demo.py +70 -0
  5. img2img_demo.py +5 -2
  6. lyrasd_model/__init__.py +5 -1
  7. lyrasd_model/lora_util.py +238 -6
  8. lyrasd_model/lyrasd_controlnet_img2img_pipeline.py +92 -110
  9. lyrasd_model/lyrasd_controlnet_txt2img_pipeline.py +40 -82
  10. lyrasd_model/lyrasd_img2img_pipeline.py +90 -95
  11. lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so +0 -3
  12. lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm86.so +0 -3
  13. lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so +2 -2
  14. lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so +2 -2
  15. lyrasd_model/lyrasd_pipeline_base.py +214 -0
  16. lyrasd_model/lyrasd_txt2img_inpaint_pipeline.py +826 -0
  17. lyrasd_model/lyrasd_txt2img_pipeline.py +172 -85
  18. lyrasd_model/lyrasd_vae_model.py +363 -0
  19. lyrasd_model/lyrasdxl_controlnet_txt2img_pipeline.py +346 -0
  20. lyrasd_model/lyrasdxl_pipeline_base.py +275 -0
  21. lyrasd_model/lyrasdxl_txt2img_inpaint_pipeline.py +535 -0
  22. lyrasd_model/lyrasdxl_txt2img_pipeline.py +267 -0
  23. lyrasd_model/{lyrasd_lib/placeholder.txt → module/__init__.py} +0 -0
  24. lyrasd_model/module/lyra_tool.py +5 -0
  25. lyrasd_model/module/lyrasd_ip_adapter.py +289 -0
  26. lyrasd_model/module/resampler.py +121 -0
  27. lyrasd_model/module/tools.py +148 -0
  28. models/README.md +14 -5
  29. outputs/res_controlnet_img2img_0.png +2 -2
  30. outputs/{res_controlnet_sdxl_txt2img.png → res_controlnet_sdxl_txt2img_0.png} +2 -2
  31. outputs/res_controlnet_txt2img_0.png +2 -2
  32. outputs/res_img2img_0.png +2 -2
  33. outputs/res_txt2img_lora_0.png +2 -2
  34. outputs/{res_sdxl_txt2img_lora_0.png → res_txt2img_xl_lora_0.png} +2 -2
  35. txt2img_demo.py +13 -10
  36. txt2img_sdxl_demo.py +55 -0
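
Taken together, the README and demo changes in this commit move every pipeline to a new loading flow: the TorchScript extension is loaded once with `torch.classes.load_library`, pipelines are constructed without arguments, and weights are attached afterwards with `reload_pipe` (ControlNets via `load_controlnet_model_v2`). A minimal sketch of that flow, assuming the library and model paths used in the demos below:

```python
import torch
from lyrasd_model import LyraSdTxt2ImgPipeline

# paths taken from the demo scripts below; adjust to your CUDA version / GPU arch
lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
model_path = "./models/rev-animated"

torch.classes.load_library(lib_path)   # load the C++/CUDA ops once per process

model = LyraSdTxt2ImgPipeline()        # pipelines no longer take (model_path, lib_path)
model.reload_pipe(model_path)          # attach the converted model weights
```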
README.md CHANGED
@@ -79,12 +79,16 @@ from lyrasd_model import LyraSdTxt2ImgPipeline
 # 4. scheduler config
 
 # LyraSD's compiled C++ dynamic library, which contains the C++/CUDA compute implementation
-lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so"
-model_path = "./models/lyrasd_rev_animated"
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/rev-animated"
 lora_path = "./models/xiaorenshu.safetensors"
 
+torch.classes.load_library(lib_path)
+
 # build the Txt2Img pipeline
-model = LyraSdTxt2ImgPipeline(model_path, lib_path)
+model = LyraSdTxt2ImgPipeline()
+
+model.reload_pipe(model_path)
 
 # load lora
 # lora model path, name, lora strength
@@ -94,7 +98,7 @@ model.load_lora_v2(lora_path, "xiaorenshu", 0.4)
 prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
 negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
 height, width = 512, 512
-steps = 30
+steps = 20
 guidance_scale = 7
 generator = torch.Generator().manual_seed(123)
 num_images = 1
@@ -128,12 +132,16 @@ from lyrasd_model import LyraSdXLTxt2ImgPipeline
 # 4. scheduler config
 
 # LyraSD's compiled C++ dynamic library, which contains the C++/CUDA compute implementation
-lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so"
-model_path = "./models/lyrasd_helloworldSDXL20Fp16"
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/helloworldSDXL20Fp16"
 lora_path = "./models/dissolve_sdxl.safetensors"
 
+torch.classes.load_library(lib_path)
+
 # build the Txt2Img pipeline
-model = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
+model = LyraSdXLTxt2ImgPipeline()
+
+model.reload_pipe(model_path)
 
 # load lora
 # lora model path, name, lora strength
@@ -143,7 +151,7 @@ model.load_lora_v2(lora_path, "dissolve_sdxl", 0.4)
 prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
 negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
 height, width = 512, 512
-steps = 30
+steps = 20
 guidance_scale = 7
 generator = torch.Generator().manual_seed(123)
 num_images = 1
@@ -181,7 +189,7 @@ model.unload_lora_v2("dissolve_sdxl", True)
 ![text2img_demo](./outputs/res_sdxl_txt2img_0.png)
 
 #### SDXL Text2Img with Lora
-![text2img_demo](./outputs/res_sdxl_txt2img_lora_0.png)
+![text2img_demo](./outputs/res_txt2img_xl_lora_0.png)
 
 
 <!-- ### Img2Img
@@ -201,7 +209,7 @@ model.unload_lora_v2("dissolve_sdxl", True)
 ![text2img_demo](./outputs/res_controlnet_txt2img_0.png)
 
 #### SDXL ControlNet Text2Img Output
-![text2img_demo](./outputs/res_controlnet_sdxl_txt2img.png)
+![text2img_demo](./outputs/res_controlnet_sdxl_txt2img_0.png)
 
 
 ## Docker Environment Recommendation
@@ -218,7 +226,7 @@ python txt2img_demo.py
 
 ## Citation
 ``` bibtex
-@Misc{lyraSD_2023,
+@Misc{lyraSD_2024,
 author = {Kangjian Wu, Zhengtao Wang, Yibo Lu, Haoxiong Su, Sa Xiao, Bin Wu},
 title = {lyraSD: Accelerating Stable Diffusion with best flexibility},
 howpublished = {\url{https://huggingface.co/TMElyralab/lyraSD}},
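
Beyond the citation bump, the hunk contexts in this README diff also show the LoRA API the examples rely on (`load_lora_v2` / `unload_lora_v2`). Continuing the README example above, a minimal load/unload round-trip; the meaning of the final boolean is an assumption based on the `clean_cache` flag of the older `unload_lora`:

```python
# lora model path, name, lora strength (arguments as used in the README demo)
model.load_lora_v2("./models/xiaorenshu.safetensors", "xiaorenshu", 0.4)

# ... run inference as in the demo above ...

# remove the LoRA again; the boolean presumably also drops the cached LoRA weights
model.unload_lora_v2("xiaorenshu", True)
```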
controlnet_img2img_demo.py CHANGED
@@ -14,14 +14,16 @@ from lyrasd_model import LyraSdControlnetImg2ImgPipeline
 
 # LyraSD's compiled C++ dynamic library, which contains the C++/CUDA compute implementation
 lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
-model_path = "./models/lyrasd_rev_animated"
-canny_controlnet_path = "./models/lyrasd_canny"
+model_path = "./models/rev-animated"
+canny_controlnet_path = "./models/canny"
+torch.classes.load_library(lib_path)
 
 # build the Img2Img pipeline
-model = LyraSdControlnetImg2ImgPipeline(model_path, lib_path)
+model = LyraSdControlnetImg2ImgPipeline()
+model.reload_pipe(model_path)
 
 # load ControlNet models (at most 3 can be loaded at once)
-model.load_controlnet_model("canny", canny_controlnet_path, "fp32")
+model.load_controlnet_model_v2("canny", canny_controlnet_path)
 
 control_img = Image.open("control_bird_canny.png")
 
controlnet_txt2img_demo.py CHANGED
@@ -12,16 +12,22 @@ from lyrasd_model import LyraSdControlnetTxt2ImgPipeline
 # 5. scheduler config
 
 # LyraSD's compiled C++ dynamic library, which contains the C++/CUDA compute implementation
-lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
-model_path = "./models/lyrasd_rev_animated"
-canny_controlnet_path = "./models/lyrasd_canny"
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/rev-animated"
+canny_controlnet_path = "./models/canny"
+
+torch.classes.load_library(lib_path)
+
 # build the Txt2Img pipeline
-pipe = LyraSdControlnetTxt2ImgPipeline(model_path, lib_path)
+pipe = LyraSdControlnetTxt2ImgPipeline()
+
+pipe.reload_pipe(model_path)
 
 # load ControlNet models (at most 3 can be loaded at once)
 start = time.perf_counter()
-pipe.load_controlnet_model("canny", canny_controlnet_path, "fp32")
+pipe.load_controlnet_model_v2("canny", canny_controlnet_path)
 print(f"controlnet load cost: {time.perf_counter() - start}")
+
 # use get_loaded_controlnet to get the list of ControlNets that are currently loaded
 print(pipe.get_loaded_controlnet())
 
controlnet_txt2img_sdxl_demo.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import time
+from PIL import Image
+import numpy as np
+from lyrasd_model import LyraSdXLControlnetTxt2ImgPipeline
+import GPUtil
+
+# path to the model files; the directory should contain:
+# 1. the clip model
+# 2. the converted, optimized unet model
+# 3. the converted, optimized controlnet model
+# 4. the vae model
+# 5. the scheduler config
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/helloworldSDXL20Fp16"
+torch.classes.load_library(lib_path)
+
+# build the Txt2Img pipeline
+pipe = LyraSdXLControlnetTxt2ImgPipeline()
+
+start = time.perf_counter()
+pipe.reload_pipe(model_path)
+print(f"pipeline load cost: {time.perf_counter() - start}")
+
+# load ControlNet models (at most 3 can be loaded at once)
+start = time.perf_counter()
+pipe.load_controlnet_model_v2("canny", "./models/controlnet-canny-sdxl-1.0")
+print(f"controlnet load cost: {time.perf_counter() - start}")
+
+# use get_loaded_controlnet to get the list of ControlNets that are currently loaded
+print(pipe.get_loaded_controlnet())
+
+# a loaded ControlNet can be removed again with unload_controlnet_model
+# pipe.unload_controlnet_model("canny")
+
+control_img = Image.open("control_bird_canny.png")
+
+# prepare the inputs and hyperparameters
+prompt = "a bird"
+negative_prompt = ""
+height, width = 1024, 1024
+steps = 20
+guidance_scale = 7.5
+generator = torch.Generator().manual_seed(123)
+num_images = 1
+guess_mode = False
+
+# up to 3 ControlNets can be loaded at once for a multi-ControlNet effect; the parameter lists must have matching lengths:
+# the ControlNet image list must be as long as controlnet_scale and controlnet_names, while each inner list must match the batch size,
+# so that corresponding indices line up
+controlnet_images = [[control_img]]
+controlnet_scale = [0.5]
+controlnet_names = ['canny']
+
+# run inference; the results are returned as ready-made PIL.Image objects
+for batch in [1]:
+    print(f"cur batch: {batch}")
+    for _ in range(3):
+        start = time.perf_counter()
+        images = pipe(prompt=prompt, height=height, width=width, num_inference_steps=steps,
+                      guidance_scale=guidance_scale, negative_prompt=negative_prompt, num_images_per_prompt=batch,
+                      generator=generator, controlnet_images=controlnet_images,
+                      controlnet_scale=controlnet_scale, controlnet_names=controlnet_names,
+                      guess_mode=guess_mode
+                      )
+        print("cur cost: ", time.perf_counter() - start)
+        GPUtil.showUtilization(all=True)
+# save the generated images
+for i, image in enumerate(images):
+    image.save(f"./outputs/res_controlnet_sdxl_txt2img_{i}.png")
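
The nesting convention used above generalizes to multiple ControlNets: the outer lists are aligned by index across `controlnet_names`, `controlnet_scale` and `controlnet_images`, and each inner image list matches the batch size. A hypothetical two-ControlNet setup continuing the demo above (the "depth" name, path and control image are placeholders, not part of this commit):

```python
# assumed second ControlNet; any converted SDXL ControlNet directory would do
pipe.load_controlnet_model_v2("depth", "./models/controlnet-depth-sdxl-1.0")

depth_img = Image.open("control_bird_depth.png")    # placeholder control image

controlnet_names = ["canny", "depth"]               # index i of every list refers to the same ControlNet
controlnet_scale = [0.5, 0.7]
controlnet_images = [[control_img], [depth_img]]    # one inner list per ControlNet, one image per batch item
```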
img2img_demo.py CHANGED
@@ -14,10 +14,13 @@ from lyrasd_model import LyraSDImg2ImgPipeline
 
 # LyraSD's compiled C++ dynamic library, which contains the C++/CUDA compute implementation
 lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
-model_path = "./models/lyrasd_rev_animated"
+model_path = "./models/rev-animated"
+
+torch.classes.load_library(lib_path)
 
 # build the Img2Img pipeline
-model = LyraSDImg2ImgPipeline(model_path, lib_path)
+model = LyraSDImg2ImgPipeline()
+model.reload_pipe(model_path)
 
 # prepare the inputs and hyperparameters
 prompt = "a cat, cartoon style"
lyrasd_model/__init__.py CHANGED
@@ -1,5 +1,9 @@
 from . import lyrasd_img2img_pipeline, lyrasd_txt2img_pipeline, lyrasd_controlnet_txt2img_pipeline, lyrasd_controlnet_img2img_pipeline
 from .lyrasd_txt2img_pipeline import LyraSdTxt2ImgPipeline
 from .lyrasd_img2img_pipeline import LyraSDImg2ImgPipeline
+from .lyrasd_txt2img_inpaint_pipeline import LyraSdTxt2ImgInpaintPipeline
 from .lyrasd_controlnet_txt2img_pipeline import LyraSdControlnetTxt2ImgPipeline
-from .lyrasd_controlnet_img2img_pipeline import LyraSdControlnetImg2ImgPipeline
+from .lyrasd_controlnet_img2img_pipeline import LyraSdControlnetImg2ImgPipeline
+from .lyrasdxl_txt2img_pipeline import LyraSdXLTxt2ImgPipeline
+from .lyrasdxl_controlnet_txt2img_pipeline import LyraSdXLControlnetTxt2ImgPipeline
+from .lyrasdxl_txt2img_inpaint_pipeline import LyraSdXLTxt2ImgInpaintPipeline
lyrasd_model/lora_util.py CHANGED
@@ -1,7 +1,18 @@
1
  import os
 
 
2
  import torch
3
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
4
  import numpy as np
 
 
 
 
 
 
 
 
 
 
5
 
6
  def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0, lora_file_format="fp32", device="cuda:0"):
7
  if lora_file_format == "fp32":
@@ -14,9 +25,10 @@ def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=
14
  unload_dict = []
15
  # directly update weight in diffusers model
16
  for file in all_files:
17
-
18
  if 'text' in file.name:
19
- layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
 
20
  curr_layer = clip_model.text_model
21
  else:
22
  continue
@@ -39,9 +51,71 @@ def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=
39
  temp_name += '_'+layer_infos.pop(0)
40
  else:
41
  temp_name = layer_infos.pop(0)
42
- data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
43
  if len(curr_layer.weight.data) == 4:
44
- adding_weight = alpha * data.permute(0,3,1,2)
45
  else:
46
  adding_weight = alpha * data
47
  curr_layer.weight.data += adding_weight
@@ -51,4 +125,162 @@ def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=
51
  "added_weight": adding_weight
52
  }
53
  unload_dict.append(curr_layer_unload_data)
54
- return unload_dict
1
  import os
2
+ import re
3
+ import time
4
  import torch
 
5
  import numpy as np
6
+ from safetensors.torch import load_file
7
+ from diffusers.loaders import LoraLoaderMixin
8
+ from diffusers.loaders.lora_conversion_utils import _maybe_map_sgm_blocks_to_diffusers, _convert_kohya_lora_to_diffusers
9
+ from types import SimpleNamespace
10
+ import logging.handlers
11
+ LORA_PREFIX_UNET = "lora_unet"
12
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
13
+ LORA_UNET_LAYERS = ['lora_unet_down_blocks_0_attentions_0', 'lora_unet_down_blocks_0_attentions_1', 'lora_unet_down_blocks_1_attentions_0', 'lora_unet_down_blocks_1_attentions_1', 'lora_unet_down_blocks_2_attentions_0', 'lora_unet_down_blocks_2_attentions_1', 'lora_unet_mid_block_attentions_0', 'lora_unet_up_blocks_1_attentions_0',
14
+ 'lora_unet_up_blocks_1_attentions_1', 'lora_unet_up_blocks_1_attentions_2', 'lora_unet_up_blocks_2_attentions_0', 'lora_unet_up_blocks_2_attentions_1', 'lora_unet_up_blocks_2_attentions_2', 'lora_unet_up_blocks_3_attentions_0', 'lora_unet_up_blocks_3_attentions_1', 'lora_unet_up_blocks_3_attentions_2']
15
+
16
 
17
  def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0, lora_file_format="fp32", device="cuda:0"):
18
  if lora_file_format == "fp32":
 
25
  unload_dict = []
26
  # directly update weight in diffusers model
27
  for file in all_files:
28
+
29
  if 'text' in file.name:
30
+ layer_infos = file.name.split('.')[0].split(
31
+ 'text_model_')[-1].split('_')
32
  curr_layer = clip_model.text_model
33
  else:
34
  continue
 
51
  temp_name += '_'+layer_infos.pop(0)
52
  else:
53
  temp_name = layer_infos.pop(0)
54
+ data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
55
+ clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
56
+ if len(curr_layer.weight.data) == 4:
57
+ adding_weight = alpha * data.permute(0, 3, 1, 2)
58
+ else:
59
+ adding_weight = alpha * data
60
+ curr_layer.weight.data += adding_weight
61
+
62
+ curr_layer_unload_data = {
63
+ "layer": curr_layer,
64
+ "added_weight": adding_weight
65
+ }
66
+ unload_dict.append(curr_layer_unload_data)
67
+ return unload_dict
68
+
69
+
70
+ def add_xltext_lora_layer(clip_model, clip_model_2, lora_model_path, alpha=1.0, lora_file_format="fp32", device="cuda:0"):
71
+ if lora_file_format == "fp32":
72
+ model_dtype = np.float32
73
+ elif lora_file_format == "fp16":
74
+ model_dtype = np.float16
75
+ else:
76
+ raise Exception(f"unsupported model dtype: {lora_file_format}")
77
+ all_files = os.scandir(lora_model_path)
78
+ unload_dict = []
79
+ # directly update weight in diffusers model
80
+ for file in all_files:
81
+
82
+ if 'text' in file.name:
83
+ layer_infos = file.name.split('.')[0].split(
84
+ 'text_model_')[-1].split('_')
85
+ if "text_encoder_2" in file.name:
86
+ curr_layer = clip_model_2.text_model
87
+ elif "text_encoder" in file.name:
88
+ curr_layer = clip_model.text_model
89
+ else:
90
+ raise ValueError(
91
+ "Cannot identify clip model, need text_encoder or text_encoder_2 in filename, found: ", file.name)
92
+ else:
93
+ continue
94
+
95
+ # find the target layer
96
+ # find the target layer
97
+ temp_name = layer_infos.pop(0)
98
+ while len(layer_infos) > -1:
99
+ try:
100
+ curr_layer = curr_layer.__getattr__(temp_name)
101
+ if len(layer_infos) > 0:
102
+ temp_name = layer_infos.pop(0)
103
+ # if temp_name == "self":
104
+ # temp_name += "_" + layer_infos.pop(0)
105
+ # elif temp_name != "mlp" and len(layer_infos) == 1:
106
+ # temp_name += "_" + layer_infos.pop(0)
107
+ elif len(layer_infos) == 0:
108
+ break
109
+ except Exception:
110
+ if len(temp_name) > 0:
111
+ temp_name += '_'+layer_infos.pop(0)
112
+ else:
113
+ temp_name = layer_infos.pop(0)
114
+
115
+ data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
116
+ clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
117
  if len(curr_layer.weight.data) == 4:
118
+ adding_weight = alpha * data.permute(0, 3, 1, 2)
119
  else:
120
  adding_weight = alpha * data
121
  curr_layer.weight.data += adding_weight
 
125
  "added_weight": adding_weight
126
  }
127
  unload_dict.append(curr_layer_unload_data)
128
+ return unload_dict
129
+
130
+ def lora_trans(state_dict):
131
+ loraload = LoraLoaderMixin()
132
+ unet_config = SimpleNamespace(**{'layers_per_block': 2})
133
+ state_dicts = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
134
+ state_dicts_trans, state_dicts_alpha = _convert_kohya_lora_to_diffusers(
135
+ state_dicts)
136
+ keys = list(state_dicts_trans.keys())
137
+ for k in keys:
138
+ key = k.replace('processor.', '')
139
+ for x in ['.lora_linear_layer.', '_lora.', '.lora.']:
140
+ key = key.replace(x, '.lora_')
141
+ if key.find('text_encoder') >= 0:
142
+ for x in ['q', 'k', 'v', 'out']:
143
+ key = key.replace(f'.to_{x}.', f'.{x}_proj.')
144
+ key = key.replace('to_out.', 'to_out.0.')
145
+ if key != k:
146
+ state_dicts_trans[key] = state_dicts_trans.pop(k)
147
+ alpha = torch.Tensor(list(set(list(state_dicts_alpha.values()))))
148
+ state_dicts_trans.update({'lora.alpha': alpha})
149
+
150
+ return state_dicts_trans
151
+
152
+
153
+ def load_state_dict(filename, need_trans=True):
154
+ state_dict = load_file(os.path.abspath(filename), device="cpu")
155
+ if need_trans:
156
+ state_dict = lora_trans(state_dict)
157
+ return state_dict
158
+
159
+
160
+ def move_state_dict_to_cuda(state_dict):
161
+ ret_state_dict = {}
162
+ for item in state_dict:
163
+ ret_state_dict[item] = state_dict[item].cuda()
164
+ return ret_state_dict
165
+
166
+
167
+ def add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0, need_trans=False):
168
+ # directly update weight in diffusers model
169
+ state_dict = move_state_dict_to_cuda(state_dict)
170
+
171
+ alpha_ks = list(filter(lambda x: x.find('.alpha') >= 0, state_dict))
172
+ lora_alpha = state_dict[alpha_ks[0]].item() if len(alpha_ks) > 0 else -1
173
+
174
+ visited = set()
175
+ for key in state_dict:
176
+ # print(key)
177
+ # it is suggested to print out the key, it usually will be something like below
178
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
179
+
180
+ # as we have set the alpha beforehand, so just skip
181
+ if '.alpha' in key or key in visited:
182
+ continue
183
+
184
+ if "text" in key:
185
+ curr_layer = clip_model_2 if key.find(
186
+ 'text_encoder_2') >= 0 else clip_model
187
+
188
+ # if is_sdxl:
189
+ layer_infos = key.split('.')[1:]
190
+
191
+ for x in layer_infos:
192
+ try:
193
+ curr_layer = curr_layer.__getattr__(x)
194
+ except Exception:
195
+ break
196
+
197
+ # update weight
198
+ pair_keys = [key.replace("lora_down", "lora_up"),
199
+ key.replace("lora_up", "lora_down")]
200
+ weight_up, weight_down = state_dict[pair_keys[0]
201
+ ], state_dict[pair_keys[1]]
202
+
203
+ weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
204
+
205
+ if len(weight_up.shape) == 4:
206
+ weight_up = weight_up.squeeze([2, 3])
207
+ weight_down = weight_down.squeeze([2, 3])
208
+ if len(weight_down.shape) == 4:
209
+ adding_weight = torch.einsum(
210
+ 'a b, b c h w -> a c h w', weight_up, weight_down)
211
+ else:
212
+ adding_weight = torch.mm(
213
+ weight_up, weight_down).unsqueeze(2).unsqueeze(3)
214
+ else:
215
+ adding_weight = torch.mm(weight_up, weight_down)
216
+ adding_weight = alpha * weight_scale * adding_weight
217
+
218
+ curr_layer.weight.data += adding_weight.to(torch.float16)
219
+ # update visited list
220
+ for item in pair_keys:
221
+ visited.add(item)
222
+
223
+ elif "unet" in key:
224
+ layer_infos = key
225
+ layer_infos = layer_infos.replace(".lora_up.weight", "")
226
+ layer_infos = layer_infos.replace(".lora_down.weight", "")
227
+
228
+ layer_infos = layer_infos[5:]
229
+ layer_names = layer_infos.split(".")
230
+
231
+ layers = []
232
+ i = 0
233
+ while i < len(layer_names):
234
+
235
+ if len(layers) >= 4:
236
+ layers[-1] += "_" + layer_names[i]
237
+ elif i + 1 < len(layer_names) and layer_names[i+1].isdigit():
238
+ layers.append(layer_names[i] + "_" + layer_names[i+1])
239
+ i += 1
240
+ elif len(layers) > 0 and "samplers" in layers[-1]:
241
+ layers[-1] += "_" + layer_names[i]
242
+ else:
243
+ layers.append(layer_names[i])
244
+ i += 1
245
+ layer_infos = ".".join(layers)
246
+
247
+ pair_keys = [key.replace("lora_down", "lora_up"),
248
+ key.replace("lora_up", "lora_down")]
249
+
250
+ # update weight
251
+ if len(state_dict[pair_keys[0]].shape) == 4:
252
+ weight_up = state_dict[pair_keys[0]].squeeze(
253
+ 3).squeeze(2).to(torch.float32)
254
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
255
+ weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
256
+
257
+ weight_up, weight_down = state_dict[pair_keys[0]
258
+ ], state_dict[pair_keys[1]]
259
+ weight_up = weight_up.squeeze([2, 3]).to(torch.float32)
260
+ weight_down = weight_down.squeeze([2, 3]).to(torch.float32)
261
+ if len(weight_down.shape) == 4:
262
+ curr_layer_weight = weight_scale * \
263
+ torch.einsum('a b, b c h w -> a c h w',
264
+ weight_up, weight_down)
265
+ else:
266
+ curr_layer_weight = weight_scale * \
267
+ torch.mm(weight_up, weight_down).unsqueeze(
268
+ 2).unsqueeze(3)
269
+
270
+ curr_layer_weight = curr_layer_weight.permute(0, 2, 3, 1)
271
+
272
+ else:
273
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
274
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
275
+ weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
276
+
277
+ curr_layer_weight = weight_scale * \
278
+ torch.mm(weight_up, weight_down)
279
+ #
280
+
281
+ curr_layer_weight = curr_layer_weight.to(torch.float16)
282
+
283
+ unet.load_lora_by_name(layers, curr_layer_weight, alpha)
284
+
285
+ for item in pair_keys:
286
+ visited.add(item)
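
The rewritten lora_util.py above adds helpers that convert kohya-style .safetensors LoRAs to the diffusers key layout (`lora_trans` / `load_state_dict`) and merge the scaled low-rank updates directly into the optimized UNet and CLIP weights (`add_lora_to_opt_model`, plus `add_xltext_lora_layer` for the two SDXL text encoders). How they chain together is only implied by the signatures in the diff; a sketch under that assumption, with the attribute names on `pipe` assumed rather than documented:

```python
from lyrasd_model.lora_util import load_state_dict, add_lora_to_opt_model

lora_file = "./models/dissolve_sdxl.safetensors"           # LoRA file from the README demo
state_dict = load_state_dict(lora_file, need_trans=True)   # load and convert kohya keys to diffusers layout

# merge the LoRA into an already-loaded SDXL pipeline
# (pipe.unet / pipe.text_encoder / pipe.text_encoder_2 are assumed attribute names)
add_lora_to_opt_model(state_dict, pipe.unet, pipe.text_encoder, pipe.text_encoder_2, alpha=0.4)
```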
lyrasd_model/lyrasd_controlnet_img2img_pipeline.py CHANGED
@@ -1,21 +1,18 @@
1
  import torch
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
- from diffusers.schedulers import KarrasDiffusionSchedulers
4
  from diffusers.loaders import TextualInversionLoaderMixin
5
- from diffusers.models import AutoencoderKL
6
- from diffusers.utils import randn_tensor, logging
7
- from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
  from diffusers.utils import PIL_INTERPOLATION
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
  import os
11
  import numpy as np
12
  import warnings
13
- from .lora_util import add_text_lora_layer
14
- import gc
15
 
16
  from PIL import Image
17
  import PIL
18
 
 
 
19
  import inspect
20
 
21
  import time
@@ -31,7 +28,8 @@ def numpy_to_pil(images):
31
  images = (images * 255).round().astype("uint8")
32
  if images.shape[-1] == 1:
33
  # special case for grayscale (single channel) images
34
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
 
35
  else:
36
  pil_images = [Image.fromarray(image) for image in images]
37
 
@@ -53,7 +51,8 @@ def preprocess(image):
53
  w, h = image[0].size
54
  w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
55
 
56
- image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
 
57
  image = np.concatenate(image, axis=0)
58
  image = np.array(image).astype(np.float32) / 255.0
59
  image = image.transpose(0, 3, 1, 2)
@@ -63,69 +62,11 @@ def preprocess(image):
63
  image = torch.cat(image, dim=0)
64
  return image
65
 
66
- class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
67
- def __init__(self, model_path, lib_so_path, model_dtype='fp32', device=torch.device("cuda"), dtype=torch.float16) -> None:
68
- self.device = device
69
- self.dtype = dtype
70
-
71
- torch.classes.load_library(lib_so_path)
72
-
73
- self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
74
- self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
75
- self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
76
-
77
- self.unet_in_channels = 4
78
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
79
- self.vae.enable_tiling()
80
- self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
81
- 3, # max num of controlnets
82
- "fp16" # inference dtype (can only use fp16 for now)
83
- )
84
-
85
- unet_path = os.path.join(model_path, "unet_bins/")
86
-
87
- self.reload_unet_model(unet_path, model_dtype)
88
-
89
- self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
90
-
91
- def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
92
- if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
93
- controlnet_path = controlnet_path + "/"
94
- self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)
95
-
96
- def unload_controlnet_model(self, model_name):
97
- self.unet.unload_controlnet_model(model_name, True)
98
-
99
- def get_loaded_controlnet(self):
100
- return self.unet.get_loaded_controlnet()
101
-
102
- def reload_unet_model(self, unet_path, unet_file_format='fp32'):
103
- if len(unet_path) > 0 and unet_path[-1] != "/":
104
- unet_path = unet_path + "/"
105
- return self.unet.reload_unet_model(unet_path, unet_file_format)
106
-
107
- def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
108
- if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
109
- lora_model_path = lora_model_path + "/"
110
- lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
111
- self.loaded_lora[lora_name] = lora
112
- self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
113
 
114
- def unload_lora(self, lora_name, clean_cache=False):
115
- for layer_data in self.loaded_lora[lora_name]:
116
- layer = layer_data['layer']
117
- added_weight = layer_data['added_weight']
118
- layer.weight.data -= added_weight
119
- self.unet.unload_lora(lora_name, clean_cache)
120
- del self.loaded_lora[lora_name]
121
- gc.collect()
122
- torch.cuda.empty_cache()
123
-
124
- def clean_lora_cache(self):
125
- self.unet.clean_lora_cache()
126
-
127
- def get_loaded_lora(self):
128
- return self.unet.get_loaded_lora()
129
 
130
  def _encode_prompt(
131
  self,
@@ -181,13 +122,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
181
  return_tensors="pt",
182
  )
183
  text_input_ids = text_inputs.input_ids
184
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
185
 
186
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
187
  text_input_ids, untruncated_ids
188
  ):
189
  removed_text = self.tokenizer.batch_decode(
190
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
191
  )
192
  logger.warning(
193
  "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -205,12 +147,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
205
  )
206
  prompt_embeds = prompt_embeds[0]
207
 
208
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
209
 
210
  bs_embed, seq_len, _ = prompt_embeds.shape
211
  # duplicate text embeddings for each generation per prompt, using mps friendly method
212
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
213
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
214
 
215
  # get unconditional embeddings for classifier free guidance
216
  if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -235,7 +179,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
235
 
236
  # textual inversion: procecss multi-vector tokens if necessary
237
  if isinstance(self, TextualInversionLoaderMixin):
238
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
 
239
 
240
  max_length = prompt_embeds.shape[1]
241
  uncond_input = self.tokenizer(
@@ -261,10 +206,13 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
261
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
262
  seq_len = negative_prompt_embeds.shape[1]
263
 
264
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
265
 
266
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
267
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
268
 
269
  # For classifier free guidance, we need to do two forward passes.
270
  # Here we concatenate the unconditional and text embeddings into a single batch
@@ -272,7 +220,6 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
272
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
273
 
274
  return prompt_embeds
275
-
276
 
277
  def decode_latents(self, latents):
278
  latents = 1 / self.vae.config.scaling_factor * latents
@@ -282,6 +229,17 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
282
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
283
  return image
284
 
 
 
 
 
 
 
 
 
 
 
 
285
  def check_inputs(
286
  self,
287
  prompt,
@@ -291,8 +249,9 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
291
  prompt_embeds=None,
292
  negative_prompt_embeds=None,
293
  ):
294
- if height % 64 != 0 or width % 64 != 0: # 初版暂时只支持 64 的倍数的 height 和 width
295
- raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
 
296
 
297
  if prompt is not None and prompt_embeds is not None:
298
  raise ValueError(
@@ -304,7 +263,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
304
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
305
  )
306
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
307
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
308
 
309
  if negative_prompt is not None and negative_prompt_embeds is not None:
310
  raise ValueError(
@@ -342,13 +302,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
342
 
343
  elif isinstance(generator, list):
344
  init_latents = [
345
- self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
346
  ]
347
  init_latents = torch.cat(init_latents, dim=0)
348
  else:
349
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
350
 
351
- init_latents = self.vae.config.scaling_factor * init_latents
352
 
353
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
354
  # expand init_latents for batch_size
@@ -358,9 +319,9 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
358
  " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
359
  " your script to pass as many initial images as text prompts to suppress this warning."
360
  )
361
- deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
362
  additional_image_per_prompt = batch_size // init_latents.shape[0]
363
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
 
364
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
365
  raise ValueError(
366
  f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
@@ -369,7 +330,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
369
  init_latents = torch.cat([init_latents], dim=0)
370
 
371
  shape = init_latents.shape
372
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
373
 
374
  # get latents
375
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -398,7 +360,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
398
 
399
  for image_ in image:
400
  image_ = image_.convert("RGB")
401
- image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
 
402
  image_ = np.array(image_)
403
  image_ = image_[None, :]
404
  images.append(image_)
@@ -434,27 +397,29 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
434
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
435
  # and should be between [0, 1]
436
 
437
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
438
  extra_step_kwargs = {}
439
  if accepts_eta:
440
  extra_step_kwargs["eta"] = eta
441
 
442
  # check if the scheduler accepts generator
443
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
444
  if accepts_generator:
445
  extra_step_kwargs["generator"] = generator
446
  return extra_step_kwargs
447
 
448
  def get_timesteps(self, num_inference_steps, strength, device):
449
  # get the original timestep using init_timestep
450
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
451
 
452
  t_start = max(num_inference_steps - init_timestep, 0)
453
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
454
 
455
  return timesteps, num_inference_steps - t_start
456
 
457
-
458
  @torch.no_grad()
459
  def __call__(
460
  self,
@@ -477,9 +442,10 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
477
  controlnet_images: Optional[List[PIL.Image.Image]] = None,
478
  controlnet_scale: Optional[List[float]] = None,
479
  controlnet_names: Optional[List[str]] = None,
480
- guess_mode = False,
481
  eta: float = 0.0,
482
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
483
  latents: Optional[torch.FloatTensor] = None,
484
  prompt_embeds: Optional[torch.FloatTensor] = None,
485
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -549,7 +515,6 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
549
  # corresponds to doing no classifier free guidance.
550
  do_classifier_free_guidance = guidance_scale > 1.0
551
 
552
-
553
  # 3. Encode input prompt
554
  start = time.perf_counter()
555
  prompt_embeds = self._encode_prompt(
@@ -583,17 +548,21 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
583
  scales = [1.0, ] * 13
584
  if guess_mode:
585
  scales = torch.logspace(-1, 0, 13).tolist()
586
-
587
  for scale in controlnet_scale:
588
  scales_ = [d * scale for d in scales]
589
  control_scales.append(scales_)
590
 
591
- image = preprocess(image)
592
-
 
 
593
  # 5. set timesteps
594
  self.scheduler.set_timesteps(num_inference_steps, device=device)
595
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
596
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
 
597
 
598
  # 6. Prepare latent variables
599
  latents = self.prepare_latents(
@@ -604,33 +573,46 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
604
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
605
 
606
  # 8. Denoising loop
607
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
608
 
609
  start_unet = time.perf_counter()
610
  for i, t in enumerate(timesteps):
611
  # expand the latents if we are doing classifier free guidance
612
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
613
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
614
- latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()
 
 
 
615
 
616
  # 后边三个 None 是给到controlnet 的参数,暂时给到 None 当 placeholder
617
- noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)
 
618
 
619
  noise_pred = noise_pred.permute(0, 3, 1, 2)
620
  # perform guidance
621
 
622
  if do_classifier_free_guidance:
623
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
624
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
625
 
626
  # compute the previous noisy sample x_t -> x_t-1
627
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
628
 
629
  torch.cuda.synchronize()
630
 
 
 
 
631
  start = time.perf_counter()
632
- image = self.decode_latents(latents)
 
633
  torch.cuda.synchronize()
 
 
634
  image = numpy_to_pil(image)
635
 
636
  return image
 
1
  import torch
2
  from typing import Any, Callable, Dict, List, Optional, Union
 
3
  from diffusers.loaders import TextualInversionLoaderMixin
4
+ from diffusers.utils.torch_utils import logging, randn_tensor
 
 
5
  from diffusers.utils import PIL_INTERPOLATION
6
+
7
  import os
8
  import numpy as np
9
  import warnings
 
 
10
 
11
  from PIL import Image
12
  import PIL
13
 
14
+ from .lyrasd_pipeline_base import LyraSDXLPipelineBase
15
+
16
  import inspect
17
 
18
  import time
 
28
  images = (images * 255).round().astype("uint8")
29
  if images.shape[-1] == 1:
30
  # special case for grayscale (single channel) images
31
+ pil_images = [Image.fromarray(image.squeeze(), mode="L")
32
+ for image in images]
33
  else:
34
  pil_images = [Image.fromarray(image) for image in images]
35
 
 
51
  w, h = image[0].size
52
  w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
53
 
54
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
55
+ None, :] for i in image]
56
  image = np.concatenate(image, axis=0)
57
  image = np.array(image).astype(np.float32) / 255.0
58
  image = image.transpose(0, 3, 1, 2)
 
62
  image = torch.cat(image, dim=0)
63
  return image
64
 
 
 
 
 
 
65
 
66
+ class LyraSdControlnetImg2ImgPipeline(LyraSDXLPipelineBase):
67
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
68
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
69
+ vae_scaling_factor=vae_scaling_factor)
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def _encode_prompt(
72
  self,
 
122
  return_tensors="pt",
123
  )
124
  text_input_ids = text_inputs.input_ids
125
+ untruncated_ids = self.tokenizer(
126
+ prompt, padding="longest", return_tensors="pt").input_ids
127
 
128
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
129
  text_input_ids, untruncated_ids
130
  ):
131
  removed_text = self.tokenizer.batch_decode(
132
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
133
  )
134
  logger.warning(
135
  "The following part of your input was truncated because CLIP can only handle sequences up to"
 
147
  )
148
  prompt_embeds = prompt_embeds[0]
149
 
150
+ prompt_embeds = prompt_embeds.to(
151
+ dtype=self.text_encoder.dtype, device=device)
152
 
153
  bs_embed, seq_len, _ = prompt_embeds.shape
154
  # duplicate text embeddings for each generation per prompt, using mps friendly method
155
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
156
+ prompt_embeds = prompt_embeds.view(
157
+ bs_embed * num_images_per_prompt, seq_len, -1)
158
 
159
  # get unconditional embeddings for classifier free guidance
160
  if do_classifier_free_guidance and negative_prompt_embeds is None:
 
179
 
180
  # textual inversion: procecss multi-vector tokens if necessary
181
  if isinstance(self, TextualInversionLoaderMixin):
182
+ uncond_tokens = self.maybe_convert_prompt(
183
+ uncond_tokens, self.tokenizer)
184
 
185
  max_length = prompt_embeds.shape[1]
186
  uncond_input = self.tokenizer(
 
206
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
207
  seq_len = negative_prompt_embeds.shape[1]
208
 
209
+ negative_prompt_embeds = negative_prompt_embeds.to(
210
+ dtype=self.text_encoder.dtype, device=device)
211
 
212
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
213
+ 1, num_images_per_prompt, 1)
214
+ negative_prompt_embeds = negative_prompt_embeds.view(
215
+ batch_size * num_images_per_prompt, seq_len, -1)
216
 
217
  # For classifier free guidance, we need to do two forward passes.
218
  # Here we concatenate the unconditional and text embeddings into a single batch
 
220
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
221
 
222
  return prompt_embeds
 
223
 
224
  def decode_latents(self, latents):
225
  latents = 1 / self.vae.config.scaling_factor * latents
 
229
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
230
  return image
231
 
232
+ def lyra_decode_latents(self, latents):
233
+ print("lyra_decode_latents")
234
+ latents = 1 / self.vae_scaling_factor * latents
235
+ image = self.vae.decode(latents)
236
+ image = image.permute(0, 2, 3, 1)
237
+
238
+ image = (image / 2 + 0.5).clamp(0, 1)
239
+ image = image.cpu().float().numpy()
240
+
241
+ return image
242
+
243
  def check_inputs(
244
  self,
245
  prompt,
 
249
  prompt_embeds=None,
250
  negative_prompt_embeds=None,
251
  ):
252
+ if height % 64 != 0 or width % 64 != 0: # 初版暂时只支持 64 的倍数的 height 和 width
253
+ raise ValueError(
254
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
255
 
256
  if prompt is not None and prompt_embeds is not None:
257
  raise ValueError(
 
263
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
264
  )
265
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
266
+ raise ValueError(
267
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
268
 
269
  if negative_prompt is not None and negative_prompt_embeds is not None:
270
  raise ValueError(
 
302
 
303
  elif isinstance(generator, list):
304
  init_latents = [
305
+ self.vae.encode(image[i: i + 1]).sample(generator[i]) for i in range(batch_size)
306
  ]
307
  init_latents = torch.cat(init_latents, dim=0)
308
  else:
309
+ init_latents = self.vae.encode(
310
+ image).sample(generator)
311
 
312
+ init_latents = self.vae.scaling_factor * init_latents
313
 
314
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
315
  # expand init_latents for batch_size
 
319
  " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
320
  " your script to pass as many initial images as text prompts to suppress this warning."
321
  )
 
322
  additional_image_per_prompt = batch_size // init_latents.shape[0]
323
+ init_latents = torch.cat(
324
+ [init_latents] * additional_image_per_prompt, dim=0)
325
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
326
  raise ValueError(
327
  f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
 
330
  init_latents = torch.cat([init_latents], dim=0)
331
 
332
  shape = init_latents.shape
333
+ noise = randn_tensor(shape, generator=generator,
334
+ device=device, dtype=dtype)
335
 
336
  # get latents
337
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
 
360
 
361
  for image_ in image:
362
  image_ = image_.convert("RGB")
363
+ image_ = image_.resize(
364
+ (width, height), resample=PIL_INTERPOLATION["lanczos"])
365
  image_ = np.array(image_)
366
  image_ = image_[None, :]
367
  images.append(image_)
 
397
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
398
  # and should be between [0, 1]
399
 
400
+ accepts_eta = "eta" in set(inspect.signature(
401
+ self.scheduler.step).parameters.keys())
402
  extra_step_kwargs = {}
403
  if accepts_eta:
404
  extra_step_kwargs["eta"] = eta
405
 
406
  # check if the scheduler accepts generator
407
+ accepts_generator = "generator" in set(
408
+ inspect.signature(self.scheduler.step).parameters.keys())
409
  if accepts_generator:
410
  extra_step_kwargs["generator"] = generator
411
  return extra_step_kwargs
412
 
413
  def get_timesteps(self, num_inference_steps, strength, device):
414
  # get the original timestep using init_timestep
415
+ init_timestep = min(
416
+ int(num_inference_steps * strength), num_inference_steps)
417
 
418
  t_start = max(num_inference_steps - init_timestep, 0)
419
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
420
 
421
  return timesteps, num_inference_steps - t_start
422
 
 
423
  @torch.no_grad()
424
  def __call__(
425
  self,
 
442
  controlnet_images: Optional[List[PIL.Image.Image]] = None,
443
  controlnet_scale: Optional[List[float]] = None,
444
  controlnet_names: Optional[List[str]] = None,
445
+ guess_mode=False,
446
  eta: float = 0.0,
447
+ generator: Optional[Union[torch.Generator,
448
+ List[torch.Generator]]] = None,
449
  latents: Optional[torch.FloatTensor] = None,
450
  prompt_embeds: Optional[torch.FloatTensor] = None,
451
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
 
515
  # corresponds to doing no classifier free guidance.
516
  do_classifier_free_guidance = guidance_scale > 1.0
517
 
 
518
  # 3. Encode input prompt
519
  start = time.perf_counter()
520
  prompt_embeds = self._encode_prompt(
 
548
  scales = [1.0, ] * 13
549
  if guess_mode:
550
  scales = torch.logspace(-1, 0, 13).tolist()
551
+
552
  for scale in controlnet_scale:
553
  scales_ = [d * scale for d in scales]
554
  control_scales.append(scales_)
555
 
556
+ print(f"clip cost: {(time.perf_counter() - start)* 1000}")
557
+
558
+ image = self.image_processor.preprocess(image)
559
+
560
  # 5. set timesteps
561
  self.scheduler.set_timesteps(num_inference_steps, device=device)
562
+ timesteps, num_inference_steps = self.get_timesteps(
563
+ num_inference_steps, strength, device)
564
+ latent_timestep = timesteps[:1].repeat(
565
+ batch_size * num_images_per_prompt)
566
 
567
  # 6. Prepare latent variables
568
  latents = self.prepare_latents(
 
573
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
574
 
575
  # 8. Denoising loop
576
+ num_warmup_steps = len(timesteps) - \
577
+ num_inference_steps * self.scheduler.order
578
 
579
  start_unet = time.perf_counter()
580
  for i, t in enumerate(timesteps):
581
  # expand the latents if we are doing classifier free guidance
582
+ latent_model_input = torch.cat(
583
+ [latents] * 2) if do_classifier_free_guidance else latents
584
+ latent_model_input = self.scheduler.scale_model_input(
585
+ latent_model_input, t)
586
+ latent_model_input = latent_model_input.permute(
587
+ 0, 2, 3, 1).contiguous()
588
 
589
  # 后边三个 None 是给到controlnet 的参数,暂时给到 None 当 placeholder
590
+ noise_pred = self.unet.forward(
591
+ latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)
592
 
593
  noise_pred = noise_pred.permute(0, 3, 1, 2)
594
  # perform guidance
595
 
596
  if do_classifier_free_guidance:
597
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
598
+ noise_pred = noise_pred_uncond + guidance_scale * \
599
+ (noise_pred_text - noise_pred_uncond)
600
 
601
  # compute the previous noisy sample x_t -> x_t-1
602
+ latents = self.scheduler.step(
603
+ noise_pred, t, latents, **extra_step_kwargs).prev_sample
604
 
605
  torch.cuda.synchronize()
606
 
607
+ print(
608
+ f"unet x {num_inference_steps} cost: {(time.perf_counter() - start_unet) * 1000}")
609
+
610
  start = time.perf_counter()
611
+ # image = self.decode_latents(latents)
612
+ image = self.lyra_decode_latents(latents)
613
  torch.cuda.synchronize()
614
+ print(f"vae cost: {(time.perf_counter() - start)* 1000}")
615
+ print()
616
  image = numpy_to_pil(image)
617
 
618
  return image
lyrasd_model/lyrasd_controlnet_txt2img_pipeline.py CHANGED
@@ -1,12 +1,8 @@
1
  import torch
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
- from diffusers.schedulers import KarrasDiffusionSchedulers
4
  from diffusers.loaders import TextualInversionLoaderMixin
5
- from diffusers.models import AutoencoderKL
6
- from diffusers.utils import randn_tensor, logging
7
- from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
  from diffusers.utils import PIL_INTERPOLATION
9
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
  import os
11
  import numpy as np
12
  from .lora_util import add_text_lora_layer
@@ -17,6 +13,7 @@ import PIL
17
  import inspect
18
 
19
  import time
 
20
 
21
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
 
@@ -36,68 +33,11 @@ def numpy_to_pil(images):
36
  return pil_images
37
 
38
 
39
- class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
40
- def __init__(self, model_path, lib_so_path, model_dtype='fp32', device=torch.device("cuda"), dtype=torch.float16) -> None:
41
- self.device = device
42
- self.dtype = dtype
43
-
44
- torch.classes.load_library(lib_so_path)
45
-
46
- self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
47
- self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
48
- self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
49
- self.unet_in_channels = 4
50
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
51
- self.vae.enable_tiling()
52
- self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
53
- 3, # max num of controlnets
54
- "fp16" # inference dtype (can only use fp16 for now)
55
- )
56
-
57
- unet_path = os.path.join(model_path, "unet_bins/")
58
- self.reload_unet_model(unet_path, model_dtype)
59
-
60
- self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
61
-
62
- def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
63
- if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
64
- controlnet_path = controlnet_path + "/"
65
- self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)
66
-
67
- def unload_controlnet_model(self, model_name):
68
- self.unet.unload_controlnet_model(model_name, True)
69
-
70
- def get_loaded_controlnet(self):
71
- return self.unet.get_loaded_controlnet()
72
-
73
- def reload_unet_model(self, unet_path, unet_file_format='fp32'):
74
- if len(unet_path) > 0 and unet_path[-1] != "/":
75
- unet_path = unet_path + "/"
76
- return self.unet.reload_unet_model(unet_path, unet_file_format)
77
-
78
- def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
79
- if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
80
- lora_model_path = lora_model_path + "/"
81
- lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
82
- self.loaded_lora[lora_name] = lora
83
- self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
84
-
85
- def unload_lora(self, lora_name, clean_cache=False):
86
- for layer_data in self.loaded_lora[lora_name]:
87
- layer = layer_data['layer']
88
- added_weight = layer_data['added_weight']
89
- layer.weight.data -= added_weight
90
- self.unet.unload_lora(lora_name, clean_cache)
91
- del self.loaded_lora[lora_name]
92
- gc.collect()
93
- torch.cuda.empty_cache()
94
-
95
- def clean_lora_cache(self):
96
- self.unet.clean_lora_cache()
97
-
98
- def get_loaded_lora(self):
99
- return self.unet.get_loaded_lora()
100
-
101
  def _encode_prompt(
102
  self,
103
  prompt,
@@ -253,6 +193,23 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
253
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
254
  return image
255
 
256
  def check_inputs(
257
  self,
258
  prompt,
@@ -342,21 +299,8 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
342
  elif isinstance(image[0], torch.Tensor):
343
  image = torch.cat(image, dim=0)
344
 
345
- image_batch_size = image.shape[0]
346
-
347
- if image_batch_size == 1:
348
- repeat_by = batch_size
349
- else:
350
- # image batch size is the same as prompt batch size
351
- repeat_by = num_images_per_prompt
352
-
353
- image = image.repeat_interleave(repeat_by, dim=0)
354
-
355
  image = image.to(device=device, dtype=dtype)
356
 
357
- if do_classifier_free_guidance and not guess_mode:
358
- image = torch.cat([image] * 2)
359
-
360
  return image
361
 
362
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -376,6 +320,18 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
376
  extra_step_kwargs["generator"] = generator
377
  return extra_step_kwargs
378
 
379
  @torch.no_grad()
380
  def __call__(
381
  self,
@@ -527,7 +483,7 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
527
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
528
  latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()
529
 
530
- # The trailing three Nones are the ControlNet arguments; None is passed as a placeholder for now
531
  noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)
532
 
533
  noise_pred = noise_pred.permute(0, 3, 1, 2)
@@ -540,7 +496,9 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
540
  # compute the previous noisy sample x_t -> x_t-1
541
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
542
 
543
- image = self.decode_latents(latents)
 
 
544
  image = numpy_to_pil(image)
545
 
546
  return image
 
1
  import torch
2
  from typing import Any, Callable, Dict, List, Optional, Union
 
3
  from diffusers.loaders import TextualInversionLoaderMixin
4
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
 
 
5
  from diffusers.utils import PIL_INTERPOLATION
 
6
  import os
7
  import numpy as np
8
  from .lora_util import add_text_lora_layer
 
13
  import inspect
14
 
15
  import time
16
+ from .lyrasd_pipeline_base import LyraSDXLPipelineBase
17
 
18
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
 
 
33
  return pil_images
34
 
35
 
36
+ class LyraSdControlnetTxt2ImgPipeline(LyraSDXLPipelineBase):
37
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
38
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
39
+ vae_scaling_factor=vae_scaling_factor)
40
+
41
  def _encode_prompt(
42
  self,
43
  prompt,
 
193
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
194
  return image
195
 
196
+ def lyra_decode_latents(self, latents):
197
+ print("lyra_decode_latents")
198
+ # np.save("", latents.)
199
+ # np.save(f"/workspace/vae_model/latent.npy", latents.detach().cpu().numpy())
200
+ latents = 1 / self.vae_scaling_factor * latents
201
+ latents = latents.permute(0, 2, 3, 1).contiguous()
202
+ image = self.vae.vae_decode(latents)
203
+
204
+ # print(image)
205
+ # GPUtil.showUtilization(all=True)
206
+
207
+ image = (image / 2 + 0.5).clamp(0, 1)
208
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
209
+ image = image.cpu().float().numpy()
210
+
211
+ return image
212
+
213
  def check_inputs(
214
  self,
215
  prompt,
 
299
  elif isinstance(image[0], torch.Tensor):
300
  image = torch.cat(image, dim=0)
301
 
302
  image = image.to(device=device, dtype=dtype)
303
 
 
 
 
304
  return image
305
 
306
  def prepare_extra_step_kwargs(self, generator, eta):
 
320
  extra_step_kwargs["generator"] = generator
321
  return extra_step_kwargs
322
 
323
+ def lyra_decode_latents(self, latents):
324
+ print("lyra_decode_latents")
325
+ latents = 1 / self.vae_scaling_factor * latents
326
+ image = self.vae.decode(latents)
327
+ image = image.permute(0, 2, 3, 1)
328
+ image = (image / 2 + 0.5).clamp(0, 1)
329
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
330
+ image = image.cpu().float().numpy()
331
+
332
+ return image
333
+
334
+
335
  @torch.no_grad()
336
  def __call__(
337
  self,
 
483
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
484
  latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()
485
 
486
487
  noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)
488
 
489
  noise_pred = noise_pred.permute(0, 3, 1, 2)
 
496
  # compute the previous noisy sample x_t -> x_t-1
497
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
498
 
499
+ # image = self.decode_latents(latents)
500
+ image = self.lyra_decode_latents(latents)
501
+
502
  image = numpy_to_pil(image)
503
 
504
  return image
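
With this change the ControlNet text-to-image pipeline no longer takes a model path or .so path in its constructor; weights are attached afterwards through the shared base class. A minimal usage sketch under those assumptions (the paths are placeholders, the class export from lyrasd_model follows the demo scripts, and the generation keywords are left to controlnet_txt2img_demo.py):

import torch
from lyrasd_model import LyraSdControlnetTxt2ImgPipeline

model_path = "./models/lyrasd_rev_animated"   # placeholder: converted SD model dir
canny_path = "./models/canny"                 # placeholder: converted ControlNet dir

pipe = LyraSdControlnetTxt2ImgPipeline()      # defaults: cuda, fp16, vae_scale_factor=8
pipe.reload_pipe(model_path)                  # tokenizer, text encoder, UNet, VAE, scheduler
pipe.load_controlnet_model_v2("canny", canny_path)
print(pipe.get_loaded_controlnet())           # should report the "canny" ControlNet

# The generation call (prompt, control image, control scales, guess mode) follows
# controlnet_txt2img_demo.py; see that script for the exact keyword names.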
lyrasd_model/lyrasd_img2img_pipeline.py CHANGED
@@ -8,13 +8,12 @@ import numpy as np
8
  import PIL
9
  import torch
10
  from diffusers.loaders import TextualInversionLoaderMixin
11
- from diffusers.models import AutoencoderKL
12
- from diffusers.schedulers import EulerAncestralDiscreteScheduler
13
- from diffusers.utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
14
  from PIL import Image
15
- from transformers import CLIPTextModel, CLIPTokenizer
16
- from .lora_util import add_text_lora_layer
17
- import gc
18
 
19
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
 
@@ -28,7 +27,8 @@ def numpy_to_pil(images):
28
  images = (images * 255).round().astype("uint8")
29
  if images.shape[-1] == 1:
30
  # special case for grayscale (single channel) images
31
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
 
32
  else:
33
  pil_images = [Image.fromarray(image) for image in images]
34
 
@@ -50,7 +50,8 @@ def preprocess(image):
50
  w, h = image[0].size
51
  w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
52
 
53
- image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
 
54
  image = np.concatenate(image, axis=0)
55
  image = np.array(image).astype(np.float32) / 255.0
56
  image = image.transpose(0, 3, 1, 2)
@@ -61,60 +62,13 @@ def preprocess(image):
61
  return image
62
 
63
 
64
- class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
65
- def __init__(self, model_path, lib_so_path, model_dtype='fp32', device=torch.device("cuda"), dtype=torch.float16) -> None:
66
- self.device = device
67
- self.dtype = dtype
68
-
69
- torch.classes.load_library(lib_so_path)
70
-
71
- self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
72
- self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
73
- self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
74
- unet_path = os.path.join(model_path, "unet_bins/")
75
-
76
- self.unet_in_channels = 4
77
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
78
- self.vae.enable_tiling()
79
- self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
80
- 3, # max num of controlnets
81
- "fp16" # inference dtype (can only use fp16 for now)
82
- )
83
-
84
- self.reload_unet_model(unet_path, model_dtype)
85
-
86
- self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
87
-
88
- def reload_unet_model(self, unet_path, unet_file_format='fp32'):
89
- if len(unet_path) > 0 and unet_path[-1] != "/":
90
- unet_path = unet_path + "/"
91
- return self.unet.reload_unet_model(unet_path, unet_file_format)
92
-
93
- def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
94
- if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
95
- lora_model_path = lora_model_path + "/"
96
- lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
97
- self.loaded_lora[lora_name] = lora
98
- self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
99
-
100
- def unload_lora(self, lora_name, clean_cache=False):
101
- for layer_data in self.loaded_lora[lora_name]:
102
- layer = layer_data['layer']
103
- added_weight = layer_data['added_weight']
104
- layer.weight.data -= added_weight
105
- self.unet.unload_lora(lora_name, clean_cache)
106
- del self.loaded_lora[lora_name]
107
- gc.collect()
108
- torch.cuda.empty_cache()
109
-
110
- def clean_lora_cache(self):
111
- self.unet.clean_lora_cache()
112
-
113
- def get_loaded_lora(self):
114
- return self.unet.get_loaded_lora()
115
-
116
 
117
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
 
118
  def _encode_prompt(
119
  self,
120
  prompt,
@@ -170,7 +124,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
170
  return_tensors="pt",
171
  )
172
  text_input_ids = text_inputs.input_ids
173
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
174
 
175
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
176
  text_input_ids, untruncated_ids
@@ -201,12 +156,14 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
201
  else:
202
  prompt_embeds_dtype = prompt_embeds.dtype
203
 
204
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
205
 
206
  bs_embed, seq_len, _ = prompt_embeds.shape
207
  # duplicate text embeddings for each generation per prompt, using mps friendly method
208
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
209
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
210
 
211
  # get unconditional embeddings for classifier free guidance
212
  if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -231,7 +188,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
231
 
232
  # textual inversion: process multi-vector tokens if necessary
233
  if isinstance(self, TextualInversionLoaderMixin):
234
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
 
235
 
236
  max_length = prompt_embeds.shape[1]
237
  uncond_input = self.tokenizer(
@@ -257,10 +215,13 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
257
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
258
  seq_len = negative_prompt_embeds.shape[1]
259
 
260
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
261
 
262
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
263
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
264
 
265
  # For classifier free guidance, we need to do two forward passes.
266
  # Here we concatenate the unconditional and text embeddings into a single batch
@@ -286,13 +247,15 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
286
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
287
  # and should be between [0, 1]
288
 
289
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
290
  extra_step_kwargs = {}
291
  if accepts_eta:
292
  extra_step_kwargs["eta"] = eta
293
 
294
  # check if the scheduler accepts generator
295
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
296
  if accepts_generator:
297
  extra_step_kwargs["generator"] = generator
298
  return extra_step_kwargs
@@ -301,10 +264,12 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
301
  self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
302
  ):
303
  if strength < 0 or strength > 1:
304
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
305
 
306
  if (callback_steps is None) or (
307
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
 
308
  ):
309
  raise ValueError(
310
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -321,7 +286,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
321
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
322
  )
323
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
324
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
325
 
326
  if negative_prompt is not None and negative_prompt_embeds is not None:
327
  raise ValueError(
@@ -339,7 +305,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
339
 
340
  def get_timesteps(self, num_inference_steps, strength, device):
341
  # get the original timestep using init_timestep
342
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
343
 
344
  t_start = max(num_inference_steps - init_timestep, 0)
345
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
@@ -354,6 +321,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
354
 
355
  image = image.to(device=device, dtype=dtype)
356
 
 
 
357
  batch_size = batch_size * num_images_per_prompt
358
 
359
  if image.shape[1] == 4:
@@ -368,13 +337,13 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
368
 
369
  elif isinstance(generator, list):
370
  init_latents = [
371
- self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
372
  ]
373
  init_latents = torch.cat(init_latents, dim=0)
374
  else:
375
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
376
 
377
- init_latents = self.vae.config.scaling_factor * init_latents
378
 
379
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
380
  # expand init_latents for batch_size
@@ -384,9 +353,11 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
384
  " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
385
  " your script to pass as many initial images as text prompts to suppress this warning."
386
  )
387
- deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
 
388
  additional_image_per_prompt = batch_size // init_latents.shape[0]
389
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
 
390
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
391
  raise ValueError(
392
  f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
@@ -395,7 +366,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
395
  init_latents = torch.cat([init_latents], dim=0)
396
 
397
  shape = init_latents.shape
398
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
399
 
400
  # get latents
401
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -403,6 +375,17 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
403
 
404
  return latents
405
 
406
  @torch.no_grad()
407
  def __call__(
408
  self,
@@ -421,10 +404,12 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
421
  negative_prompt: Optional[Union[str, List[str]]] = None,
422
  num_images_per_prompt: Optional[int] = 1,
423
  eta: Optional[float] = 0.0,
424
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
425
  prompt_embeds: Optional[torch.FloatTensor] = None,
426
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
427
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
428
  callback_steps: int = 1,
429
  ):
430
  r"""
@@ -482,7 +467,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
482
  "not-safe-for-work" (nsfw) content.
483
  """
484
  # 1. Check inputs. Raise error if not correct
485
- self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
 
486
 
487
  # 2. Define call parameters
488
  if prompt is not None and isinstance(prompt, str):
@@ -510,12 +496,14 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
510
  )
511
 
512
  # 4. Preprocess image
513
- image = preprocess(image)
514
 
515
  # 5. set timesteps
516
  self.scheduler.set_timesteps(num_inference_steps, device=device)
517
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
518
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
 
519
 
520
  # 6. Prepare latent variables
521
  latents = self.prepare_latents(
@@ -526,29 +514,36 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
526
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
527
 
528
  # 8. Denoising loop
529
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
530
 
531
  for i, t in enumerate(timesteps):
532
  # expand the latents if we are doing classifier free guidance
533
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
534
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
535
- latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()
 
 
 
536
 
537
  # predict the noise residual
538
- # The trailing four Nones are the ControlNet arguments; None is passed as a placeholder for now
539
- noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, None, None, None, None)
540
-
541
  noise_pred = noise_pred.permute(0, 3, 1, 2)
542
 
543
  # perform guidance
544
  if do_classifier_free_guidance:
545
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
546
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
547
 
548
  # compute the previous noisy sample x_t -> x_t-1
549
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
550
 
551
- image = self.decode_latents(latents)
 
552
  image = numpy_to_pil(image)
553
 
554
  return image
 
8
  import PIL
9
  import torch
10
  from diffusers.loaders import TextualInversionLoaderMixin
11
+ from diffusers.utils import PIL_INTERPOLATION, deprecate
12
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
 
13
  from PIL import Image
14
+
15
+ from .lyrasd_pipeline_base import LyraSDXLPipelineBase
16
+
17
 
18
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
 
 
27
  images = (images * 255).round().astype("uint8")
28
  if images.shape[-1] == 1:
29
  # special case for grayscale (single channel) images
30
+ pil_images = [Image.fromarray(image.squeeze(), mode="L")
31
+ for image in images]
32
  else:
33
  pil_images = [Image.fromarray(image) for image in images]
34
 
 
50
  w, h = image[0].size
51
  w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
52
 
53
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
54
+ None, :] for i in image]
55
  image = np.concatenate(image, axis=0)
56
  image = np.array(image).astype(np.float32) / 255.0
57
  image = image.transpose(0, 3, 1, 2)
 
62
  return image
63
 
64
 
65
+ class LyraSDImg2ImgPipeline(LyraSDXLPipelineBase):
66
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
67
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
68
+ vae_scaling_factor=vae_scaling_factor)
69
 
70
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
71
+
72
  def _encode_prompt(
73
  self,
74
  prompt,
 
124
  return_tensors="pt",
125
  )
126
  text_input_ids = text_inputs.input_ids
127
+ untruncated_ids = self.tokenizer(
128
+ prompt, padding="longest", return_tensors="pt").input_ids
129
 
130
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
131
  text_input_ids, untruncated_ids
 
156
  else:
157
  prompt_embeds_dtype = prompt_embeds.dtype
158
 
159
+ prompt_embeds = prompt_embeds.to(
160
+ dtype=prompt_embeds_dtype, device=device)
161
 
162
  bs_embed, seq_len, _ = prompt_embeds.shape
163
  # duplicate text embeddings for each generation per prompt, using mps friendly method
164
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
165
+ prompt_embeds = prompt_embeds.view(
166
+ bs_embed * num_images_per_prompt, seq_len, -1)
167
 
168
  # get unconditional embeddings for classifier free guidance
169
  if do_classifier_free_guidance and negative_prompt_embeds is None:
 
188
 
189
  # textual inversion: process multi-vector tokens if necessary
190
  if isinstance(self, TextualInversionLoaderMixin):
191
+ uncond_tokens = self.maybe_convert_prompt(
192
+ uncond_tokens, self.tokenizer)
193
 
194
  max_length = prompt_embeds.shape[1]
195
  uncond_input = self.tokenizer(
 
215
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
216
  seq_len = negative_prompt_embeds.shape[1]
217
 
218
+ negative_prompt_embeds = negative_prompt_embeds.to(
219
+ dtype=prompt_embeds_dtype, device=device)
220
 
221
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
222
+ 1, num_images_per_prompt, 1)
223
+ negative_prompt_embeds = negative_prompt_embeds.view(
224
+ batch_size * num_images_per_prompt, seq_len, -1)
225
 
226
  # For classifier free guidance, we need to do two forward passes.
227
  # Here we concatenate the unconditional and text embeddings into a single batch
 
247
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
248
  # and should be between [0, 1]
249
 
250
+ accepts_eta = "eta" in set(inspect.signature(
251
+ self.scheduler.step).parameters.keys())
252
  extra_step_kwargs = {}
253
  if accepts_eta:
254
  extra_step_kwargs["eta"] = eta
255
 
256
  # check if the scheduler accepts generator
257
+ accepts_generator = "generator" in set(
258
+ inspect.signature(self.scheduler.step).parameters.keys())
259
  if accepts_generator:
260
  extra_step_kwargs["generator"] = generator
261
  return extra_step_kwargs
 
264
  self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
265
  ):
266
  if strength < 0 or strength > 1:
267
+ raise ValueError(
268
+ f"The value of strength should in [0.0, 1.0] but is {strength}")
269
 
270
  if (callback_steps is None) or (
271
+ callback_steps is not None and (not isinstance(
272
+ callback_steps, int) or callback_steps <= 0)
273
  ):
274
  raise ValueError(
275
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
 
286
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
287
  )
288
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
289
+ raise ValueError(
290
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
291
 
292
  if negative_prompt is not None and negative_prompt_embeds is not None:
293
  raise ValueError(
 
305
 
306
  def get_timesteps(self, num_inference_steps, strength, device):
307
  # get the original timestep using init_timestep
308
+ init_timestep = min(
309
+ int(num_inference_steps * strength), num_inference_steps)
310
 
311
  t_start = max(num_inference_steps - init_timestep, 0)
312
  timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
 
321
 
322
  image = image.to(device=device, dtype=dtype)
323
 
324
325
+
326
  batch_size = batch_size * num_images_per_prompt
327
 
328
  if image.shape[1] == 4:
 
337
 
338
  elif isinstance(generator, list):
339
  init_latents = [
340
+ self.vae.encode(image[i: i + 1]).sample(generator[i]) for i in range(batch_size)
341
  ]
342
  init_latents = torch.cat(init_latents, dim=0)
343
  else:
344
+ init_latents = self.vae.encode(image).sample(generator)
345
 
346
+ init_latents = self.vae.scaling_factor * init_latents
347
 
348
  if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
349
  # expand init_latents for batch_size
 
353
  " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
354
  " your script to pass as many initial images as text prompts to suppress this warning."
355
  )
356
+ deprecate("len(prompt) != len(image)", "1.0.0",
357
+ deprecation_message, standard_warn=False)
358
  additional_image_per_prompt = batch_size // init_latents.shape[0]
359
+ init_latents = torch.cat(
360
+ [init_latents] * additional_image_per_prompt, dim=0)
361
  elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
362
  raise ValueError(
363
  f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
 
366
  init_latents = torch.cat([init_latents], dim=0)
367
 
368
  shape = init_latents.shape
369
+ noise = randn_tensor(shape, generator=generator,
370
+ device=device, dtype=dtype)
371
 
372
  # get latents
373
  init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
 
375
 
376
  return latents
377
 
378
+ def lyra_decode_latents(self, latents):
379
+ print("lyra_decode_latents")
380
+ latents = 1 / self.vae_scaling_factor * latents
381
+ image = self.vae.decode(latents)
382
+ image = image.permute(0, 2, 3, 1)
383
+ image = (image / 2 + 0.5).clamp(0, 1)
384
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
385
+ image = image.cpu().float().numpy()
386
+
387
+ return image
388
+
389
  @torch.no_grad()
390
  def __call__(
391
  self,
 
404
  negative_prompt: Optional[Union[str, List[str]]] = None,
405
  num_images_per_prompt: Optional[int] = 1,
406
  eta: Optional[float] = 0.0,
407
+ generator: Optional[Union[torch.Generator,
408
+ List[torch.Generator]]] = None,
409
  prompt_embeds: Optional[torch.FloatTensor] = None,
410
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
411
+ callback: Optional[Callable[[
412
+ int, int, torch.FloatTensor], None]] = None,
413
  callback_steps: int = 1,
414
  ):
415
  r"""
 
467
  "not-safe-for-work" (nsfw) content.
468
  """
469
  # 1. Check inputs. Raise error if not correct
470
+ self.check_inputs(prompt, strength, callback_steps,
471
+ negative_prompt, prompt_embeds, negative_prompt_embeds)
472
 
473
  # 2. Define call parameters
474
  if prompt is not None and isinstance(prompt, str):
 
496
  )
497
 
498
  # 4. Preprocess image
499
+ image = self.image_processor.preprocess(image)
500
 
501
  # 5. set timesteps
502
  self.scheduler.set_timesteps(num_inference_steps, device=device)
503
+ timesteps, num_inference_steps = self.get_timesteps(
504
+ num_inference_steps, strength, device)
505
+ latent_timestep = timesteps[:1].repeat(
506
+ batch_size * num_images_per_prompt)
507
 
508
  # 6. Prepare latent variables
509
  latents = self.prepare_latents(
 
514
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
515
 
516
  # 8. Denoising loop
517
+ num_warmup_steps = len(timesteps) - \
518
+ num_inference_steps * self.scheduler.order
519
 
520
  for i, t in enumerate(timesteps):
521
  # expand the latents if we are doing classifier free guidance
522
+ latent_model_input = torch.cat(
523
+ [latents] * 2) if do_classifier_free_guidance else latents
524
+ latent_model_input = self.scheduler.scale_model_input(
525
+ latent_model_input, t)
526
+ latent_model_input = latent_model_input.permute(
527
+ 0, 2, 3, 1).contiguous()
528
 
529
  # predict the noise residual
530
+ # 后边 None 是给到controlnet 的参数,暂时给到 None 当 placeholder
531
+ noise_pred = self.unet.forward(
532
+ latent_model_input, prompt_embeds, t)
533
  noise_pred = noise_pred.permute(0, 3, 1, 2)
534
 
535
  # perform guidance
536
  if do_classifier_free_guidance:
537
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
538
+ noise_pred = noise_pred_uncond + guidance_scale * \
539
+ (noise_pred_text - noise_pred_uncond)
540
 
541
  # compute the previous noisy sample x_t -> x_t-1
542
+ latents = self.scheduler.step(
543
+ noise_pred, t, latents, **extra_step_kwargs).prev_sample
544
 
545
+ # image = self.decode_latents(latents)
546
+ image = self.lyra_decode_latents(latents)
547
  image = numpy_to_pil(image)
548
 
549
  return image
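
For reference, a minimal sketch of the refactored img2img pipeline driven through the new parameterless constructor plus reload_pipe. The model path and input image are placeholders, and the leading call keywords mirror the diffusers-style img2img signature used above, so treat them as assumptions rather than a fixed API:

import torch
from PIL import Image
from lyrasd_model import LyraSDImg2ImgPipeline

model_path = "./models/lyrasd_rev_animated"   # placeholder: converted SD model dir

pipe = LyraSDImg2ImgPipeline()                # cuda / fp16 defaults from the shared base class
pipe.reload_pipe(model_path)

init_image = Image.open("input.png").convert("RGB").resize((512, 512))
generator = torch.Generator().manual_seed(123)

images = pipe(
    prompt="a cat, cartoon, best quality",
    image=init_image,
    strength=0.7,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=generator,
)
images[0].save("res_img2img.png")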
lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0689ed5d3b55f5033a8869d5f23ce900793aa0ab7fdc4a3e3c0a0f3a243c83da
3
- size 65441456
lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm86.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b8e27e715fa3a17ce25bf23b772e0dd355d0780c1bd93cfeeb12ef45b0ba2444
3
- size 65389176
lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c2eaa9067ad8eb1d20872afa71ed9497f62d930819704d15e5e8bf559623eca7
3
- size 65498752
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8600f5414d283ebf64cb3974ef520858747cbb1a6d59dd46a3dcd9427758613b
3
+ size 97823240
lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d0c909ff2498934c6d1ed8f46af6cdc7812872177c0a4e7ca0ee99bf88fcb65
3
- size 65519232
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e5aefbb32667eeacb7fa60283656b4bb2ebb7dcd54276f9d101c856ed64e340
3
+ size 97823240
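
The CUDA 11 libraries are dropped and the CUDA 12 builds are replaced with noticeably larger binaries. If your environment does not load the library for you, a minimal sketch of picking the matching build by compute capability (assuming only the sm80 and sm86 binaries ship, as above):

import torch

# sm80 covers A100/A30-class GPUs, sm86 covers A10/RTX 30xx-class GPUs.
major, minor = torch.cuda.get_device_capability()
lib_path = f"./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm{major}{minor}.so"
torch.classes.load_library(lib_path)          # registers the torch.classes.lyrasd ops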
lyrasd_model/lyrasd_pipeline_base.py ADDED
@@ -0,0 +1,214 @@
1
+ import inspect
2
+ import os
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
+
6
+ import gc
7
+ import torch
8
+ import numpy as np
9
+ from glob import glob
10
+
11
+ from diffusers.loaders import TextualInversionLoaderMixin
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models import AutoencoderKL
14
+ from diffusers.schedulers import (DPMSolverMultistepScheduler,
15
+ EulerAncestralDiscreteScheduler,
16
+ EulerDiscreteScheduler,
17
+ KarrasDiffusionSchedulers)
18
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
19
+ from .lyrasd_vae_model import LyraSdVaeModel
20
+ from .module.lyrasd_ip_adapter import LyraIPAdapter
21
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
22
+ from safetensors.torch import load_file
23
+
24
+
25
+ class LyraSDXLPipelineBase(TextualInversionLoaderMixin):
26
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, num_channels_unet=4, num_channels_latents=4, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
27
+ self.device = device
28
+ self.dtype = dtype
29
+
30
+ self.num_channels_unet = num_channels_unet
31
+ self.num_channels_latents = num_channels_latents
32
+ self.vae_scale_factor = vae_scale_factor
33
+ self.vae_scaling_factor = vae_scaling_factor
34
+
35
+ self.unet_cache = {}
36
+ self.unet_in_channels = 4
37
+
38
+ self.controlnet_cache = {}
39
+
40
+ self.loaded_lora = {}
41
+ self.loaded_lora_strength = {}
42
+
43
+ self.scheduler = None
44
+
45
+ self.init_pipe()
46
+
47
+ def init_pipe(self):
48
+ self.vae = LyraSdVaeModel(
49
+ scale_factor=self.vae_scale_factor, scaling_factor=self.vae_scaling_factor)
50
+
51
+ self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
52
+ 3,
53
+ "fp16",
54
+ self.num_channels_unet,
55
+ self.num_channels_latents
56
+ )
57
+
58
+ self.image_processor = VaeImageProcessor(
59
+ vae_scale_factor=self.vae_scale_factor)
60
+
61
+ self.mask_processor = VaeImageProcessor(
62
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
63
+ )
64
+
65
+ self.feature_extractor = CLIPImageProcessor()
66
+
67
+ def reload_pipe(self, model_path):
68
+ self.tokenizer = CLIPTokenizer.from_pretrained(
69
+ model_path, subfolder="tokenizer")
70
+ self.text_encoder = CLIPTextModel.from_pretrained(
71
+ model_path, subfolder="text_encoder").to(self.dtype).to(self.device)
72
+
73
+ self.reload_unet_model_v2(model_path)
74
+ self.reload_vae_model_v2(model_path)
75
+
76
+ if not self.scheduler:
77
+ self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
78
+ model_path, subfolder="scheduler")
79
+
80
+ @property
81
+ def _execution_device(self):
82
+ if not hasattr(self.unet, "_hf_hook"):
83
+ return self.device
84
+ for module in self.unet.modules():
85
+ if (
86
+ hasattr(module, "_hf_hook")
87
+ and hasattr(module._hf_hook, "execution_device")
88
+ and module._hf_hook.execution_device is not None
89
+ ):
90
+ return torch.device(module._hf_hook.execution_device)
91
+ return self.device
92
+
93
+ def reload_unet_model(self, unet_path, unet_file_format='fp32'):
94
+ if len(unet_path) > 0 and unet_path[-1] != "/":
95
+ unet_path = unet_path + "/"
96
+ self.unet.reload_unet_model(unet_path, unet_file_format)
97
+ self.load_embedding_weight(
98
+ self.add_embedding, f"{unet_path}add_embedding*", unet_file_format=unet_file_format)
99
+
100
+ def reload_vae_model(self, vae_path, vae_file_format='fp32'):
101
+ if len(vae_path) > 0 and vae_path[-1] != "/":
102
+ vae_path = vae_path + "/"
103
+ return self.vae.reload_vae_model(vae_path, vae_file_format)
104
+
105
+ def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
106
+ if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
107
+ lora_model_path = lora_model_path + "/"
108
+ lora = add_xltext_lora_layer(
109
+ self.text_encoder, self.text_encoder_2, lora_model_path, lora_strength, lora_file_format)
110
+
111
+ self.loaded_lora[lora_name] = lora
112
+ self.unet.load_lora(lora_model_path, lora_name,
113
+ lora_strength, lora_file_format)
114
+
115
+ def unload_lora(self, lora_name, clean_cache=False):
116
+ for layer_data in self.loaded_lora[lora_name]:
117
+ layer = layer_data['layer']
118
+ added_weight = layer_data['added_weight']
119
+ layer.weight.data -= added_weight
120
+ self.unet.unload_lora(lora_name, clean_cache)
121
+ del self.loaded_lora[lora_name]
122
+ gc.collect()
123
+ torch.cuda.empty_cache()
124
+
125
+ def load_lora_v2(self, lora_model_path, lora_name, lora_strength):
126
+ if lora_name in self.loaded_lora:
127
+ state_dict = self.loaded_lora[lora_name]
128
+ else:
129
+ state_dict = load_state_dict(lora_model_path)
130
+ self.loaded_lora[lora_name] = state_dict
131
+ self.loaded_lora_strength[lora_name] = lora_strength
132
+ add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
133
+ None, lora_strength)
134
+
135
+ def unload_lora_v2(self, lora_name, clean_cache=False):
136
+ state_dict = self.loaded_lora[lora_name]
137
+ lora_strength = self.loaded_lora_strength[lora_name]
138
+ add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
139
+ None, -1.0 * lora_strength)
140
+ del self.loaded_lora_strength[lora_name]
141
+
142
+ if clean_cache:
143
+ del self.loaded_lora[lora_name]
144
+ gc.collect()
145
+ torch.cuda.empty_cache()
146
+
147
+ def clean_lora_cache(self):
148
+ self.unet.clean_lora_cache()
149
+
150
+ def get_loaded_lora(self):
151
+ return self.unet.get_loaded_lora()
152
+
153
+ def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
154
+ self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
155
+ num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)
156
+
157
+ def reload_unet_model_v2(self, model_path):
158
+ checkpoint_file = os.path.join(
159
+ model_path, "unet/diffusion_pytorch_model.bin")
160
+ if not os.path.exists(checkpoint_file):
161
+ checkpoint_file = os.path.join(
162
+ model_path, "unet/diffusion_pytorch_model.safetensors")
163
+ if checkpoint_file in self.unet_cache:
164
+ state_dict = self.unet_cache[checkpoint_file]
165
+ else:
166
+ if "safetensors" in checkpoint_file:
167
+ state_dict = load_file(checkpoint_file)
168
+ else:
169
+ state_dict = torch.load(checkpoint_file, map_location="cpu")
170
+
171
+ for key in state_dict:
172
+ if len(state_dict[key].shape) == 4:
173
+ # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
174
+ state_dict[key] = state_dict[key].to(
175
+ torch.float16).permute(0, 2, 3, 1).contiguous()
176
+ state_dict[key] = state_dict[key].to(torch.float16)
177
+ self.unet_cache[checkpoint_file] = state_dict
178
+
179
+ self.unet.reload_unet_model_from_cache(state_dict, "cpu")
180
+
181
+ def reload_vae_model_v2(self, model_path):
182
+ self.vae.reload_vae_model_v2(model_path)
183
+
184
+ def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
185
+ if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
186
+ controlnet_path = controlnet_path + "/"
187
+ self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)
188
+
189
+ def unload_controlnet_model(self, model_name):
190
+ self.unet.unload_controlnet_model(model_name, True)
191
+
192
+ def get_loaded_controlnet(self):
193
+ return self.unet.get_loaded_controlnet()
194
+
195
+ def load_controlnet_model_v2(self, model_name, controlnet_path):
196
+ checkpoint_file = os.path.join(controlnet_path, "diffusion_pytorch_model.bin")
197
+ if not os.path.exists(checkpoint_file):
198
+ checkpoint_file = os.path.join(controlnet_path, "diffusion_pytorch_model.safetensors")
199
+ if checkpoint_file in self.controlnet_cache:
200
+ state_dict = self.controlnet_cache[checkpoint_file]
201
+ else:
202
+ if "safetensors" in checkpoint_file:
203
+ state_dict = load_file(checkpoint_file)
204
+ else:
205
+ state_dict = torch.load(checkpoint_file, map_location="cpu")
206
+
207
+ for key in state_dict:
208
+ if len(state_dict[key].shape) == 4:
209
+ # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
210
+ state_dict[key] = state_dict[key].to(torch.float16).permute(0,2,3,1).contiguous()
211
+ state_dict[key] = state_dict[key].to(torch.float16)
212
+ self.controlnet_cache[checkpoint_file] = state_dict
213
+
214
+ self.unet.load_controlnet_model_from_state_dict(model_name, state_dict, "cpu")
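
The base class above centralizes UNet/VAE/LoRA/ControlNet weight handling for every pipeline in this commit. A minimal sketch of the v2 LoRA round trip it implements, with placeholder path and name; note that unload_lora_v2 undoes the merge by re-applying the cached state dict with the negated strength, so the strength recorded at load time is what gets reversed:

# `pipe` is any pipeline built on LyraSDXLPipelineBase (see the classes above).
pipe.load_lora_v2("./models/my_style_lora.safetensors", "my_style", 0.6)
# ... run generations with the LoRA deltas merged into the UNet and text encoder ...
pipe.unload_lora_v2("my_style", clean_cache=False)
# The parsed state dict stays cached under "my_style", so a later load_lora_v2
# call with the same name skips re-reading the file from disk.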
lyrasd_model/lyrasd_txt2img_inpaint_pipeline.py ADDED
@@ -0,0 +1,826 @@
1
+ import inspect
2
+ import os
3
+ import sys
4
+ import time
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+ import GPUtil
7
+ import torch
8
+ from diffusers.loaders import TextualInversionLoaderMixin
9
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
10
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
11
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from PIL import Image
13
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
14
+ import gc
15
+ import numpy as np
16
+
17
+ from .lyrasd_vae_model import LyraSdVaeModel
18
+
19
+ from diffusers.models.embeddings import ImageProjection
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPVisionModelWithProjection,
23
+ )
24
+
25
+ from .lyrasd_pipeline_base import LyraSDXLPipelineBase
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
31
+ """
32
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
33
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
34
+ """
35
+ std_text = noise_pred_text.std(
36
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
37
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
38
+ # rescale the results from guidance (fixes overexposure)
39
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
40
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
41
+ noise_cfg = guidance_rescale * noise_pred_rescaled + \
42
+ (1 - guidance_rescale) * noise_cfg
43
+ return noise_cfg
44
+
45
+
46
+ def numpy_to_pil(images):
47
+ """
48
+ Convert a numpy image or a batch of images to a PIL image.
49
+ """
50
+ if images.ndim == 3:
51
+ images = images[None, ...]
52
+ images = (images * 255).round().astype("uint8")
53
+ if images.shape[-1] == 1:
54
+ # special case for grayscale (single channel) images
55
+ pil_images = [Image.fromarray(image.squeeze(), mode="L")
56
+ for image in images]
57
+ else:
58
+ pil_images = [Image.fromarray(image) for image in images]
59
+
60
+ return pil_images
61
+
62
+
63
+ def retrieve_timesteps(
64
+ scheduler,
65
+ num_inference_steps: Optional[int] = None,
66
+ device: Optional[Union[str, torch.device]] = None,
67
+ timesteps: Optional[List[int]] = None,
68
+ **kwargs,
69
+ ):
70
+ """
71
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
72
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
73
+
74
+ Args:
75
+ scheduler (`SchedulerMixin`):
76
+ The scheduler to get timesteps from.
77
+ num_inference_steps (`int`):
78
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
79
+ `timesteps` must be `None`.
80
+ device (`str` or `torch.device`, *optional*):
81
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
82
+ timesteps (`List[int]`, *optional*):
83
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
84
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
85
+ must be `None`.
86
+
87
+ Returns:
88
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
89
+ second element is the number of inference steps.
90
+ """
91
+ if timesteps is not None:
92
+ print("set(inspect.signature(scheduler.set_timesteps).parameters.keys())", set(
93
+ inspect.signature(scheduler.set_timesteps).parameters.keys()))
94
+ accepts_timesteps = "timesteps" in set(
95
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
96
+ if not accepts_timesteps:
97
+ raise ValueError(
98
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
99
+ f" timestep schedules. Please check whether you are using the correct scheduler."
100
+ )
101
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
102
+ timesteps = scheduler.timesteps
103
+ num_inference_steps = len(timesteps)
104
+ else:
105
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
106
+ timesteps = scheduler.timesteps
107
+ return timesteps, num_inference_steps
108
+
109
+
110
+ def retrieve_latents(
111
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
112
+ ):
113
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
114
+ return encoder_output.latent_dist.sample(generator)
115
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
116
+ return encoder_output.latent_dist.mode()
117
+ elif hasattr(encoder_output, "latents"):
118
+ return encoder_output.latents
119
+ else:
120
+ raise AttributeError(
121
+ "Could not access latents of provided encoder_output")
122
+
123
+
124
+ class LyraSdTxt2ImgInpaintPipeline(LyraSDXLPipelineBase):
125
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215, num_channels_unet=9, num_channels_latents=4) -> None:
126
+ super().__init__(device, dtype, num_channels_unet=num_channels_unet, num_channels_latents=num_channels_latents,
127
+ vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
128
+
129
+ def _encode_prompt(
130
+ self,
131
+ prompt,
132
+ device,
133
+ num_images_per_prompt,
134
+ do_classifier_free_guidance,
135
+ negative_prompt=None,
136
+ prompt_embeds: Optional[torch.FloatTensor] = None,
137
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
138
+ ):
139
+ r"""
140
+ Encodes the prompt into text encoder hidden states.
141
+
142
+ Args:
143
+ prompt (`str` or `List[str]`, *optional*):
144
+ prompt to be encoded
145
+ device: (`torch.device`):
146
+ torch device
147
+ num_images_per_prompt (`int`):
148
+ number of images that should be generated per prompt
149
+ do_classifier_free_guidance (`bool`):
150
+ whether to use classifier free guidance or not
151
+ negative_prompt (`str` or `List[str]`, *optional*):
152
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
153
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
154
+ less than `1`).
155
+ prompt_embeds (`torch.FloatTensor`, *optional*):
156
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
157
+ provided, text embeddings will be generated from `prompt` input argument.
158
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
159
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
160
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
161
+ argument.
162
+ """
163
+ if prompt is not None and isinstance(prompt, str):
164
+ batch_size = 1
165
+ elif prompt is not None and isinstance(prompt, list):
166
+ batch_size = len(prompt)
167
+ else:
168
+ batch_size = prompt_embeds.shape[0]
169
+
170
+ if prompt_embeds is None:
171
+ # textual inversion: procecss multi-vector tokens if necessary
172
+ if isinstance(self, TextualInversionLoaderMixin):
173
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
174
+
175
+ text_inputs = self.tokenizer(
176
+ prompt,
177
+ padding="max_length",
178
+ max_length=self.tokenizer.model_max_length,
179
+ truncation=True,
180
+ return_tensors="pt",
181
+ )
182
+ text_input_ids = text_inputs.input_ids
183
+ untruncated_ids = self.tokenizer(
184
+ prompt, padding="longest", return_tensors="pt").input_ids
185
+
186
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
187
+ text_input_ids, untruncated_ids
188
+ ):
189
+ removed_text = self.tokenizer.batch_decode(
190
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
191
+ )
192
+ logger.warning(
193
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
194
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
195
+ )
196
+
197
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
198
+ attention_mask = text_inputs.attention_mask.to(device)
199
+ else:
200
+ attention_mask = None
201
+
202
+ prompt_embeds = self.text_encoder(
203
+ text_input_ids.to(device),
204
+ attention_mask=attention_mask,
205
+ )
206
+ prompt_embeds = prompt_embeds[0]
207
+
208
+ prompt_embeds = prompt_embeds.to(
209
+ dtype=self.text_encoder.dtype, device=device)
210
+
211
+ bs_embed, seq_len, _ = prompt_embeds.shape
212
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
213
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
214
+ prompt_embeds = prompt_embeds.view(
215
+ bs_embed * num_images_per_prompt, seq_len, -1)
216
+
217
+ # get unconditional embeddings for classifier free guidance
218
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
219
+ uncond_tokens: List[str]
220
+ if negative_prompt is None:
221
+ uncond_tokens = [""] * batch_size
222
+ elif type(prompt) is not type(negative_prompt):
223
+ raise TypeError(
224
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
225
+ f" {type(prompt)}."
226
+ )
227
+ elif isinstance(negative_prompt, str):
228
+ uncond_tokens = [negative_prompt]
229
+ elif batch_size != len(negative_prompt):
230
+ raise ValueError(
231
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
232
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
233
+ " the batch size of `prompt`."
234
+ )
235
+ else:
236
+ uncond_tokens = negative_prompt
237
+
238
+ # textual inversion: procecss multi-vector tokens if necessary
239
+ if isinstance(self, TextualInversionLoaderMixin):
240
+ uncond_tokens = self.maybe_convert_prompt(
241
+ uncond_tokens, self.tokenizer)
242
+
243
+ max_length = prompt_embeds.shape[1]
244
+ uncond_input = self.tokenizer(
245
+ uncond_tokens,
246
+ padding="max_length",
247
+ max_length=max_length,
248
+ truncation=True,
249
+ return_tensors="pt",
250
+ )
251
+
252
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
253
+ attention_mask = uncond_input.attention_mask.to(device)
254
+ else:
255
+ attention_mask = None
256
+
257
+ negative_prompt_embeds = self.text_encoder(
258
+ uncond_input.input_ids.to(device),
259
+ attention_mask=attention_mask,
260
+ )
261
+ negative_prompt_embeds = negative_prompt_embeds[0]
262
+
263
+ if do_classifier_free_guidance:
264
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
265
+ seq_len = negative_prompt_embeds.shape[1]
266
+
267
+ negative_prompt_embeds = negative_prompt_embeds.to(
268
+ dtype=self.text_encoder.dtype, device=device)
269
+
270
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
271
+ 1, num_images_per_prompt, 1)
272
+ negative_prompt_embeds = negative_prompt_embeds.view(
273
+ batch_size * num_images_per_prompt, seq_len, -1)
274
+
275
+ # For classifier free guidance, we need to do two forward passes.
276
+ # Here we concatenate the unconditional and text embeddings into a single batch
277
+ # to avoid doing two forward passes
278
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
279
+
280
+ return prompt_embeds
281
+
282
+ def load_ip_adapter(self,
283
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
284
+ subfolder: str,
285
+ weight_name: str,
286
+ **kwargs
287
+ ):
288
+ # if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
289
+ self.feature_extractor = CLIPImageProcessor()
290
+
291
+ # if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
292
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
293
+ pretrained_model_name_or_path_or_dict,
294
+ subfolder=os.path.join(subfolder, "image_encoder"),
295
+ ).to(self.device, dtype=self.dtype)
296
+ # else:
297
+ # print("kio: already has image_encoder", hasattr(self, "image_encoder"), getattr(self, "feature_extractor", None) is None)
298
+
299
+ # kiotodo: init ImageProjection
300
+ model_path = os.path.join(
301
+ pretrained_model_name_or_path_or_dict, subfolder, weight_name)
302
+ state_dict = torch.load(model_path, map_location="cpu")
303
+
304
+ clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
305
+ cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
306
+ self.encoder_hid_proj = ImageProjection(
307
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
308
+ )
309
+
310
+ image_proj_state_dict = {}
311
+ image_proj_state_dict.update(
312
+ {
313
+ "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
314
+ "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
315
+ "norm.weight": state_dict["image_proj"]["norm.weight"],
316
+ "norm.bias": state_dict["image_proj"]["norm.bias"],
317
+ }
318
+ )
319
+
320
+ self.encoder_hid_proj.load_state_dict(image_proj_state_dict)
321
+ self.encoder_hid_proj.to(dtype=self.dtype, device=self.device)
322
+
323
+ dir_ipadapter = os.path.join(
324
+ pretrained_model_name_or_path_or_dict, subfolder, '.'.join(weight_name.split(".")[:-1]))
325
+ self.unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
326
+
327
+ def encode_image(self, image, device, num_images_per_prompt):
328
+ dtype = next(self.image_encoder.parameters()).dtype
329
+ if not isinstance(image, torch.Tensor):
330
+ image = self.feature_extractor(
331
+ image, return_tensors="pt").pixel_values
332
+
333
+ image = image.to(device=device, dtype=dtype)
334
+ image_embeds = self.image_encoder(image).image_embeds
335
+ image_embeds = image_embeds.repeat_interleave(
336
+ num_images_per_prompt, dim=0)
337
+
338
+ uncond_image_embeds = torch.zeros_like(image_embeds)
339
+ return image_embeds, uncond_image_embeds
340
+
341
+ def decode_latents(self, latents):
342
+ latents = 1 / self.vae.scaling_factor * latents
343
+ image = self.vae.decode(latents).sample
344
+ image = (image / 2 + 0.5).clamp(0, 1)
345
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
346
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
347
+ return image
348
+
349
+ def lyra_decode_latents(self, latents):
350
+ # print("lyra_decode_latents")
351
+ # np.save("", latents.)
352
+ # np.save(f"/workspace/vae_model/latent.npy", latents.detach().cpu().numpy())
353
+ latents = 1 / self.vae.scaling_factor * latents
354
+ # latents = latents.permute(0, 2, 3, 1).contiguous()
355
+ image = self.vae.decode(latents)
356
+ image = image.permute(0, 2, 3, 1)
357
+ # print(image)
358
+ # GPUtil.showUtilization(all=True)
359
+
360
+ image = (image / 2 + 0.5).clamp(0, 1)
361
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
362
+ image = image.cpu().float().numpy()
363
+
364
+ return image
365
+
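For clarity, a runnable sketch of what `lyra_decode_latents` does with layouts and value ranges; the fused VAE decode is mocked with a plain upsample so the snippet does not need the custom CUDA op.

```python
# Sketch: latents -> numpy image, with the custom VAE decode mocked out.
import torch

scaling_factor = 0.18215
latents = torch.randn(1, 4, 64, 64)

latents = latents / scaling_factor
# Stand-in for self.vae.decode(latents): any NCHW tensor of shape (1, 3, 512, 512).
decoded = torch.nn.functional.interpolate(latents[:, :3], scale_factor=8)

image = decoded.permute(0, 2, 3, 1)        # NCHW -> NHWC
image = (image / 2 + 0.5).clamp(0, 1)      # [-1, 1] -> [0, 1]
image = image.cpu().float().numpy()        # ready for numpy_to_pil()
```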
366
+ def get_timesteps(self, num_inference_steps, strength, device):
367
+ # get the original timestep using init_timestep
368
+ init_timestep = min(
369
+ int(num_inference_steps * strength), num_inference_steps)
370
+
371
+ t_start = max(num_inference_steps - init_timestep, 0)
372
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
373
+
374
+ return timesteps, num_inference_steps - t_start
375
+
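A small worked example of the truncation `get_timesteps` performs: with 30 steps and `strength=0.6`, only the last 18 steps of the schedule are run. The scheduler below is a stock diffusers Euler-Ancestral instance, used purely for illustration.

```python
# Sketch: how `strength` shortens an img2img/inpaint schedule.
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
num_inference_steps, strength = 30, 0.6

scheduler.set_timesteps(num_inference_steps)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)                          # 12
timesteps = scheduler.timesteps[t_start * scheduler.order:]

print(len(timesteps))  # 18 -> denoising starts 40% of the way into the schedule
```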
376
+ def check_inputs(
377
+ self,
378
+ prompt,
379
+ height,
380
+ width,
381
+ negative_prompt=None,
382
+ prompt_embeds=None,
383
+ negative_prompt_embeds=None,
384
+ ):
385
+ if height % 64 != 0 or width % 64 != 0: # the initial version only supports height and width that are multiples of 64
386
+ raise ValueError(
387
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
388
+
389
+ if prompt is not None and prompt_embeds is not None:
390
+ raise ValueError(
391
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
392
+ " only forward one of the two."
393
+ )
394
+ elif prompt is None and prompt_embeds is None:
395
+ raise ValueError(
396
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
397
+ )
398
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
399
+ raise ValueError(
400
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
401
+
402
+ if negative_prompt is not None and negative_prompt_embeds is not None:
403
+ raise ValueError(
404
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
405
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
406
+ )
407
+
408
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
409
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
410
+ raise ValueError(
411
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
412
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
413
+ f" {negative_prompt_embeds.shape}."
414
+ )
415
+
416
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
417
+ if isinstance(generator, list):
418
+ image_latents = [
419
+ retrieve_latents(AutoencoderKLOutput(
420
+ latent_dist=self.vae.encode(image[i: i + 1])), generator=generator[i])
421
+ for i in range(image.shape[0])
422
+ ]
423
+ image_latents = torch.cat(image_latents, dim=0)
424
+ else:
425
+ image_latents = retrieve_latents(AutoencoderKLOutput(
426
+ latent_dist=self.vae.encode(image)), generator=generator)
427
+
428
+ image_latents = self.vae_scaling_factor * image_latents
429
+
430
+ return image_latents
431
+
432
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None,
433
+ image=None, timestep=None, is_strength_max=True, return_noise=False, return_image_latents=False):
434
+ shape = (batch_size, num_channels_latents, height //
435
+ self.vae_scale_factor, width // self.vae_scale_factor)
436
+ if isinstance(generator, list) and len(generator) != batch_size:
437
+ raise ValueError(
438
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
439
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
440
+ )
441
+
442
+ if (image is None or timestep is None) and not is_strength_max:
443
+ raise ValueError(
444
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
445
+ "However, either the image or the noise timestep has not been provided."
446
+ )
447
+
448
+ if return_image_latents or (latents is None and not is_strength_max):
449
+ image = image.to(device=device, dtype=dtype)
450
+
451
+ if image.shape[1] == 4:
452
+ image_latents = image
453
+ else:
454
+ image_latents = self._encode_vae_image(
455
+ image=image, generator=generator)
456
+ image_latents = image_latents.repeat(
457
+ batch_size // image_latents.shape[0], 1, 1, 1)
458
+
459
+ if latents is None:
460
+ noise = randn_tensor(shape, generator=generator,
461
+ device=device, dtype=dtype)
462
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
463
+ latents = noise if is_strength_max else self.scheduler.add_noise(
464
+ image_latents, noise, timestep)
465
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
466
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
467
+ else:
468
+ noise = latents.to(device)
469
+ latents = noise * self.scheduler.init_noise_sigma
470
+
471
+ outputs = (latents,)
472
+
473
+ if return_noise:
474
+ outputs += (noise,)
475
+
476
+ if return_image_latents:
477
+ outputs += (image_latents,)
478
+
479
+ return outputs
480
+
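A condensed sketch of the two initialisation branches in `prepare_latents`: pure noise when `strength == 1.0`, otherwise the encoded image latents noised to the start timestep chosen by `get_timesteps`. Shapes and the scheduler are placeholders.

```python
# Sketch: inpainting latent initialisation as a function of `strength`.
import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(30)

image_latents = torch.randn(1, 4, 64, 64)     # stand-in for VAE-encoded init image
noise = torch.randn_like(image_latents)
is_strength_max = False                        # e.g. strength = 0.6
latent_timestep = scheduler.timesteps[12:13]   # start index picked by get_timesteps

if is_strength_max:
    latents = noise * scheduler.init_noise_sigma                          # pure noise
else:
    latents = scheduler.add_noise(image_latents, noise, latent_timestep)  # image + noise
```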
481
+ def prepare_mask_latents(
482
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
483
+ ):
484
+ # resize the mask to latents shape as we concatenate the mask to the latents
485
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
486
+ # and half precision
487
+ mask = torch.nn.functional.interpolate(
488
+ mask, size=(height // self.vae_scale_factor,
489
+ width // self.vae_scale_factor)
490
+ )
491
+ mask = mask.to(device=device, dtype=dtype)
492
+
493
+ masked_image = masked_image.to(device=device, dtype=dtype)
494
+
495
+ if masked_image.shape[1] == 4:
496
+ masked_image_latents = masked_image
497
+ else:
498
+ masked_image_latents = self._encode_vae_image(
499
+ masked_image, generator=generator)
500
+
501
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
502
+ if mask.shape[0] < batch_size:
503
+ if not batch_size % mask.shape[0] == 0:
504
+ raise ValueError(
505
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
506
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
507
+ " of masks that you pass is divisible by the total requested batch size."
508
+ )
509
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
510
+ if masked_image_latents.shape[0] < batch_size:
511
+ if not batch_size % masked_image_latents.shape[0] == 0:
512
+ raise ValueError(
513
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
514
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
515
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
516
+ )
517
+ masked_image_latents = masked_image_latents.repeat(
518
+ batch_size // masked_image_latents.shape[0], 1, 1, 1)
519
+
520
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
521
+ masked_image_latents = (
522
+ torch.cat([masked_image_latents] *
523
+ 2) if do_classifier_free_guidance else masked_image_latents
524
+ )
525
+
526
+ # aligning device to prevent device errors when concating it with the latent model input
527
+ masked_image_latents = masked_image_latents.to(
528
+ device=device, dtype=dtype)
529
+ return mask, masked_image_latents
530
+
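A short numeric sketch of the mask handling above: the pixel-space mask is resized to the latent grid, and the masked image keeps only the pixels that should be preserved. It assumes 512×512 inputs and a VAE scale factor of 8.

```python
# Sketch: mask downsampling and masked-image construction for inpainting.
import torch

vae_scale_factor = 8
init_image = torch.rand(1, 3, 512, 512) * 2 - 1     # preprocessed to [-1, 1]
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()   # 1 = repaint, 0 = keep

latent_mask = torch.nn.functional.interpolate(
    mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor)
)
masked_image = init_image * (mask < 0.5)             # zero out the region to repaint

print(latent_mask.shape, masked_image.shape)         # (1, 1, 64, 64) (1, 3, 512, 512)
```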
531
+ def prepare_extra_step_kwargs(self, generator, eta):
532
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
533
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
534
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
535
+ # and should be between [0, 1]
536
+
537
+ accepts_eta = "eta" in set(inspect.signature(
538
+ self.scheduler.step).parameters.keys())
539
+ extra_step_kwargs = {}
540
+ if accepts_eta:
541
+ extra_step_kwargs["eta"] = eta
542
+
543
+ # check if the scheduler accepts generator
544
+ accepts_generator = "generator" in set(
545
+ inspect.signature(self.scheduler.step).parameters.keys())
546
+ if accepts_generator:
547
+ extra_step_kwargs["generator"] = generator
548
+ return extra_step_kwargs
549
+
550
+ @torch.no_grad()
551
+ def __call__(
552
+ self,
553
+ prompt: Union[str, List[str]] = None,
554
+ image: PipelineImageInput = None,
555
+ mask_image: PipelineImageInput = None,
556
+ masked_image_latents: torch.FloatTensor = None,
557
+ height: Optional[int] = None,
558
+ width: Optional[int] = None,
559
+ strength: float = 1.0,
560
+ num_inference_steps: int = 50,
561
+ guidance_scale: float = 7.5,
562
+ negative_prompt: Optional[Union[str, List[str]]] = None,
563
+ num_images_per_prompt: Optional[int] = 1,
564
+ eta: float = 0.0,
565
+ generator: Optional[Union[torch.Generator,
566
+ List[torch.Generator]]] = None,
567
+ latents: Optional[torch.FloatTensor] = None,
568
+ prompt_embeds: Optional[torch.FloatTensor] = None,
569
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ ip_adapter_image: Optional[PipelineImageInput] = None,
571
+ param_scale_dict: Optional[dict] = {}
572
+ ):
573
+ r"""
574
+ Function invoked when calling the pipeline for generation.
575
+
576
+ Args:
577
+ prompt (`str` or `List[str]`, *optional*):
578
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
579
+ instead.
580
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
581
+ The height in pixels of the generated image.
582
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
583
+ The width in pixels of the generated image.
584
+ num_inference_steps (`int`, *optional*, defaults to 50):
585
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
586
+ expense of slower inference.
587
+ guidance_scale (`float`, *optional*, defaults to 7.5):
588
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
589
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
590
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
591
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
592
+ usually at the expense of lower image quality.
593
+ negative_prompt (`str` or `List[str]`, *optional*):
594
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
595
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
596
+ less than `1`).
597
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
598
+ The number of images to generate per prompt.
599
+ eta (`float`, *optional*, defaults to 0.0):
600
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
601
+ [`schedulers.DDIMScheduler`], will be ignored for others.
602
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
603
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
604
+ to make generation deterministic.
605
+ latents (`torch.FloatTensor`, *optional*):
606
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
607
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
608
+ tensor will be generated by sampling using the supplied random `generator`.
609
+ prompt_embeds (`torch.FloatTensor`, *optional*):
610
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
611
+ provided, text embeddings will be generated from `prompt` input argument.
612
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
613
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
614
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
615
+ argument.
616
+
617
+ """
618
+ # 0. Default height and width to unet
619
+ height = height or self.unet_config_sample_size * self.vae_scale_factor
620
+ width = width or self.unet_config_sample_size * self.vae_scale_factor
621
+ # self.unet_config.sample_size = 64
622
+ # height = 512
623
+ # width = 512
624
+
625
+ # 1. Check inputs. Raise error if not correct
626
+ # self.check_inputs(
627
+ # prompt, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds
628
+ # )
629
+
630
+ # 2. Define call parameters
631
+ if prompt is not None and isinstance(prompt, str):
632
+ batch_size = 1
633
+ elif prompt is not None and isinstance(prompt, list):
634
+ batch_size = len(prompt)
635
+ else:
636
+ batch_size = prompt_embeds.shape[0]
637
+
638
+ device = self.device
639
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
640
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
641
+ # corresponds to doing no classifier free guidance.
642
+ do_classifier_free_guidance = guidance_scale > 1.0
643
+
644
+ # 3. Encode input prompt
645
+ prompt_embeds = self._encode_prompt(
646
+ prompt,
647
+ device,
648
+ num_images_per_prompt,
649
+ do_classifier_free_guidance,
650
+ negative_prompt,
651
+ prompt_embeds=prompt_embeds,
652
+ negative_prompt_embeds=negative_prompt_embeds,
653
+ )
654
+
655
+ # 3.5 Encode ipadapter_image
656
+ if ip_adapter_image is not None:
657
+ image_embeds, negative_image_embeds = self.encode_image(
658
+ ip_adapter_image, device, num_images_per_prompt)
659
+ if do_classifier_free_guidance:
660
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
661
+ image_embeds = self.encoder_hid_proj(image_embeds).to(self.dtype)
662
+
663
+ # 4. Prepare timesteps
664
+ # self.scheduler.set_timesteps(num_inference_steps, device=device)
665
+ # timesteps = self.scheduler.timesteps
666
+
667
+ # 4.5 Prepare mask and image
668
+ timesteps = None
669
+ timesteps, num_inference_steps = retrieve_timesteps(
670
+ self.scheduler, num_inference_steps, device, timesteps)
671
+ timesteps, num_inference_steps = self.get_timesteps(
672
+ num_inference_steps=num_inference_steps, strength=strength, device=device
673
+ )
674
+ # check that number of inference steps is not < 1 - as this doesn't make sense
675
+ if num_inference_steps < 1:
676
+ raise ValueError(
677
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
678
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
679
+ )
680
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
681
+ latent_timestep = timesteps[:1].repeat(
682
+ batch_size * num_images_per_prompt)
683
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
684
+ is_strength_max = strength == 1.0
685
+
686
+ # 5. Preprocess mask and image
687
+
688
+ init_image = self.image_processor.preprocess(
689
+ image, height=height, width=width)
690
+ init_image = init_image.to(dtype=torch.float32)
691
+
692
+ # 5. Prepare latent variables
693
+ return_image_latents = self.num_channels_unet == 4
694
+ latents_outputs = self.prepare_latents(
695
+ batch_size * num_images_per_prompt,
696
+ self.num_channels_latents,
697
+ height,
698
+ width,
699
+ prompt_embeds.dtype,
700
+ device,
701
+ generator,
702
+ latents,
703
+ image=init_image,
704
+ timestep=latent_timestep,
705
+ is_strength_max=is_strength_max,
706
+ return_noise=True,
707
+ return_image_latents=return_image_latents
708
+ )
709
+
710
+ if return_image_latents:
711
+ latents, noise, image_latents = latents_outputs
712
+ else:
713
+ latents, noise = latents_outputs
714
+
715
+ # 5.5 Prepare mask latent variables
716
+ mask_condition = self.mask_processor.preprocess(
717
+ mask_image, height=height, width=width)
718
+ if masked_image_latents is None:
719
+ masked_image = init_image * (mask_condition < 0.5)
720
+ else:
721
+ masked_image = masked_image_latents
722
+
723
+ mask, masked_image_latents = self.prepare_mask_latents(
724
+ mask_condition,
725
+ masked_image,
726
+ batch_size * num_images_per_prompt,
727
+ height,
728
+ width,
729
+ prompt_embeds.dtype,
730
+ device,
731
+ generator,
732
+ do_classifier_free_guidance,
733
+ )
734
+
735
+ # Check that sizes of mask, masked image and latents match
736
+ if self.num_channels_unet == 9:
737
+ # default case for runwayml/stable-diffusion-inpainting
738
+ num_channels_mask = mask.shape[1]
739
+ num_channels_masked_image = masked_image_latents.shape[1]
740
+ if self.num_channels_latents + num_channels_mask + num_channels_masked_image != self.num_channels_unet:
741
+ raise ValueError(
742
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
743
+ f" {self.num_channels_latents} but received `num_channels_latents`: {self.num_channels_latents} +"
744
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
745
+ f" = {self.num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
746
+ " `pipeline.unet` or your `mask_image` or `image` input."
747
+ )
748
+ elif self.num_channels_unet != 4:
749
+ raise ValueError(
750
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
751
+ )
752
+
753
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
754
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
755
+
756
+ # 7. Denoising loop
757
+ num_warmup_steps = len(timesteps) - \
758
+ num_inference_steps * self.scheduler.order
759
+
760
+ for i, t in enumerate(timesteps):
761
+ # expand the latents if we are doing classifier free guidance
762
+ latent_model_input = torch.cat(
763
+ [latents] * 2) if do_classifier_free_guidance else latents
764
+ latent_model_input = self.scheduler.scale_model_input(
765
+ latent_model_input, t)
766
+
767
+ if self.num_channels_unet == 9:
768
+ latent_model_input = torch.cat(
769
+ [latent_model_input, mask, masked_image_latents], dim=1)
770
+
771
+ latent_model_input = latent_model_input.permute(
772
+ 0, 2, 3, 1).contiguous()
773
+
774
+ # latent_model_input = latent_model_input[:,:4,:,:].
775
+
776
+ # The trailing None arguments are ControlNet inputs; they are passed as placeholders for now
777
+ # todo: forward ip image_embeds
778
+ # break
779
+ if ip_adapter_image is not None:
780
+ noise_pred = self.unet.forward(
781
+ latent_model_input, prompt_embeds, t, None, None, None, None, {"ip_hidden_states": image_embeds}, param_scale_dict)
782
+ else:
783
+ noise_pred = self.unet.forward(
784
+ latent_model_input, prompt_embeds, t)
785
+
786
+ noise_pred = noise_pred.permute(0, 3, 1, 2).contiguous()
787
+ # saver.save_v(f"latent_model_input_{i}", latent_model_input)
788
+ # saver.save_v(f"noise_pred_{i}", noise_pred)
789
+ # saver.save_v(f"prompt_embeds_{i}", prompt_embeds)
790
+
791
+ # perform guidance
792
+ if do_classifier_free_guidance:
793
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
794
+ noise_pred = noise_pred_uncond + guidance_scale * \
795
+ (noise_pred_text - noise_pred_uncond)
796
+
797
+ # compute the previous noisy sample x_t -> x_t-1
798
+ latents = self.scheduler.step(
799
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
800
+ if self.num_channels_unet == 4:
801
+ init_latents_proper = image_latents
802
+ if do_classifier_free_guidance:
803
+ init_mask, _ = mask.chunk(2)
804
+ else:
805
+ init_mask = mask
806
+
807
+ if i < len(timesteps) - 1:
808
+ noise_timestep = timesteps[i + 1]
809
+ init_latents_proper = self.scheduler.add_noise(
810
+ init_latents_proper, noise, torch.tensor(
811
+ [noise_timestep])
812
+ )
813
+
814
+ latents = (1 - init_mask) * init_latents_proper + \
815
+ init_mask * latents
816
+
817
+ # if do_classifier_free_guidance and guidance_rescale > 0.0:
818
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
819
+ # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
820
+ # # compute the previous noisy sample x_t -> x_t-1
821
+ # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
822
+ # image = self.decode_latents(latents)
823
+ image = self.lyra_decode_latents(latents)
824
+ image = numpy_to_pil(image)
825
+
826
+ return image
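For orientation, a hedged end-to-end sketch of driving this inpainting pipeline in the style of the repository demos. The exported class name `LyraSdTxt2ImgInpaintPipeline`, the `reload_pipe` loader and all paths are assumptions, not taken from this diff; only the `__call__` arguments mirror the signature above.

```python
# Illustrative usage sketch (class name, loader and paths are assumed).
import torch
from PIL import Image
from lyrasd_model import LyraSdTxt2ImgInpaintPipeline  # assumed export name

pipe = LyraSdTxt2ImgInpaintPipeline()
pipe.reload_pipe("./models/lyrasd_rev_animated")        # assumed model loader

init_image = Image.open("./images/input.png").convert("RGB")
mask_image = Image.open("./images/mask.png").convert("L")

images = pipe(
    prompt="a wooden bench in a sunny park",
    image=init_image,
    mask_image=mask_image,
    height=512,
    width=512,
    strength=0.75,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(123),
)
images[0].save("inpaint_result.png")
```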
lyrasd_model/lyrasd_txt2img_pipeline.py CHANGED
@@ -2,7 +2,7 @@ import inspect
2
  import os
3
  import time
4
  from typing import Any, Callable, Dict, List, Optional, Union
5
-
6
  import torch
7
  from diffusers.loaders import TextualInversionLoaderMixin
8
  from diffusers.models import AutoencoderKL
@@ -10,17 +10,43 @@ from diffusers.schedulers import (DPMSolverMultistepScheduler,
10
  EulerAncestralDiscreteScheduler,
11
  EulerDiscreteScheduler,
12
  KarrasDiffusionSchedulers)
13
- from diffusers.utils import logging, randn_tensor
14
  from PIL import Image
15
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
16
  import gc
17
  import numpy as np
18
 
19
- from .lora_util import add_text_lora_layer
 
 
 
20
 
21
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
 
23
 
 
 
 
24
  def numpy_to_pil(images):
25
  """
26
  Convert a numpy image or a batch of images to a PIL image.
@@ -30,68 +56,18 @@ def numpy_to_pil(images):
30
  images = (images * 255).round().astype("uint8")
31
  if images.shape[-1] == 1:
32
  # special case for grayscale (single channel) images
33
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
 
34
  else:
35
  pil_images = [Image.fromarray(image) for image in images]
36
 
37
  return pil_images
38
 
39
 
40
- class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
41
- def __init__(self, model_path, lib_so_path, model_dtype="fp32", device=torch.device("cuda"), dtype=torch.float16) -> None:
42
- self.device = device
43
- self.dtype = dtype
44
-
45
- torch.classes.load_library(lib_so_path)
46
-
47
- self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
48
- self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
49
- self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
50
- unet_path = os.path.join(model_path, "unet_bins/")
51
-
52
- self.unet_in_channels = 4
53
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
54
- self.vae.enable_tiling()
55
- self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
56
- 3, # max num of controlnets
57
- "fp16" # inference dtype (can only use fp16 for now)
58
- )
59
-
60
- unet_path = os.path.join(model_path, "unet_bins/")
61
 
62
- self.reload_unet_model(unet_path, model_dtype)
63
-
64
- self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
65
-
66
- self.loaded_lora = {}
67
-
68
- def reload_unet_model(self, unet_path, unet_file_format='fp32'):
69
- if len(unet_path) > 0 and unet_path[-1] != "/":
70
- unet_path = unet_path + "/"
71
- return self.unet.reload_unet_model(unet_path, unet_file_format)
72
-
73
- def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
74
- if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
75
- lora_model_path = lora_model_path + "/"
76
- lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
77
- self.loaded_lora[lora_name] = lora
78
- self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
79
-
80
- def unload_lora(self, lora_name, clean_cache=False):
81
- for layer_data in self.loaded_lora[lora_name]:
82
- layer = layer_data['layer']
83
- added_weight = layer_data['added_weight']
84
- layer.weight.data -= added_weight
85
- self.unet.unload_lora(lora_name, clean_cache)
86
- del self.loaded_lora[lora_name]
87
- gc.collect()
88
- torch.cuda.empty_cache()
89
-
90
- def clean_lora_cache(self):
91
- self.unet.clean_lora_cache()
92
-
93
- def get_loaded_lora(self):
94
- return self.unet.get_loaded_lora()
95
 
96
  def _encode_prompt(
97
  self,
@@ -147,7 +123,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
147
  return_tensors="pt",
148
  )
149
  text_input_ids = text_inputs.input_ids
150
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
151
 
152
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
153
  text_input_ids, untruncated_ids
@@ -171,12 +148,14 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
171
  )
172
  prompt_embeds = prompt_embeds[0]
173
 
174
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
175
 
176
  bs_embed, seq_len, _ = prompt_embeds.shape
177
  # duplicate text embeddings for each generation per prompt, using mps friendly method
178
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
179
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
180
 
181
  # get unconditional embeddings for classifier free guidance
182
  if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -201,7 +180,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
201
 
202
  # textual inversion: process multi-vector tokens if necessary
203
  if isinstance(self, TextualInversionLoaderMixin):
204
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
 
205
 
206
  max_length = prompt_embeds.shape[1]
207
  uncond_input = self.tokenizer(
@@ -227,10 +207,13 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
227
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
228
  seq_len = negative_prompt_embeds.shape[1]
229
 
230
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
231
 
232
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
233
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
234
 
235
  # For classifier free guidance, we need to do two forward passes.
236
  # Here we concatenate the unconditional and text embeddings into a single batch
@@ -239,14 +222,83 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
239
 
240
  return prompt_embeds
241
 
 
242
  def decode_latents(self, latents):
243
- latents = 1 / self.vae.config.scaling_factor * latents
244
  image = self.vae.decode(latents).sample
245
  image = (image / 2 + 0.5).clamp(0, 1)
246
  # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
247
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
248
  return image
249
 
250
  def check_inputs(
251
  self,
252
  prompt,
@@ -257,7 +309,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
257
  negative_prompt_embeds=None,
258
  ):
259
  if height % 64 != 0 or width % 64 != 0: # 初版暂时只支持 64 的倍数的 height 和 width
260
- raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
 
261
 
262
  if prompt is not None and prompt_embeds is not None:
263
  raise ValueError(
@@ -269,7 +322,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
269
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
270
  )
271
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
272
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
273
 
274
  if negative_prompt is not None and negative_prompt_embeds is not None:
275
  raise ValueError(
@@ -286,7 +340,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
286
  )
287
 
288
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
289
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
290
  if isinstance(generator, list) and len(generator) != batch_size:
291
  raise ValueError(
292
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -294,7 +349,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
294
  )
295
 
296
  if latents is None:
297
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
298
  else:
299
  latents = latents.to(device)
300
 
@@ -308,13 +364,15 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
308
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
309
  # and should be between [0, 1]
310
 
311
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
312
  extra_step_kwargs = {}
313
  if accepts_eta:
314
  extra_step_kwargs["eta"] = eta
315
 
316
  # check if the scheduler accepts generator
317
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
318
  if accepts_generator:
319
  extra_step_kwargs["generator"] = generator
320
  return extra_step_kwargs
@@ -330,10 +388,13 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
330
  negative_prompt: Optional[Union[str, List[str]]] = None,
331
  num_images_per_prompt: Optional[int] = 1,
332
  eta: float = 0.0,
333
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
334
  latents: Optional[torch.FloatTensor] = None,
335
  prompt_embeds: Optional[torch.FloatTensor] = None,
336
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
 
 
337
  ):
338
  r"""
339
  Function invoked when calling the pipeline for generation.
@@ -410,6 +471,14 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
410
  negative_prompt_embeds=negative_prompt_embeds,
411
  )
412
 
413
  # 4. Prepare timesteps
414
  self.scheduler.set_timesteps(num_inference_steps, device=device)
415
  timesteps = self.scheduler.timesteps
@@ -431,28 +500,46 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
431
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
432
 
433
  # 7. Denoising loop
434
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
435
 
436
  for i, t in enumerate(timesteps):
437
  # expand the latents if we are doing classifier free guidance
438
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
439
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
440
- latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()
441
-
442
- # The trailing 4 None arguments are ControlNet inputs; they are passed as placeholders for now
443
- noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, None, None, None, None)
444
 
445
  noise_pred = noise_pred.permute(0, 3, 1, 2)
446
- # perform guidance
447
 
 
 
 
448
  if do_classifier_free_guidance:
449
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
450
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
451
 
 
 
 
452
  # compute the previous noisy sample x_t -> x_t-1
453
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
454
-
455
- image = self.decode_latents(latents)
 
456
  image = numpy_to_pil(image)
457
 
458
  return image
 
2
  import os
3
  import time
4
  from typing import Any, Callable, Dict, List, Optional, Union
5
+ import GPUtil
6
  import torch
7
  from diffusers.loaders import TextualInversionLoaderMixin
8
  from diffusers.models import AutoencoderKL
 
10
  EulerAncestralDiscreteScheduler,
11
  EulerDiscreteScheduler,
12
  KarrasDiffusionSchedulers)
13
+ from diffusers.utils.torch_utils import logging, randn_tensor
14
  from PIL import Image
15
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
16
  import gc
17
  import numpy as np
18
 
19
+ from .lyrasd_vae_model import LyraSdVaeModel
20
+
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.models.embeddings import ImageProjection
23
+ from transformers import (
24
+ CLIPImageProcessor,
25
+ CLIPVisionModelWithProjection,
26
+ )
27
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
28
+ from safetensors.torch import load_file
29
+ from .lyrasd_pipeline_base import LyraSDXLPipelineBase
30
 
31
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
 
33
 
34
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
35
+ """
36
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
37
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
38
+ """
39
+ std_text = noise_pred_text.std(
40
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
41
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
42
+ # rescale the results from guidance (fixes overexposure)
43
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
44
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
45
+ noise_cfg = guidance_rescale * noise_pred_rescaled + \
46
+ (1 - guidance_rescale) * noise_cfg
47
+ return noise_cfg
48
+
49
+
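`rescale_noise_cfg` implements the guidance-rescale fix from arXiv:2305.08891; a brief sketch of where it would slot in after classifier-free guidance (it assumes the function defined just above is in scope, and the `guidance_rescale` value is illustrative).

```python
# Sketch: applying rescale_noise_cfg right after classifier-free guidance.
import torch

noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)
guidance_scale, guidance_rescale = 7.5, 0.7

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
```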
50
  def numpy_to_pil(images):
51
  """
52
  Convert a numpy image or a batch of images to a PIL image.
 
56
  images = (images * 255).round().astype("uint8")
57
  if images.shape[-1] == 1:
58
  # special case for grayscale (single channel) images
59
+ pil_images = [Image.fromarray(image.squeeze(), mode="L")
60
+ for image in images]
61
  else:
62
  pil_images = [Image.fromarray(image) for image in images]
63
 
64
  return pil_images
65
 
66
 
67
+ class LyraSdTxt2ImgPipeline(LyraSDXLPipelineBase):
68
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
69
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
70
71
 
72
  def _encode_prompt(
73
  self,
 
123
  return_tensors="pt",
124
  )
125
  text_input_ids = text_inputs.input_ids
126
+ untruncated_ids = self.tokenizer(
127
+ prompt, padding="longest", return_tensors="pt").input_ids
128
 
129
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
130
  text_input_ids, untruncated_ids
 
148
  )
149
  prompt_embeds = prompt_embeds[0]
150
 
151
+ prompt_embeds = prompt_embeds.to(
152
+ dtype=self.text_encoder.dtype, device=device)
153
 
154
  bs_embed, seq_len, _ = prompt_embeds.shape
155
  # duplicate text embeddings for each generation per prompt, using mps friendly method
156
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
157
+ prompt_embeds = prompt_embeds.view(
158
+ bs_embed * num_images_per_prompt, seq_len, -1)
159
 
160
  # get unconditional embeddings for classifier free guidance
161
  if do_classifier_free_guidance and negative_prompt_embeds is None:
 
180
 
181
  # textual inversion: process multi-vector tokens if necessary
182
  if isinstance(self, TextualInversionLoaderMixin):
183
+ uncond_tokens = self.maybe_convert_prompt(
184
+ uncond_tokens, self.tokenizer)
185
 
186
  max_length = prompt_embeds.shape[1]
187
  uncond_input = self.tokenizer(
 
207
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
208
  seq_len = negative_prompt_embeds.shape[1]
209
 
210
+ negative_prompt_embeds = negative_prompt_embeds.to(
211
+ dtype=self.text_encoder.dtype, device=device)
212
 
213
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
214
+ 1, num_images_per_prompt, 1)
215
+ negative_prompt_embeds = negative_prompt_embeds.view(
216
+ batch_size * num_images_per_prompt, seq_len, -1)
217
 
218
  # For classifier free guidance, we need to do two forward passes.
219
  # Here we concatenate the unconditional and text embeddings into a single batch
 
222
 
223
  return prompt_embeds
224
 
225
+ def load_ip_adapter(self,
226
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
227
+ subfolder: str,
228
+ weight_name: str,
229
+ **kwargs
230
+ ):
231
+ # if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
232
+ self.feature_extractor = CLIPImageProcessor()
233
+
234
+ # if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
235
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
236
+ pretrained_model_name_or_path_or_dict,
237
+ subfolder=os.path.join(subfolder, "image_encoder"),
238
+ ).to(self.device, dtype=self.dtype)
239
+ # else:
240
+ # print("kio: already has image_encoder", hasattr(self, "image_encoder"), getattr(self, "feature_extractor", None) is None)
241
+
242
+ # kiotodo: init ImageProjection
243
+ model_path = os.path.join(
244
+ pretrained_model_name_or_path_or_dict, subfolder, weight_name)
245
+ state_dict = torch.load(model_path, map_location="cpu")
246
+
247
+ clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
248
+ cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
249
+ self.encoder_hid_proj = ImageProjection(
250
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
251
+ )
252
+
253
+ image_proj_state_dict = {}
254
+ image_proj_state_dict.update(
255
+ {
256
+ "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
257
+ "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
258
+ "norm.weight": state_dict["image_proj"]["norm.weight"],
259
+ "norm.bias": state_dict["image_proj"]["norm.bias"],
260
+ }
261
+ )
262
+
263
+ self.encoder_hid_proj.load_state_dict(image_proj_state_dict)
264
+ self.encoder_hid_proj.to(dtype=self.dtype, device=self.device)
265
+
266
+ dir_ipadapter = os.path.join(
267
+ pretrained_model_name_or_path_or_dict, subfolder, '.'.join(weight_name.split(".")[:-1]))
268
+ self.unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
269
+
270
+ def encode_image(self, image, device, num_images_per_prompt):
271
+ dtype = next(self.image_encoder.parameters()).dtype
272
+ if not isinstance(image, torch.Tensor):
273
+ image = self.feature_extractor(
274
+ image, return_tensors="pt").pixel_values
275
+
276
+ image = image.to(device=device, dtype=dtype)
277
+ image_embeds = self.image_encoder(image).image_embeds
278
+ image_embeds = image_embeds.repeat_interleave(
279
+ num_images_per_prompt, dim=0)
280
+
281
+ uncond_image_embeds = torch.zeros_like(image_embeds)
282
+ return image_embeds, uncond_image_embeds
283
+
284
  def decode_latents(self, latents):
285
+ latents = 1 / self.vae.scaling_factor * latents
286
  image = self.vae.decode(latents).sample
287
  image = (image / 2 + 0.5).clamp(0, 1)
288
  # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
289
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
290
  return image
291
 
292
+ def lyra_decode_latents(self, latents):
293
+ latents = 1 / self.vae.scaling_factor * latents
294
+ image = self.vae.decode(latents)
295
+ image = image.permute(0, 2, 3, 1)
296
+ image = (image / 2 + 0.5).clamp(0, 1)
297
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
298
+ image = image.cpu().float().numpy()
299
+
300
+ return image
301
+
302
  def check_inputs(
303
  self,
304
  prompt,
 
309
  negative_prompt_embeds=None,
310
  ):
311
  if height % 64 != 0 or width % 64 != 0: # the initial version only supports height and width that are multiples of 64
312
+ raise ValueError(
313
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
314
 
315
  if prompt is not None and prompt_embeds is not None:
316
  raise ValueError(
 
322
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
323
  )
324
  elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
325
+ raise ValueError(
326
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
327
 
328
  if negative_prompt is not None and negative_prompt_embeds is not None:
329
  raise ValueError(
 
340
  )
341
 
342
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
343
+ shape = (batch_size, num_channels_latents, height //
344
+ self.vae.scale_factor, width // self.vae.scale_factor)
345
  if isinstance(generator, list) and len(generator) != batch_size:
346
  raise ValueError(
347
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
 
349
  )
350
 
351
  if latents is None:
352
+ latents = randn_tensor(
353
+ shape, generator=generator, device=device, dtype=dtype)
354
  else:
355
  latents = latents.to(device)
356
 
 
364
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
365
  # and should be between [0, 1]
366
 
367
+ accepts_eta = "eta" in set(inspect.signature(
368
+ self.scheduler.step).parameters.keys())
369
  extra_step_kwargs = {}
370
  if accepts_eta:
371
  extra_step_kwargs["eta"] = eta
372
 
373
  # check if the scheduler accepts generator
374
+ accepts_generator = "generator" in set(
375
+ inspect.signature(self.scheduler.step).parameters.keys())
376
  if accepts_generator:
377
  extra_step_kwargs["generator"] = generator
378
  return extra_step_kwargs
 
388
  negative_prompt: Optional[Union[str, List[str]]] = None,
389
  num_images_per_prompt: Optional[int] = 1,
390
  eta: float = 0.0,
391
+ generator: Optional[Union[torch.Generator,
392
+ List[torch.Generator]]] = None,
393
  latents: Optional[torch.FloatTensor] = None,
394
  prompt_embeds: Optional[torch.FloatTensor] = None,
395
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
396
+ ip_adapter_image: Optional[PipelineImageInput] = None,
397
+ param_scale_dict: Optional[dict] = {}
398
  ):
399
  r"""
400
  Function invoked when calling the pipeline for generation.
 
471
  negative_prompt_embeds=negative_prompt_embeds,
472
  )
473
 
474
+ # 3.5 Encode ipadapter_image
475
+ if ip_adapter_image is not None:
476
+ image_embeds, negative_image_embeds = self.encode_image(
477
+ ip_adapter_image, device, num_images_per_prompt)
478
+ if do_classifier_free_guidance:
479
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
480
+ image_embeds = self.encoder_hid_proj(image_embeds).to(self.dtype)
481
+
482
  # 4. Prepare timesteps
483
  self.scheduler.set_timesteps(num_inference_steps, device=device)
484
  timesteps = self.scheduler.timesteps
 
500
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
501
 
502
  # 7. Denoising loop
503
+ num_warmup_steps = len(timesteps) - \
504
+ num_inference_steps * self.scheduler.order
505
 
506
  for i, t in enumerate(timesteps):
507
  # expand the latents if we are doing classifier free guidance
508
+ latent_model_input = torch.cat(
509
+ [latents] * 2) if do_classifier_free_guidance else latents
510
+ latent_model_input = self.scheduler.scale_model_input(
511
+ latent_model_input, t)
512
+ latent_model_input = latent_model_input.permute(
513
+ 0, 2, 3, 1).contiguous()
514
+
515
+ # 后边三个 None 是给到controlnet 的参数,暂时给到 None 当 placeholder
516
+ # todo: forward ip image_embeds
517
+ # break
518
+ if ip_adapter_image is not None:
519
+ noise_pred = self.unet.forward(
520
+ latent_model_input, prompt_embeds, t, None, None, None, None, {"ip_hidden_states": image_embeds}, param_scale_dict)
521
+ else:
522
+ noise_pred = self.unet.forward(
523
+ latent_model_input, prompt_embeds, t)
524
 
525
  noise_pred = noise_pred.permute(0, 3, 1, 2)
 
526
 
527
+ # np.save(f"/workspace/noise_pred_{i}.npy", noise_pred.detach().cpu().numpy())  # debug only; avoid writing to disk on every step
528
+
529
+ # perform guidance
530
  if do_classifier_free_guidance:
531
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
532
+ noise_pred = noise_pred_uncond + guidance_scale * \
533
+ (noise_pred_text - noise_pred_uncond)
534
 
535
+ # if do_classifier_free_guidance and guidance_rescale > 0.0:
536
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
537
+ # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
538
  # compute the previous noisy sample x_t -> x_t-1
539
+ latents = self.scheduler.step(
540
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
541
+ # image = self.decode_latents(latents)
542
+ image = self.lyra_decode_latents(latents)
543
  image = numpy_to_pil(image)
544
 
545
  return image
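To tie the new `ip_adapter_image` argument to `load_ip_adapter`, a hedged usage sketch in the style of the repository demos. The model directory, the IP-Adapter folder layout, the weight file name and the `reload_pipe` loader are assumptions.

```python
# Illustrative IP-Adapter usage sketch (paths, layout and loader are assumed).
import torch
from PIL import Image
from lyrasd_model import LyraSdTxt2ImgPipeline

pipe = LyraSdTxt2ImgPipeline()
pipe.reload_pipe("./models/lyrasd_rev_animated")                 # assumed model loader
pipe.load_ip_adapter("./models/IP-Adapter", "sd15", "ip-adapter_sd15.bin")  # assumed layout

style_image = Image.open("./images/style_ref.png").convert("RGB")

images = pipe(
    prompt="a cozy cabin in the woods, warm light",
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(123),
    ip_adapter_image=style_image,
)
images[0].save("txt2img_ip_adapter.png")
```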
lyrasd_model/lyrasd_vae_model.py ADDED
@@ -0,0 +1,363 @@
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
20
+ import numpy as np
21
+
22
+ from safetensors.torch import load_file
23
+
24
+ import os
25
+
26
+ class LyraSdVaeModel():
27
+ r"""
28
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
29
+
30
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
31
+ for all models (such as downloading or saving).
32
+
33
+ Parameters:
34
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
35
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
36
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
37
+ Tuple of downsample block types.
38
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
39
+ Tuple of upsample block types.
40
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
41
+ Tuple of block output channels.
42
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
43
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
44
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
45
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
46
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
47
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
48
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
49
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
50
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
51
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
52
+ force_upcast (`bool`, *optional*, default to `True`):
53
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
54
+ can be fine-tuned / trained to a lower range without losing too much precision in which case
55
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
56
+ """
57
+
58
+ _supports_gradient_checkpointing = True
59
+
60
+ def __init__(
61
+ self,
62
+ dtype: str = "fp16",
63
+ scaling_factor: float = 0.18215,
64
+ scale_factor: int = 8,
65
+ is_upcast: bool = False
66
+ ):
67
+ super().__init__()
68
+ self.is_upcast = is_upcast
69
+ self.scaling_factor = scaling_factor
70
+ self.scale_factor = scale_factor
71
+ self.model = torch.classes.lyrasd.VaeModelOp(
72
+ dtype,
73
+ is_upcast
74
+ )
75
+
76
+ self.vae_cache = {}
77
+
78
+ self.use_slicing = False
79
+ self.use_tiling = False
80
+
81
+ self.tile_sample_min_size = 512
82
+ self.tile_latent_min_size = 64
83
+ self.tile_overlap_factor = 0.25
84
+
85
+ def reload_vae_model(self, vae_path, vae_file_format='fp32'):
86
+ if len(vae_path) > 0 and vae_path[-1] != "/":
87
+ vae_path = vae_path + "/"
88
+ return self.model.reload_vae_model(vae_path, vae_file_format)
89
+
90
+ def reload_vae_model_v2(self, model_path):
91
+ checkpoint_file = os.path.join(model_path, "vae/diffusion_pytorch_model.bin")
92
+ if not os.path.exists(checkpoint_file):
93
+ checkpoint_file = os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors")
94
+ if checkpoint_file in self.vae_cache:
95
+ state_dict = self.vae_cache[checkpoint_file]
96
+ else:
97
+ if "safetensors" in checkpoint_file:
98
+ state_dict = load_file(checkpoint_file)
99
+ else:
100
+ state_dict = torch.load(checkpoint_file, map_location="cpu")
101
+
102
+ # replace deprecated weights
103
+ for path in ["encoder.mid_block.attentions.0", "decoder.mid_block.attentions.0"]:
104
+ # group_norm path stays the same
105
+
106
+ # query -> to_q
107
+ if f"{path}.query.weight" in state_dict:
108
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
109
+ if f"{path}.query.bias" in state_dict:
110
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
111
+
112
+ # key -> to_k
113
+ if f"{path}.key.weight" in state_dict:
114
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
115
+ if f"{path}.key.bias" in state_dict:
116
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
117
+
118
+ # value -> to_v
119
+ if f"{path}.value.weight" in state_dict:
120
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
121
+ if f"{path}.value.bias" in state_dict:
122
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
123
+
124
+ # proj_attn -> to_out.0
125
+ if f"{path}.proj_attn.weight" in state_dict:
126
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
127
+ if f"{path}.proj_attn.bias" in state_dict:
128
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
129
+
130
+ for key in state_dict:
131
+ # print(key)
132
+ if len(state_dict[key].shape) == 4:
133
+ state_dict[key] = state_dict[key].permute(0,2,3,1).contiguous()
134
+ else:
135
+ state_dict[key] = state_dict[key]
136
+ if self.is_upcast and (key.startswith("decoder.up_blocks.2") or key.startswith("decoder.up_blocks.3") or key.startswith("decoder.conv_norm_out")):
137
+ # print(key)
138
+ state_dict[key] = state_dict[key].to(torch.float32)
139
+ else:
140
+ state_dict[key] = state_dict[key].to(torch.float16)
141
+
142
+ self.vae_cache[checkpoint_file] = state_dict
143
+
144
+ return self.model.reload_vae_model_from_cache(state_dict, "cpu")
145
+
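The block above renames deprecated VAE attention keys one by one; as a compact reference, an equivalent table-driven sketch of the same mapping (behaviourally the same idea, not a drop-in replacement for the method).

```python
# Sketch: the deprecated-key renaming above, expressed as a lookup table.
RENAMES = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}

def rename_deprecated_attention_keys(state_dict: dict) -> dict:
    for path in ["encoder.mid_block.attentions.0", "decoder.mid_block.attentions.0"]:
        for old, new in RENAMES.items():
            for suffix in ("weight", "bias"):
                old_key, new_key = f"{path}.{old}.{suffix}", f"{path}.{new}.{suffix}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
```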
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
150
+ processing larger images.
151
+ """
152
+ self.use_tiling = use_tiling
153
+
154
+ def disable_tiling(self):
155
+ r"""
156
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
157
+ decoding in one step.
158
+ """
159
+ self.enable_tiling(False)
160
+
161
+ def enable_slicing(self):
162
+ r"""
163
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
164
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
165
+ """
166
+ self.use_slicing = True
167
+
168
+ def disable_slicing(self):
169
+ r"""
170
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
171
+ decoding in one step.
172
+ """
173
+ self.use_slicing = False
174
+
175
+ def lyra_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
176
+ x = x.permute(0, 2, 3, 1).contiguous()
177
+ x = self.model.vae_decode(x)
178
+ return x.permute(0, 3, 1, 2)
179
+
180
+ def lyra_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
181
+ x = x.permute(0, 2, 3, 1).contiguous()
182
+ x = self.model.vae_encode(x)
183
+ return x.permute(0, 3, 1, 2)
184
+
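`lyra_encode`/`lyra_decode` above only shuttle tensors between PyTorch's NCHW layout and the NHWC layout the fused kernels work in; a tiny round-trip sketch with the custom op replaced by an identity function.

```python
# Sketch: NCHW <-> NHWC round trip around a (mocked) channels-last kernel.
import torch

def fake_channels_last_op(x_nhwc: torch.Tensor) -> torch.Tensor:
    # stand-in for the lyrasd VaeModelOp encode/decode call
    return x_nhwc

x = torch.randn(1, 4, 64, 64)                                   # NCHW latents
y = fake_channels_last_op(x.permute(0, 2, 3, 1).contiguous())   # NHWC into the kernel
y = y.permute(0, 3, 1, 2)                                       # back to NCHW
assert y.shape == x.shape
```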
185
+ def encode(
186
+ self, x: torch.FloatTensor, return_dict: bool = True
187
+ ) -> DiagonalGaussianDistribution:
188
+ """
189
+ Encode a batch of images into latents.
190
+
191
+ Args:
192
+ x (`torch.FloatTensor`): Input batch of images.
193
+ return_dict (`bool`, *optional*, defaults to `True`):
194
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
195
+
196
+ Returns:
197
+ The latent representations of the encoded images. If `return_dict` is True, a
198
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
199
+ """
200
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
201
+ return self.tiled_encode(x, return_dict=return_dict)
202
+
203
+ if self.use_slicing and x.shape[0] > 1:
204
+ encoded_slices = [self.lyra_encode(
205
+ x_slice) for x_slice in x.split(1)]
206
+ h = torch.cat(encoded_slices)
207
+ posterior = DiagonalGaussianDistribution(h)
208
+ else:
209
+ moments = self.lyra_encode(x)
210
+ posterior = DiagonalGaussianDistribution(moments)
211
+
212
+ return posterior
213
+
214
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
215
+ if self.use_tiling and (z.shape[2] > self.tile_latent_min_size or z.shape[3] > self.tile_latent_min_size):
216
+ return self.tiled_decode(z, return_dict=return_dict)
217
+
218
+ dec = self.lyra_decode(z)
219
+
220
+ return dec
221
+
222
+ def decode(
223
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
224
+ ) -> torch.FloatTensor:
225
+ """
226
+ Decode a batch of images.
227
+
228
+ Args:
229
+ z (`torch.FloatTensor`): Input batch of latent vectors.
230
+ return_dict (`bool`, *optional*, defaults to `True`):
231
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
232
+
233
+ Returns:
234
+ [`~models.vae.DecoderOutput`] or `tuple`:
235
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
236
+ returned.
237
+
238
+ """
239
+ if self.use_slicing and z.shape[0] > 1:
240
+ decoded_slices = [self._decode(
241
+ z_slice) for z_slice in z.split(1)]
242
+ decoded = torch.cat(decoded_slices)
243
+ else:
244
+ decoded = self._decode(z)
245
+
246
+ return decoded
247
+
248
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
249
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
250
+ for y in range(blend_extent):
251
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * \
252
+ (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
253
+ return b
254
+
255
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
256
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
257
+ for x in range(blend_extent):
258
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * \
259
+ (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
260
+ return b
261
+
262
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> DiagonalGaussianDistribution:
263
+ r"""Encode a batch of images using a tiled encoder.
264
+
265
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
266
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
267
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
268
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
269
+ output, but they should be much less noticeable.
270
+
271
+ Args:
272
+ x (`torch.FloatTensor`): Input batch of images.
273
+ return_dict (`bool`, *optional*, defaults to `True`):
274
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
275
+
276
+ Returns:
277
+ `DiagonalGaussianDistribution`: The posterior over the latents of the encoded image tiles.
280
+ """
281
+ overlap_size = int(self.tile_sample_min_size *
282
+ (1 - self.tile_overlap_factor))
283
+ blend_extent = int(self.tile_latent_min_size *
284
+ self.tile_overlap_factor)
285
+ row_limit = self.tile_latent_min_size - blend_extent
286
+
287
+ # Split the image into 512x512 tiles and encode them separately.
288
+ rows = []
289
+ for i in range(0, x.shape[2], overlap_size):
290
+ row = []
291
+ for j in range(0, x.shape[3], overlap_size):
292
+ tile = x[:, :, i: i + self.tile_sample_min_size,
293
+ j: j + self.tile_sample_min_size]
294
+ tile = self.lyra_encode(tile)
295
+ row.append(tile)
296
+ rows.append(row)
297
+ result_rows = []
298
+ for i, row in enumerate(rows):
299
+ result_row = []
300
+ for j, tile in enumerate(row):
301
+ # blend the above tile and the left tile
302
+ # to the current tile and add the current tile to the result row
303
+ if i > 0:
304
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
305
+ if j > 0:
306
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
307
+ result_row.append(tile[:, :, :row_limit, :row_limit])
308
+ result_rows.append(torch.cat(result_row, dim=3))
309
+
310
+ moments = torch.cat(result_rows, dim=2)
311
+ posterior = DiagonalGaussianDistribution(moments)
312
+
313
+ return posterior
314
+
315
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
316
+ r"""
317
+ Decode a batch of images using a tiled decoder.
318
+
319
+ Args:
320
+ z (`torch.FloatTensor`): Input batch of latent vectors.
321
+ return_dict (`bool`, *optional*, defaults to `True`):
322
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
323
+
324
+ Returns:
325
+ `torch.FloatTensor` or `tuple`: The decoded images, or a plain `tuple` containing them when `return_dict` is False.
328
+ """
329
+ overlap_size = int(self.tile_latent_min_size *
330
+ (1 - self.tile_overlap_factor))
331
+ blend_extent = int(self.tile_sample_min_size *
332
+ self.tile_overlap_factor)
333
+ row_limit = self.tile_sample_min_size - blend_extent
334
+
335
+ # Split z into overlapping 64x64 tiles and decode them separately.
336
+ # The tiles have an overlap to avoid seams between tiles.
337
+ rows = []
338
+ for i in range(0, z.shape[2], overlap_size):
339
+ row = []
340
+ for j in range(0, z.shape[3], overlap_size):
341
+ tile = z[:, :, i: i + self.tile_latent_min_size,
342
+ j: j + self.tile_latent_min_size]
343
+ decoded = self.lyra_decode(tile)
344
+ row.append(decoded)
345
+ rows.append(row)
346
+ result_rows = []
347
+ for i, row in enumerate(rows):
348
+ result_row = []
349
+ for j, tile in enumerate(row):
350
+ # blend the above tile and the left tile
351
+ # to the current tile and add the current tile to the result row
352
+ if i > 0:
353
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
354
+ if j > 0:
355
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
356
+ result_row.append(tile[:, :, :row_limit, :row_limit])
357
+ result_rows.append(torch.cat(result_row, dim=3))
358
+
359
+ dec = torch.cat(result_rows, dim=2)
360
+ if not return_dict:
361
+ return (dec,)
362
+
363
+ return dec
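The tiled encode/decode path above hides seams by linearly cross-fading the overlapping rows and columns of neighbouring tiles (blend_v / blend_h). A minimal, self-contained sketch of that weighting, on toy tensors rather than the real tile sizes:

import torch

# Row-wise linear cross-fade, same weighting as blend_v above (blend_h is the
# same idea along the width axis). Toy example only, not the real tile sizes.
def blend_rows(above: torch.Tensor, current: torch.Tensor, blend_extent: int) -> torch.Tensor:
    blend_extent = min(above.shape[2], current.shape[2], blend_extent)
    current = current.clone()
    for y in range(blend_extent):
        w = y / blend_extent  # 0 -> keep the tile above, 1 -> keep the current tile
        current[:, :, y, :] = above[:, :, -blend_extent + y, :] * (1 - w) + current[:, :, y, :] * w
    return current

above = torch.zeros(1, 1, 4, 4)    # tile from the previous row
current = torch.ones(1, 1, 4, 4)   # current tile below it
print(blend_rows(above, current, 4)[0, 0, :, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500])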
lyrasd_model/lyrasdxl_controlnet_txt2img_pipeline.py ADDED
@@ -0,0 +1,346 @@
1
+ import inspect
2
+ import os
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
+
6
+ import gc
7
+ import torch
8
+ import numpy as np
9
+ from glob import glob
10
+
11
+ import PIL
12
+
13
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
14
+ from diffusers.loaders import TextualInversionLoaderMixin
15
+ from diffusers.image_processor import VaeImageProcessor
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers.schedulers import (DPMSolverMultistepScheduler,
18
+ EulerAncestralDiscreteScheduler,
19
+ EulerDiscreteScheduler,
20
+ KarrasDiffusionSchedulers)
21
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+ from diffusers.utils import logging
24
+ from PIL import Image
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
26
+ from diffusers.utils import PIL_INTERPOLATION
27
+ from .lyrasd_vae_model import LyraSdVaeModel
28
+
29
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
30
+ from safetensors.torch import load_file
31
+ from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase
32
+
33
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
34
+ """
35
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
36
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
37
+ """
38
+ std_text = noise_pred_text.std(
39
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
40
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
41
+ # rescale the results from guidance (fixes overexposure)
42
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
43
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
44
+ noise_cfg = guidance_rescale * noise_pred_rescaled + \
45
+ (1 - guidance_rescale) * noise_cfg
46
+ return noise_cfg
47
+
48
+
49
+ class LyraSdXLControlnetTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
50
+ device = torch.device("cpu")
51
+ dtype = torch.float32
52
+
53
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
54
+ self.register_to_config(force_zeros_for_empty_prompt=True)
55
+
56
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
57
+
58
+
59
+ def prepare_image(
60
+ self,
61
+ image,
62
+ width,
63
+ height,
64
+ batch_size,
65
+ num_images_per_prompt,
66
+ device,
67
+ dtype,
68
+ do_classifier_free_guidance=False,
69
+ guess_mode=False,
70
+ ):
71
+ image = self.control_image_processor.preprocess(image, height, width)
72
+ image = image.permute(0, 2, 3, 1)
73
+
74
+ image = image.to(device=device, dtype=dtype)
75
+ # print(image.shape)
76
+ # print(image)
77
+
78
+ return image
79
+
80
+ @property
81
+ def _execution_device(self):
82
+ if not hasattr(self.unet, "_hf_hook"):
83
+ return self.device
84
+ for module in self.unet.modules():
85
+ if (
86
+ hasattr(module, "_hf_hook")
87
+ and hasattr(module._hf_hook, "execution_device")
88
+ and module._hf_hook.execution_device is not None
89
+ ):
90
+ return torch.device(module._hf_hook.execution_device)
91
+ return self.device
92
+
93
+ def _get_aug_emb(self, add_embedding, time_ids, text_embeds, dtype):
94
+ time_embeds = self.add_time_proj(time_ids.flatten())
95
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
96
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
97
+ add_embeds = add_embeds.to(dtype)
98
+ aug_emb = add_embedding(add_embeds)
99
+ return aug_emb
100
+
101
+ @torch.no_grad()
102
+ def __call__(
103
+ self,
104
+ prompt: Union[str, List[str]] = None,
105
+ prompt_2: Optional[Union[str, List[str]]] = None,
106
+ height: Optional[int] = None,
107
+ width: Optional[int] = None,
108
+ num_inference_steps: int = 50,
109
+ denoising_end: Optional[float] = None,
110
+ guidance_scale: float = 5.0,
111
+ negative_prompt: Optional[Union[str, List[str]]] = None,
112
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
113
+ num_images_per_prompt: Optional[int] = 1,
114
+ controlnet_names: Optional[List[str]] = None,
115
+ controlnet_images: Optional[List[PIL.Image.Image]] = None,
116
+ controlnet_scale: Optional[List[float]] = None,
117
+ guess_mode=False,
118
+ eta: float = 0.0,
119
+ generator: Optional[Union[torch.Generator,
120
+ List[torch.Generator]]] = None,
121
+ latents: Optional[torch.FloatTensor] = None,
122
+ prompt_embeds: Optional[torch.FloatTensor] = None,
123
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
124
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
125
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
126
+ output_type: Optional[str] = "pil",
127
+ return_dict: bool = True,
128
+ callback: Optional[Callable[[
129
+ int, int, torch.FloatTensor], None]] = None,
130
+ callback_steps: int = 1,
131
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
132
+ guidance_rescale: float = 0.0,
133
+ original_size: Optional[Tuple[int, int]] = None,
134
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
135
+ target_size: Optional[Tuple[int, int]] = None,
136
+ ):
137
+
138
+ # 0. Default height and width to unet
139
+ height = height or self.default_sample_size * self.vae_scale_factor
140
+ width = width or self.default_sample_size * self.vae_scale_factor
141
+
142
+ original_size = original_size or (height, width)
143
+ target_size = target_size or (height, width)
144
+
145
+ # 1. Check inputs. Raise error if not correct
146
+ self.check_inputs(
147
+ prompt,
148
+ prompt_2,
149
+ height,
150
+ width,
151
+ callback_steps,
152
+ negative_prompt,
153
+ negative_prompt_2,
154
+ prompt_embeds,
155
+ negative_prompt_embeds,
156
+ pooled_prompt_embeds,
157
+ negative_pooled_prompt_embeds,
158
+ )
159
+
160
+ # 2. Define call parameters
161
+ if prompt is not None and isinstance(prompt, str):
162
+ batch_size = 1
163
+ elif prompt is not None and isinstance(prompt, list):
164
+ batch_size = len(prompt)
165
+ else:
166
+ batch_size = prompt_embeds.shape[0]
167
+
168
+ device = self._execution_device
169
+
170
+ do_classifier_free_guidance = guidance_scale > 1.0
171
+
172
+ # 3. Encode input prompt
173
+ text_encoder_lora_scale = (
174
+ cross_attention_kwargs.get(
175
+ "scale", None) if cross_attention_kwargs is not None else None
176
+ )
177
+ (
178
+ prompt_embeds,
179
+ negative_prompt_embeds,
180
+ pooled_prompt_embeds,
181
+ negative_pooled_prompt_embeds,
182
+ ) = self.encode_prompt(
183
+ prompt=prompt,
184
+ prompt_2=prompt_2,
185
+ device=device,
186
+ num_images_per_prompt=num_images_per_prompt,
187
+ do_classifier_free_guidance=do_classifier_free_guidance,
188
+ negative_prompt=negative_prompt,
189
+ negative_prompt_2=negative_prompt_2,
190
+ prompt_embeds=prompt_embeds,
191
+ negative_prompt_embeds=negative_prompt_embeds,
192
+ pooled_prompt_embeds=pooled_prompt_embeds,
193
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
194
+ lora_scale=text_encoder_lora_scale,
195
+ )
196
+
197
+ control_images = []
198
+
199
+ for image_ in controlnet_images:
200
+ image_ = self.prepare_image(
201
+ image=image_,
202
+ width=width,
203
+ height=height,
204
+ batch_size=batch_size * num_images_per_prompt,
205
+ num_images_per_prompt=num_images_per_prompt,
206
+ device=device,
207
+ dtype=prompt_embeds.dtype,
208
+ do_classifier_free_guidance=do_classifier_free_guidance
209
+ )
210
+
211
+ control_images.append(image_)
212
+
213
+ control_scales = []
214
+
215
+ scales = [1.0, ] * 10
216
+ if guess_mode:
217
+ scales = torch.logspace(-1, 0, 10).tolist()
218
+
219
+ for scale in controlnet_scale:
220
+ scales_ = [d * scale for d in scales]
221
+ control_scales.append(scales_)
222
+
223
+ # 4. Prepare timesteps
224
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
225
+
226
+ timesteps = self.scheduler.timesteps
227
+
228
+ # 5. Prepare latent variables
229
+ num_channels_latents = self.unet_in_channels
230
+ latents = self.prepare_latents(
231
+ batch_size * num_images_per_prompt,
232
+ num_channels_latents,
233
+ height,
234
+ width,
235
+ prompt_embeds.dtype,
236
+ device,
237
+ generator,
238
+ latents,
239
+ )
240
+
241
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
242
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
243
+
244
+ # 7. Prepare added time ids & embeddings
245
+ add_text_embeds = pooled_prompt_embeds
246
+ add_time_ids = list(
247
+ original_size + crops_coords_top_left + target_size)
248
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
249
+
250
+ if do_classifier_free_guidance:
251
+ prompt_embeds = torch.cat(
252
+ [negative_prompt_embeds, prompt_embeds], dim=0)
253
+ add_text_embeds = torch.cat(
254
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
255
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
256
+
257
+ prompt_embeds = prompt_embeds.to(device)
258
+ add_text_embeds = add_text_embeds.to(device)
259
+ add_time_ids = add_time_ids.to(device).repeat(
260
+ batch_size * num_images_per_prompt, 1)
261
+
262
+ # 8. Denoising loop
263
+ num_warmup_steps = max(
264
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0)
265
+
266
+ # 7.1 Apply denoising_end
267
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
268
+ discrete_timestep_cutoff = int(
269
+ round(
270
+ self.scheduler.config.num_train_timesteps
271
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
272
+ )
273
+ )
274
+ num_inference_steps = len(
275
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
276
+ timesteps = timesteps[:num_inference_steps]
277
+
278
+ aug_emb = self._get_aug_emb(
279
+ self.add_embedding, add_time_ids, add_text_embeds, prompt_embeds.dtype)
280
+
281
+ controlnet_aug_embs = []
282
+ for controlnet_name in controlnet_names:
283
+ controlnet_aug_embs.append(self._get_aug_emb(self.controlnet_add_embedding[controlnet_name],
284
+ add_time_ids, add_text_embeds, prompt_embeds.dtype))
285
+
286
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
287
+ for i, t in enumerate(timesteps):
288
+ # expand the latents if we are doing classifier free guidance
289
+ latent_model_input = torch.cat(
290
+ [latents] * 2) if do_classifier_free_guidance else latents
291
+
292
+ latent_model_input = self.scheduler.scale_model_input(
293
+ latent_model_input, t)
294
+ latent_model_input = latent_model_input.permute(
295
+ 0, 2, 3, 1).contiguous()
296
+
297
+ noise_pred = self.unet.forward(
298
+ latent_model_input, prompt_embeds, t, aug_emb,
299
+ controlnet_names, control_images, controlnet_aug_embs, control_scales, guess_mode).permute(0, 3, 1, 2)
300
+
301
+ # print(noise_pred)
302
+
303
+ # perform guidance
304
+ if do_classifier_free_guidance:
305
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
306
+ noise_pred = noise_pred_uncond + guidance_scale * \
307
+ (noise_pred_text - noise_pred_uncond)
308
+
309
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
310
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
311
+ noise_pred = rescale_noise_cfg(
312
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
313
+
314
+ # compute the previous noisy sample x_t -> x_t-1
315
+ latents = self.scheduler.step(
316
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
317
+
318
+ # call the callback, if provided
319
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
320
+ progress_bar.update()
321
+ if callback is not None and i % callback_steps == 0:
322
+ callback(i, t, latents)
323
+
324
+ # make sure the VAE is in float32 mode, as it overflows in float16
325
+ # if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
326
+ # self.upcast_vae()
327
+ # latents = latents.to(
328
+ # next(iter(self.vae.post_quant_conv.parameters())).dtype)
329
+ # # latents = latents.to(torch.float32)
330
+ # if output_type == "latent":
331
+ # return latents
332
+
333
+ # np.save(f"/workspace/latents.npy", latents.detach().cpu().numpy())
334
+
335
+ # image = self.vae.decode(
336
+ # latents / self.vae.config.scaling_factor, return_dict=False)[0]
337
+ image = self.vae.decode(1 / self.vae.scaling_factor * latents)
338
+
339
+ image = self.image_processor.postprocess(
340
+ image, output_type=output_type)
341
+
342
+ # Offload last model to CPU
343
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
344
+ self.final_offload_hook.offload()
345
+
346
+ return image
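Taken together, the new SDXL ControlNet text-to-image pipeline is driven much like the existing SD ControlNet pipelines. A rough usage sketch follows; the paths, the "canny" model name, and the export from lyrasd_model are assumptions, and the demo script controlnet_txt2img_sdxl_demo.py added in this commit is the authoritative example:

import torch
from PIL import Image
from lyrasd_model import LyraSdXLControlnetTxt2ImgPipeline  # assumes __init__.py exports it

# Placeholder paths; adjust to the local model layout. Also assumes the lyrasd
# C++ ops library has already been loaded (see lyrasd_model/__init__.py).
model_path = "./models/lyrasd_xl_base"       # hypothetical converted SDXL base model
controlnet_path = "./models/canny_sdxl"      # hypothetical converted SDXL ControlNet

pipe = LyraSdXLControlnetTxt2ImgPipeline()
pipe.reload_pipe(model_path)                            # tokenizers, text encoders, UNet, VAE, scheduler
pipe.load_controlnet_model_v2("canny", controlnet_path)

control_image = Image.open("./images/canny.png").convert("RGB")
images = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    height=1024, width=1024,
    num_inference_steps=30,
    guidance_scale=7.5,
    controlnet_names=["canny"],
    controlnet_images=[control_image],
    controlnet_scale=[0.8],
    generator=torch.Generator().manual_seed(123),
)
images[0].save("out_controlnet_sdxl_txt2img.png")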
lyrasd_model/lyrasdxl_pipeline_base.py ADDED
@@ -0,0 +1,275 @@
1
+ import inspect
2
+ import os
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
+
6
+ import gc
7
+ import torch
8
+ import numpy as np
9
+ from glob import glob
10
+
11
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
12
+ from diffusers.loaders import TextualInversionLoaderMixin
13
+ from diffusers.image_processor import VaeImageProcessor
14
+ from diffusers.models import AutoencoderKL
15
+ from diffusers.schedulers import (DPMSolverMultistepScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ EulerDiscreteScheduler,
18
+ KarrasDiffusionSchedulers)
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from diffusers.utils import logging
22
+ from PIL import Image
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
24
+ from .lyrasd_vae_model import LyraSdVaeModel
25
+ from .module.lyrasd_ip_adapter import LyraIPAdapter
26
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
27
+ from safetensors.torch import load_file
28
+
29
+
30
+ class LyraSDXLPipelineBase(TextualInversionLoaderMixin):
31
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, num_channels_unet=4, num_channels_latents=4, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
32
+ self.device = device
33
+ self.dtype = dtype
34
+
35
+ self.num_channels_unet = num_channels_unet
36
+ self.num_channels_latents = num_channels_latents
37
+ self.vae_scale_factor = vae_scale_factor
38
+ self.vae_scaling_factor = vae_scaling_factor
39
+
40
+ self.unet_cache = {}
41
+ self.unet_in_channels = 4
42
+
43
+ self.controlnet_cache = {}
44
+ self.controlnet_add_embedding = {}
45
+
46
+ self.loaded_lora = {}
47
+ self.loaded_lora_strength = {}
48
+
49
+ self.scheduler = None
50
+
51
+ self.init_pipe()
52
+
53
+ def init_pipe(self):
54
+ self.vae = LyraSdVaeModel(
55
+ scale_factor=self.vae_scale_factor, scaling_factor=self.vae_scaling_factor, is_upcast=True)
56
+
57
+ self.unet = torch.classes.lyrasd.XLUnet2dConditionalModelOp(
58
+ "fp16",
59
+ self.num_channels_unet,
60
+ self.num_channels_latents)
61
+
62
+ self.default_sample_size = 128
63
+ self.addition_time_embed_dim = 256
64
+ flip_sin_to_cos, freq_shift = True, 0
65
+ self.projection_class_embeddings_input_dim, self.time_embed_dim = 2816, 1280
66
+
67
+ self.add_time_proj = Timesteps(
68
+ self.addition_time_embed_dim, flip_sin_to_cos, freq_shift).to(self.dtype).to(self.device)
69
+
70
+ self.add_embedding = TimestepEmbedding(
71
+ self.projection_class_embeddings_input_dim, self.time_embed_dim).to(self.dtype).to(self.device)
72
+
73
+ self.image_processor = VaeImageProcessor(
74
+ vae_scale_factor=self.vae_scale_factor)
75
+
76
+ self.mask_processor = VaeImageProcessor(
77
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
78
+ )
79
+
80
+ self.control_image_processor = VaeImageProcessor(
81
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
82
+ )
83
+
84
+ self.feature_extractor = CLIPImageProcessor()
85
+
86
+ def reload_pipe(self, model_path):
87
+ self.tokenizer = CLIPTokenizer.from_pretrained(
88
+ model_path, subfolder="tokenizer")
89
+ self.text_encoder = CLIPTextModel.from_pretrained(
90
+ model_path, subfolder="text_encoder").to(self.dtype).to(self.device)
91
+
92
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(
93
+ model_path, subfolder="tokenizer_2")
94
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
95
+ model_path, subfolder="text_encoder_2").to(self.dtype).to(self.device)
96
+
97
+ self.reload_unet_model_v2(model_path)
98
+ self.reload_vae_model_v2(model_path)
99
+
100
+ if not self.scheduler:
101
+ self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
102
+ model_path, subfolder="scheduler")
103
+
104
+ def load_embedding_weight(self, model, weight_path, unet_file_format="fp16"):
105
+ bin_list = glob(weight_path)
106
+ state_dicts = model.state_dict()
107
+ dtype = np.float32 if unet_file_format == "fp32" else np.float16
108
+ for bin_file in bin_list:
109
+ weight = torch.from_numpy(np.fromfile(bin_file, dtype=dtype)).to(
110
+ self.dtype).to(self.device)
111
+ key = '.'.join(os.path.basename(bin_file).split('.')[1:-1])
112
+ weight = weight.reshape(state_dicts[key].shape)
113
+ state_dicts.update({key: weight})
114
+ model.load_state_dict(state_dicts)
115
+
116
+ @property
117
+ def _execution_device(self):
118
+ if not hasattr(self.unet, "_hf_hook"):
119
+ return self.device
120
+ for module in self.unet.modules():
121
+ if (
122
+ hasattr(module, "_hf_hook")
123
+ and hasattr(module._hf_hook, "execution_device")
124
+ and module._hf_hook.execution_device is not None
125
+ ):
126
+ return torch.device(module._hf_hook.execution_device)
127
+ return self.device
128
+
129
+ def reload_unet_model(self, unet_path, unet_file_format='fp32'):
130
+ if len(unet_path) > 0 and unet_path[-1] != "/":
131
+ unet_path = unet_path + "/"
132
+ self.unet.reload_unet_model(unet_path, unet_file_format)
133
+ self.load_embedding_weight(
134
+ self.add_embedding, f"{unet_path}add_embedding*", unet_file_format=unet_file_format)
135
+
136
+ def reload_vae_model(self, vae_path, vae_file_format='fp32'):
137
+ if len(vae_path) > 0 and vae_path[-1] != "/":
138
+ vae_path = vae_path + "/"
139
+ return self.vae.reload_vae_model(vae_path, vae_file_format)
140
+
141
+ def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
142
+ if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
143
+ lora_model_path = lora_model_path + "/"
144
+ lora = add_xltext_lora_layer(
145
+ self.text_encoder, self.text_encoder_2, lora_model_path, lora_strength, lora_file_format)
146
+
147
+ self.loaded_lora[lora_name] = lora
148
+ self.unet.load_lora(lora_model_path, lora_name,
149
+ lora_strength, lora_file_format)
150
+
151
+ def unload_lora(self, lora_name, clean_cache=False):
152
+ for layer_data in self.loaded_lora[lora_name]:
153
+ layer = layer_data['layer']
154
+ added_weight = layer_data['added_weight']
155
+ layer.weight.data -= added_weight
156
+ self.unet.unload_lora(lora_name, clean_cache)
157
+ del self.loaded_lora[lora_name]
158
+ gc.collect()
159
+ torch.cuda.empty_cache()
160
+
161
+ def load_lora_v2(self, lora_model_path, lora_name, lora_strength):
162
+ if lora_name in self.loaded_lora:
163
+ state_dict = self.loaded_lora[lora_name]
164
+ else:
165
+ state_dict = load_state_dict(lora_model_path)
166
+ self.loaded_lora[lora_name] = state_dict
167
+ self.loaded_lora_strength[lora_name] = lora_strength
168
+ add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
169
+ self.text_encoder_2, lora_strength)
170
+
171
+ def unload_lora_v2(self, lora_name, clean_cache=False):
172
+ state_dict = self.loaded_lora[lora_name]
173
+ lora_strength = self.loaded_lora_strength[lora_name]
174
+ add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
175
+ self.text_encoder_2, -1.0 * lora_strength)
176
+ del self.loaded_lora_strength[lora_name]
177
+
178
+ if clean_cache:
179
+ del self.loaded_lora[lora_name]
180
+ gc.collect()
181
+ torch.cuda.empty_cache()
182
+
183
+ def clean_lora_cache(self):
184
+ self.unet.clean_lora_cache()
185
+
186
+ def get_loaded_lora(self):
187
+ return self.unet.get_loaded_lora()
188
+
189
+ def _get_aug_emb(self, time_ids, text_embeds, dtype):
190
+ time_embeds = self.add_time_proj(time_ids.flatten())
191
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
192
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
193
+ add_embeds = add_embeds.to(dtype)
194
+ aug_emb = self.add_embedding(add_embeds)
195
+ return aug_emb
196
+
197
+ def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
198
+ self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
199
+ num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)
200
+
201
+ def reload_unet_model_v2(self, model_path):
202
+ checkpoint_file = os.path.join(
203
+ model_path, "unet/diffusion_pytorch_model.bin")
204
+ if not os.path.exists(checkpoint_file):
205
+ checkpoint_file = os.path.join(
206
+ model_path, "unet/diffusion_pytorch_model.safetensors")
207
+ if checkpoint_file in self.unet_cache:
208
+ state_dict = self.unet_cache[checkpoint_file]
209
+ else:
210
+ if "safetensors" in checkpoint_file:
211
+ state_dict = load_file(checkpoint_file)
212
+ else:
213
+ state_dict = torch.load(checkpoint_file, map_location="cpu")
214
+
215
+ for key in state_dict:
216
+ if len(state_dict[key].shape) == 4:
217
+ # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
218
+ state_dict[key] = state_dict[key].to(
219
+ torch.float16).permute(0, 2, 3, 1).contiguous()
220
+ state_dict[key] = state_dict[key].to(torch.float16)
221
+ self.unet_cache[checkpoint_file] = state_dict
222
+
223
+ self.unet.reload_unet_model_from_cache(state_dict, "cpu")
224
+ self.load_embedding_weight_v2(self.add_embedding, state_dict)
225
+
226
+ def load_embedding_weight_v2(self, model, state_dict):
227
+ sub_state_dict = {}
228
+ for k in state_dict:
229
+ if k.startswith("add_embedding"):
230
+ v = state_dict[k]
231
+ sub_k = ".".join(k.split(".")[1:])
232
+ sub_state_dict[sub_k] = v
233
+
234
+ model.load_state_dict(sub_state_dict)
235
+
236
+ def reload_vae_model_v2(self, model_path):
237
+ self.vae.reload_vae_model_v2(model_path)
238
+
239
+ def load_controlnet_model_v2(self, model_name, controlnet_path):
240
+ checkpoint_file = os.path.join(
241
+ controlnet_path, "diffusion_pytorch_model.bin")
242
+ if not os.path.exists(checkpoint_file):
243
+ checkpoint_file = os.path.join(
244
+ controlnet_path, "diffusion_pytorch_model.safetensors")
245
+ if checkpoint_file in self.controlnet_cache:
246
+ state_dict = self.controlnet_cache[checkpoint_file]
247
+ else:
248
+ if "safetensors" in checkpoint_file:
249
+ state_dict = load_file(checkpoint_file)
250
+ else:
251
+ state_dict = torch.load(checkpoint_file, map_location="cpu")
252
+
253
+ for key in state_dict:
254
+ if len(state_dict[key].shape) == 4:
255
+ # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
256
+ state_dict[key] = state_dict[key].to(
257
+ torch.float16).permute(0, 2, 3, 1).contiguous()
258
+ state_dict[key] = state_dict[key].to(torch.float16)
259
+ self.controlnet_cache[checkpoint_file] = state_dict
260
+
261
+ self.unet.load_controlnet_model_from_state_dict(
262
+ model_name, state_dict, "cpu")
263
+
264
+ add_embedding = TimestepEmbedding(
265
+ self.projection_class_embeddings_input_dim, self.time_embed_dim).to(self.dtype).to(self.device)
266
+
267
+ self.load_embedding_weight_v2(add_embedding, state_dict)
268
+ self.controlnet_add_embedding[model_name] = add_embedding
269
+
270
+ def unload_controlnet_model(self, model_name):
271
+ self.unet.unload_controlnet_model(model_name, True)
272
+ del self.controlnet_add_embedding[model_name]
273
+
274
+ def get_loaded_controlnet(self):
275
+ return self.unet.get_loaded_controlnet()
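The v2 LoRA path above fuses the LoRA delta directly into the UNet and text-encoder weights, and unloading re-applies the same delta with a negated strength, so every load_lora_v2 call should be paired with a matching unload_lora_v2. A rough sketch of the expected call pattern (the pipeline construction, paths, and LoRA file are placeholders):

from lyrasd_model import LyraSdXLTxt2ImgPipeline  # assumes __init__.py exports it

pipe = LyraSdXLTxt2ImgPipeline()
pipe.reload_pipe("./models/lyrasd_xl_base")        # hypothetical converted SDXL model dir

lora_path = "./models/some_sdxl_lora.safetensors"  # hypothetical LoRA file
pipe.load_lora_v2(lora_path, "my_lora", 0.6)       # merge +0.6 * delta into the weights

images = pipe(prompt="a cat in a spacesuit", height=1024, width=1024,
              num_inference_steps=30, guidance_scale=7.0)

pipe.unload_lora_v2("my_lora", clean_cache=True)   # merge -0.6 * delta, restoring the base weights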
lyrasd_model/lyrasdxl_txt2img_inpaint_pipeline.py ADDED
@@ -0,0 +1,535 @@
1
+ import inspect
2
+ import os
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
+
6
+ import gc
7
+ import torch
8
+ import numpy as np
9
+ from glob import glob
10
+
11
+ from diffusers import StableDiffusionXLInpaintPipeline, UNet2DConditionModel
12
+ from diffusers.loaders import TextualInversionLoaderMixin
13
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
14
+ from diffusers.models import AutoencoderKL
15
+ from diffusers.schedulers import (DPMSolverMultistepScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ EulerDiscreteScheduler,
18
+ KarrasDiffusionSchedulers)
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from diffusers.utils import logging
22
+ from PIL import Image
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
24
+ from .lyrasd_vae_model import LyraSdVaeModel
25
+ from .module.lyrasd_ip_adapter import LyraIPAdapter
26
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
27
+ from safetensors.torch import load_file
28
+
29
+ from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase
30
+
31
+
32
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
33
+ """
34
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
35
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
36
+ """
37
+ std_text = noise_pred_text.std(
38
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
39
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
40
+ # rescale the results from guidance (fixes overexposure)
41
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
42
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
43
+ noise_cfg = guidance_rescale * noise_pred_rescaled + \
44
+ (1 - guidance_rescale) * noise_cfg
45
+ return noise_cfg
46
+
47
+ def retrieve_latents(
48
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
49
+ ):
50
+ if sample_mode == "sample":
51
+ return encoder_output.sample(generator)
52
+ elif sample_mode == "argmax":
53
+ return encoder_output.mode()
54
+ else:
55
+ return encoder_output
56
+
57
+
58
+ def retrieve_timesteps(
59
+ scheduler,
60
+ num_inference_steps: Optional[int] = None,
61
+ device: Optional[Union[str, torch.device]] = None,
62
+ timesteps: Optional[List[int]] = None,
63
+ **kwargs,
64
+ ):
65
+ """
66
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
67
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
68
+
69
+ Args:
70
+ scheduler (`SchedulerMixin`):
71
+ The scheduler to get timesteps from.
72
+ num_inference_steps (`int`):
73
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
74
+ `timesteps` must be `None`.
75
+ device (`str` or `torch.device`, *optional*):
76
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
77
+ timesteps (`List[int]`, *optional*):
78
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
79
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
80
+ must be `None`.
81
+
82
+ Returns:
83
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
84
+ second element is the number of inference steps.
85
+ """
86
+ if timesteps is not None:
87
+ accepts_timesteps = "timesteps" in set(
88
+ inspect.signature(scheduler.set_timesteps).parameters.keys())
89
+ if not accepts_timesteps:
90
+ raise ValueError(
91
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
92
+ f" timestep schedules. Please check whether you are using the correct scheduler."
93
+ )
94
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
95
+ timesteps = scheduler.timesteps
96
+ num_inference_steps = len(timesteps)
97
+ else:
98
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
99
+ timesteps = scheduler.timesteps
100
+ return timesteps, num_inference_steps
101
+
102
+
103
+ class LyraSdXLTxt2ImgInpaintPipeline(LyraSDXLPipelineBase, StableDiffusionXLInpaintPipeline):
104
+ device = torch.device("cpu")
105
+ dtype = torch.float32
106
+
107
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025, num_channels_unet=9, num_channels_latents=4, requires_aesthetics_score: bool = False,
108
+ force_zeros_for_empty_prompt: bool = True) -> None:
109
+ self.register_to_config(
110
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
111
+ self.register_to_config(
112
+ requires_aesthetics_score=requires_aesthetics_score)
113
+
114
+ super().__init__(device, dtype, num_channels_unet=num_channels_unet, num_channels_latents=num_channels_latents, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
115
+
116
+
117
+ def encode_image(self, image, device, num_images_per_prompt):
118
+ dtype = next(self.image_encoder.parameters()).dtype
119
+ if not isinstance(image, torch.Tensor):
120
+ image = self.feature_extractor(
121
+ image, return_tensors="pt").pixel_values
122
+
123
+ image = image.to(device=device, dtype=dtype)
124
+ image_embeds = self.image_encoder(image).image_embeds
125
+ image_embeds = image_embeds.repeat_interleave(
126
+ num_images_per_prompt, dim=0)
127
+
128
+ uncond_image_embeds = torch.zeros_like(image_embeds)
129
+ return image_embeds, uncond_image_embeds
130
+
131
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
132
+ dtype = image.dtype
133
+ # if self.vae.config.force_upcast:
134
+ # image = image.float()
135
+ # self.vae.to(dtype=torch.float32)
136
+
137
+ if isinstance(generator, list):
138
+ image_latents = [
139
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
140
+ for i in range(image.shape[0])
141
+ ]
142
+ image_latents = torch.cat(image_latents, dim=0)
143
+ else:
144
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
145
+
146
+ image_latents = image_latents.to(dtype)
147
+ image_latents = self.vae.scaling_factor * image_latents
148
+
149
+ return image_latents
150
+
151
+ def _get_add_time_ids(
152
+ self,
153
+ original_size,
154
+ crops_coords_top_left,
155
+ target_size,
156
+ aesthetic_score,
157
+ negative_aesthetic_score,
158
+ negative_original_size,
159
+ negative_crops_coords_top_left,
160
+ negative_target_size,
161
+ dtype,
162
+ text_encoder_projection_dim=None,
163
+ ):
164
+ if self.config.requires_aesthetics_score:
165
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
166
+ add_neg_time_ids = list(
167
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
168
+ )
169
+ else:
170
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
171
+ add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)
172
+
173
+ passed_add_embed_dim = (
174
+ self.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
175
+ )
176
+ expected_add_embed_dim = self.add_embedding.linear_1.in_features
177
+
178
+ if (
179
+ expected_add_embed_dim > passed_add_embed_dim
180
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.addition_time_embed_dim
181
+ ):
182
+ raise ValueError(
183
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
184
+ )
185
+ elif (
186
+ expected_add_embed_dim < passed_add_embed_dim
187
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.addition_time_embed_dim
188
+ ):
189
+ raise ValueError(
190
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
191
+ )
192
+ elif expected_add_embed_dim != passed_add_embed_dim:
193
+ raise ValueError(
194
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
195
+ )
196
+
197
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
198
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
199
+
200
+ return add_time_ids, add_neg_time_ids
201
+
202
+ def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
203
+ self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
204
+ num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)
205
+
206
+ @torch.no_grad()
207
+ def __call__(
208
+ self,
209
+ prompt: Union[str, List[str]] = None,
210
+ prompt_2: Optional[Union[str, List[str]]] = None,
211
+ image: PipelineImageInput = None,
212
+ mask_image: PipelineImageInput = None,
213
+ masked_image_latents: torch.FloatTensor = None,
214
+ height: Optional[int] = None,
215
+ width: Optional[int] = None,
216
+ strength: float = 0.9999,
217
+ num_inference_steps: int = 50,
218
+ timesteps: List[int] = None,
219
+ denoising_start: Optional[float] = None,
220
+ denoising_end: Optional[float] = None,
221
+ guidance_scale: float = 7.5,
222
+ negative_prompt: Optional[Union[str, List[str]]] = None,
223
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
224
+ num_images_per_prompt: Optional[int] = 1,
225
+ eta: float = 0.0,
226
+ generator: Optional[Union[torch.Generator,
227
+ List[torch.Generator]]] = None,
228
+ latents: Optional[torch.FloatTensor] = None,
229
+ prompt_embeds: Optional[torch.FloatTensor] = None,
230
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
231
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
232
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
233
+ output_type: Optional[str] = "pil",
234
+ return_dict: bool = True,
235
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
236
+ guidance_rescale: float = 0.0,
237
+ original_size: Tuple[int, int] = None,
238
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
239
+ target_size: Tuple[int, int] = None,
240
+ negative_original_size: Optional[Tuple[int, int]] = None,
241
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
242
+ negative_target_size: Optional[Tuple[int, int]] = None,
243
+ aesthetic_score: float = 6.0,
244
+ negative_aesthetic_score: float = 2.5,
245
+ clip_skip: Optional[int] = None,
246
+ extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = {},
247
+ param_scale_dict: Optional[Dict[str, int]] = {},
248
+ **kwargs
249
+ ):
250
+
251
+ callback = kwargs.pop("callback", None)
252
+ callback_steps = kwargs.pop("callback_steps", None)
253
+
254
+ # 0. Default height and width to unet
255
+ height = height or self.default_sample_size * self.vae_scale_factor
256
+ width = width or self.default_sample_size * self.vae_scale_factor
257
+
258
+ original_size = original_size or (height, width)
259
+ target_size = target_size or (height, width)
260
+
261
+ self._guidance_scale = guidance_scale
262
+ self._guidance_rescale = guidance_rescale
263
+ self._clip_skip = clip_skip
264
+ self._cross_attention_kwargs = cross_attention_kwargs
265
+ self._denoising_end = denoising_end
266
+ self._denoising_start = denoising_start
267
+
268
+ # 1. Check inputs. Raise error if not correct
269
+ self.check_inputs(
270
+ prompt,
271
+ prompt_2,
272
+ height,
273
+ width,
274
+ strength,
275
+ callback_steps,
276
+ negative_prompt,
277
+ negative_prompt_2,
278
+ prompt_embeds,
279
+ negative_prompt_embeds,
280
+ )
281
+
282
+ # 2. Define call parameters
283
+ if prompt is not None and isinstance(prompt, str):
284
+ batch_size = 1
285
+ elif prompt is not None and isinstance(prompt, list):
286
+ batch_size = len(prompt)
287
+ else:
288
+ batch_size = prompt_embeds.shape[0]
289
+
290
+ device = self._execution_device
291
+
292
+ do_classifier_free_guidance = guidance_scale > 1.0
293
+
294
+ # 3. Encode input prompt
295
+ text_encoder_lora_scale = (
296
+ cross_attention_kwargs.get(
297
+ "scale", None) if cross_attention_kwargs is not None else None
298
+ )
299
+ (
300
+ prompt_embeds,
301
+ negative_prompt_embeds,
302
+ pooled_prompt_embeds,
303
+ negative_pooled_prompt_embeds,
304
+ ) = self.encode_prompt(
305
+ prompt=prompt,
306
+ prompt_2=prompt_2,
307
+ device=device,
308
+ num_images_per_prompt=num_images_per_prompt,
309
+ do_classifier_free_guidance=do_classifier_free_guidance,
310
+ negative_prompt=negative_prompt,
311
+ negative_prompt_2=negative_prompt_2,
312
+ prompt_embeds=prompt_embeds,
313
+ negative_prompt_embeds=negative_prompt_embeds,
314
+ pooled_prompt_embeds=pooled_prompt_embeds,
315
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
316
+ lora_scale=text_encoder_lora_scale,
317
+ clip_skip=clip_skip
318
+ )
319
+
320
+ def denoising_value_valid(dnv):
321
+ return isinstance(dnv, float) and 0 < dnv < 1
322
+
323
+ # 4. Prepare timesteps
324
+ timesteps, num_inference_steps = retrieve_timesteps(
325
+ self.scheduler, num_inference_steps, device, timesteps)
326
+ timesteps, num_inference_steps = self.get_timesteps(
327
+ num_inference_steps,
328
+ strength,
329
+ device,
330
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
331
+ )
332
+
333
+ latent_timestep = timesteps[:1].repeat(
334
+ batch_size * num_images_per_prompt)
335
+ is_strength_max = strength == 1.0
336
+
337
+ # 5. Prepare latent variables
338
+
339
+ init_image = self.image_processor.preprocess(
340
+ image, height=height, width=width)
341
+ init_image = init_image.to(dtype=torch.float32)
342
+
343
+ mask = self.mask_processor.preprocess(
344
+ mask_image, height=height, width=width)
345
+
346
+ if masked_image_latents is not None:
347
+ masked_image = masked_image_latents
348
+ elif init_image.shape[1] == 4:
349
+ # if images are in latent space, we can't mask it
350
+ masked_image = None
351
+ else:
352
+ masked_image = init_image * (mask < 0.5)
353
+
354
+ add_noise = True if self.denoising_start is None else False
355
+
356
+ return_image_latents = self.num_channels_unet == 4
357
+
358
+ latents_outputs = self.prepare_latents(
359
+ batch_size * num_images_per_prompt,
360
+ self.num_channels_latents,
361
+ height,
362
+ width,
363
+ prompt_embeds.dtype,
364
+ device,
365
+ generator,
366
+ latents,
367
+ image=init_image,
368
+ timestep=latent_timestep,
369
+ is_strength_max=is_strength_max,
370
+ add_noise=add_noise,
371
+ return_noise=True,
372
+ return_image_latents=return_image_latents,
373
+ )
374
+
375
+ if return_image_latents:
376
+ latents, noise, image_latents = latents_outputs
377
+ else:
378
+ latents, noise = latents_outputs
379
+
380
+ mask, masked_image_latents = self.prepare_mask_latents(
381
+ mask,
382
+ masked_image,
383
+ batch_size * num_images_per_prompt,
384
+ height,
385
+ width,
386
+ prompt_embeds.dtype,
387
+ device,
388
+ generator,
389
+ do_classifier_free_guidance,
390
+ )
391
+
392
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
393
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
394
+
395
+ # 7. Prepare added time ids & embeddings
396
+ add_text_embeds = pooled_prompt_embeds
397
+ if self.text_encoder_2 is None:
398
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
399
+ else:
400
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
401
+
402
+ if negative_original_size is None:
403
+ negative_original_size = original_size
404
+ if negative_target_size is None:
405
+ negative_target_size = target_size
406
+
407
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
408
+ original_size,
409
+ crops_coords_top_left,
410
+ target_size,
411
+ aesthetic_score,
412
+ negative_aesthetic_score,
413
+ negative_original_size,
414
+ negative_crops_coords_top_left,
415
+ negative_target_size,
416
+ dtype=prompt_embeds.dtype,
417
+ text_encoder_projection_dim=text_encoder_projection_dim,
418
+ )
419
+ add_time_ids = add_time_ids.repeat(
420
+ batch_size * num_images_per_prompt, 1)
421
+
422
+ if do_classifier_free_guidance:
423
+ prompt_embeds = torch.cat(
424
+ [negative_prompt_embeds, prompt_embeds], dim=0)
425
+ add_text_embeds = torch.cat(
426
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
427
+ add_neg_time_ids = add_neg_time_ids.repeat(
428
+ batch_size * num_images_per_prompt, 1)
429
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
430
+
431
+ prompt_embeds = prompt_embeds.to(device)
432
+ add_text_embeds = add_text_embeds.to(device)
433
+ add_time_ids = add_time_ids.to(device)
434
+
435
+ # 8. Denoising loop
436
+ num_warmup_steps = max(
437
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0)
438
+
439
+ # 7.1 Apply denoising_end
440
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
441
+ discrete_timestep_cutoff = int(
442
+ round(
443
+ self.scheduler.config.num_train_timesteps
444
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
445
+ )
446
+ )
447
+ num_inference_steps = len(
448
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
449
+ timesteps = timesteps[:num_inference_steps]
450
+
451
+ aug_emb = self._get_aug_emb(
452
+ add_time_ids, add_text_embeds, prompt_embeds.dtype)
453
+
454
+ extra_tensor_dict2 = {}
455
+ for name in extra_tensor_dict:
456
+ if name in ["fp_hidden_states", "ip_hidden_states"]:
457
+ v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
458
+ extra_tensor_dict2[name] = torch.cat(
459
+ [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
460
+ else:
461
+ extra_tensor_dict2[name] = extra_tensor_dict[name]
462
+
463
+ # np.save("/workspace/prompt_embeds.npy", prompt_embeds.detach().cpu().numpy())
464
+ # prompt_embeds = torch.from_numpy(np.load("/workspace/gt_prompt_embeds.npy")).cuda()
465
+ self._num_timesteps = len(timesteps)
466
+
467
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
468
+ for i, t in enumerate(timesteps):
469
+ # expand the latents if we are doing classifier free guidance
470
+ latent_model_input = torch.cat(
471
+ [latents] * 2) if do_classifier_free_guidance else latents
472
+
473
+ latent_model_input = self.scheduler.scale_model_input(
474
+ latent_model_input, t)
475
+
476
+ if self.num_channels_unet == 9:
477
+ latent_model_input = torch.cat(
478
+ [latent_model_input, mask, masked_image_latents], dim=1)
479
+
480
+ latent_model_input = latent_model_input.permute(
481
+ 0, 2, 3, 1).contiguous()
482
+
483
+ noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None, None,
484
+ None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()
485
+
486
+ # perform guidance
487
+ if do_classifier_free_guidance:
488
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
489
+ noise_pred = noise_pred_uncond + self.guidance_scale * \
490
+ (noise_pred_text - noise_pred_uncond)
491
+
492
+ if do_classifier_free_guidance and self.guidance_rescale > 0.0:
493
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
494
+ noise_pred = rescale_noise_cfg(
495
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
496
+
497
+ # compute the previous noisy sample x_t -> x_t-1
498
+ latents = self.scheduler.step(
499
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
500
+
501
+ if self.num_channels_unet == 4:
502
+ init_latents_proper = image_latents
503
+ if do_classifier_free_guidance:
504
+ init_mask, _ = mask.chunk(2)
505
+ else:
506
+ init_mask = mask
507
+
508
+ if i < len(timesteps) - 1:
509
+ noise_timestep = timesteps[i + 1]
510
+ init_latents_proper = self.scheduler.add_noise(
511
+ init_latents_proper, noise, torch.tensor(
512
+ [noise_timestep])
513
+ )
514
+
515
+ latents = (1 - init_mask) * \
516
+ init_latents_proper + init_mask * latents
517
+
518
+ # call the callback, if provided
519
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
520
+ progress_bar.update()
521
+ if callback is not None and i % callback_steps == 0:
522
+ callback(i, t, latents)
523
+
524
+ if output_type == "latent":
525
+ return latents
526
+
527
+ image = self.vae.decode(1 / self.vae.scaling_factor * latents)
528
+ image = self.image_processor.postprocess(
529
+ image, output_type=output_type)
530
+
531
+ # Offload last model to CPU
532
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
533
+ self.final_offload_hook.offload()
534
+
535
+ return image
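A rough usage sketch for the new SDXL inpainting pipeline. The paths and file names are placeholders, an inpainting UNet with 9 input channels is assumed (matching the num_channels_unet=9 default), and the import assumes lyrasd_model exports the class:

import torch
from PIL import Image
from lyrasd_model import LyraSdXLTxt2ImgInpaintPipeline  # assumed export

pipe = LyraSdXLTxt2ImgInpaintPipeline()
pipe.reload_pipe("./models/lyrasd_xl_inpaint")     # hypothetical converted inpainting model dir

init_image = Image.open("./images/input.png").convert("RGB")
mask_image = Image.open("./images/mask.png").convert("L")  # white pixels are repainted

images = pipe(
    prompt="a wooden bench in a park, autumn leaves",
    image=init_image,
    mask_image=mask_image,
    height=1024, width=1024,
    strength=0.85,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(123),
)
images[0].save("out_sdxl_inpaint.png")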
lyrasd_model/lyrasdxl_txt2img_pipeline.py ADDED
@@ -0,0 +1,267 @@
1
+ import inspect
2
+ import os
3
+ import time
4
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
+
6
+ import gc
7
+ import torch
8
+ import numpy as np
9
+ from glob import glob
10
+
11
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
12
+ from diffusers.loaders import TextualInversionLoaderMixin
13
+ from diffusers.image_processor import VaeImageProcessor
14
+ from diffusers.models import AutoencoderKL
15
+ from diffusers.schedulers import (DPMSolverMultistepScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ EulerDiscreteScheduler,
18
+ KarrasDiffusionSchedulers)
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from diffusers.utils import logging
22
+ from PIL import Image
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
24
+ from .lyrasd_vae_model import LyraSdVaeModel
25
+ from .module.lyrasd_ip_adapter import LyraIPAdapter
26
+ from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
27
+ from safetensors.torch import load_file
28
+ from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase
29
+
30
+
31
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
32
+ """
33
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
34
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
35
+ """
36
+ std_text = noise_pred_text.std(
37
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
38
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
39
+ # rescale the results from guidance (fixes overexposure)
40
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
41
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
42
+ noise_cfg = guidance_rescale * noise_pred_rescaled + \
43
+ (1 - guidance_rescale) * noise_cfg
44
+ return noise_cfg
45
+
46
+
47
+ class LyraSdXLTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
48
+ device = torch.device("cpu")
49
+ dtype = torch.float32
50
+
51
+ def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
52
+ self.register_to_config(force_zeros_for_empty_prompt=True)
53
+
54
+ super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
55
+
56
+ @torch.no_grad()
57
+ def __call__(
58
+ self,
59
+ prompt: Union[str, List[str]] = None,
60
+ prompt_2: Optional[Union[str, List[str]]] = None,
61
+ height: Optional[int] = None,
62
+ width: Optional[int] = None,
63
+ num_inference_steps: int = 50,
64
+ denoising_end: Optional[float] = None,
65
+ guidance_scale: float = 5.0,
66
+ negative_prompt: Optional[Union[str, List[str]]] = None,
67
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
68
+ num_images_per_prompt: Optional[int] = 1,
69
+ eta: float = 0.0,
70
+ generator: Optional[Union[torch.Generator,
71
+ List[torch.Generator]]] = None,
72
+ latents: Optional[torch.FloatTensor] = None,
73
+ prompt_embeds: Optional[torch.FloatTensor] = None,
74
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
75
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
76
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
77
+ output_type: Optional[str] = "pil",
78
+ return_dict: bool = True,
79
+ callback: Optional[Callable[[
80
+ int, int, torch.FloatTensor], None]] = None,
81
+ callback_steps: int = 1,
82
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
83
+ guidance_rescale: float = 0.0,
84
+ original_size: Optional[Tuple[int, int]] = None,
85
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
86
+ target_size: Optional[Tuple[int, int]] = None,
87
+ extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = {},
88
+ param_scale_dict: Optional[Dict[str, int]] = {},
89
+ clip_skip: Optional[int] = None
90
+ ):
91
+
92
+ # 0. Default height and width to unet
93
+ height = height or self.default_sample_size * self.vae_scale_factor
94
+ width = width or self.default_sample_size * self.vae_scale_factor
95
+
96
+ original_size = original_size or (height, width)
97
+ target_size = target_size or (height, width)
98
+
99
+ # 1. Check inputs. Raise error if not correct
100
+ self.check_inputs(
101
+ prompt,
102
+ prompt_2,
103
+ height,
104
+ width,
105
+ callback_steps,
106
+ negative_prompt,
107
+ negative_prompt_2,
108
+ prompt_embeds,
109
+ negative_prompt_embeds,
110
+ pooled_prompt_embeds,
111
+ negative_pooled_prompt_embeds,
112
+ )
113
+
114
+ # 2. Define call parameters
115
+ if prompt is not None and isinstance(prompt, str):
116
+ batch_size = 1
117
+ elif prompt is not None and isinstance(prompt, list):
118
+ batch_size = len(prompt)
119
+ else:
120
+ batch_size = prompt_embeds.shape[0]
121
+
122
+ device = self._execution_device
123
+
124
+ do_classifier_free_guidance = guidance_scale > 1.0
125
+
126
+ # 3. Encode input prompt
127
+ text_encoder_lora_scale = (
128
+ cross_attention_kwargs.get(
129
+ "scale", None) if cross_attention_kwargs is not None else None
130
+ )
131
+ (
132
+ prompt_embeds,
133
+ negative_prompt_embeds,
134
+ pooled_prompt_embeds,
135
+ negative_pooled_prompt_embeds,
136
+ ) = self.encode_prompt(
137
+ prompt=prompt,
138
+ prompt_2=prompt_2,
139
+ device=device,
140
+ num_images_per_prompt=num_images_per_prompt,
141
+ do_classifier_free_guidance=do_classifier_free_guidance,
142
+ negative_prompt=negative_prompt,
143
+ negative_prompt_2=negative_prompt_2,
144
+ prompt_embeds=prompt_embeds,
145
+ negative_prompt_embeds=negative_prompt_embeds,
146
+ pooled_prompt_embeds=pooled_prompt_embeds,
147
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
148
+ lora_scale=text_encoder_lora_scale,
149
+ clip_skip=clip_skip
150
+ )
151
+
152
+ # 4. Prepare timesteps
153
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
154
+
155
+ timesteps = self.scheduler.timesteps
156
+
157
+ # 5. Prepare latent variables
158
+ num_channels_latents = self.unet_in_channels
159
+ latents = self.prepare_latents(
160
+ batch_size * num_images_per_prompt,
161
+ num_channels_latents,
162
+ height,
163
+ width,
164
+ prompt_embeds.dtype,
165
+ device,
166
+ generator,
167
+ latents,
168
+ )
169
+
170
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
171
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
172
+
173
+ # 7. Prepare added time ids & embeddings
174
+ add_text_embeds = pooled_prompt_embeds
175
+ add_time_ids = list(
176
+ original_size + crops_coords_top_left + target_size)
177
+ add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
178
+
179
+ if do_classifier_free_guidance:
180
+ prompt_embeds = torch.cat(
181
+ [negative_prompt_embeds, prompt_embeds], dim=0)
182
+ add_text_embeds = torch.cat(
183
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
184
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
185
+
186
+ prompt_embeds = prompt_embeds.to(device)
187
+ add_text_embeds = add_text_embeds.to(device)
188
+ add_time_ids = add_time_ids.to(device).repeat(
189
+ batch_size * num_images_per_prompt, 1)
190
+
191
+ # 8. Denoising loop
192
+ num_warmup_steps = max(
193
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0)
194
+
195
+ # 7.1 Apply denoising_end
196
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
197
+ discrete_timestep_cutoff = int(
198
+ round(
199
+ self.scheduler.config.num_train_timesteps
200
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
201
+ )
202
+ )
203
+ num_inference_steps = len(
204
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
205
+ timesteps = timesteps[:num_inference_steps]
206
+
207
+ aug_emb = self._get_aug_emb(
208
+ add_time_ids, add_text_embeds, prompt_embeds.dtype)
209
+
210
+ extra_tensor_dict2 = {}
211
+ for name in extra_tensor_dict:
212
+ if name in ["fp_hidden_states", "ip_hidden_states"]:
213
+ v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
214
+ extra_tensor_dict2[name] = torch.cat(
215
+ [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
216
+ else:
217
+ extra_tensor_dict2[name] = extra_tensor_dict[name]
218
+
219
+ # np.save("/workspace/prompt_embeds.npy", prompt_embeds.detach().cpu().numpy())
220
+ # prompt_embeds = torch.from_numpy(np.load("/workspace/gt_prompt_embeds.npy")).cuda()
221
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
222
+ for i, t in enumerate(timesteps):
223
+ # expand the latents if we are doing classifier free guidance
224
+ latent_model_input = torch.cat(
225
+ [latents] * 2) if do_classifier_free_guidance else latents
226
+
227
+ latent_model_input = self.scheduler.scale_model_input(
228
+ latent_model_input, t)
229
+ latent_model_input = latent_model_input.permute(
230
+ 0, 2, 3, 1).contiguous()
231
+
232
+ noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None, None,
233
+ None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()
234
+
235
+ # perform guidance
236
+ if do_classifier_free_guidance:
237
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
238
+ noise_pred = noise_pred_uncond + guidance_scale * \
239
+ (noise_pred_text - noise_pred_uncond)
240
+
241
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
242
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
243
+ noise_pred = rescale_noise_cfg(
244
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
245
+
246
+ # compute the previous noisy sample x_t -> x_t-1
247
+ latents = self.scheduler.step(
248
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
249
+
250
+ # call the callback, if provided
251
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
252
+ progress_bar.update()
253
+ if callback is not None and i % callback_steps == 0:
254
+ callback(i, t, latents)
255
+
256
+ if output_type == "latent":
257
+ return latents
258
+
259
+ image = self.vae.decode(1 / self.vae.scaling_factor * latents)
260
+ image = self.image_processor.postprocess(
261
+ image, output_type=output_type)
262
+
263
+ # Offload last model to CPU
264
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
265
+ self.final_offload_hook.offload()
266
+
267
+ return image
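
The `rescale_noise_cfg` helper added above follows Section 3.4 of the referenced paper: the guided prediction is rescaled so its per-sample standard deviation matches the text-conditioned branch, then blended back in by `guidance_rescale`. A minimal standalone sanity check with dummy tensors (the function body is copied from the helper above so the snippet runs on its own):

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # match the std of the guided prediction to the text branch, then blend
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

uncond = torch.randn(2, 4, 64, 64)
text = torch.randn(2, 4, 64, 64)
cfg = uncond + 7.5 * (text - uncond)                    # plain classifier-free guidance
out = rescale_noise_cfg(cfg, text, guidance_rescale=0.7)

assert torch.allclose(rescale_noise_cfg(cfg, text, 0.0), cfg)   # 0.0 leaves CFG unchanged
print(cfg.std().item(), out.std().item(), text.std().item())    # std is pulled toward the text branch
```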
lyrasd_model/{lyrasd_lib/placeholder.txt → module/__init__.py} RENAMED
File without changes
lyrasd_model/module/lyra_tool.py ADDED
@@ -0,0 +1,5 @@
1
+ import yaml
2
+
3
+ def load_yaml(cfg_path):
4
+ with open(cfg_path, 'r', encoding='utf-8') as f:
5
+ return yaml.safe_load(f)
lyrasd_model/module/lyrasd_ip_adapter.py ADDED
@@ -0,0 +1,289 @@
1
+ import os, sys
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from diffusers.models.embeddings import ImageProjection
8
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
9
+ from PIL import Image
10
+ from typing import Any, Callable, Dict, List, Optional, Union
11
+ from copy import deepcopy
12
+ import time
13
+ sys.path.append(os.path.dirname(__file__))
14
+ from resampler import Resampler
15
+ from diffusers import DiffusionPipeline
16
+ import numpy as np
17
+ # sys.path.append(os.environ['LYRASD_WORKDIR'] + "/tests/utils")
18
+ from .tools import get_mem_use
19
+
20
+ class ImageProjModel(torch.nn.Module):
21
+ """Projection Model"""
22
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
23
+ super().__init__()
24
+
25
+ self.cross_attention_dim = cross_attention_dim
26
+ self.clip_extra_context_tokens = clip_extra_context_tokens
27
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
28
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
29
+
30
+ def forward(self, image_embeds):
31
+ embeds = image_embeds
32
+ clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
33
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
34
+ return clip_extra_context_tokens
35
+
36
+
37
+ class LyraIPAdapter:
38
+ def __init__(
39
+ self,
40
+ sd_pipe,
41
+ sdxl,
42
+ device,
43
+ ip_ckpt=None,
44
+ ip_plus=False,
45
+ image_encoder_path=None,
46
+ num_ip_tokens=4,
47
+ ip_projection_dim=None,
48
+ fp_ckpt=None,
49
+ num_fp_tokens=1,
50
+ fp_projection_dim=None,
51
+ ):
52
+ self.pipe = sd_pipe
53
+ self.device = device
54
+ self.fp_ckpt = fp_ckpt
55
+ self.ip_ckpt = ip_ckpt
56
+ self.num_fp_tokens = num_fp_tokens
57
+ self.num_ip_tokens = num_ip_tokens
58
+ self.fp_projection_dim = fp_projection_dim
59
+ self.ip_projection_dim = ip_projection_dim
60
+ self.sdxl = sdxl
61
+ self.ip_plus = ip_plus
62
+ self.cross_attention_dim = 2048
63
+ # self.pipe = sd_pipe.to(self.device)
64
+ # self.set_ip_adapter()
65
+
66
+ if image_encoder_path:
67
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=torch.float16)
68
+ self.clip_image_processor = CLIPImageProcessor()
69
+ self.projection_dim = self.image_encoder.config.projection_dim
70
+
71
+ # image proj model
72
+ if self.ip_ckpt:
73
+ if self.ip_plus:
74
+ proj_heads = 20 if self.sdxl else 12
75
+ self.image_proj_model = self.init_proj_plus(proj_heads, self.num_ip_tokens)
76
+ else:
77
+ self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)
78
+
79
+ # face proj model
80
+ if self.fp_ckpt:
81
+ self.face_proj_model = self.init_proj(self.fp_projection_dim, self.num_fp_tokens)
82
+
83
+ self.load_ip_adapter()
84
+
85
+ def init_proj_diffuser(self, state_dict):
86
+ # variant that loads the diffusers-format image_proj weights
87
+ clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
88
+ cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
89
+
90
+ image_proj_model = ImageProjection(
91
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
92
+ ).to(dtype=self.dtype, device=self.device)
93
+ return image_proj_model
94
+
95
+ # init_proj / init_proj_plus follow the FaceIn implementation
96
+ def init_proj(self, projection_dim, num_tokens):
97
+ image_proj_model = ImageProjModel(
98
+ cross_attention_dim=self.cross_attention_dim,
99
+ clip_embeddings_dim=projection_dim,
100
+ clip_extra_context_tokens=num_tokens,
101
+ ).to(self.device, dtype=torch.float16)
102
+ return image_proj_model
103
+
104
+
105
+ def init_proj_plus(self, heads, num_tokens):
106
+ image_proj_model = Resampler(
107
+ dim=1280,
108
+ depth=4,
109
+ dim_head=64,
110
+ heads=heads,
111
+ num_queries=num_tokens,
112
+ embedding_dim=self.image_encoder.config.hidden_size,
113
+ output_dim=self.cross_attention_dim,
114
+ ff_mult=4,
115
+ ).to(self.device, dtype=torch.float16)
116
+ return image_proj_model
117
+
118
+ def load_ip_adapter(self):
119
+ unet = self.pipe.unet
120
+
121
+ def parse_ckpt_path(ckpt):
122
+ ll = ckpt.split("/")
123
+ weight_name = ll[-1]
124
+ subfolder = ll[-2]
125
+ pretrained_path = "/".join(ll[:-2])
126
+ return pretrained_path, subfolder, weight_name
127
+
128
+ if self.ip_ckpt:
129
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
130
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
131
+ pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
132
+ dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
133
+ unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
134
+
135
+ if self.fp_ckpt:
136
+ state_dict = torch.load(self.fp_ckpt, map_location="cpu")
137
+ self.face_proj_model.load_state_dict(state_dict["face_proj"])
138
+ pretrained_path, subfolder, weight_name = parse_ckpt_path(self.fp_ckpt)
139
+ dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
140
+ unet.load_facein(dir_ipadapter, "fp16")
141
+
142
+ @torch.inference_mode()
143
+ def get_image_embeds(self, image=None, face_emb=None):
144
+ image_prompt_embeds, uncond_image_prompt_embeds = None, None
145
+
146
+ if image is not None:
147
+ if not isinstance(image, list):
148
+ image = [image]
149
+ clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
150
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
151
+ if self.ip_plus:
152
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
153
+ uncond_clip_image_embeds = self.image_encoder(
154
+ torch.zeros_like(clip_image), output_hidden_states=True
155
+ ).hidden_states[-2]
156
+ else:
157
+ clip_image_embeds = self.image_encoder(clip_image).image_embeds
158
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
159
+ clip_image_prompt_embeds = self.image_proj_model(clip_image_embeds)
160
+ uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
161
+ image_prompt_embeds = clip_image_prompt_embeds
162
+ uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds
163
+
164
+ if face_emb is not None:
165
+ face_embeds = face_emb.to(self.device, dtype=torch.float16)
166
+ face_prompt_embeds = self.face_proj_model(face_embeds)
167
+ uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
168
+ if image_prompt_embeds is None:
169
+ image_prompt_embeds = face_prompt_embeds
170
+ uncond_image_prompt_embeds = uncond_face_prompt_embeds
171
+ else:
172
+ image_prompt_embeds = torch.cat([face_prompt_embeds, image_prompt_embeds], axis=1)
173
+ uncond_image_prompt_embeds = torch.cat([uncond_face_prompt_embeds, uncond_image_prompt_embeds], dim=1)
174
+
175
+ return image_prompt_embeds, uncond_image_prompt_embeds
176
+
177
+ @torch.inference_mode()
178
+ def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, face_emb=None, batch_size = 1, ip_scale=1.0, fp_scale=1.0, do_classifier_free_guidance=True):
179
+ dict_tensor = {}
180
+
181
+ if self.ip_ckpt and ip_scale>0:
182
+ if ip_image_embeds is not None:
183
+ dict_tensor["ip_hidden_states"] = ip_image_embeds
184
+ elif image is not None:
185
+ if not isinstance(image, list):
186
+ image = [image]
187
+ clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
188
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
189
+ if self.ip_plus:
190
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
191
+ uncond_clip_image_embeds = self.image_encoder(
192
+ torch.zeros_like(clip_image), output_hidden_states=True
193
+ ).hidden_states[-2]
194
+ else:
195
+ clip_image_embeds = self.image_encoder(clip_image).image_embeds
196
+ uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
197
+
198
+ if do_classifier_free_guidance:
199
+ clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
200
+ ip_image_embeds = self.image_proj_model(clip_image_embeds)
201
+ dict_tensor["ip_hidden_states"] = ip_image_embeds
202
+
203
+ if face_emb is not None and self.fp_ckpt and ip_scale>0:
204
+ face_embeds = face_emb.to(self.device, dtype=torch.float16)
205
+ face_prompt_embeds = self.face_proj_model(face_embeds)
206
+ uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
207
+ if do_classifier_free_guidance:
208
+ fp_image_embeds = torch.cat([uncond_face_prompt_embeds, face_prompt_embeds])
209
+ else:
210
+ fp_image_embeds = face_prompt_embeds
211
+ dict_tensor["fp_hidden_states"] = fp_image_embeds
212
+ return dict_tensor
213
+
214
+
215
+ if __name__ == "__main__":
216
+ sys.path.append("/data/home/kiokaxiao/repos/LyraSD/python/lyrasd")
217
+ from lyrasd_model import LyraSdXLTxt2ImgPipeline
218
+
219
+ model_path = "/data/SharedModels/SD/checkpoints/stable-diffusion-xl-base-1.0/"
220
+ # model_path = "/cfs-datasets/projects/VirtualIdol/models/base_model/sdxl/xxmix9realisticsdxlV1"
221
+ lib_path = os.environ.get("LIBLYRASD_SO")
222
+
223
+ dir_ip_adapter = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
224
+ dir_facein = "/cfs-datasets/projects/VirtualIdol/models/FaceIn/v1/FaceIn_sdxl.bin"
225
+ image_encoder_path = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/models/image_encoder"
226
+
227
+ pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
228
+ pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, dir_facein, 1, 512)
229
+ # pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, "", 1, 512)
230
+
231
+ face_emb = np.load("/data/home/kiokaxiao/repos/VidolImageDraw/girl.npy")
232
+ face_emb = torch.Tensor(face_emb.reshape([1,-1]))
233
+ ip_image = Image.open("/data/home/kiokaxiao/repos/VidolImageDraw/images/input_image.png").convert('RGB')
234
+
235
+ generator = torch.Generator("cuda").manual_seed(123)
236
+ batches = [2]
237
+ sizes = [[512, 512], [768, 768], [1024, 1024]]
238
+ # sizes = [[832, 640]]
239
+ # sizes = [[1024, 1024]]
240
+ running_cnt = 1
241
+ do_bench = False
242
+
243
+ ip_ratio = 1
244
+ facein_ratio = 0.6
245
+ extra_tensor_dict = {}
246
+ extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(ip_image, None, face_emb, batches[0], ip_ratio, facein_ratio)
247
+ param_scale_dict = {"facein_ratio": facein_ratio, "ip_ratio": ip_ratio}
248
+ draw_cfg = {'width': 640,
249
+ 'num_inference_steps': 30,
250
+ 'height': 832,
251
+ 'negative_prompt': '(worst quality, low quality, 3d, 2d, cartoons, sketch), tooth, open mouth',
252
+ 'guidance_scale': 7,
253
+ 'prompt': 'xxmixgirl, masterpiece, best quality, 1girl, solo, looking at viewer, simple background, hair ornament, black eyes, portrait',
254
+ 'output_type': 'pil',
255
+ 'extra_tensor_dict': extra_tensor_dict,
256
+ "param_scale_dict": param_scale_dict}
257
+
258
+
259
+ def warmup(draw_cfg):
260
+ draw_cfg_wm = deepcopy(draw_cfg)
261
+ draw_cfg_wm['num_inference_steps'] = 1
262
+ pipeline(**draw_cfg_wm, generator= generator)
263
+
264
+ if not do_bench:
265
+ images = pipeline(**draw_cfg, generator= generator)
266
+ else:
267
+ for batch in batches:
268
+ for height, width in sizes:
269
+ draw_cfg['width'] = width
270
+ draw_cfg['height'] = height
271
+ draw_cfg['num_images_per_prompt'] = batch
272
+ draw_cfg["num_inference_steps"] = 20
273
+ warmup(draw_cfg)
274
+ time_uses = []
275
+ for x in range(running_cnt):
276
+ start = time.perf_counter()
277
+ draw_cfg['num_images_per_prompt'] = batch
278
+ generator = torch.Generator("cuda").manual_seed(123)
279
+ print("draw_cfg: ", draw_cfg.keys())
280
+ print("draw_cfg: ", draw_cfg)
281
+
282
+ images = pipeline(**draw_cfg, generator= generator)
283
+ time_use = time.perf_counter() - start
284
+ time_uses.append(time_use)
285
+ print("bench", batch, width, sum(time_uses)/running_cnt, get_mem_use())
286
+
287
+ print(type(images))
288
+ images[0].save("t.png")
289
+
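
`LyraIPAdapter.load_ip_adapter` above derives the directory of the pre-converted LyraSD attention weights from the original `.bin` checkpoint path: it expects them under a sibling `lyra_tran/<subfolder>/<checkpoint-name>` directory. A standalone sketch of that mapping (the checkpoint path below is illustrative, not shipped with the repo):

```python
import os

def parse_ckpt_path(ckpt):
    # mirrors the inner helper in LyraIPAdapter.load_ip_adapter
    parts = ckpt.split("/")
    weight_name, subfolder = parts[-1], parts[-2]
    pretrained_path = "/".join(parts[:-2])
    return pretrained_path, subfolder, weight_name

ckpt = "./models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"  # illustrative path
root, sub, name = parse_ckpt_path(ckpt)
dir_ipadapter = os.path.join(root, "lyra_tran", sub, ".".join(name.split(".")[:-1]))
print(dir_ipadapter)  # ./models/ip_adapter/lyra_tran/sdxl_models/ip-adapter-plus_sdxl_vit-h
```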
lyrasd_model/module/resampler.py ADDED
@@ -0,0 +1,121 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # FFN
9
+ def FeedForward(dim, mult=4):
10
+ inner_dim = int(dim * mult)
11
+ return nn.Sequential(
12
+ nn.LayerNorm(dim),
13
+ nn.Linear(dim, inner_dim, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(inner_dim, dim, bias=False),
16
+ )
17
+
18
+
19
+ def reshape_tensor(x, heads):
20
+ bs, length, width = x.shape
21
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
22
+ x = x.view(bs, length, heads, -1)
23
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24
+ x = x.transpose(1, 2)
25
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
26
+ x = x.reshape(bs, heads, length, -1)
27
+ return x
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.dim_head = dim_head
35
+ self.heads = heads
36
+ inner_dim = dim_head * heads
37
+
38
+ self.norm1 = nn.LayerNorm(dim)
39
+ self.norm2 = nn.LayerNorm(dim)
40
+
41
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
42
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
44
+
45
+
46
+ def forward(self, x, latents):
47
+ """
48
+ Args:
49
+ x (torch.Tensor): image features
50
+ shape (b, n1, D)
51
+ latent (torch.Tensor): latent features
52
+ shape (b, n2, D)
53
+ """
54
+ x = self.norm1(x)
55
+ latents = self.norm2(latents)
56
+
57
+ b, l, _ = latents.shape
58
+
59
+ q = self.to_q(latents)
60
+ kv_input = torch.cat((x, latents), dim=-2)
61
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
62
+
63
+ q = reshape_tensor(q, self.heads)
64
+ k = reshape_tensor(k, self.heads)
65
+ v = reshape_tensor(v, self.heads)
66
+
67
+ # attention
68
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
69
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ out = weight @ v
72
+
73
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
74
+
75
+ return self.to_out(out)
76
+
77
+
78
+ class Resampler(nn.Module):
79
+ def __init__(
80
+ self,
81
+ dim=1024,
82
+ depth=8,
83
+ dim_head=64,
84
+ heads=16,
85
+ num_queries=8,
86
+ embedding_dim=768,
87
+ output_dim=1024,
88
+ ff_mult=4,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
93
+
94
+ self.proj_in = nn.Linear(embedding_dim, dim)
95
+
96
+ self.proj_out = nn.Linear(dim, output_dim)
97
+ self.norm_out = nn.LayerNorm(output_dim)
98
+
99
+ self.layers = nn.ModuleList([])
100
+ for _ in range(depth):
101
+ self.layers.append(
102
+ nn.ModuleList(
103
+ [
104
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
105
+ FeedForward(dim=dim, mult=ff_mult),
106
+ ]
107
+ )
108
+ )
109
+
110
+ def forward(self, x):
111
+
112
+ latents = self.latents.repeat(x.size(0), 1, 1)
113
+
114
+ x = self.proj_in(x)
115
+ print("layers: ", len(self.layers))
116
+ for attn, ff in self.layers:
117
+ latents = attn(x, latents) + latents
118
+ latents = ff(latents) + latents
119
+
120
+ latents = self.proj_out(latents)
121
+ return self.norm_out(latents)
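
A quick shape check for the `Resampler` above, using the SDXL settings that `LyraIPAdapter.init_proj_plus` passes (heads=20, 16 IP tokens, CLIP hidden size 1280, UNet cross-attention dim 2048); the sequence length of 257 assumes ViT-H penultimate hidden states:

```python
import sys
import torch

sys.path.append("lyrasd_model/module")  # same import trick lyrasd_ip_adapter.py uses
from resampler import Resampler

resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
                      num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4)
clip_hidden_states = torch.randn(1, 257, 1280)   # penultimate CLIP image hidden states
tokens = resampler(clip_hidden_states)
print(tokens.shape)                               # torch.Size([1, 16, 2048])
```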
lyrasd_model/module/tools.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import numpy as np
3
+ import os, sys
4
+ import time
5
+
6
+ class LyraChecker:
7
+ def __init__(self, dir_data, tol):
8
+ self.dir_data = dir_data
9
+ self.tol = tol
10
+
11
+ def cmp(self, fpath1, fpath2="", tol=0):
12
+ tolbk = self.tol
13
+ if tol != 0:
14
+ self.tol = tol
15
+ if fpath2 == "":
16
+ fpath2 = fpath1
17
+ fpath1 += "_1"
18
+ fpath2 += "_2"
19
+ v1 = self.get_npy(fpath1) #np.load(os.path.join(self.dir_data, fpath1))
20
+ v2 = self.get_npy(fpath2) #np.load(os.path.join(self.dir_data, fpath2))
21
+ name = fpath1
22
+ if ".npy" in fpath1:
23
+ name = ".".join(os.path.basename(fpath1).split(".")[:-1])
24
+ self._cmp_inner(v1, v2, name)
25
+ self.tol = tolbk
26
+
27
+ def _cmp_inner(self, v1, v2, name):
28
+ print(v1.shape, v2.shape)
29
+ if v1.shape != v2.shape:
30
+ if v1.shape[1] == v2.shape[1]:
31
+ v2 = v2.reshape([v2.shape[0], v2.shape[1], -1])
32
+ else:
33
+ v2 = torch.tensor(v2).permute(0, 3, 1, 2).numpy()
34
+ print(v1.shape, v2.shape)
35
+ self._check_data(name, v1, v2)
36
+ print(np.size(v1))
37
+
38
+ def _check_data(self, stage, x_out, x_gt):
39
+ print(f"========== {stage} =============")
40
+ print(x_out.shape, x_gt.shape)
41
+ if np.allclose(x_gt, x_out, atol=self.tol):
42
+ print(f"[OK] At {stage}, tol: {self.tol}")
43
+ else:
44
+ diff_cnt = np.count_nonzero(np.abs(x_gt - x_out)>self.tol)
45
+ print(f"[FAIL]At {stage}, not aligned. tol: {self.tol}")
46
+ print(" [INFO]Max diff: ", np.max(np.abs(x_gt - x_out)))
47
+ print(" [INFO]Diff count: ", diff_cnt, ", ratio: ", round(diff_cnt/np.size(x_out), 2))
48
+ print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
49
+
50
+
51
+ def cmp_query(self, fpath1, fpath2):
52
+ v1 = np.load(os.path.join(self.dir_data, fpath1))
53
+ vk = np.load(os.path.join(self.dir_data, fpath1).replace("query", "key"))
54
+ vv = np.load(os.path.join(self.dir_data, fpath1).replace("query", "value"))
55
+
56
+ v2 = np.load(os.path.join(self.dir_data, fpath2))
57
+ # print(v1.shape, v2.shape)
58
+ q2 = v2[:,:,0,:,:].transpose([0,2,1,3])
59
+ # print(v1.shape, q2.shape)
60
+ self._check_data("query", v1, q2)
61
+ # print(vk.shape, v2.shape)
62
+ k2 = v2[:,:,1,:,:].transpose([0,2,1,3])
63
+ self._check_data("key", vk, k2)
64
+ vv2 = v2[:,:,2,:,:].transpose([0,2,1,3])
65
+ # print(vv.shape, vv2.shape)
66
+ self._check_data("value", vv, vv2)
67
+
68
+ def _get_data_fpath(self, fname):
69
+ fpath = os.path.join(self.dir_data, fname)
70
+ if not fpath.endswith(".npy"):
71
+ fpath += ".npy"
72
+ return fpath
73
+
74
+ def get_npy(self, fname):
75
+ fpath = self._get_data_fpath(fname)
76
+ return np.load(fpath)
77
+
78
+
79
+
80
+
81
+ class MkDataHelper:
82
+ def __init__(self, data_dir="/data/home/kiokaxiao/data"):
83
+ self.data_dir = data_dir
84
+
85
+ def mkdata(self, subdir, name, shape, dtype=torch.float16):
86
+ outdir = os.path.join(self.data_dir, subdir)
87
+ os.makedirs(outdir, exist_ok=True)
88
+ fpath = os.path.join(outdir, name+".npy")
89
+ data = torch.randn(shape, dtype=torch.float16)
90
+ np.save(fpath, data.to(dtype).numpy())
91
+ return data
92
+
93
+ def gen_out_with_func(self, func, inputs):
94
+ output = func(inputs)
95
+ return output
96
+
97
+ def savedata(self, subdir, name, data):
98
+ outdir = os.path.join(self.data_dir, subdir)
99
+ os.makedirs(outdir, exist_ok=True)
100
+ fpath = os.path.join(outdir, name+".npy")
101
+ np.save(fpath, data.cpu().numpy())
102
+
103
+
104
+ class TorchSaver:
105
+ def __init__(self, data_dir):
106
+ self.data_dir = data_dir
107
+ os.makedirs(self.data_dir, exist_ok=True)
108
+ self.is_save = True
109
+
110
+ def save_v(self, name, v):
111
+ if not self.is_save:
112
+ return
113
+ fpath = os.path.join(self.data_dir, name+"_1.npy")
114
+ np.save(fpath, v.detach().cpu().numpy())
115
+
116
+ def save_v2(self, name, v):
117
+ if not self.is_save:
118
+ return
119
+ fpath = os.path.join(self.data_dir, name+"_1.npy")
120
+ np.save(fpath, v.detach().cpu().numpy())
121
+
122
+ def timer_annoc(funct):
123
+ def inner(*args,**kwargs):
124
+ start = time.perf_counter()
125
+ res = funct(*args,**kwargs)
126
+ torch.cuda.synchronize()
127
+ end = time.perf_counter()
128
+ print("torch cost: ", end-start)
129
+ return res
130
+ return inner
131
+
132
+ def get_mem_use():
133
+ f = os.popen("nvidia-smi | grep MiB" )
134
+ line = f.read().strip()
135
+ while "  " in line:
136
+ line = line.replace("  ", " ")
137
+ memuse = line.split(" ")[8]
138
+ return memuse
139
+
140
+ if __name__ == "__main__":
141
+ dir_data = sys.argv[1]
142
+ fname_v1 = sys.argv[2]
143
+ fname_v2 = sys.argv[3]
144
+ tol = 0.01
145
+ if len(sys.argv) > 4:
146
+ tol = float(sys.argv[4])
147
+ checker = LyraChecker(dir_data, tol)
148
+ checker.cmp(fname_v1, fname_v2)
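
`LyraChecker` compares paired activation dumps named `<name>_1.npy` / `<name>_2.npy` (the `_1` files are what `TorchSaver.save_v` writes). A hypothetical invocation, assuming such dumps already exist in a local directory:

```python
import sys

sys.path.append("lyrasd_model/module")
from tools import LyraChecker

checker = LyraChecker("./debug_dumps", tol=0.01)   # hypothetical dump directory
checker.cmp("unet_out")                            # compares unet_out_1.npy against unet_out_2.npy
```

The module can also be run directly, e.g. `python lyrasd_model/module/tools.py <dump_dir> <name_1> <name_2> [tol]`, which is what its `__main__` block does.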
models/README.md CHANGED
@@ -2,11 +2,20 @@
2
  ### This is the place where you should download the checkpoints, and unzip them
3
 
4
  ```bash
5
- wget -O lyrasd_rev_animated.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/lyrasd_rev_animated.tar.gz?q-sign-algorithm=sha1&q-ak=AKIDBF6i7GCtKWS8ZkgOtACzX3MQDl37xYty&q-sign-time=1694078210;1866878210&q-key-time=1694078210;1866878210&q-header-list=&q-url-param-list=&q-signature=6046546135631dee9e8be7d8e061a77e8790e675"
6
- wget -O lyrasd_canny.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/lyrasd_canny.tar.gz?q-sign-algorithm=sha1&q-ak=AKIDBF6i7GCtKWS8ZkgOtACzX3MQDl37xYty&q-sign-time=1694078194;1866878194&q-key-time=1694078194;1866878194&q-header-list=&q-url-param-list=&q-signature=efb713ee650a0ee3c954fb3a0e148c37ef13cd3b"
7
- wget -O lyrasd_xiaorenshu_lora.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/lyrasd_xiaorenshu_lora.tar.gz?q-sign-algorithm=sha1&q-ak=AKIDBF6i7GCtKWS8ZkgOtACzX3MQDl37xYty&q-sign-time=1694078234;1866878234&q-key-time=1694078234;1866878234&q-header-list=&q-url-param-list=&q-signature=fb9a577a54ea6dedd9be696e40b96b71a1b23b5d"
8
 
9
  tar -xvf lyrasd_rev_animated.tar.gz
10
- tar -xvf lyrasd_canny.tar.gz
11
- tar -xvf lyrasd_xiaorenshu_lora.tar.gz
 
12
  ```
 
2
  ### This is the place where you should download the checkpoints, and unzip them
3
 
4
  ```bash
5
+ wget -O lyrasd_rev_animated.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/lyrasd_rev_animated.tar.gz"
6
+
7
+ wget -O sd-controlnet-canny.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/sd-controlnet-canny.tar.gz"
8
+
9
+ wget -O xiaorenshu.safetensors "https://civitai.com/api/download/models/25661"
10
+
11
+ wget -O helloworldSDXL20Fp16.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/helloworldSDXL20Fp16.tar.gz"
12
+
13
+ wget -O controlnet-canny-sdxl-1.0.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/controlnet-canny-sdxl-1.0.tar.gz"
14
+
15
+ wget -O dissolve_sdxl.safetensors "https://civitai.com/api/download/models/277389?type=Model&format=SafeTensor"
16
 
17
  tar -xvf lyrasd_rev_animated.tar.gz
18
+ tar -xvf sd-controlnet-canny.tar.gz
19
+ tar -xvf helloworldSDXL20Fp16.tar.gz
20
+ tar -xvf controlnet-canny-sdxl-1.0.tar.gz
21
  ```
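
The demo scripts in this commit reference the downloads above under `./models/`; a small check can confirm everything is in place (the names below are taken from the demo scripts, so adjust them if your archives extract to different folder names):

```python
import os

expected = [
    "models/rev-animated",               # SD 1.5 base used by txt2img_demo.py
    "models/xiaorenshu.safetensors",     # SD 1.5 LoRA
    "models/helloworldSDXL20Fp16",       # SDXL base used by txt2img_sdxl_demo.py
    "models/dissolve_sdxl.safetensors",  # SDXL LoRA
]
for path in expected:
    print(("ok      " if os.path.exists(path) else "missing ") + path)
```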
outputs/res_controlnet_img2img_0.png CHANGED

Git LFS Details

  • SHA256: 1b314eb678f2f3d76737b2b90507fe66e3a62393a89f00681e29bf821d273a60
  • Pointer size: 131 Bytes
  • Size of remote file: 447 kB

Git LFS Details

  • SHA256: 96aea3fc1f0992974935c798380f1ce008e61ff3b75d89c5d12700ed10fddbc9
  • Pointer size: 131 Bytes
  • Size of remote file: 436 kB
outputs/{res_controlnet_sdxl_txt2img.png → res_controlnet_sdxl_txt2img_0.png} RENAMED
File without changes
outputs/res_controlnet_txt2img_0.png CHANGED

Git LFS Details

  • SHA256: 225654758e835c97f49749170bb2440988d34607c023d47a03935068c9778993
  • Pointer size: 131 Bytes
  • Size of remote file: 398 kB

Git LFS Details

  • SHA256: b6d15a9715dd171a9e58ed2b8d628a5655b2b2de2539a9b9147f64d3e1529838
  • Pointer size: 131 Bytes
  • Size of remote file: 389 kB
outputs/res_img2img_0.png CHANGED

Git LFS Details

  • SHA256: cfe8f20e1e4382eacfa6851c5f7d386b5aeb875bca6ff7d927ede1ba43e7677a
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB

Git LFS Details

  • SHA256: 500882308c72de757094b7d8cc097eadc02dae2745c75ec223b37254190ad9f3
  • Pointer size: 131 Bytes
  • Size of remote file: 409 kB
outputs/res_txt2img_lora_0.png CHANGED

Git LFS Details

  • SHA256: b3879bf13166e9a16cd5314ab69072b6f7f69b80840b2e2204342c7fcfafbe04
  • Pointer size: 131 Bytes
  • Size of remote file: 433 kB

Git LFS Details

  • SHA256: cc46ad18b2444ddee13772c4862eca9519099c8bbb93722061004bfaec486bb7
  • Pointer size: 131 Bytes
  • Size of remote file: 436 kB
outputs/{res_sdxl_txt2img_lora_0.png → res_txt2img_xl_lora_0.png} RENAMED
File without changes
txt2img_demo.py CHANGED
@@ -10,22 +10,25 @@ from lyrasd_model import LyraSdTxt2ImgPipeline
10
  # 4. scheduler 配置
11
 
12
  # LyraSD 的 C++ 编译动态链接库,其中包含 C++ CUDA 计算的细节
13
- lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
14
- model_path = "./models/lyrasd_rev_animated"
15
- lora_path = "./models/lyrasd_xiaorenshu_lora"
 
 
16
 
17
  # 构建 Txt2Img 的 Pipeline
18
- model = LyraSdTxt2ImgPipeline(model_path, lib_path)
 
19
 
20
  # load lora
21
  # 参数分别为 lora 存放位置,名字,lora 强度,lora模型精度
22
- model.load_lora(lora_path, "xiaorenshu", 0.4, "fp32")
23
 
24
  # 准备应用的输入和超参数
25
  prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
26
  negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
27
  height, width = 512, 512
28
- steps = 30
29
  guidance_scale = 7
30
  generator = torch.Generator().manual_seed(123)
31
  num_images = 1
@@ -33,12 +36,12 @@ num_images = 1
33
  start = time.perf_counter()
34
  # 推理生成
35
  images = model(prompt, height, width, steps,
36
- guidance_scale, negative_prompt, num_images,
37
- generator=generator)
38
- print("image gen cost: ",time.perf_counter() - start)
39
  # 存储生成的图片
40
  for i, image in enumerate(images):
41
  image.save(f"outputs/res_txt2img_lora_{i}.png")
42
 
43
  # unload lora,参数为 lora 的名字,是否清除 lora 缓存
44
- # model.unload_lora("xiaorenshu", True)
 
10
  # 4. scheduler 配置
11
 
12
  # LyraSD 的 C++ 编译动态链接库,其中包含 C++ CUDA 计算的细节
13
+ lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
14
+ model_path = "./models/rev-animated"
15
+ lora_path = "./models/xiaorenshu.safetensors"
16
+
17
+ torch.classes.load_library(lib_path)
18
 
19
  # 构建 Txt2Img 的 Pipeline
20
+ model = LyraSdTxt2ImgPipeline()
21
+ model.reload_pipe(model_path)
22
 
23
  # load lora
24
  # 参数分别为 lora 存放位置,名字,lora 强度,lora模型精度
25
+ model.load_lora_v2(lora_path, "xiaorenshu", 0.4)
26
 
27
  # 准备应用的输入和超参数
28
  prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
29
  negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
30
  height, width = 512, 512
31
+ steps = 20
32
  guidance_scale = 7
33
  generator = torch.Generator().manual_seed(123)
34
  num_images = 1
 
36
  start = time.perf_counter()
37
  # 推理生成
38
  images = model(prompt, height, width, steps,
39
+ guidance_scale, negative_prompt, num_images,
40
+ generator=generator)
41
+ print("image gen cost: ", time.perf_counter() - start)
42
  # 存储生成的图片
43
  for i, image in enumerate(images):
44
  image.save(f"outputs/res_txt2img_lora_{i}.png")
45
 
46
  # unload lora,参数为 lora 的名字,是否清除 lora 缓存
47
+ model.unload_lora_v2("xiaorenshu", True)
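
`load_lora_v2` / `unload_lora_v2` mutate the pipeline in place (the demo loads the LoRA once and explicitly unloads it at the end), so callers that switch LoRAs between generations may want to scope them. A sketch built only on the two calls used above, assuming they can be paired freely:

```python
from contextlib import contextmanager

@contextmanager
def scoped_lora(pipe, lora_path, name, strength, clear_cache=False):
    # apply a LoRA for the duration of the block, then revert it
    pipe.load_lora_v2(lora_path, name, strength)
    try:
        yield pipe
    finally:
        pipe.unload_lora_v2(name, clear_cache)

# with scoped_lora(model, lora_path, "xiaorenshu", 0.4):
#     images = model(prompt, height, width, steps, guidance_scale,
#                    negative_prompt, num_images, generator=generator)
```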
txt2img_sdxl_demo.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ from lyrasd_model import LyraSdXLTxt2ImgPipeline
3
+ import time
4
+ import GPUtil
5
+ import os
6
+ from glob import glob
7
+ import random
8
+
9
+ # Path to the model directory, which should contain the following:
10
+ # 1. the CLIP model
11
+ # 2. the converted and optimized UNet weights, placed in its unet_bins folder
12
+ # 3. the VAE model
13
+ # 4. the scheduler config
14
+
15
+ # LyraSD's compiled C++ shared library, containing the C++/CUDA kernels
16
+ lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
17
+ model_path = "./models/helloworldSDXL20Fp16"
18
+ lora_path = "./models/dissolve_sdxl.safetensors"
19
+ torch.classes.load_library(lib_path)
20
+
21
+ # Build the Txt2Img pipeline
22
+ model = LyraSdXLTxt2ImgPipeline()
23
+
24
+ model.reload_pipe(model_path)
25
+
26
+ # load lora
27
+ # lora model path, name, lora strength
28
+ model.load_lora_v2(lora_path, "dissolve_sdxl", 0.4)
29
+
30
+ # Prepare the inputs and hyperparameters
31
+ prompt = "a cat, ral-dissolve"
32
+ negative_prompt = "nsfw, watermark"
33
+ height, width = 1024, 1024
34
+ steps = 20
35
+ guidance_scale = 7.5
36
+ generator = torch.Generator().manual_seed(8788800)
37
+
38
+ start = time.perf_counter()
39
+ # Run inference
40
+ images = model(prompt,
41
+ height=height,
42
+ width=width,
43
+ num_inference_steps=steps,
44
+ num_images_per_prompt=1,
45
+ guidance_scale=guidance_scale,
46
+ negative_prompt=negative_prompt,
47
+ generator=generator
48
+ )
49
+ print("image gen cost: ", time.perf_counter() - start)
50
+ # Save the generated images
51
+ for i, image in enumerate(images):
52
+ image.save(f"outputs/res_txt2img_xl_lora_{i}.png")
53
+
54
+ # unload lora: arguments are the lora name and whether to clear the lora cache
55
+ model.unload_lora_v2("dissolve_sdxl", True)