add code
- app.py +258 -0
- config/chinese-ink-paint.json +72 -0
- config/cloud.json +69 -0
- config/default.json +75 -0
- config/digital-art.json +72 -0
- config/fire.json +72 -0
- config/klimt.json +72 -0
- config/line-art.json +72 -0
- config/low-poly.json +72 -0
- config/munch.json +72 -0
- config/totoro.json +72 -0
- config/van-gogh.json +72 -0
- pipelines/__init__.py +0 -0
- pipelines/controlnet.py +844 -0
- pipelines/inverted_ve_pipeline.py +1615 -0
- pipelines/pipeline_controlnet_sd_xl.py +0 -0
- pipelines/pipeline_stable_diffusion_xl.py +1792 -0
- requirements.txt +10 -0
- utils.py +143 -0
- visualize_attention_src/__init__.py +0 -0
- visualize_attention_src/pipeline_stable_diffusion_xl_attn.py +1573 -0
- visualize_attention_src/save_attn_map_script.py +283 -0
- visualize_attention_src/utils.py +111 -0
- visualize_attention_src/visualize_attn_map_script.py +168 -0
app.py
ADDED
@@ -0,0 +1,258 @@
import torch
from pipelines.inverted_ve_pipeline import STYLE_DESCRIPTION_DICT, create_image_grid
import gradio as gr
import os, json
import numpy as np
from PIL import Image

from pipelines.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from diffusers import ControlNetModel, AutoencoderKL
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from random import randint
from utils import init_latent

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    torch_dtype = torch.float32
else:
    torch_dtype = torch.float16


def memory_efficient(model):
    try:
        model.to(device)
    except Exception as e:
        print("Error moving model to device:", e)

    try:
        model.enable_model_cpu_offload()
    except AttributeError:
        print("enable_model_cpu_offload is not supported.")
    try:
        model.enable_vae_slicing()
    except AttributeError:
        print("enable_vae_slicing is not supported.")

    if device == 'cuda':
        try:
            model.enable_xformers_memory_efficient_attention()
        except AttributeError:
            print("enable_xformers_memory_efficient_attention is not supported.")


controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype)

model_controlnet = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch_dtype
)


print("vae")
memory_efficient(vae)
print("control")
memory_efficient(controlnet)
print("ControlNet-SDXL")
memory_efficient(model_controlnet)

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

# controlnet_scale, canny thres 1, 2 (2 > 1, 2:1, 3:1)


def parse_config(config):
    with open(config, 'r') as f:
        config = json.load(f)
    return config


def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


def get_depth_edge_array(depth_img_path):
    depth_image_tmp = Image.fromarray(depth_img_path)

    # get depth map
    depth_map = get_depth_map(depth_image_tmp)

    return depth_map


def load_example_controlnet():
    folder_path = 'assets/ref'
    examples = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".png"):
            image_path = os.path.join(folder_path, filename)
            image_name = os.path.basename(image_path)
            style_name = image_name.split('_')[1]

            config_path = './config/{}.json'.format(style_name)
            config = parse_config(config_path)
            inf_object_name = config["inference_info"]["inf_object_list"][0]

            canny_path = './assets/depth_dir/gundam.png'
            image_info = [image_path, canny_path, style_name, inf_object_name, 1, 0.5, 50]

            examples.append(image_info)

    return examples


def controlnet_fn(image_path, depth_image_path, style_name, content_text, output_number, controlnet_scale=0.5, diffusion_step=50):
    """
    :param style_name: which JSON config file to load
    :param content_text: the content to generate in the chosen style
    :param output_number: how many images to generate
    :return: a grid of the generated images
    """
    config_path = './config/{}.json'.format(style_name)
    config = parse_config(config_path)

    inf_object = content_text
    inf_seeds = [randint(0, 10**10) for _ in range(int(output_number))]
    # inf_seeds = [i for i in range(int(output_number))]

    activate_layer_indices_list = config['inference_info']['activate_layer_indices_list']
    activate_step_indices_list = config['inference_info']['activate_step_indices_list']
    ref_seed = config['reference_info']['ref_seeds'][0]

    attn_map_save_steps = config['inference_info']['attn_map_save_steps']
    guidance_scale = config['guidance_scale']
    use_inf_negative_prompt = config['inference_info']['use_negative_prompt']

    style_name = config["style_name_list"][0]

    ref_object = config["reference_info"]["ref_object_list"][0]
    ref_with_style_description = config['reference_info']['with_style_description']
    inf_with_style_description = config['inference_info']['with_style_description']

    use_shared_attention = config['inference_info']['use_shared_attention']
    adain_queries = config['inference_info']['adain_queries']
    adain_keys = config['inference_info']['adain_keys']
    adain_values = config['inference_info']['adain_values']

    use_advanced_sampling = config['inference_info']['use_advanced_sampling']

    # get the depth map for the uploaded control image
    depth_image = get_depth_edge_array(depth_image_path)

    style_description_pos, style_description_neg = STYLE_DESCRIPTION_DICT[style_name][0], \
                                                   STYLE_DESCRIPTION_DICT[style_name][1]

    # Inference
    with torch.inference_mode():
        grid = None
        if ref_with_style_description:
            ref_prompt = style_description_pos.replace("{object}", ref_object)
        else:
            ref_prompt = ref_object

        if inf_with_style_description:
            inf_prompt = style_description_pos.replace("{object}", inf_object)
        else:
            inf_prompt = inf_object

        for activate_layer_indices in activate_layer_indices_list:

            for activate_step_indices in activate_step_indices_list:

                str_activate_layer, str_activate_step = model_controlnet.activate_layer(
                    activate_layer_indices=activate_layer_indices,
                    attn_map_save_steps=attn_map_save_steps,
                    activate_step_indices=activate_step_indices,
                    use_shared_attention=use_shared_attention,
                    adain_queries=adain_queries,
                    adain_keys=adain_keys,
                    adain_values=adain_values,
                )

                # ref_latent = model_controlnet.get_init_latent(ref_seed, precomputed_path=None)
                ref_latent = init_latent(model_controlnet, device_name=device, dtype=torch_dtype, seed=ref_seed)
                latents = [ref_latent]

                for inf_seed in inf_seeds:
                    # latents.append(model_controlnet.get_init_latent(inf_seed, precomputed_path=None))
                    inf_latent = init_latent(model_controlnet, device_name=device, dtype=torch_dtype, seed=inf_seed)
                    latents.append(inf_latent)

                latents = torch.cat(latents, dim=0)
                latents = latents.to(device)

                images = model_controlnet.generated_ve_inference(
                    prompt=ref_prompt,
                    negative_prompt=style_description_neg,
                    guidance_scale=guidance_scale,
                    num_inference_steps=diffusion_step,
                    controlnet_conditioning_scale=controlnet_scale,
                    latents=latents,
                    num_images_per_prompt=len(inf_seeds) + 1,
                    target_prompt=inf_prompt,
                    image=depth_image,
                    use_inf_negative_prompt=use_inf_negative_prompt,
                    use_advanced_sampling=use_advanced_sampling
                )[0][1:]

                n_row = 1
                n_col = len(inf_seeds)  # add 1 to also include the reference image

                # make grid
                grid = create_image_grid(images, n_row, n_col)

    torch.cuda.empty_cache()
    return grid


description_md = """

### We introduce `Visual Style Prompting`, which reflects the style of a reference image in the images generated by a pretrained text-to-image diffusion model, without finetuning or optimization (e.g., Figure N).
### 📖 [[Paper](https://arxiv.org/abs/2402.12974)] | ✨ [[Project page](https://curryjung.github.io/VisualStylePrompt)] | ✨ [[Code](https://github.com/naver-ai/Visual-Style-Prompting)]
### 🔥 [[Default ver](https://huggingface.co/spaces/naver-ai/VisualStylePrompting)]
---
### Visual Style Prompting also works with `ControlNet`, which constrains the shape of the results via a depth map or keypoints.

### To try out our demo with ControlNet,
1. Upload an `image for depth control`. An off-the-shelf model will produce the depth map from it.
2. Choose the `ControlNet scale`, which determines how closely the output follows the depth map.
3. Choose a `style reference` from the collection of images below.
4. Enter the `text prompt`. (`Empty text` is okay, but a description of the depth image helps.)
5. Choose the `number of outputs`.

### To achieve faster results, we recommend lowering the diffusion steps to 30.
### Enjoy ! 😄
"""

iface_controlnet = gr.Interface(
    fn=controlnet_fn,
    inputs=[
        gr.components.Image(label="Style image"),
        gr.components.Image(label="Depth image"),
        gr.components.Textbox(label='Style name', visible=False),
        gr.components.Textbox(label="Text prompt", placeholder="Enter Text prompt"),
        gr.components.Textbox(label="Number of outputs", placeholder="Enter Number of outputs"),
        gr.components.Slider(minimum=0.5, maximum=10, step=0.5, value=0.5, label="Controlnet scale"),
        gr.components.Slider(minimum=50, maximum=50, step=10, value=50, label="Diffusion steps")
    ],
    outputs=gr.components.Image(type="pil"),
    title="🎨 Visual Style Prompting (w/ ControlNet)",
    description=description_md,
    examples=load_example_controlnet(),
)

iface_controlnet.launch(debug=True)
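Outside the Gradio UI, controlnet_fn can also be driven directly. The following is a minimal sketch, not part of the commit: it assumes the models above have been built (e.g., app.py imported with the launch call guarded), that ./assets/depth_dir/gundam.png exists (the path used by load_example_controlnet), and that create_image_grid returns a PIL image, as the `type="pil"` Gradio output suggests.

# Hypothetical headless call to controlnet_fn (sketch, not in the repo).
import numpy as np
from PIL import Image

depth_source = np.array(Image.open("./assets/depth_dir/gundam.png").convert("RGB"))

grid = controlnet_fn(
    image_path=None,                # the style image is only displayed in the UI; the function body does not read it
    depth_image_path=depth_source,  # numpy array, matching what gr.components.Image delivers
    style_name="fire",              # resolves to ./config/fire.json
    content_text="A dragon",
    output_number=2,
    controlnet_scale=0.5,
    diffusion_step=30,
)
grid.save("fire_dragon_grid.png")   # assumes a PIL return value
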
config/chinese-ink-paint.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["chinese-ink-paint"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [1],
    "ref_object_list": ["A horse"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A tiger"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
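A minimal sketch of how app.py's controlnet_fn consumes one of these configs (the field names come straight from the JSON above; the script is illustrative and assumes it runs from the repo root):

# Hypothetical standalone reader for a style config (sketch, not in the repo).
import json

def parse_config(path):
    with open(path, "r") as f:
        return json.load(f)

config = parse_config("./config/chinese-ink-paint.json")
inference = config["inference_info"]

ref_object = config["reference_info"]["ref_object_list"][0]   # "A horse"
inf_object = inference["inf_object_list"][0]                   # "A tiger"
layer_ranges = inference["activate_layer_indices_list"][0]     # [[0, 0], [128, 140]]
step_range = inference["activate_step_indices_list"][0][0]     # [0, 49] -> swap attention over all 50 steps

print(ref_object, inf_object, layer_ranges, step_range)
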
config/cloud.json
ADDED
@@ -0,0 +1,69 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["cloud"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [3],
    "ref_object_list": ["a Cloud in the sky"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5],
    "inf_object_list": ["A photo of a dog"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/default.json
ADDED
@@ -0,0 +1,75 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["anime", "Artstyle_Pop_Art", "low_poly", "line_art"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [42],
    "ref_object_list": ["cat"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A photo of a dog"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/digital-art.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["digital-art"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [42],
    "ref_object_list": ["A robot"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A woman playing basketball"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/fire.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["fire"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [3],
    "ref_object_list": ["fire"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A dragon"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/klimt.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["klimt"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [1],
    "ref_object_list": ["the kiss"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["Frog"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/line-art.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["line-art"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [42],
    "ref_object_list": ["an owl"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A dragon"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/low-poly.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["low-poly"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [42],
    "ref_object_list": ["A cat"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A rhino"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/munch.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["munch"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [1],
    "ref_object_list": ["The scream"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A dragon"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/totoro.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["totoro"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [42],
    "ref_object_list": ["totoro holding a tiny umbrella in the rain"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [108, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["1 cute bird holding a tiny umbrella, forward facing"],
    "with_style_description": true,
    "negative_prompts": true,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
config/van-gogh.json
ADDED
@@ -0,0 +1,72 @@
{
  "device": "cuda",
  "precomputed_path": "./precomputed",
  "guidance_scale": 7.0,
  "style_name_list": ["van-gogh"],
  "save_info": {
    "base_exp_dir": "experiments",
    "base_exp_name": "results"
  },
  "reference_info": {
    "ref_seeds": [1],
    "ref_object_list": ["The Starry Night"],
    "with_style_description": true,
    "external_init_noise_path": false,
    "guidance_scale": 7.0,
    "use_negative_prompt": true
  },
  "inference_info": {
    "activate_layer_indices_list": [[[0, 0], [128, 140]]],
    "inf_seeds": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "inf_object_list": ["A dragon"],
    "with_style_description": true,
    "negative_prompts": false,
    "external_init_noise_path": false,
    "attn_map_save_steps": [],
    "guidance_scale": 7.0,
    "use_negative_prompt": true,
    "activate_step_indices_list": [[[0, 49]]],
    "use_advanced_sampling": true,
    "use_shared_attention": false,
    "adain_queries": true,
    "adain_keys": true,
    "adain_values": false
  }
}
pipelines/__init__.py
ADDED
File without changes
pipelines/controlnet.py
ADDED
@@ -0,0 +1,844 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
    """
    A ControlNet model.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers are
            skipped in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to 0):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
            `class_embed_type="projection"`.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `conditioning_embedding` layer.
        global_pool_conditions (`bool`, defaults to `False`):
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads=64,
    ):
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )

        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
                where applicable.
        """
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
445 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
446 |
+
addition_time_embed_dim = (
|
447 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
448 |
+
)
|
449 |
+
|
450 |
+
controlnet = cls(
|
451 |
+
encoder_hid_dim=encoder_hid_dim,
|
452 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
453 |
+
addition_embed_type=addition_embed_type,
|
454 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
455 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
456 |
+
in_channels=unet.config.in_channels,
|
457 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
458 |
+
freq_shift=unet.config.freq_shift,
|
459 |
+
down_block_types=unet.config.down_block_types,
|
460 |
+
only_cross_attention=unet.config.only_cross_attention,
|
461 |
+
block_out_channels=unet.config.block_out_channels,
|
462 |
+
layers_per_block=unet.config.layers_per_block,
|
463 |
+
downsample_padding=unet.config.downsample_padding,
|
464 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
465 |
+
act_fn=unet.config.act_fn,
|
466 |
+
norm_num_groups=unet.config.norm_num_groups,
|
467 |
+
norm_eps=unet.config.norm_eps,
|
468 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
469 |
+
attention_head_dim=unet.config.attention_head_dim,
|
470 |
+
num_attention_heads=unet.config.num_attention_heads,
|
471 |
+
use_linear_projection=unet.config.use_linear_projection,
|
472 |
+
class_embed_type=unet.config.class_embed_type,
|
473 |
+
num_class_embeds=unet.config.num_class_embeds,
|
474 |
+
upcast_attention=unet.config.upcast_attention,
|
475 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
476 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
477 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
478 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
479 |
+
)
|
480 |
+
|
481 |
+
if load_weights_from_unet:
|
482 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
483 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
484 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
485 |
+
|
486 |
+
if controlnet.class_embedding:
|
487 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
488 |
+
|
489 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
490 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
491 |
+
|
492 |
+
return controlnet
|
493 |
+
|
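The `from_unet` classmethod above mirrors an existing UNet's configuration and copies its encoder weights into the new ControlNet. A minimal usage sketch, assuming the diffusers `UNet2DConditionModel` loading API; the checkpoint id is illustrative:

```python
# Sketch: derive a ControlNet from a pretrained SDXL UNet (checkpoint id is illustrative).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)
# conv_in, the time embeddings, the down blocks and the mid block start from the UNet
# weights, while the zero-initialized 1x1 controlnet blocks keep the added residuals at zero.
```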
494 |
+
@property
|
495 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
496 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
497 |
+
r"""
|
498 |
+
Returns:
|
499 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
500 |
+
indexed by their weight names.
|
501 |
+
"""
|
502 |
+
# set recursively
|
503 |
+
processors = {}
|
504 |
+
|
505 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
506 |
+
if hasattr(module, "get_processor"):
|
507 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
508 |
+
|
509 |
+
for sub_name, child in module.named_children():
|
510 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
511 |
+
|
512 |
+
return processors
|
513 |
+
|
514 |
+
for name, module in self.named_children():
|
515 |
+
fn_recursive_add_processors(name, module, processors)
|
516 |
+
|
517 |
+
return processors
|
518 |
+
|
519 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
520 |
+
def set_attn_processor(
|
521 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
522 |
+
):
|
523 |
+
r"""
|
524 |
+
Sets the attention processor to use to compute attention.
|
525 |
+
|
526 |
+
Parameters:
|
527 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
528 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
529 |
+
for **all** `Attention` layers.
|
530 |
+
|
531 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
532 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
533 |
+
|
534 |
+
"""
|
535 |
+
count = len(self.attn_processors.keys())
|
536 |
+
|
537 |
+
if isinstance(processor, dict) and len(processor) != count:
|
538 |
+
raise ValueError(
|
539 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
540 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
541 |
+
)
|
542 |
+
|
543 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
544 |
+
if hasattr(module, "set_processor"):
|
545 |
+
if not isinstance(processor, dict):
|
546 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
547 |
+
else:
|
548 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
549 |
+
|
550 |
+
for sub_name, child in module.named_children():
|
551 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
552 |
+
|
553 |
+
for name, module in self.named_children():
|
554 |
+
fn_recursive_attn_processor(name, module, processor)
|
555 |
+
|
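`set_attn_processor` is the hook this repo relies on to swap in its custom processors: pass one processor to apply it everywhere, or a dict with exactly one entry per key of `attn_processors`. A hedged sketch of both call patterns (`model` stands for any instance of this class):

```python
# Sketch: single processor for every attention layer, or an explicit per-layer dict.
from diffusers.models.attention_processor import AttnProcessor

model.set_attn_processor(AttnProcessor())

# A dict must cover every attention layer, keyed like `model.attn_processors`.
procs = {name: AttnProcessor() for name in model.attn_processors.keys()}
model.set_attn_processor(procs)
```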
556 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
557 |
+
def set_default_attn_processor(self):
|
558 |
+
"""
|
559 |
+
Disables custom attention processors and sets the default attention implementation.
|
560 |
+
"""
|
561 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
562 |
+
processor = AttnAddedKVProcessor()
|
563 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
564 |
+
processor = AttnProcessor()
|
565 |
+
else:
|
566 |
+
raise ValueError(
|
567 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
568 |
+
)
|
569 |
+
|
570 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
571 |
+
|
572 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
573 |
+
def set_attention_slice(self, slice_size):
|
574 |
+
r"""
|
575 |
+
Enable sliced attention computation.
|
576 |
+
|
577 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
578 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
582 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
583 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
584 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
585 |
+
must be a multiple of `slice_size`.
|
586 |
+
"""
|
587 |
+
sliceable_head_dims = []
|
588 |
+
|
589 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
590 |
+
if hasattr(module, "set_attention_slice"):
|
591 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
592 |
+
|
593 |
+
for child in module.children():
|
594 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
595 |
+
|
596 |
+
# retrieve number of attention layers
|
597 |
+
for module in self.children():
|
598 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
599 |
+
|
600 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
601 |
+
|
602 |
+
if slice_size == "auto":
|
603 |
+
# half the attention head size is usually a good trade-off between
|
604 |
+
# speed and memory
|
605 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
606 |
+
elif slice_size == "max":
|
607 |
+
# make smallest slice possible
|
608 |
+
slice_size = num_sliceable_layers * [1]
|
609 |
+
|
610 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
611 |
+
|
612 |
+
if len(slice_size) != len(sliceable_head_dims):
|
613 |
+
raise ValueError(
|
614 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
615 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
616 |
+
)
|
617 |
+
|
618 |
+
for i in range(len(slice_size)):
|
619 |
+
size = slice_size[i]
|
620 |
+
dim = sliceable_head_dims[i]
|
621 |
+
if size is not None and size > dim:
|
622 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
623 |
+
|
624 |
+
# Recursively walk through all the children.
|
625 |
+
# Any children which exposes the set_attention_slice method
|
626 |
+
# gets the message
|
627 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
628 |
+
if hasattr(module, "set_attention_slice"):
|
629 |
+
module.set_attention_slice(slice_size.pop())
|
630 |
+
|
631 |
+
for child in module.children():
|
632 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
633 |
+
|
634 |
+
reversed_slice_size = list(reversed(slice_size))
|
635 |
+
for module in self.children():
|
636 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
637 |
+
|
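The three accepted `slice_size` forms documented above translate to simple slice lists; a short sketch (`model` again stands for an instance of this class):

```python
# Sketch: accepted slice_size values for set_attention_slice.
model.set_attention_slice("auto")  # halve each sliceable head dimension (two slices per layer)
model.set_attention_slice("max")   # one slice at a time: slowest, most memory-frugal
model.set_attention_slice(4)       # a single int is broadcast to every sliceable layer
```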
638 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
639 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
640 |
+
module.gradient_checkpointing = value
|
641 |
+
|
642 |
+
def forward(
|
643 |
+
self,
|
644 |
+
sample: torch.FloatTensor,
|
645 |
+
timestep: Union[torch.Tensor, float, int],
|
646 |
+
encoder_hidden_states: torch.Tensor,
|
647 |
+
controlnet_cond: torch.FloatTensor,
|
648 |
+
conditioning_scale: float = 1.0,
|
649 |
+
class_labels: Optional[torch.Tensor] = None,
|
650 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
651 |
+
attention_mask: Optional[torch.Tensor] = None,
|
652 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
653 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
654 |
+
guess_mode: bool = False,
|
655 |
+
return_dict: bool = True,
|
656 |
+
) -> Union[ControlNetOutput, Tuple]:
|
657 |
+
"""
|
658 |
+
The [`ControlNetModel`] forward method.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
sample (`torch.FloatTensor`):
|
662 |
+
The noisy input tensor.
|
663 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
664 |
+
The number of timesteps to denoise an input.
|
665 |
+
encoder_hidden_states (`torch.Tensor`):
|
666 |
+
The encoder hidden states.
|
667 |
+
controlnet_cond (`torch.FloatTensor`):
|
668 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
669 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
670 |
+
The scale factor for ControlNet outputs.
|
671 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
672 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
673 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
674 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
675 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
676 |
+
embeddings.
|
677 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
678 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
679 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
680 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
681 |
+
added_cond_kwargs (`dict`):
|
682 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
683 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
684 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
685 |
+
guess_mode (`bool`, defaults to `False`):
|
686 |
+
In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
|
687 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
688 |
+
return_dict (`bool`, defaults to `True`):
|
689 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
690 |
+
|
691 |
+
Returns:
|
692 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
693 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
694 |
+
returned where the first element is the sample tensor.
|
695 |
+
"""
|
696 |
+
# check channel order
|
697 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
698 |
+
|
699 |
+
if channel_order == "rgb":
|
700 |
+
# in rgb order by default
|
701 |
+
...
|
702 |
+
elif channel_order == "bgr":
|
703 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
704 |
+
else:
|
705 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
706 |
+
|
707 |
+
# prepare attention_mask
|
708 |
+
if attention_mask is not None:
|
709 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
710 |
+
attention_mask = attention_mask.unsqueeze(1)
|
711 |
+
|
712 |
+
# 1. time
|
713 |
+
timesteps = timestep
|
714 |
+
if not torch.is_tensor(timesteps):
|
715 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
716 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
717 |
+
is_mps = sample.device.type == "mps"
|
718 |
+
if isinstance(timestep, float):
|
719 |
+
dtype = torch.float32 if is_mps else torch.float64
|
720 |
+
else:
|
721 |
+
dtype = torch.int32 if is_mps else torch.int64
|
722 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
723 |
+
elif len(timesteps.shape) == 0:
|
724 |
+
timesteps = timesteps[None].to(sample.device)
|
725 |
+
|
726 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
727 |
+
timesteps = timesteps.expand(sample.shape[0])
|
728 |
+
|
729 |
+
t_emb = self.time_proj(timesteps)
|
730 |
+
|
731 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
732 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
733 |
+
# there might be better ways to encapsulate this.
|
734 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
735 |
+
|
736 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
737 |
+
aug_emb = None
|
738 |
+
|
739 |
+
if self.class_embedding is not None:
|
740 |
+
if class_labels is None:
|
741 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
742 |
+
|
743 |
+
if self.config.class_embed_type == "timestep":
|
744 |
+
class_labels = self.time_proj(class_labels)
|
745 |
+
|
746 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
747 |
+
emb = emb + class_emb
|
748 |
+
|
749 |
+
if self.config.addition_embed_type is not None:
|
750 |
+
if self.config.addition_embed_type == "text":
|
751 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
752 |
+
|
753 |
+
elif self.config.addition_embed_type == "text_time":
|
754 |
+
if "text_embeds" not in added_cond_kwargs:
|
755 |
+
raise ValueError(
|
756 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
757 |
+
)
|
758 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
759 |
+
if "time_ids" not in added_cond_kwargs:
|
760 |
+
raise ValueError(
|
761 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
762 |
+
)
|
763 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
764 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
765 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
766 |
+
|
767 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
768 |
+
add_embeds = add_embeds.to(emb.dtype)
|
769 |
+
aug_emb = self.add_embedding(add_embeds)
|
770 |
+
|
771 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
772 |
+
|
773 |
+
# 2. pre-process
|
774 |
+
sample = self.conv_in(sample)
|
775 |
+
|
776 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
777 |
+
sample = sample + controlnet_cond
|
778 |
+
|
779 |
+
# 3. down
|
780 |
+
down_block_res_samples = (sample,)
|
781 |
+
for downsample_block in self.down_blocks:
|
782 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
783 |
+
sample, res_samples = downsample_block(
|
784 |
+
hidden_states=sample,
|
785 |
+
temb=emb,
|
786 |
+
encoder_hidden_states=encoder_hidden_states,
|
787 |
+
attention_mask=attention_mask,
|
788 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
789 |
+
)
|
790 |
+
else:
|
791 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
792 |
+
|
793 |
+
down_block_res_samples += res_samples
|
794 |
+
|
795 |
+
# 4. mid
|
796 |
+
if self.mid_block is not None:
|
797 |
+
sample = self.mid_block(
|
798 |
+
sample,
|
799 |
+
emb,
|
800 |
+
encoder_hidden_states=encoder_hidden_states,
|
801 |
+
attention_mask=attention_mask,
|
802 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
803 |
+
)
|
804 |
+
|
805 |
+
# 5. Control net blocks
|
806 |
+
|
807 |
+
controlnet_down_block_res_samples = ()
|
808 |
+
|
809 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
810 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
811 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
812 |
+
|
813 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
814 |
+
|
815 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
816 |
+
|
817 |
+
# 6. scaling
|
818 |
+
if guess_mode and not self.config.global_pool_conditions:
|
819 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
820 |
+
scales = scales * conditioning_scale
|
821 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
822 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
823 |
+
else:
|
824 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
825 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
826 |
+
|
827 |
+
if self.config.global_pool_conditions:
|
828 |
+
down_block_res_samples = [
|
829 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
830 |
+
]
|
831 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
832 |
+
|
833 |
+
if not return_dict:
|
834 |
+
return (down_block_res_samples, mid_block_res_sample)
|
835 |
+
|
836 |
+
return ControlNetOutput(
|
837 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
838 |
+
)
|
839 |
+
|
840 |
+
|
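In the `guess_mode` branch above, the residuals are weighted with a log-spaced ramp so the shallow, high-resolution features contribute least and the mid block contributes fully. A standalone sketch of that ramp (13 outputs is an illustrative count; the real number depends on the UNet config):

```python
import torch

n_residuals = 13                             # e.g. 12 down-block residuals + 1 mid block
scales = torch.logspace(-1, 0, n_residuals)  # 0.1 ... 1.0, log-spaced
print(round(scales[0].item(), 3), round(scales[-1].item(), 3))  # 0.1 1.0
# The first (highest-resolution) residual is damped 10x; the mid block is left untouched.
```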
841 |
+
def zero_module(module):
|
842 |
+
for p in module.parameters():
|
843 |
+
nn.init.zeros_(p)
|
844 |
+
return module
|
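`zero_module` is what makes the ControlNet branch a no-op at initialization: every 1x1 projection starts with zero weights and bias, so the frozen UNet initially receives unchanged residuals. A quick check of that property:

```python
import torch
import torch.nn as nn

conv = zero_module(nn.Conv2d(8, 8, kernel_size=1))
x = torch.randn(1, 8, 16, 16)
assert torch.all(conv(x) == 0)  # stays zero until training moves the weights
```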
pipelines/inverted_ve_pipeline.py
ADDED
@@ -0,0 +1,1615 @@
1 |
+
from __future__ import annotations
|
2 |
+
from diffusers import StableDiffusionPipeline
|
3 |
+
import torch
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Callable, List, Optional, Union, Any, Dict
|
6 |
+
import numpy as np
|
7 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from torch.nn.functional import grid_sample
|
10 |
+
from torch.nn import functional as nnf
|
11 |
+
import torchvision.transforms as T
|
12 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
13 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel, attention_processor
|
14 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
15 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
16 |
+
import PIL
|
17 |
+
from PIL import Image
|
18 |
+
from kornia.morphology import dilation
|
19 |
+
from collections import OrderedDict
|
20 |
+
from packaging import version
|
21 |
+
import inspect
|
22 |
+
from diffusers.utils import (
|
23 |
+
deprecate,
|
24 |
+
is_accelerate_available,
|
25 |
+
is_accelerate_version,
|
26 |
+
logging,
|
27 |
+
replace_example_docstring,
|
28 |
+
)
|
29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
30 |
+
import torch.nn as nn
|
31 |
+
|
32 |
+
T = torch.Tensor  # NOTE: this alias shadows the `torchvision.transforms as T` import above
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass(frozen=True)
|
36 |
+
class StyleAlignedArgs:
|
37 |
+
share_group_norm: bool = True
|
38 |
+
share_layer_norm: bool = True
|
39 |
+
share_attention: bool = True
|
40 |
+
adain_queries: bool = True
|
41 |
+
adain_keys: bool = True
|
42 |
+
adain_values: bool = False
|
43 |
+
full_attention_share: bool = False
|
44 |
+
keys_scale: float = 1.
|
45 |
+
only_self_level: float = 0.
|
46 |
+
|
47 |
+
def expand_first(feat: T, scale=1.) -> T:
|
48 |
+
b = feat.shape[0]
|
49 |
+
feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
|
50 |
+
if scale == 1:
|
51 |
+
feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
|
52 |
+
else:
|
53 |
+
feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
|
54 |
+
feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
|
55 |
+
return feat_style.reshape(*feat.shape)
|
56 |
+
|
57 |
+
|
58 |
+
def concat_first(feat: T, dim=2, scale=1.) -> T:
|
59 |
+
feat_style = expand_first(feat, scale=scale)
|
60 |
+
return torch.cat((feat, feat_style), dim=dim)
|
61 |
+
|
62 |
+
|
63 |
+
def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
|
64 |
+
feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
|
65 |
+
feat_mean = feat.mean(dim=-2, keepdims=True)
|
66 |
+
return feat_mean, feat_std
|
67 |
+
|
68 |
+
|
69 |
+
def adain(feat: T) -> T:
|
70 |
+
feat_mean, feat_std = calc_mean_std(feat)
|
71 |
+
feat_style_mean = expand_first(feat_mean)
|
72 |
+
feat_style_std = expand_first(feat_std)
|
73 |
+
feat = (feat - feat_mean) / feat_std
|
74 |
+
feat = feat * feat_style_std + feat_style_mean
|
75 |
+
return feat
|
76 |
+
|
77 |
+
|
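`calc_mean_std` and `adain` implement the AdaIN step of shared attention: each element's per-token statistics are stripped and replaced by those of its chunk's reference element (index 0 for the first half of the batch, index `b // 2` for the second). A small self-contained check, with illustrative shapes:

```python
import torch

# (batch, heads, tokens, dim); batch = [uncond chunk | cond chunk], 2 images per chunk.
feat = torch.randn(4, 8, 77, 64)
out = adain(feat)

out_mean, _ = calc_mean_std(out)
ref_mean, _ = calc_mean_std(feat)
# Non-reference elements now carry the token-wise mean of their chunk's reference.
assert torch.allclose(out_mean[1], ref_mean[0], atol=1e-4)
assert torch.allclose(out_mean[3], ref_mean[2], atol=1e-4)
```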
78 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
79 |
+
|
80 |
+
|
81 |
+
EXAMPLE_DOC_STRING = """
|
82 |
+
Examples:
|
83 |
+
```py
|
84 |
+
>>> import torch
|
85 |
+
>>> from diffusers import StableDiffusionPipeline
|
86 |
+
|
87 |
+
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
88 |
+
>>> pipe = pipe.to("cuda")
|
89 |
+
|
90 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
91 |
+
>>> image = pipe(prompt).images[0]
|
92 |
+
```
|
93 |
+
"""
|
94 |
+
|
95 |
+
# ACTIVATE_STEP_CANDIDATE = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 1]
|
96 |
+
|
97 |
+
|
98 |
+
def create_image_grid(image_list, rows, cols, padding=10):
|
99 |
+
# Ensure the number of rows and columns doesn't exceed the number of images
|
100 |
+
rows = min(rows, len(image_list))
|
101 |
+
cols = min(cols, len(image_list))
|
102 |
+
|
103 |
+
# Get the dimensions of a single image
|
104 |
+
image_width, image_height = image_list[0].size
|
105 |
+
|
106 |
+
# Calculate the size of the output image
|
107 |
+
grid_width = cols * (image_width + padding) - padding
|
108 |
+
grid_height = rows * (image_height + padding) - padding
|
109 |
+
|
110 |
+
# Create an empty grid image
|
111 |
+
grid_image = Image.new('RGB', (grid_width, grid_height), (255, 255, 255))
|
112 |
+
|
113 |
+
# Paste images into the grid
|
114 |
+
for i, img in enumerate(image_list[:rows * cols]):
|
115 |
+
row = i // cols
|
116 |
+
col = i % cols
|
117 |
+
x = col * (image_width + padding)
|
118 |
+
y = row * (image_height + padding)
|
119 |
+
grid_image.paste(img, (x, y))
|
120 |
+
|
121 |
+
return grid_image
|
122 |
+
|
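A minimal usage sketch for `create_image_grid` (tile sizes, colors, and the output filename are arbitrary):

```python
from PIL import Image

tiles = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
grid = create_image_grid(tiles, rows=2, cols=2, padding=10)
grid.save("grid_preview.png")  # 2x2 grid with 10 px of white padding between tiles
```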
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
class CrossFrameAttnProcessor_backup:
|
127 |
+
def __init__(self, unet_chunk_size=2):
|
128 |
+
self.unet_chunk_size = unet_chunk_size
|
129 |
+
|
130 |
+
def __call__(
|
131 |
+
self,
|
132 |
+
attn,
|
133 |
+
hidden_states,
|
134 |
+
encoder_hidden_states=None,
|
135 |
+
attention_mask=None):
|
136 |
+
|
137 |
+
|
138 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
139 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
140 |
+
query = attn.to_q(hidden_states)
|
141 |
+
|
142 |
+
is_cross_attention = encoder_hidden_states is not None
|
143 |
+
if encoder_hidden_states is None:
|
144 |
+
encoder_hidden_states = hidden_states
|
145 |
+
# elif attn.cross_attention_norm:
|
146 |
+
# encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
147 |
+
key = attn.to_k(encoder_hidden_states)
|
148 |
+
value = attn.to_v(encoder_hidden_states)
|
149 |
+
# Sparse Attention
|
150 |
+
if not is_cross_attention:
|
151 |
+
video_length = key.size()[0] // self.unet_chunk_size
|
152 |
+
# former_frame_index = torch.arange(video_length) - 1
|
153 |
+
# former_frame_index[0] = 0
|
154 |
+
# import pdb; pdb.set_trace()
|
155 |
+
|
156 |
+
# if video_length > 3:
|
157 |
+
# import pdb; pdb.set_trace()
|
158 |
+
former_frame_index = [0] * video_length
|
159 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
|
160 |
+
key = key[:, former_frame_index]
|
161 |
+
key = rearrange(key, "b f d c -> (b f) d c")
|
162 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
|
163 |
+
value = value[:, former_frame_index]
|
164 |
+
value = rearrange(value, "b f d c -> (b f) d c")
|
165 |
+
|
166 |
+
|
167 |
+
query = attn.head_to_batch_dim(query)
|
168 |
+
key = attn.head_to_batch_dim(key)
|
169 |
+
value = attn.head_to_batch_dim(value)
|
170 |
+
|
171 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
172 |
+
hidden_states = torch.bmm(attention_probs, value)
|
173 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
174 |
+
|
175 |
+
# linear proj
|
176 |
+
hidden_states = attn.to_out[0](hidden_states)
|
177 |
+
# dropout
|
178 |
+
hidden_states = attn.to_out[1](hidden_states)
|
179 |
+
|
180 |
+
return hidden_states
|
181 |
+
|
182 |
+
|
183 |
+
class SharedAttentionProcessor:
|
184 |
+
def __init__(self,
|
185 |
+
adain_keys=True,
|
186 |
+
adain_queries=True,
|
187 |
+
adain_values=False,
|
188 |
+
keys_scale=1.,
|
189 |
+
attn_map_save_steps=[]):
|
190 |
+
super().__init__()
|
191 |
+
self.adain_queries = adain_queries
|
192 |
+
self.adain_keys = adain_keys
|
193 |
+
self.adain_values = adain_values
|
194 |
+
# self.full_attention_share = style_aligned_args.full_attention_share
|
195 |
+
self.keys_scale = keys_scale
|
196 |
+
self.attn_map_save_steps = attn_map_save_steps
|
197 |
+
|
198 |
+
|
199 |
+
def __call__(
|
200 |
+
self,
|
201 |
+
attn: attention_processor.Attention,
|
202 |
+
hidden_states,
|
203 |
+
encoder_hidden_states=None,
|
204 |
+
attention_mask=None,
|
205 |
+
**kwargs
|
206 |
+
):
|
207 |
+
|
208 |
+
if not hasattr(attn, "attn_map"):
|
209 |
+
setattr(attn, "attn_map", {})
|
210 |
+
setattr(attn, "inference_step", 0)
|
211 |
+
else:
|
212 |
+
attn.inference_step += 1
|
213 |
+
|
214 |
+
residual = hidden_states
|
215 |
+
input_ndim = hidden_states.ndim
|
216 |
+
if input_ndim == 4:
|
217 |
+
batch_size, channel, height, width = hidden_states.shape
|
218 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
219 |
+
batch_size, sequence_length, _ = (
|
220 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
221 |
+
)
|
222 |
+
|
223 |
+
is_cross_attention = encoder_hidden_states is not None
|
224 |
+
|
225 |
+
if attention_mask is not None:
|
226 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
227 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
228 |
+
# (batch, heads, source_length, target_length)
|
229 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
230 |
+
|
231 |
+
if attn.group_norm is not None:
|
232 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
233 |
+
|
234 |
+
query = attn.to_q(hidden_states)
|
235 |
+
|
236 |
+
if encoder_hidden_states is None:
|
237 |
+
encoder_hidden_states = hidden_states
|
238 |
+
# elif attn.cross_attention_norm:
|
239 |
+
# encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
240 |
+
key = attn.to_k(encoder_hidden_states)
|
241 |
+
value = attn.to_v(encoder_hidden_states)
|
242 |
+
|
243 |
+
inner_dim = key.shape[-1]
|
244 |
+
head_dim = inner_dim // attn.heads
|
245 |
+
|
246 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
247 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
248 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
249 |
+
# if self.step >= self.start_inject:
|
250 |
+
|
251 |
+
|
252 |
+
if not is_cross_attention:# and self.share_attention:
|
253 |
+
if self.adain_queries:
|
254 |
+
query = adain(query)
|
255 |
+
if self.adain_keys:
|
256 |
+
key = adain(key)
|
257 |
+
if self.adain_values:
|
258 |
+
value = adain(value)
|
259 |
+
key = concat_first(key, -2, scale=self.keys_scale)
|
260 |
+
value = concat_first(value, -2)
|
261 |
+
hidden_states = nnf.scaled_dot_product_attention(
|
262 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
263 |
+
)
|
264 |
+
else:
|
265 |
+
hidden_states = nnf.scaled_dot_product_attention(
|
266 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
# hidden_states = adain(hidden_states)
|
273 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
274 |
+
hidden_states = hidden_states.to(query.dtype)
|
275 |
+
|
276 |
+
# linear proj
|
277 |
+
hidden_states = attn.to_out[0](hidden_states)
|
278 |
+
# dropout
|
279 |
+
hidden_states = attn.to_out[1](hidden_states)
|
280 |
+
|
281 |
+
if input_ndim == 4:
|
282 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
283 |
+
|
284 |
+
if attn.residual_connection:
|
285 |
+
hidden_states = hidden_states + residual
|
286 |
+
|
287 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
288 |
+
return hidden_states
|
289 |
+
|
290 |
+
|
291 |
+
class SharedAttentionProcessor_v2:
|
292 |
+
def __init__(self,
|
293 |
+
adain_keys=True,
|
294 |
+
adain_queries=True,
|
295 |
+
adain_values=False,
|
296 |
+
keys_scale=1.,
|
297 |
+
attn_map_save_steps=[]):
|
298 |
+
super().__init__()
|
299 |
+
self.adain_queries = adain_queries
|
300 |
+
self.adain_keys = adain_keys
|
301 |
+
self.adain_values = adain_values
|
302 |
+
# self.full_attention_share = style_aligned_args.full_attention_share
|
303 |
+
self.keys_scale = keys_scale
|
304 |
+
self.attn_map_save_steps = attn_map_save_steps
|
305 |
+
|
306 |
+
|
307 |
+
def __call__(
|
308 |
+
self,
|
309 |
+
attn: attention_processor.Attention,
|
310 |
+
hidden_states,
|
311 |
+
encoder_hidden_states=None,
|
312 |
+
attention_mask=None,
|
313 |
+
**kwargs
|
314 |
+
):
|
315 |
+
|
316 |
+
if not hasattr(attn, "attn_map"):
|
317 |
+
setattr(attn, "attn_map", {})
|
318 |
+
setattr(attn, "inference_step", 0)
|
319 |
+
else:
|
320 |
+
attn.inference_step += 1
|
321 |
+
|
322 |
+
residual = hidden_states
|
323 |
+
input_ndim = hidden_states.ndim
|
324 |
+
if input_ndim == 4:
|
325 |
+
batch_size, channel, height, width = hidden_states.shape
|
326 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
327 |
+
batch_size, sequence_length, _ = (
|
328 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
329 |
+
)
|
330 |
+
|
331 |
+
is_cross_attention = encoder_hidden_states is not None
|
332 |
+
|
333 |
+
if attention_mask is not None:
|
334 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
335 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
336 |
+
# (batch, heads, source_length, target_length)
|
337 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
338 |
+
|
339 |
+
if attn.group_norm is not None:
|
340 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
341 |
+
|
342 |
+
query = attn.to_q(hidden_states)
|
343 |
+
|
344 |
+
|
345 |
+
if encoder_hidden_states is None:
|
346 |
+
encoder_hidden_states = hidden_states
|
347 |
+
# elif attn.cross_attention_norm:
|
348 |
+
# encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
349 |
+
key = attn.to_k(encoder_hidden_states)
|
350 |
+
value = attn.to_v(encoder_hidden_states)
|
351 |
+
|
352 |
+
tmp_query_shape = query.shape
|
353 |
+
tmp_key_shape = key.shape
|
354 |
+
tmp_value_shape = value.shape
|
355 |
+
|
356 |
+
|
357 |
+
inner_dim = key.shape[-1]
|
358 |
+
head_dim = inner_dim // attn.heads
|
359 |
+
|
360 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
361 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
362 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
363 |
+
# if self.step >= self.start_inject:
|
364 |
+
|
365 |
+
|
366 |
+
if not is_cross_attention:# and self.share_attention:
|
367 |
+
if self.adain_queries:
|
368 |
+
query = adain(query)
|
369 |
+
if self.adain_keys:
|
370 |
+
key = adain(key)
|
371 |
+
if self.adain_values:
|
372 |
+
value = adain(value)
|
373 |
+
key = concat_first(key, -2, scale=self.keys_scale)
|
374 |
+
value = concat_first(value, -2)
|
375 |
+
# hidden_states = nnf.scaled_dot_product_attention(
|
376 |
+
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
377 |
+
# )
|
378 |
+
|
379 |
+
if attn.inference_step in self.attn_map_save_steps:
|
380 |
+
|
381 |
+
query = query.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
382 |
+
key = key.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
383 |
+
value = value.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
384 |
+
|
385 |
+
query = attn.head_to_batch_dim(query)
|
386 |
+
key = attn.head_to_batch_dim(key)
|
387 |
+
value = attn.head_to_batch_dim(value)
|
388 |
+
|
389 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
390 |
+
|
391 |
+
if attn.inference_step in self.attn_map_save_steps:
|
392 |
+
attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
|
393 |
+
|
394 |
+
hidden_states = torch.bmm(attention_probs, value)
|
395 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
396 |
+
else:
|
397 |
+
hidden_states = nnf.scaled_dot_product_attention(
|
398 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
399 |
+
)
|
400 |
+
# hidden_states = adain(hidden_states)
|
401 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
402 |
+
hidden_states = hidden_states.to(query.dtype)
|
403 |
+
|
404 |
+
else:
|
405 |
+
|
406 |
+
hidden_states = nnf.scaled_dot_product_attention(
|
407 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
408 |
+
)
|
409 |
+
# hidden_states = adain(hidden_states)
|
410 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
411 |
+
hidden_states = hidden_states.to(query.dtype)
|
412 |
+
|
413 |
+
# linear proj
|
414 |
+
hidden_states = attn.to_out[0](hidden_states)
|
415 |
+
# dropout
|
416 |
+
hidden_states = attn.to_out[1](hidden_states)
|
417 |
+
|
418 |
+
if input_ndim == 4:
|
419 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
420 |
+
|
421 |
+
if attn.residual_connection:
|
422 |
+
hidden_states = hidden_states + residual
|
423 |
+
|
424 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
425 |
+
|
426 |
+
if attn.inference_step == 49:
|
427 |
+
# reset the step counter after the final step (assumes a 50-step schedule)
|
428 |
+
attn.inference_step = -1
|
429 |
+
|
430 |
+
return hidden_states
|
431 |
+
|
432 |
+
|
433 |
+
def swapping_attention(key, value, chunk_size=2):
|
434 |
+
chunk_length = key.size()[0] // chunk_size # [text-condition, null-condition]
|
435 |
+
reference_image_index = [0] * chunk_length # [0 0 0 0 0]
|
436 |
+
key = rearrange(key, "(b f) d c -> b f d c", f=chunk_length)
|
437 |
+
key = key[:, reference_image_index] # ref to all
|
438 |
+
key = rearrange(key, "b f d c -> (b f) d c")
|
439 |
+
value = rearrange(value, "(b f) d c -> b f d c", f=chunk_length)
|
440 |
+
value = value[:, reference_image_index] # ref to all
|
441 |
+
value = rearrange(value, "b f d c -> (b f) d c")
|
442 |
+
|
443 |
+
return key, value
|
444 |
+
|
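`swapping_attention` is the core style-injection step: inside each chunk of the batch, every image's keys and values are replaced by those of the chunk's first (reference) image. A small check on random tensors, with illustrative shapes:

```python
import torch

# (batch, tokens, dim); batch = 2 chunks (e.g. uncond / cond) of 3 images each.
key = torch.randn(6, 77, 64)
value = torch.randn(6, 77, 64)

swapped_key, swapped_value = swapping_attention(key, value, chunk_size=2)

# Every image now attends to its chunk's reference keys/values.
assert torch.equal(swapped_key[1], key[0]) and torch.equal(swapped_key[2], key[0])
assert torch.equal(swapped_key[4], key[3]) and torch.equal(swapped_value[5], value[3])
```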
445 |
+
class CrossFrameAttnProcessor:
|
446 |
+
def __init__(self, unet_chunk_size=2, attn_map_save_steps=[],activate_step_indices=None):
|
447 |
+
self.unet_chunk_size = unet_chunk_size
|
448 |
+
self.attn_map_save_steps = attn_map_save_steps
|
449 |
+
self.activate_step_indices = activate_step_indices
|
450 |
+
|
451 |
+
def __call__(
|
452 |
+
self,
|
453 |
+
attn,
|
454 |
+
hidden_states,
|
455 |
+
encoder_hidden_states=None,
|
456 |
+
attention_mask=None):
|
457 |
+
|
458 |
+
if not hasattr(attn, "attn_map"):
|
459 |
+
setattr(attn, "attn_map", {})
|
460 |
+
setattr(attn, "inference_step", 0)
|
461 |
+
else:
|
462 |
+
attn.inference_step += 1
|
463 |
+
|
464 |
+
|
465 |
+
|
466 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
467 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
468 |
+
query = attn.to_q(hidden_states)
|
469 |
+
|
470 |
+
is_cross_attention = encoder_hidden_states is not None
|
471 |
+
if encoder_hidden_states is None:
|
472 |
+
encoder_hidden_states = hidden_states
|
473 |
+
|
474 |
+
key = attn.to_k(encoder_hidden_states)
|
475 |
+
value = attn.to_v(encoder_hidden_states)
|
476 |
+
|
477 |
+
is_in_inference_step = False
|
478 |
+
|
479 |
+
if self.activate_step_indices is not None:
|
480 |
+
for activate_step_index in self.activate_step_indices:
|
481 |
+
if attn.inference_step >= activate_step_index[0] and attn.inference_step <= activate_step_index[1]:
|
482 |
+
is_in_inference_step = True
|
483 |
+
break
|
484 |
+
|
485 |
+
# Swapping Attention
|
486 |
+
if not is_cross_attention and is_in_inference_step:
|
487 |
+
key, value = swapping_attention(key, value, self.unet_chunk_size)
|
488 |
+
|
489 |
+
|
490 |
+
|
491 |
+
|
492 |
+
query = attn.head_to_batch_dim(query)
|
493 |
+
key = attn.head_to_batch_dim(key)
|
494 |
+
value = attn.head_to_batch_dim(value)
|
495 |
+
|
496 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
497 |
+
|
498 |
+
if attn.inference_step in self.attn_map_save_steps:
|
499 |
+
attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
|
500 |
+
|
501 |
+
hidden_states = torch.bmm(attention_probs, value)
|
502 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
503 |
+
|
504 |
+
# linear proj
|
505 |
+
hidden_states = attn.to_out[0](hidden_states)
|
506 |
+
# dropout
|
507 |
+
hidden_states = attn.to_out[1](hidden_states)
|
508 |
+
|
509 |
+
if attn.inference_step == 49:
|
510 |
+
attn.inference_step = -1
|
511 |
+
|
512 |
+
return hidden_states
|
513 |
+
|
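One way the processor above can be wired into a pipeline is to assign it to every self-attention (`attn1`) layer of the UNet while leaving cross-attention untouched; the layer-name filter, step range, and `pipe` handle below are illustrative assumptions, not values fixed by this file:

```python
# Sketch: register the swapping-attention processor on the UNet's self-attention layers.
from diffusers.models.attention_processor import AttnProcessor2_0

procs = {}
for name in pipe.unet.attn_processors.keys():
    if name.endswith("attn1.processor"):      # attn1 = self-attention
        procs[name] = CrossFrameAttnProcessor(
            unet_chunk_size=2,                # [uncond | cond] halves of the batch
            activate_step_indices=[(0, 49)],  # active across a 50-step schedule
        )
    else:                                     # attn2 = cross-attention, default processor
        procs[name] = AttnProcessor2_0()
pipe.unet.set_attn_processor(procs)
```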
514 |
+
|
515 |
+
|
516 |
+
|
517 |
+
class CrossFrameAttnProcessor4Inversion:
|
518 |
+
def __init__(self, unet_chunk_size=2, attn_map_save_steps=[],activate_step_indices=None):
|
519 |
+
self.unet_chunk_size = unet_chunk_size
|
520 |
+
self.attn_map_save_steps = attn_map_save_steps
|
521 |
+
self.activate_step_indices = activate_step_indices
|
522 |
+
|
523 |
+
def __call__(
|
524 |
+
self,
|
525 |
+
attn,
|
526 |
+
hidden_states,
|
527 |
+
encoder_hidden_states=None,
|
528 |
+
attention_mask=None):
|
529 |
+
|
530 |
+
if not hasattr(attn, "attn_map"):
|
531 |
+
setattr(attn, "attn_map", {})
|
532 |
+
setattr(attn, "inference_step", 0)
|
533 |
+
else:
|
534 |
+
attn.inference_step += 1
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
539 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
540 |
+
query = attn.to_q(hidden_states)
|
541 |
+
|
542 |
+
is_cross_attention = encoder_hidden_states is not None
|
543 |
+
if encoder_hidden_states is None:
|
544 |
+
encoder_hidden_states = hidden_states
|
545 |
+
# elif attn.cross_attention_norm:
|
546 |
+
# encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
547 |
+
key = attn.to_k(encoder_hidden_states)
|
548 |
+
value = attn.to_v(encoder_hidden_states)
|
549 |
+
|
550 |
+
is_in_inference_step = False
|
551 |
+
|
552 |
+
if self.activate_step_indices is not None:
|
553 |
+
for activate_step_index in self.activate_step_indices:
|
554 |
+
if attn.inference_step >= activate_step_index[0] and attn.inference_step <= activate_step_index[1]:
|
555 |
+
is_in_inference_step = True
|
556 |
+
break
|
557 |
+
|
558 |
+
# Swapping Attention
|
559 |
+
if not is_cross_attention and is_in_inference_step:
|
560 |
+
key, value = swapping_attention(key, value, self.unet_chunk_size)
|
561 |
+
|
562 |
+
|
563 |
+
|
564 |
+
query = attn.head_to_batch_dim(query)
|
565 |
+
key = attn.head_to_batch_dim(key)
|
566 |
+
value = attn.head_to_batch_dim(value)
|
567 |
+
|
568 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
569 |
+
|
570 |
+
# if attn.inference_step > 45 and attn.inference_step < 50:
|
571 |
+
# if attn.inference_step == 42 or attn.inference_step==49:
|
572 |
+
if attn.inference_step in self.attn_map_save_steps:
|
573 |
+
attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
|
574 |
+
|
575 |
+
hidden_states = torch.bmm(attention_probs, value)
|
576 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
577 |
+
|
578 |
+
# linear proj
|
579 |
+
hidden_states = attn.to_out[0](hidden_states)
|
580 |
+
# dropout
|
581 |
+
hidden_states = attn.to_out[1](hidden_states)
|
582 |
+
|
583 |
+
if attn.inference_step == 49:
|
584 |
+
# reset the step counter after the final step (assumes a 50-step schedule)
|
585 |
+
attn.inference_step = -1
|
586 |
+
|
587 |
+
return hidden_states
|
588 |
+
|
589 |
+
|
590 |
+
|
591 |
+
class CrossFrameAttnProcessor_store:
|
592 |
+
def __init__(self, unet_chunk_size=2, attn_map_save_steps=[]):
|
593 |
+
self.unet_chunk_size = unet_chunk_size
|
594 |
+
self.attn_map_save_steps = attn_map_save_steps
|
595 |
+
|
596 |
+
def __call__(
|
597 |
+
self,
|
598 |
+
attn,
|
599 |
+
hidden_states,
|
600 |
+
encoder_hidden_states=None,
|
601 |
+
attention_mask=None):
|
602 |
+
|
603 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
604 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
605 |
+
query = attn.to_q(hidden_states)
|
606 |
+
|
607 |
+
is_cross_attention = encoder_hidden_states is not None
|
608 |
+
if encoder_hidden_states is None:
|
609 |
+
encoder_hidden_states = hidden_states
|
610 |
+
# elif attn.cross_attention_norm:
|
611 |
+
# encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
612 |
+
key = attn.to_k(encoder_hidden_states)
|
613 |
+
value = attn.to_v(encoder_hidden_states)
|
614 |
+
|
615 |
+
# Swapping Attention
|
616 |
+
if not is_cross_attention:
|
617 |
+
key, value = swapping_attention(key, value, self.unet_chunk_size)
|
618 |
+
|
619 |
+
|
620 |
+
query = attn.head_to_batch_dim(query)
|
621 |
+
key = attn.head_to_batch_dim(key)
|
622 |
+
value = attn.head_to_batch_dim(value)
|
623 |
+
|
624 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
625 |
+
|
626 |
+
if not hasattr(attn, "attn_map"):
|
627 |
+
setattr(attn, "attn_map", {})
|
628 |
+
setattr(attn, "inference_step", 0)
|
629 |
+
else:
|
630 |
+
attn.inference_step += 1
|
631 |
+
|
632 |
+
|
633 |
+
# if attn.inference_step > 45 and attn.inference_step < 50:
|
634 |
+
# if attn.inference_step == 42 or attn.inference_step==49:
|
635 |
+
if attn.inference_step in self.attn_map_save_steps:
|
636 |
+
attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
|
637 |
+
|
638 |
+
hidden_states = torch.bmm(attention_probs, value)
|
639 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
640 |
+
|
641 |
+
# linear proj
|
642 |
+
hidden_states = attn.to_out[0](hidden_states)
|
643 |
+
# dropout
|
644 |
+
hidden_states = attn.to_out[1](hidden_states)
|
645 |
+
|
646 |
+
return hidden_states
|
647 |
+
|
648 |
+
|
649 |
+
class InvertedVEAttnProcessor:
|
650 |
+
def __init__(self, unet_chunk_size=2, scale=1.0):
|
651 |
+
self.unet_chunk_size = unet_chunk_size
|
652 |
+
self.scale = scale
|
653 |
+
|
654 |
+
def __call__(
|
655 |
+
self,
|
656 |
+
attn,
|
657 |
+
hidden_states,
|
658 |
+
encoder_hidden_states=None,
|
659 |
+
attention_mask=None):
|
660 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
661 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
662 |
+
query = attn.to_q(hidden_states)
|
663 |
+
|
664 |
+
is_cross_attention = encoder_hidden_states is not None
|
665 |
+
if encoder_hidden_states is None:
|
666 |
+
encoder_hidden_states = hidden_states
|
667 |
+
elif attn.cross_attention_norm:
|
668 |
+
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
669 |
+
key = attn.to_k(encoder_hidden_states)
|
670 |
+
value = attn.to_v(encoder_hidden_states)
|
671 |
+
|
672 |
+
#Dual Attention
|
673 |
+
if not is_cross_attention:
|
674 |
+
ve_key = key.clone()
|
675 |
+
ve_value = value.clone()
|
676 |
+
video_length = ve_key.size()[0] // self.unet_chunk_size
|
677 |
+
|
678 |
+
former_frame_index = [0] * video_length
|
679 |
+
ve_key = rearrange(ve_key, "(b f) d c -> b f d c", f=video_length)
|
680 |
+
ve_key = ve_key[:, former_frame_index]
|
681 |
+
ve_key = rearrange(ve_key, "b f d c -> (b f) d c")
|
682 |
+
ve_value = rearrange(ve_value, "(b f) d c -> b f d c", f=video_length)
|
683 |
+
ve_value = ve_value[:, former_frame_index]
|
684 |
+
ve_value = rearrange(ve_value, "b f d c -> (b f) d c")
|
685 |
+
|
686 |
+
ve_key = attn.head_to_batch_dim(ve_key)
|
687 |
+
ve_value = attn.head_to_batch_dim(ve_value)
|
688 |
+
ve_query = attn.head_to_batch_dim(query)
|
689 |
+
|
690 |
+
ve_attention_probs = attn.get_attention_scores(ve_query, ve_key, attention_mask)
|
691 |
+
ve_hidden_states = torch.bmm(ve_attention_probs, ve_value)
|
692 |
+
ve_hidden_states = attn.batch_to_head_dim(ve_hidden_states)
|
693 |
+
ve_hidden_states[0,...] = 0
|
694 |
+
ve_hidden_states[video_length,...] = 0
|
695 |
+
|
696 |
+
query = attn.head_to_batch_dim(query)
|
697 |
+
key = attn.head_to_batch_dim(key)
|
698 |
+
value = attn.head_to_batch_dim(value)
|
699 |
+
|
700 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
701 |
+
hidden_states = torch.bmm(attention_probs, value)
|
702 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
703 |
+
|
704 |
+
hidden_states = hidden_states + ve_hidden_states * self.scale
|
705 |
+
|
706 |
+
else:
|
707 |
+
query = attn.head_to_batch_dim(query)
|
708 |
+
key = attn.head_to_batch_dim(key)
|
709 |
+
value = attn.head_to_batch_dim(value)
|
710 |
+
|
711 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
712 |
+
hidden_states = torch.bmm(attention_probs, value)
|
713 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
714 |
+
|
715 |
+
|
716 |
+
|
717 |
+
# linear proj
|
718 |
+
hidden_states = attn.to_out[0](hidden_states)
|
719 |
+
# dropout
|
720 |
+
hidden_states = attn.to_out[1](hidden_states)
|
721 |
+
|
722 |
+
return hidden_states
|
723 |
+
|
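A toy view of the dual attention implemented above: the ordinary self-attention output is kept, and a second pass that attends only to the reference image's keys and values is added on top, scaled by `scale`, with the reference rows zeroed so the reference image is reconstructed from its own attention alone. The shapes, the scale value, and the plain softmax normalization below are assumptions for illustration, not the module's exact arithmetic.

# Toy illustration of out = attention(q, k, v) + scale * attention(q, k_ref, v_ref)
import torch

tokens, dim, scale = 4, 8, 1.0
q = torch.randn(2, tokens, dim)            # rows: [reference image, target image]
k, v = torch.randn_like(q), torch.randn_like(q)

attn = torch.softmax(q @ k.transpose(-1, -2) / dim**0.5, dim=-1) @ v
ve_attn = torch.softmax(q @ k[:1].transpose(-1, -2) / dim**0.5, dim=-1) @ v[:1]
ve_attn[0] = 0                             # the reference image keeps only its own attention
out = attn + scale * ve_attn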
724 |
+
class AttnProcessor(nn.Module):
|
725 |
+
r"""
|
726 |
+
Default processor for performing attention-related computations.
|
727 |
+
"""
|
728 |
+
def __init__(
|
729 |
+
self,
|
730 |
+
hidden_size=None,
|
731 |
+
cross_attention_dim=None,
|
732 |
+
):
|
733 |
+
super().__init__()
|
734 |
+
|
735 |
+
def __call__(
|
736 |
+
self,
|
737 |
+
attn,
|
738 |
+
hidden_states,
|
739 |
+
encoder_hidden_states=None,
|
740 |
+
attention_mask=None,
|
741 |
+
temb=None,
|
742 |
+
):
|
743 |
+
|
744 |
+
residual = hidden_states
|
745 |
+
# import pdb; pdb.set_trace()
|
746 |
+
# if attn.spatial_norm is not None:
|
747 |
+
# hidden_states = attn.spatial_norm(hidden_states, temb)
|
748 |
+
|
749 |
+
input_ndim = hidden_states.ndim
|
750 |
+
|
751 |
+
if input_ndim == 4:
|
752 |
+
batch_size, channel, height, width = hidden_states.shape
|
753 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
754 |
+
|
755 |
+
batch_size, sequence_length, _ = (
|
756 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
757 |
+
)
|
758 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
759 |
+
|
760 |
+
# if attn.group_norm is not None:
|
761 |
+
# hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
762 |
+
|
763 |
+
query = attn.to_q(hidden_states)
|
764 |
+
|
765 |
+
if encoder_hidden_states is None:
|
766 |
+
encoder_hidden_states = hidden_states
|
767 |
+
elif attn.norm_cross:
|
768 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
769 |
+
|
770 |
+
key = attn.to_k(encoder_hidden_states)
|
771 |
+
value = attn.to_v(encoder_hidden_states)
|
772 |
+
|
773 |
+
query = attn.head_to_batch_dim(query)
|
774 |
+
key = attn.head_to_batch_dim(key)
|
775 |
+
value = attn.head_to_batch_dim(value)
|
776 |
+
|
777 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
778 |
+
hidden_states = torch.bmm(attention_probs, value)
|
779 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
780 |
+
|
781 |
+
# linear proj
|
782 |
+
hidden_states = attn.to_out[0](hidden_states)
|
783 |
+
# dropout
|
784 |
+
hidden_states = attn.to_out[1](hidden_states)
|
785 |
+
|
786 |
+
if input_ndim == 4:
|
787 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
788 |
+
|
789 |
+
if attn.residual_connection:
|
790 |
+
hidden_states = hidden_states + residual
|
791 |
+
|
792 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
793 |
+
|
794 |
+
return hidden_states
|
795 |
+
|
796 |
+
|
797 |
+
@dataclass
|
798 |
+
class StableDiffusionPipelineOutput(BaseOutput):
|
799 |
+
"""
|
800 |
+
Output class for Stable Diffusion pipelines.
|
801 |
+
|
802 |
+
Args:
|
803 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
804 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
805 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
806 |
+
nsfw_content_detected (`List[bool]`)
|
807 |
+
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
808 |
+
(nsfw) content, or `None` if safety checking could not be performed.
|
809 |
+
"""
|
810 |
+
|
811 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
812 |
+
nsfw_content_detected: Optional[List[bool]]
|
813 |
+
|
814 |
+
class FrozenDict(OrderedDict):
|
815 |
+
def __init__(self, *args, **kwargs):
|
816 |
+
super().__init__(*args, **kwargs)
|
817 |
+
|
818 |
+
for key, value in self.items():
|
819 |
+
setattr(self, key, value)
|
820 |
+
|
821 |
+
self.__frozen = True
|
822 |
+
|
823 |
+
def __delitem__(self, *args, **kwargs):
|
824 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
825 |
+
|
826 |
+
def setdefault(self, *args, **kwargs):
|
827 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
828 |
+
|
829 |
+
def pop(self, *args, **kwargs):
|
830 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
831 |
+
|
832 |
+
def update(self, *args, **kwargs):
|
833 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
834 |
+
|
835 |
+
def __setattr__(self, name, value):
|
836 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
837 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
838 |
+
super().__setattr__(name, value)
|
839 |
+
|
840 |
+
def __setitem__(self, name, value):
|
841 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
842 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
843 |
+
super().__setitem__(name, value)
|
844 |
+
|
845 |
+
|
846 |
+
class InvertedVEPipeline(StableDiffusionPipeline):
|
847 |
+
r"""
|
848 |
+
Pipeline for text-to-image generation using Stable Diffusion.
|
849 |
+
|
850 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
851 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
852 |
+
|
853 |
+
Args:
|
854 |
+
vae ([`AutoencoderKL`]):
|
855 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
856 |
+
text_encoder ([`CLIPTextModel`]):
|
857 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
858 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
859 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
860 |
+
tokenizer (`CLIPTokenizer`):
|
861 |
+
Tokenizer of class
|
862 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
863 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
864 |
+
scheduler ([`SchedulerMixin`]):
|
865 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
866 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
867 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
868 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
869 |
+
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
870 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
871 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
872 |
+
"""
|
873 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
874 |
+
|
875 |
+
def __init__(
|
876 |
+
self,
|
877 |
+
vae: AutoencoderKL,
|
878 |
+
text_encoder: CLIPTextModel,
|
879 |
+
tokenizer: CLIPTokenizer,
|
880 |
+
unet: UNet2DConditionModel,
|
881 |
+
scheduler: KarrasDiffusionSchedulers,
|
882 |
+
safety_checker: StableDiffusionSafetyChecker,
|
883 |
+
feature_extractor: CLIPFeatureExtractor,
|
884 |
+
requires_safety_checker: bool = True,
|
885 |
+
):
|
886 |
+
# super().__init__()
|
887 |
+
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
|
888 |
+
safety_checker, feature_extractor, requires_safety_checker)
|
889 |
+
|
890 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
891 |
+
deprecation_message = (
|
892 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
893 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
894 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
895 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
896 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
897 |
+
" file"
|
898 |
+
)
|
899 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
900 |
+
new_config = dict(scheduler.config)
|
901 |
+
new_config["steps_offset"] = 1
|
902 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
903 |
+
|
904 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
905 |
+
deprecation_message = (
|
906 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
907 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
908 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
909 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
910 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
911 |
+
)
|
912 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
913 |
+
new_config = dict(scheduler.config)
|
914 |
+
new_config["clip_sample"] = False
|
915 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
916 |
+
|
917 |
+
if safety_checker is None and requires_safety_checker:
|
918 |
+
logger.warning(
|
919 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
920 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
921 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
922 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
923 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
924 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
925 |
+
)
|
926 |
+
|
927 |
+
if safety_checker is not None and feature_extractor is None:
|
928 |
+
raise ValueError(
|
929 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
930 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
931 |
+
)
|
932 |
+
|
933 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
934 |
+
version.parse(unet.config._diffusers_version).base_version
|
935 |
+
) < version.parse("0.9.0.dev0")
|
936 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
937 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
938 |
+
deprecation_message = (
|
939 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
940 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
941 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
942 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
943 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
944 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
945 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
946 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
947 |
+
" the `unet/config.json` file"
|
948 |
+
)
|
949 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
950 |
+
new_config = dict(unet.config)
|
951 |
+
new_config["sample_size"] = 64
|
952 |
+
unet._internal_dict = FrozenDict(new_config)
|
953 |
+
|
954 |
+
self.register_modules(
|
955 |
+
vae=vae,
|
956 |
+
text_encoder=text_encoder,
|
957 |
+
tokenizer=tokenizer,
|
958 |
+
unet=unet,
|
959 |
+
scheduler=scheduler,
|
960 |
+
safety_checker=safety_checker,
|
961 |
+
feature_extractor=feature_extractor,
|
962 |
+
)
|
963 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
964 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
965 |
+
|
966 |
+
def enable_vae_slicing(self):
|
967 |
+
r"""
|
968 |
+
Enable sliced VAE decoding.
|
969 |
+
|
970 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
971 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
972 |
+
"""
|
973 |
+
self.vae.enable_slicing()
|
974 |
+
|
975 |
+
def disable_vae_slicing(self):
|
976 |
+
r"""
|
977 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
978 |
+
computing decoding in one step.
|
979 |
+
"""
|
980 |
+
self.vae.disable_slicing()
|
981 |
+
|
982 |
+
def enable_vae_tiling(self):
|
983 |
+
r"""
|
984 |
+
Enable tiled VAE decoding.
|
985 |
+
|
986 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
987 |
+
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
988 |
+
"""
|
989 |
+
self.vae.enable_tiling()
|
990 |
+
|
991 |
+
def disable_vae_tiling(self):
|
992 |
+
r"""
|
993 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
994 |
+
computing decoding in one step.
|
995 |
+
"""
|
996 |
+
self.vae.disable_tiling()
|
997 |
+
|
998 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
999 |
+
r"""
|
1000 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
1001 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
1002 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
1003 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
1004 |
+
`enable_model_cpu_offload`, but performance is lower.
|
1005 |
+
"""
|
1006 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
1007 |
+
from accelerate import cpu_offload
|
1008 |
+
else:
|
1009 |
+
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
1010 |
+
|
1011 |
+
device = torch.device(f"cuda:{gpu_id}")
|
1012 |
+
|
1013 |
+
if self.device.type != "cpu":
|
1014 |
+
self.to("cpu", silence_dtype_warnings=True)
|
1015 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
1016 |
+
|
1017 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
1018 |
+
cpu_offload(cpu_offloaded_model, device)
|
1019 |
+
|
1020 |
+
if self.safety_checker is not None:
|
1021 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
1022 |
+
|
1023 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
1024 |
+
r"""
|
1025 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
1026 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
1027 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
1028 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
1029 |
+
"""
|
1030 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
1031 |
+
from accelerate import cpu_offload_with_hook
|
1032 |
+
else:
|
1033 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
1034 |
+
|
1035 |
+
device = torch.device(f"cuda:{gpu_id}")
|
1036 |
+
|
1037 |
+
if self.device.type != "cpu":
|
1038 |
+
self.to("cpu", silence_dtype_warnings=True)
|
1039 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
1040 |
+
|
1041 |
+
hook = None
|
1042 |
+
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
1043 |
+
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
1044 |
+
|
1045 |
+
if self.safety_checker is not None:
|
1046 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
1047 |
+
|
1048 |
+
# We'll offload the last model manually.
|
1049 |
+
self.final_offload_hook = hook
|
1050 |
+
|
1051 |
+
@property
|
1052 |
+
def _execution_device(self):
|
1053 |
+
r"""
|
1054 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
1055 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
1056 |
+
hooks.
|
1057 |
+
"""
|
1058 |
+
if not hasattr(self.unet, "_hf_hook"):
|
1059 |
+
return self.device
|
1060 |
+
for module in self.unet.modules():
|
1061 |
+
if (
|
1062 |
+
hasattr(module, "_hf_hook")
|
1063 |
+
and hasattr(module._hf_hook, "execution_device")
|
1064 |
+
and module._hf_hook.execution_device is not None
|
1065 |
+
):
|
1066 |
+
return torch.device(module._hf_hook.execution_device)
|
1067 |
+
return self.device
|
1068 |
+
|
1069 |
+
|
1070 |
+
def run_safety_checker(self, image, device, dtype):
|
1071 |
+
if self.safety_checker is not None:
|
1072 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
1073 |
+
image, has_nsfw_concept = self.safety_checker(
|
1074 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
1075 |
+
)
|
1076 |
+
else:
|
1077 |
+
has_nsfw_concept = None
|
1078 |
+
return image, has_nsfw_concept
|
1079 |
+
|
1080 |
+
def decode_latents(self, latents):
|
1081 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
1082 |
+
image = self.vae.decode(latents).sample
|
1083 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
1084 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
1085 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
1086 |
+
return image
|
1087 |
+
|
1088 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
1089 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
1090 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
1091 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
1092 |
+
# and should be between [0, 1]
|
1093 |
+
|
1094 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
1095 |
+
extra_step_kwargs = {}
|
1096 |
+
if accepts_eta:
|
1097 |
+
extra_step_kwargs["eta"] = eta
|
1098 |
+
|
1099 |
+
# check if the scheduler accepts generator
|
1100 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
1101 |
+
if accepts_generator:
|
1102 |
+
extra_step_kwargs["generator"] = generator
|
1103 |
+
return extra_step_kwargs
|
1104 |
+
|
1105 |
+
def check_inputs(
|
1106 |
+
self,
|
1107 |
+
prompt,
|
1108 |
+
height,
|
1109 |
+
width,
|
1110 |
+
callback_steps,
|
1111 |
+
negative_prompt=None,
|
1112 |
+
prompt_embeds=None,
|
1113 |
+
negative_prompt_embeds=None,
|
1114 |
+
):
|
1115 |
+
if height % 8 != 0 or width % 8 != 0:
|
1116 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
1117 |
+
|
1118 |
+
if (callback_steps is None) or (
|
1119 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
1120 |
+
):
|
1121 |
+
raise ValueError(
|
1122 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
1123 |
+
f" {type(callback_steps)}."
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
if prompt is not None and prompt_embeds is not None:
|
1127 |
+
raise ValueError(
|
1128 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
1129 |
+
" only forward one of the two."
|
1130 |
+
)
|
1131 |
+
elif prompt is None and prompt_embeds is None:
|
1132 |
+
raise ValueError(
|
1133 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
1134 |
+
)
|
1135 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
1136 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
1137 |
+
|
1138 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
1139 |
+
raise ValueError(
|
1140 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
1141 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
1142 |
+
)
|
1143 |
+
|
1144 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
1145 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
1146 |
+
raise ValueError(
|
1147 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
1148 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
1149 |
+
f" {negative_prompt_embeds.shape}."
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
1153 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
1154 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
1155 |
+
raise ValueError(
|
1156 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
1157 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
if latents is None:
|
1161 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
1162 |
+
else:
|
1163 |
+
latents = latents.to(device)
|
1164 |
+
|
1165 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
1166 |
+
latents = latents * self.scheduler.init_noise_sigma
|
1167 |
+
return latents
|
1168 |
+
|
1169 |
+
@torch.no_grad()
|
1170 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
1171 |
+
def __call__(
|
1172 |
+
self,
|
1173 |
+
prompt: Union[str, List[str]] = None,
|
1174 |
+
height: Optional[int] = None,
|
1175 |
+
width: Optional[int] = None,
|
1176 |
+
num_inference_steps: int = 50,
|
1177 |
+
guidance_scale: float = 7.5,
|
1178 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1179 |
+
num_images_per_prompt: Optional[int] = 1,
|
1180 |
+
eta: float = 0.0,
|
1181 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1182 |
+
latents: Optional[torch.FloatTensor] = None,
|
1183 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1184 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1185 |
+
output_type: Optional[str] = "pil",
|
1186 |
+
return_dict: bool = True,
|
1187 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1188 |
+
callback_steps: int = 1,
|
1189 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1190 |
+
target_prompt: Optional[str] = None,
|
1191 |
+
# device: Optional[Union[str, torch.device]] = "cpu",
|
1192 |
+
):
|
1193 |
+
r"""
|
1194 |
+
Function invoked when calling the pipeline for generation.
|
1195 |
+
|
1196 |
+
Args:
|
1197 |
+
prompt (`str` or `List[str]`, *optional*):
|
1198 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
1199 |
+
instead.
|
1200 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1201 |
+
The height in pixels of the generated image.
|
1202 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1203 |
+
The width in pixels of the generated image.
|
1204 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1205 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1206 |
+
expense of slower inference.
|
1207 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1208 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1209 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1210 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1211 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1212 |
+
usually at the expense of lower image quality.
|
1213 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1214 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
1215 |
+
`negative_prompt_embeds` instead.
|
1216 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
1217 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1218 |
+
The number of images to generate per prompt.
|
1219 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1220 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1221 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1222 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1223 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1224 |
+
to make generation deterministic.
|
1225 |
+
latents (`torch.FloatTensor`, *optional*):
|
1226 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1227 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1228 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
1229 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1230 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1231 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1232 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1233 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1234 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1235 |
+
argument.
|
1236 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1237 |
+
The output format of the generated image. Choose between
|
1238 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1239 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1240 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1241 |
+
plain tuple.
|
1242 |
+
callback (`Callable`, *optional*):
|
1243 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1244 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1245 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1246 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1247 |
+
called at every step.
|
1248 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1249 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
1250 |
+
`self.processor` in
|
1251 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
1252 |
+
|
1253 |
+
Examples:
|
1254 |
+
|
1255 |
+
Returns:
|
1256 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1257 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1258 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1259 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1260 |
+
(nsfw) content, according to the `safety_checker`.
|
1261 |
+
"""
|
1262 |
+
# 0. Default height and width to unet
|
1263 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
1264 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
1265 |
+
|
1266 |
+
# 1. Check inputs. Raise error if not correct
|
1267 |
+
self.check_inputs(
|
1268 |
+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
1269 |
+
)
|
1270 |
+
|
1271 |
+
# 2. Define call parameters
|
1272 |
+
if prompt is not None and isinstance(prompt, str):
|
1273 |
+
batch_size = 1
|
1274 |
+
elif prompt is not None and isinstance(prompt, list):
|
1275 |
+
batch_size = len(prompt)
|
1276 |
+
else:
|
1277 |
+
batch_size = prompt_embeds.shape[0]
|
1278 |
+
|
1279 |
+
device = self._execution_device
|
1280 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1281 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1282 |
+
# corresponds to doing no classifier free guidance.
|
1283 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1284 |
+
|
1285 |
+
# 3. Encode input prompt
|
1286 |
+
# import pdb; pdb.set_trace()
|
1287 |
+
|
1288 |
+
|
1289 |
+
prompt_embeds = self._encode_prompt(
|
1290 |
+
prompt,
|
1291 |
+
device,
|
1292 |
+
num_images_per_prompt,
|
1293 |
+
do_classifier_free_guidance,
|
1294 |
+
negative_prompt,
|
1295 |
+
prompt_embeds=prompt_embeds,
|
1296 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1297 |
+
)
|
1298 |
+
|
1299 |
+
# import pdb; pdb.set_trace()
|
1300 |
+
|
1301 |
+
if target_prompt is not None:
|
1302 |
+
target_prompt_embeds = self._encode_prompt(
|
1303 |
+
target_prompt,
|
1304 |
+
device,
|
1305 |
+
num_images_per_prompt,
|
1306 |
+
do_classifier_free_guidance,
|
1307 |
+
negative_prompt,
|
1308 |
+
prompt_embeds=None,
|
1309 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1310 |
+
)
|
1311 |
+
prompt_embeds[num_images_per_prompt+1: ] = target_prompt_embeds[num_images_per_prompt+1:]
|
1312 |
+
# import pdb; pdb.set_trace()  # stray debug breakpoint disabled; it would halt generation whenever target_prompt is set
|
1313 |
+
|
1314 |
+
# 4. Prepare timesteps
|
1315 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1316 |
+
timesteps = self.scheduler.timesteps
|
1317 |
+
|
1318 |
+
# 5. Prepare latent variables
|
1319 |
+
num_channels_latents = self.unet.in_channels
|
1320 |
+
latents = self.prepare_latents(
|
1321 |
+
batch_size * num_images_per_prompt,
|
1322 |
+
num_channels_latents,
|
1323 |
+
height,
|
1324 |
+
width,
|
1325 |
+
prompt_embeds.dtype,
|
1326 |
+
device,
|
1327 |
+
generator,
|
1328 |
+
latents,
|
1329 |
+
)
|
1330 |
+
|
1331 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1332 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1333 |
+
|
1334 |
+
# 7. Denoising loop
|
1335 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1336 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1337 |
+
for i, t in enumerate(timesteps):
|
1338 |
+
# expand the latents if we are doing classifier free guidance
|
1339 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
1340 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1341 |
+
|
1342 |
+
# predict the noise residual
|
1343 |
+
noise_pred = self.unet(
|
1344 |
+
latent_model_input,
|
1345 |
+
t,
|
1346 |
+
encoder_hidden_states=prompt_embeds,
|
1347 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1348 |
+
).sample
|
1349 |
+
|
1350 |
+
# perform guidance
|
1351 |
+
if do_classifier_free_guidance:
|
1352 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1353 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1354 |
+
|
1355 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1356 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
1357 |
+
|
1358 |
+
# call the callback, if provided
|
1359 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1360 |
+
progress_bar.update()
|
1361 |
+
if callback is not None and i % callback_steps == 0:
|
1362 |
+
callback(i, t, latents)
|
1363 |
+
|
1364 |
+
if output_type == "latent":
|
1365 |
+
image = latents
|
1366 |
+
has_nsfw_concept = None
|
1367 |
+
elif output_type == "pil":
|
1368 |
+
# 8. Post-processing
|
1369 |
+
image = self.decode_latents(latents)
|
1370 |
+
|
1371 |
+
# 9. Run safety checker
|
1372 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
1373 |
+
|
1374 |
+
# 10. Convert to PIL
|
1375 |
+
image = self.numpy_to_pil(image)
|
1376 |
+
else:
|
1377 |
+
# 8. Post-processing
|
1378 |
+
image = self.decode_latents(latents)
|
1379 |
+
|
1380 |
+
# 9. Run safety checker
|
1381 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
1382 |
+
|
1383 |
+
# Offload last model to CPU
|
1384 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
1385 |
+
self.final_offload_hook.offload()
|
1386 |
+
|
1387 |
+
if not return_dict:
|
1388 |
+
return (image, has_nsfw_concept)
|
1389 |
+
|
1390 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
1391 |
+
|
1392 |
+
|
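For orientation, a minimal usage sketch of the pipeline defined above. The checkpoint id, prompts, and argument values are placeholders, not defaults taken from this repository; `target_prompt` is the extra argument handled in `__call__` above, which overwrites the prompt embeddings of the non-reference images after encoding.

# Hedged usage sketch; checkpoint id and prompts are placeholders.
import torch

pipe = InvertedVEPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

images = pipe(
    prompt="a cat, watercolor painting",   # reference prompt shared by the batch
    target_prompt="a dog",                 # applied to the non-reference images
    num_images_per_prompt=2,
    num_inference_steps=50,
    guidance_scale=7.5,
).images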
1393 |
+
ACTIVATE_LAYER_CANDIDATE= [
|
1394 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
|
1395 |
+
'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
|
1396 |
+
'down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
|
1397 |
+
'down_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
|
1398 |
+
'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
|
1399 |
+
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
|
1400 |
+
'down_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
|
1401 |
+
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor', #8
|
1402 |
+
|
1403 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor',
|
1404 |
+
'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor',
|
1405 |
+
'down_blocks.2.attentions.0.transformer_blocks.1.attn1.processor',
|
1406 |
+
'down_blocks.2.attentions.0.transformer_blocks.1.attn2.processor',
|
1407 |
+
'down_blocks.2.attentions.0.transformer_blocks.2.attn1.processor',
|
1408 |
+
'down_blocks.2.attentions.0.transformer_blocks.2.attn2.processor',
|
1409 |
+
'down_blocks.2.attentions.0.transformer_blocks.3.attn1.processor',
|
1410 |
+
'down_blocks.2.attentions.0.transformer_blocks.3.attn2.processor',
|
1411 |
+
'down_blocks.2.attentions.0.transformer_blocks.4.attn1.processor',
|
1412 |
+
'down_blocks.2.attentions.0.transformer_blocks.4.attn2.processor',
|
1413 |
+
'down_blocks.2.attentions.0.transformer_blocks.5.attn1.processor',
|
1414 |
+
'down_blocks.2.attentions.0.transformer_blocks.5.attn2.processor',
|
1415 |
+
'down_blocks.2.attentions.0.transformer_blocks.6.attn1.processor',
|
1416 |
+
'down_blocks.2.attentions.0.transformer_blocks.6.attn2.processor',
|
1417 |
+
'down_blocks.2.attentions.0.transformer_blocks.7.attn1.processor',
|
1418 |
+
'down_blocks.2.attentions.0.transformer_blocks.7.attn2.processor',
|
1419 |
+
'down_blocks.2.attentions.0.transformer_blocks.8.attn1.processor',
|
1420 |
+
'down_blocks.2.attentions.0.transformer_blocks.8.attn2.processor',
|
1421 |
+
'down_blocks.2.attentions.0.transformer_blocks.9.attn1.processor',
|
1422 |
+
'down_blocks.2.attentions.0.transformer_blocks.9.attn2.processor', #20
|
1423 |
+
|
1424 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor',
|
1425 |
+
'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor',
|
1426 |
+
'down_blocks.2.attentions.1.transformer_blocks.1.attn1.processor',
|
1427 |
+
'down_blocks.2.attentions.1.transformer_blocks.1.attn2.processor',
|
1428 |
+
'down_blocks.2.attentions.1.transformer_blocks.2.attn1.processor',
|
1429 |
+
'down_blocks.2.attentions.1.transformer_blocks.2.attn2.processor',
|
1430 |
+
'down_blocks.2.attentions.1.transformer_blocks.3.attn1.processor',
|
1431 |
+
'down_blocks.2.attentions.1.transformer_blocks.3.attn2.processor',
|
1432 |
+
'down_blocks.2.attentions.1.transformer_blocks.4.attn1.processor',
|
1433 |
+
'down_blocks.2.attentions.1.transformer_blocks.4.attn2.processor',
|
1434 |
+
'down_blocks.2.attentions.1.transformer_blocks.5.attn1.processor',
|
1435 |
+
'down_blocks.2.attentions.1.transformer_blocks.5.attn2.processor',
|
1436 |
+
'down_blocks.2.attentions.1.transformer_blocks.6.attn1.processor',
|
1437 |
+
'down_blocks.2.attentions.1.transformer_blocks.6.attn2.processor',
|
1438 |
+
'down_blocks.2.attentions.1.transformer_blocks.7.attn1.processor',
|
1439 |
+
'down_blocks.2.attentions.1.transformer_blocks.7.attn2.processor',
|
1440 |
+
'down_blocks.2.attentions.1.transformer_blocks.8.attn1.processor',
|
1441 |
+
'down_blocks.2.attentions.1.transformer_blocks.8.attn2.processor',
|
1442 |
+
'down_blocks.2.attentions.1.transformer_blocks.9.attn1.processor',
|
1443 |
+
'down_blocks.2.attentions.1.transformer_blocks.9.attn2.processor',#20
|
1444 |
+
|
1445 |
+
'mid_block.attentions.0.transformer_blocks.0.attn1.processor',
|
1446 |
+
'mid_block.attentions.0.transformer_blocks.0.attn2.processor',
|
1447 |
+
'mid_block.attentions.0.transformer_blocks.1.attn1.processor',
|
1448 |
+
'mid_block.attentions.0.transformer_blocks.1.attn2.processor',
|
1449 |
+
'mid_block.attentions.0.transformer_blocks.2.attn1.processor',
|
1450 |
+
'mid_block.attentions.0.transformer_blocks.2.attn2.processor',
|
1451 |
+
'mid_block.attentions.0.transformer_blocks.3.attn1.processor',
|
1452 |
+
'mid_block.attentions.0.transformer_blocks.3.attn2.processor',
|
1453 |
+
'mid_block.attentions.0.transformer_blocks.4.attn1.processor',
|
1454 |
+
'mid_block.attentions.0.transformer_blocks.4.attn2.processor',
|
1455 |
+
'mid_block.attentions.0.transformer_blocks.5.attn1.processor',
|
1456 |
+
'mid_block.attentions.0.transformer_blocks.5.attn2.processor',
|
1457 |
+
'mid_block.attentions.0.transformer_blocks.6.attn1.processor',
|
1458 |
+
'mid_block.attentions.0.transformer_blocks.6.attn2.processor',
|
1459 |
+
'mid_block.attentions.0.transformer_blocks.7.attn1.processor',
|
1460 |
+
'mid_block.attentions.0.transformer_blocks.7.attn2.processor',
|
1461 |
+
'mid_block.attentions.0.transformer_blocks.8.attn1.processor',
|
1462 |
+
'mid_block.attentions.0.transformer_blocks.8.attn2.processor',
|
1463 |
+
'mid_block.attentions.0.transformer_blocks.9.attn1.processor',
|
1464 |
+
'mid_block.attentions.0.transformer_blocks.9.attn2.processor', #20
|
1465 |
+
|
1466 |
+
'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor',
|
1467 |
+
'up_blocks.0.attentions.0.transformer_blocks.0.attn2.processor',
|
1468 |
+
'up_blocks.0.attentions.0.transformer_blocks.1.attn1.processor',
|
1469 |
+
'up_blocks.0.attentions.0.transformer_blocks.1.attn2.processor',
|
1470 |
+
'up_blocks.0.attentions.0.transformer_blocks.2.attn1.processor',
|
1471 |
+
'up_blocks.0.attentions.0.transformer_blocks.2.attn2.processor',
|
1472 |
+
'up_blocks.0.attentions.0.transformer_blocks.3.attn1.processor',
|
1473 |
+
'up_blocks.0.attentions.0.transformer_blocks.3.attn2.processor',
|
1474 |
+
'up_blocks.0.attentions.0.transformer_blocks.4.attn1.processor',
|
1475 |
+
'up_blocks.0.attentions.0.transformer_blocks.4.attn2.processor',
|
1476 |
+
'up_blocks.0.attentions.0.transformer_blocks.5.attn1.processor',
|
1477 |
+
'up_blocks.0.attentions.0.transformer_blocks.5.attn2.processor',
|
1478 |
+
'up_blocks.0.attentions.0.transformer_blocks.6.attn1.processor',
|
1479 |
+
'up_blocks.0.attentions.0.transformer_blocks.6.attn2.processor',
|
1480 |
+
'up_blocks.0.attentions.0.transformer_blocks.7.attn1.processor',
|
1481 |
+
'up_blocks.0.attentions.0.transformer_blocks.7.attn2.processor',
|
1482 |
+
'up_blocks.0.attentions.0.transformer_blocks.8.attn1.processor',
|
1483 |
+
'up_blocks.0.attentions.0.transformer_blocks.8.attn2.processor',
|
1484 |
+
'up_blocks.0.attentions.0.transformer_blocks.9.attn1.processor',
|
1485 |
+
'up_blocks.0.attentions.0.transformer_blocks.9.attn2.processor',#20
|
1486 |
+
|
1487 |
+
'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor',
|
1488 |
+
'up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor',
|
1489 |
+
'up_blocks.0.attentions.1.transformer_blocks.1.attn1.processor',
|
1490 |
+
'up_blocks.0.attentions.1.transformer_blocks.1.attn2.processor',
|
1491 |
+
'up_blocks.0.attentions.1.transformer_blocks.2.attn1.processor',
|
1492 |
+
'up_blocks.0.attentions.1.transformer_blocks.2.attn2.processor',
|
1493 |
+
'up_blocks.0.attentions.1.transformer_blocks.3.attn1.processor',
|
1494 |
+
'up_blocks.0.attentions.1.transformer_blocks.3.attn2.processor',
|
1495 |
+
'up_blocks.0.attentions.1.transformer_blocks.4.attn1.processor',
|
1496 |
+
'up_blocks.0.attentions.1.transformer_blocks.4.attn2.processor',
|
1497 |
+
'up_blocks.0.attentions.1.transformer_blocks.5.attn1.processor',
|
1498 |
+
'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor',
|
1499 |
+
'up_blocks.0.attentions.1.transformer_blocks.6.attn1.processor',
|
1500 |
+
'up_blocks.0.attentions.1.transformer_blocks.6.attn2.processor',
|
1501 |
+
'up_blocks.0.attentions.1.transformer_blocks.7.attn1.processor',
|
1502 |
+
'up_blocks.0.attentions.1.transformer_blocks.7.attn2.processor',
|
1503 |
+
'up_blocks.0.attentions.1.transformer_blocks.8.attn1.processor',
|
1504 |
+
'up_blocks.0.attentions.1.transformer_blocks.8.attn2.processor',
|
1505 |
+
'up_blocks.0.attentions.1.transformer_blocks.9.attn1.processor',
|
1506 |
+
'up_blocks.0.attentions.1.transformer_blocks.9.attn2.processor',#20
|
1507 |
+
|
1508 |
+
'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor',
|
1509 |
+
'up_blocks.0.attentions.2.transformer_blocks.0.attn2.processor',
|
1510 |
+
'up_blocks.0.attentions.2.transformer_blocks.1.attn1.processor',
|
1511 |
+
'up_blocks.0.attentions.2.transformer_blocks.1.attn2.processor',
|
1512 |
+
'up_blocks.0.attentions.2.transformer_blocks.2.attn1.processor',
|
1513 |
+
'up_blocks.0.attentions.2.transformer_blocks.2.attn2.processor',
|
1514 |
+
'up_blocks.0.attentions.2.transformer_blocks.3.attn1.processor',
|
1515 |
+
'up_blocks.0.attentions.2.transformer_blocks.3.attn2.processor',
|
1516 |
+
'up_blocks.0.attentions.2.transformer_blocks.4.attn1.processor',
|
1517 |
+
'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor',
|
1518 |
+
'up_blocks.0.attentions.2.transformer_blocks.5.attn1.processor',
|
1519 |
+
'up_blocks.0.attentions.2.transformer_blocks.5.attn2.processor',
|
1520 |
+
'up_blocks.0.attentions.2.transformer_blocks.6.attn1.processor',
|
1521 |
+
'up_blocks.0.attentions.2.transformer_blocks.6.attn2.processor',
|
1522 |
+
'up_blocks.0.attentions.2.transformer_blocks.7.attn1.processor',
|
1523 |
+
'up_blocks.0.attentions.2.transformer_blocks.7.attn2.processor',
|
1524 |
+
'up_blocks.0.attentions.2.transformer_blocks.8.attn1.processor',
|
1525 |
+
'up_blocks.0.attentions.2.transformer_blocks.8.attn2.processor',
|
1526 |
+
'up_blocks.0.attentions.2.transformer_blocks.9.attn1.processor',
|
1527 |
+
'up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor', #20
|
1528 |
+
|
1529 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
|
1530 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
|
1531 |
+
'up_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
|
1532 |
+
'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
|
1533 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
|
1534 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
|
1535 |
+
'up_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
|
1536 |
+
'up_blocks.1.attentions.1.transformer_blocks.1.attn2.processor',
|
1537 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor',
|
1538 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
|
1539 |
+
'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor',
|
1540 |
+
'up_blocks.1.attentions.2.transformer_blocks.1.attn2.processor',#12
|
1541 |
+
|
1542 |
+
]
|
1543 |
+
|
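The list above enumerates the attention-processor names in the UNet that are candidates for the shared/swapped attention (the trailing comments count entries per block). As a hedged sketch only, one way such a list could be used to install the processors defined earlier is shown below; the particular filter and the `unet` object are assumptions for illustration, not the repository's defaults.

# Illustrative only: activate dual attention on a subset of the candidate layers.
activate_layers = [
    name for name in ACTIVATE_LAYER_CANDIDATE
    if "up_blocks" in name and "attn1" in name   # arbitrary example filter
]

attn_procs = {}
for name in unet.attn_processors:                # `unet` is whatever UNet the pipeline drives
    if name in activate_layers:
        attn_procs[name] = InvertedVEAttnProcessor(unet_chunk_size=2, scale=1.0)
    else:
        attn_procs[name] = AttnProcessor()
unet.set_attn_processor(attn_procs)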
1544 |
+
STYLE_DESCRIPTION_DICT = {
|
1545 |
+
"chinese-ink-paint":("{object} in colorful chinese ink paintings style",""),
|
1546 |
+
"cloud":("Photography of {object}, realistic",""),
|
1547 |
+
"digital-art":("{object} in digital glitch arts style",""),
|
1548 |
+
"fire":("{object} photography, realistic, black background'",""),
|
1549 |
+
"klimt":("{object} in style of Gustav Klimt",""),
|
1550 |
+
"line-art":("line art drawing of {object} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",""),
|
1551 |
+
"low-poly":("low-poly style of {object} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
|
1552 |
+
"noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"),
|
1553 |
+
"munch":("{object} in Edvard Munch style",""),
|
1554 |
+
"van-gogh":("{object}, Van Gogh",""),
|
1555 |
+
"totoro":("{object}, art by studio ghibli, cinematic, masterpiece,key visual, studio anime, highly detailed",
|
1556 |
+
"photo, deformed, black and white, realism, disfigured, low contrast"),
|
1557 |
+
|
1558 |
+
"realistic": ("A portrait of {object}, photorealistic, 35mm film, realistic",
|
1559 |
+
"gray, ugly, deformed, noisy, blurry"),
|
1560 |
+
|
1561 |
+
"line_art": ("line art drawing of {object} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
|
1562 |
+
"anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic"
|
1563 |
+
) ,
|
1564 |
+
|
1565 |
+
"anime": ("anime artwork of {object} . anime style, key visual, vibrant, studio anime, highly detailed",
|
1566 |
+
"photo, deformed, black and white, realism, disfigured, low contrast"
|
1567 |
+
),
|
1568 |
+
|
1569 |
+
"Artstyle_Pop_Art" : ("pop Art style of {object} . bright colors, bold outlines, popular culture themes, ironic or kitsch",
|
1570 |
+
"ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, minimalist"
|
1571 |
+
),
|
1572 |
+
|
1573 |
+
"Artstyle_Pointillism": ("pointillism style of {object} . composed entirely of small, distinct dots of color, vibrant, highly detailed",
|
1574 |
+
"line drawing, smooth shading, large color fields, simplistic"
|
1575 |
+
),
|
1576 |
+
|
1577 |
+
"origami": ("origami style of {object} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
|
1578 |
+
"noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
1579 |
+
),
|
1580 |
+
|
1581 |
+
"craft_clay": ("play-doh style of {object} . sculpture, clay art, centered composition, Claymation",
|
1582 |
+
"sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
1583 |
+
),
|
1584 |
+
|
1585 |
+
"low_poly" : ("low-poly style of {object} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
|
1586 |
+
"noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
1587 |
+
),
|
1588 |
+
|
1589 |
+
"Artstyle_watercolor": ("watercolor painting of {object} . vibrant, beautiful, painterly, detailed, textural, artistic",
|
1590 |
+
"anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
|
1591 |
+
),
|
1592 |
+
|
1593 |
+
"Papercraft_Collage" : ("collage style of {object} . mixed media, layered, textural, detailed, artistic",
|
1594 |
+
"ugly, deformed, noisy, blurry, low contrast, realism, photorealistic"
|
1595 |
+
),
|
1596 |
+
|
1597 |
+
"Artstyle_Impressionist" : ("impressionist painting of {object} . loose brushwork, vibrant color, light and shadow play, captures feeling over form",
|
1598 |
+
"anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
|
1599 |
+
),
|
1600 |
+
"realistic_bg_black":("{object} photography, realistic, black background",
|
1601 |
+
""),
|
1602 |
+
"photography_realistic":("Photography of {object}, realistic",
|
1603 |
+
""),
|
1604 |
+
"digital_art":("{object} in digital glitch arts style.",
|
1605 |
+
""
|
1606 |
+
),
|
1607 |
+
"chinese_painting":("{object} in traditional a chinese ink painting style.",
|
1608 |
+
""
|
1609 |
+
),
|
1610 |
+
"no_style":("{object}",
|
1611 |
+
""),
|
1612 |
+
"kid_drawing":("{object} in kid crayon drawings style.",""),
|
1613 |
+
"onepiece":("{object}, wanostyle, angry looking, straw hat, looking at viewer, solo, upper body, masterpiece, best quality, (extremely detailed), watercolor, illustration, depth of field, sketch, dark intense shadows, sharp focus, soft lighting, hdr, colorful, good composition, fire all around, spectacular, closed shirt",
|
1614 |
+
" watermark, text, error, blurry, jpeg artifacts, many objects, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature")
|
1615 |
+
}
|
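Each value in `STYLE_DESCRIPTION_DICT` is a `(positive_template, negative_prompt)` pair whose positive part carries an `{object}` placeholder. A small illustrative helper (not part of the repository) showing how a concrete prompt pair can be built from an entry:

# Illustrative helper: expand a style entry into (prompt, negative_prompt).
def build_style_prompts(style_name: str, obj: str):
    positive_template, negative_prompt = STYLE_DESCRIPTION_DICT[style_name]
    return positive_template.format(object=obj), negative_prompt

prompt, negative_prompt = build_style_prompts("low-poly", "a red sports car")
# prompt -> "low-poly style of a red sports car . low-poly game art, polygon mesh, ..."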
pipelines/pipeline_controlnet_sd_xl.py
ADDED
The diff for this file is too large to render.
|
|
pipelines/pipeline_stable_diffusion_xl.py
ADDED
@@ -0,0 +1,1792 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
import PIL
|
18 |
+
import torch
|
19 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
20 |
+
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
from diffusers.loaders import (
|
23 |
+
FromSingleFileMixin,
|
24 |
+
StableDiffusionXLLoraLoaderMixin,
|
25 |
+
TextualInversionLoaderMixin,
|
26 |
+
)
|
27 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
28 |
+
from diffusers.models.attention_processor import (
|
29 |
+
AttnProcessor2_0,
|
30 |
+
LoRAAttnProcessor2_0,
|
31 |
+
LoRAXFormersAttnProcessor,
|
32 |
+
XFormersAttnProcessor,
|
33 |
+
)
|
34 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
35 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
36 |
+
from diffusers.utils import (
|
37 |
+
USE_PEFT_BACKEND,
|
38 |
+
deprecate,
|
39 |
+
is_invisible_watermark_available,
|
40 |
+
is_torch_xla_available,
|
41 |
+
logging,
|
42 |
+
replace_example_docstring,
|
43 |
+
scale_lora_layers,
|
44 |
+
unscale_lora_layers,
|
45 |
+
)
|
46 |
+
from diffusers.utils.torch_utils import randn_tensor
|
47 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
48 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
49 |
+
|
50 |
+
from pipelines.inverted_ve_pipeline import CrossFrameAttnProcessor, ACTIVATE_LAYER_CANDIDATE, SharedAttentionProcessor, SharedAttentionProcessor_v2
|
51 |
+
from diffusers.models.attention_processor import AttnProcessor
|
52 |
+
|
53 |
+
import os
|
54 |
+
|
55 |
+
if is_invisible_watermark_available():
|
56 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
57 |
+
|
58 |
+
if is_torch_xla_available():
|
59 |
+
import torch_xla.core.xla_model as xm
|
60 |
+
|
61 |
+
XLA_AVAILABLE = True
|
62 |
+
else:
|
63 |
+
XLA_AVAILABLE = False
|
64 |
+
|
65 |
+
|
66 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
67 |
+
|
68 |
+
EXAMPLE_DOC_STRING = """
|
69 |
+
Examples:
|
70 |
+
```py
|
71 |
+
>>> import torch
|
72 |
+
>>> from diffusers import StableDiffusionXLPipeline
|
73 |
+
|
74 |
+
>>> pipe = StableDiffusionXLPipeline.from_pretrained(
|
75 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
76 |
+
... )
|
77 |
+
>>> pipe = pipe.to("cuda")
|
78 |
+
|
79 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
80 |
+
>>> image = pipe(prompt).images[0]
|
81 |
+
```
|
82 |
+
"""
|
83 |
+
|
84 |
+
|
85 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
86 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
87 |
+
"""
|
88 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
89 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
90 |
+
"""
|
91 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
92 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
93 |
+
# rescale the results from guidance (fixes overexposure)
|
94 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
95 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
96 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
97 |
+
return noise_cfg
|
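A minimal sketch of where `rescale_noise_cfg` sits in classifier-free guidance, using random tensors in place of real UNet predictions; the guidance values below are illustrative, and the import path assumes the file layout added in this commit.

```py
import torch

from pipelines.pipeline_stable_diffusion_xl import rescale_noise_cfg

noise_pred_uncond = torch.randn(2, 4, 128, 128)  # stand-in for the unconditional branch
noise_pred_text = torch.randn(2, 4, 128, 128)    # stand-in for the text-conditioned branch

# standard CFG combination, then variance rescaling to counter overexposure
noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
print(noise_pred.shape)  # torch.Size([2, 4, 128, 128])
```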
98 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
99 |
+
def retrieve_latents(encoder_output, generator):
|
100 |
+
if hasattr(encoder_output, "latent_dist"):
|
101 |
+
return encoder_output.latent_dist.sample(generator)
|
102 |
+
elif hasattr(encoder_output, "latents"):
|
103 |
+
return encoder_output.latents
|
104 |
+
else:
|
105 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
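A small sketch of the duck typing in `retrieve_latents`: it works with any object exposing either a `latent_dist` with a `sample()` method (as `AutoencoderKL.encode` returns) or a plain `latents` attribute. The stand-in classes below are hypothetical, not diffusers types.

```py
import torch

from pipelines.pipeline_stable_diffusion_xl import retrieve_latents

class FakeDist:
    # minimal stand-in for a diffusers DiagonalGaussianDistribution
    def __init__(self, mean):
        self.mean = mean
    def sample(self, generator=None):
        return self.mean + torch.randn(self.mean.shape, generator=generator)

class EncodeOutput:
    def __init__(self, mean):
        self.latent_dist = FakeDist(mean)

class LatentsOutput:
    def __init__(self, latents):
        self.latents = latents

gen = torch.Generator().manual_seed(0)
print(retrieve_latents(EncodeOutput(torch.zeros(1, 4, 128, 128)), generator=gen).shape)
print(retrieve_latents(LatentsOutput(torch.zeros(1, 4, 128, 128)), generator=gen).shape)
```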
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
class StableDiffusionXLPipeline(
|
110 |
+
DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
111 |
+
):
|
112 |
+
r"""
|
113 |
+
Pipeline for text-to-image generation using Stable Diffusion XL.
|
114 |
+
|
115 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
116 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
117 |
+
|
118 |
+
In addition the pipeline inherits the following loading methods:
|
119 |
+
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
120 |
+
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
121 |
+
|
122 |
+
as well as the following saving methods:
|
123 |
+
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
124 |
+
|
125 |
+
Args:
|
126 |
+
vae ([`AutoencoderKL`]):
|
127 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
128 |
+
text_encoder ([`CLIPTextModel`]):
|
129 |
+
Frozen text-encoder. Stable Diffusion XL uses the text portion of
|
130 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
131 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
132 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
133 |
+
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
|
134 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
135 |
+
specifically the
|
136 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
137 |
+
variant.
|
138 |
+
tokenizer (`CLIPTokenizer`):
|
139 |
+
Tokenizer of class
|
140 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
141 |
+
tokenizer_2 (`CLIPTokenizer`):
|
142 |
+
Second Tokenizer of class
|
143 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
144 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
145 |
+
scheduler ([`SchedulerMixin`]):
|
146 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
147 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
148 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
149 |
+
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
|
150 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
151 |
+
add_watermarker (`bool`, *optional*):
|
152 |
+
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
|
153 |
+
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
|
154 |
+
watermarker will be used.
|
155 |
+
"""
|
156 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
157 |
+
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
158 |
+
_callback_tensor_inputs = [
|
159 |
+
"latents",
|
160 |
+
"prompt_embeds",
|
161 |
+
"negative_prompt_embeds",
|
162 |
+
"add_text_embeds",
|
163 |
+
"add_time_ids",
|
164 |
+
"negative_pooled_prompt_embeds",
|
165 |
+
"negative_add_time_ids",
|
166 |
+
]
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
vae: AutoencoderKL,
|
171 |
+
text_encoder: CLIPTextModel,
|
172 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
173 |
+
tokenizer: CLIPTokenizer,
|
174 |
+
tokenizer_2: CLIPTokenizer,
|
175 |
+
unet: UNet2DConditionModel,
|
176 |
+
scheduler: KarrasDiffusionSchedulers,
|
177 |
+
force_zeros_for_empty_prompt: bool = True,
|
178 |
+
add_watermarker: Optional[bool] = None,
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
|
182 |
+
self.register_modules(
|
183 |
+
vae=vae,
|
184 |
+
text_encoder=text_encoder,
|
185 |
+
text_encoder_2=text_encoder_2,
|
186 |
+
tokenizer=tokenizer,
|
187 |
+
tokenizer_2=tokenizer_2,
|
188 |
+
unet=unet,
|
189 |
+
scheduler=scheduler,
|
190 |
+
)
|
191 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
192 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
193 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
194 |
+
|
195 |
+
self.default_sample_size = self.unet.config.sample_size
|
196 |
+
|
197 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
198 |
+
|
199 |
+
if add_watermarker:
|
200 |
+
self.watermark = StableDiffusionXLWatermarker()
|
201 |
+
else:
|
202 |
+
self.watermark = None
|
203 |
+
|
204 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
205 |
+
def enable_vae_slicing(self):
|
206 |
+
r"""
|
207 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
208 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
209 |
+
"""
|
210 |
+
self.vae.enable_slicing()
|
211 |
+
|
212 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
213 |
+
def disable_vae_slicing(self):
|
214 |
+
r"""
|
215 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
216 |
+
computing decoding in one step.
|
217 |
+
"""
|
218 |
+
self.vae.disable_slicing()
|
219 |
+
|
220 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
221 |
+
def enable_vae_tiling(self):
|
222 |
+
r"""
|
223 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
224 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
225 |
+
processing larger images.
|
226 |
+
"""
|
227 |
+
self.vae.enable_tiling()
|
228 |
+
|
229 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
230 |
+
def disable_vae_tiling(self):
|
231 |
+
r"""
|
232 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
233 |
+
computing decoding in one step.
|
234 |
+
"""
|
235 |
+
self.vae.disable_tiling()
|
236 |
+
|
237 |
+
def encode_prompt(
|
238 |
+
self,
|
239 |
+
prompt: str,
|
240 |
+
prompt_2: Optional[str] = None,
|
241 |
+
device: Optional[torch.device] = None,
|
242 |
+
num_images_per_prompt: int = 1,
|
243 |
+
do_classifier_free_guidance: bool = True,
|
244 |
+
negative_prompt: Optional[str] = None,
|
245 |
+
negative_prompt_2: Optional[str] = None,
|
246 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
247 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
248 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
249 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
250 |
+
lora_scale: Optional[float] = None,
|
251 |
+
clip_skip: Optional[int] = None,
|
252 |
+
):
|
253 |
+
r"""
|
254 |
+
Encodes the prompt into text encoder hidden states.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
prompt (`str` or `List[str]`, *optional*):
|
258 |
+
prompt to be encoded
|
259 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
260 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
261 |
+
used in both text-encoders
|
262 |
+
device: (`torch.device`):
|
263 |
+
torch device
|
264 |
+
num_images_per_prompt (`int`):
|
265 |
+
number of images that should be generated per prompt
|
266 |
+
do_classifier_free_guidance (`bool`):
|
267 |
+
whether to use classifier free guidance or not
|
268 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
269 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
270 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
271 |
+
less than `1`).
|
272 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
273 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
274 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
275 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
276 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
277 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
278 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
279 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
280 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
281 |
+
argument.
|
282 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
283 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
284 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
285 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
286 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
287 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
288 |
+
input argument.
|
289 |
+
lora_scale (`float`, *optional*):
|
290 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
291 |
+
clip_skip (`int`, *optional*):
|
292 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
293 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
294 |
+
"""
|
295 |
+
device = device or self._execution_device
|
296 |
+
|
297 |
+
# set lora scale so that monkey patched LoRA
|
298 |
+
# function of text encoder can correctly access it
|
299 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
300 |
+
self._lora_scale = lora_scale
|
301 |
+
|
302 |
+
# dynamically adjust the LoRA scale
|
303 |
+
if self.text_encoder is not None:
|
304 |
+
if not USE_PEFT_BACKEND:
|
305 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
306 |
+
else:
|
307 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
308 |
+
|
309 |
+
if self.text_encoder_2 is not None:
|
310 |
+
if not USE_PEFT_BACKEND:
|
311 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
312 |
+
else:
|
313 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
314 |
+
|
315 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
316 |
+
|
317 |
+
if prompt is not None:
|
318 |
+
batch_size = len(prompt)
|
319 |
+
else:
|
320 |
+
batch_size = prompt_embeds.shape[0]
|
321 |
+
|
322 |
+
# Define tokenizers and text encoders
|
323 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
324 |
+
text_encoders = (
|
325 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
326 |
+
)
|
327 |
+
|
328 |
+
if prompt_embeds is None:
|
329 |
+
prompt_2 = prompt_2 or prompt
|
330 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
331 |
+
|
332 |
+
# textual inversion: process multi-vector tokens if necessary
|
333 |
+
prompt_embeds_list = []
|
334 |
+
prompts = [prompt, prompt_2]
|
335 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
336 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
337 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
338 |
+
|
339 |
+
text_inputs = tokenizer(
|
340 |
+
prompt,
|
341 |
+
padding="max_length",
|
342 |
+
max_length=tokenizer.model_max_length,
|
343 |
+
truncation=True,
|
344 |
+
return_tensors="pt",
|
345 |
+
)
|
346 |
+
|
347 |
+
text_input_ids = text_inputs.input_ids
|
348 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
349 |
+
|
350 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
351 |
+
text_input_ids, untruncated_ids
|
352 |
+
):
|
353 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
354 |
+
logger.warning(
|
355 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
356 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
357 |
+
)
|
358 |
+
|
359 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
360 |
+
|
361 |
+
# We are always interested only in the pooled output of the final text encoder
|
362 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
363 |
+
if clip_skip is None:
|
364 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
365 |
+
else:
|
366 |
+
# "2" because SDXL always indexes from the penultimate layer.
|
367 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
368 |
+
|
369 |
+
prompt_embeds_list.append(prompt_embeds)
|
370 |
+
|
371 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
372 |
+
|
373 |
+
# get unconditional embeddings for classifier free guidance
|
374 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
375 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
376 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
377 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
378 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
379 |
+
negative_prompt = negative_prompt or ""
|
380 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
381 |
+
|
382 |
+
# normalize str to list
|
383 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
384 |
+
negative_prompt_2 = (
|
385 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
386 |
+
)
|
387 |
+
|
388 |
+
uncond_tokens: List[str]
|
389 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
390 |
+
raise TypeError(
|
391 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
392 |
+
f" {type(prompt)}."
|
393 |
+
)
|
394 |
+
elif batch_size != len(negative_prompt):
|
395 |
+
raise ValueError(
|
396 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
397 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
398 |
+
" the batch size of `prompt`."
|
399 |
+
)
|
400 |
+
else:
|
401 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
402 |
+
|
403 |
+
negative_prompt_embeds_list = []
|
404 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
405 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
406 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
407 |
+
|
408 |
+
max_length = prompt_embeds.shape[1]
|
409 |
+
uncond_input = tokenizer(
|
410 |
+
negative_prompt,
|
411 |
+
padding="max_length",
|
412 |
+
max_length=max_length,
|
413 |
+
truncation=True,
|
414 |
+
return_tensors="pt",
|
415 |
+
)
|
416 |
+
|
417 |
+
negative_prompt_embeds = text_encoder(
|
418 |
+
uncond_input.input_ids.to(device),
|
419 |
+
output_hidden_states=True,
|
420 |
+
)
|
421 |
+
# We are always interested only in the pooled output of the final text encoder
|
422 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
423 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
424 |
+
|
425 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
426 |
+
|
427 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
428 |
+
|
429 |
+
if self.text_encoder_2 is not None:
|
430 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
431 |
+
else:
|
432 |
+
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
433 |
+
|
434 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
435 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
436 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
437 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
438 |
+
|
439 |
+
if do_classifier_free_guidance:
|
440 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
441 |
+
seq_len = negative_prompt_embeds.shape[1]
|
442 |
+
|
443 |
+
if self.text_encoder_2 is not None:
|
444 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
445 |
+
else:
|
446 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
447 |
+
|
448 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
449 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
450 |
+
|
451 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
452 |
+
bs_embed * num_images_per_prompt, -1
|
453 |
+
)
|
454 |
+
if do_classifier_free_guidance:
|
455 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
456 |
+
bs_embed * num_images_per_prompt, -1
|
457 |
+
)
|
458 |
+
|
459 |
+
|
460 |
+
if self.text_encoder is not None:
|
461 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
462 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
463 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
464 |
+
|
465 |
+
if self.text_encoder_2 is not None:
|
466 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
467 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
468 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
469 |
+
|
470 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
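A sketch of how `encode_prompt` is typically called once the pipeline is loaded (weights download and a GPU assumed); the shapes in the comments correspond to the SDXL base configuration.

```py
import torch

from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

(prompt_embeds, negative_prompt_embeds,
 pooled_prompt_embeds, negative_pooled_prompt_embeds) = pipe.encode_prompt(
    prompt="a watercolor painting of a fox",
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
)
print(prompt_embeds.shape)         # [1, 77, 2048]: per-token embeddings from both text encoders, concatenated
print(pooled_prompt_embeds.shape)  # [1, 1280]: pooled output of text_encoder_2, used for the added conditioning
```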
471 |
+
|
472 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
473 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
474 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
475 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
476 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
477 |
+
# and should be between [0, 1]
|
478 |
+
|
479 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
480 |
+
extra_step_kwargs = {}
|
481 |
+
if accepts_eta:
|
482 |
+
extra_step_kwargs["eta"] = eta
|
483 |
+
|
484 |
+
# check if the scheduler accepts generator
|
485 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
486 |
+
if accepts_generator:
|
487 |
+
extra_step_kwargs["generator"] = generator
|
488 |
+
return extra_step_kwargs
|
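A short sketch of the signature probing used above: `eta` and `generator` are forwarded only when the scheduler's `step()` accepts them, which is why DDIM gets both while, for example, the Euler scheduler only receives a generator.

```py
import inspect

from diffusers import DDIMScheduler, EulerDiscreteScheduler

for cls in (DDIMScheduler, EulerDiscreteScheduler):
    params = set(inspect.signature(cls.step).parameters.keys())
    print(cls.__name__, "accepts eta:", "eta" in params, "| accepts generator:", "generator" in params)
```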
489 |
+
|
490 |
+
def check_inputs(
|
491 |
+
self,
|
492 |
+
prompt,
|
493 |
+
prompt_2,
|
494 |
+
height,
|
495 |
+
width,
|
496 |
+
callback_steps,
|
497 |
+
negative_prompt=None,
|
498 |
+
negative_prompt_2=None,
|
499 |
+
prompt_embeds=None,
|
500 |
+
negative_prompt_embeds=None,
|
501 |
+
pooled_prompt_embeds=None,
|
502 |
+
negative_pooled_prompt_embeds=None,
|
503 |
+
callback_on_step_end_tensor_inputs=None,
|
504 |
+
):
|
505 |
+
if height % 8 != 0 or width % 8 != 0:
|
506 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
507 |
+
|
508 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
509 |
+
raise ValueError(
|
510 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
511 |
+
f" {type(callback_steps)}."
|
512 |
+
)
|
513 |
+
|
514 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
515 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
516 |
+
):
|
517 |
+
raise ValueError(
|
518 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
519 |
+
)
|
520 |
+
|
521 |
+
if prompt is not None and prompt_embeds is not None:
|
522 |
+
raise ValueError(
|
523 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
524 |
+
" only forward one of the two."
|
525 |
+
)
|
526 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
527 |
+
raise ValueError(
|
528 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
529 |
+
" only forward one of the two."
|
530 |
+
)
|
531 |
+
elif prompt is None and prompt_embeds is None:
|
532 |
+
raise ValueError(
|
533 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
534 |
+
)
|
535 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
536 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
537 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
538 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
539 |
+
|
540 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
541 |
+
raise ValueError(
|
542 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
543 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
544 |
+
)
|
545 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
546 |
+
raise ValueError(
|
547 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
548 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
549 |
+
)
|
550 |
+
|
551 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
552 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
553 |
+
raise ValueError(
|
554 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
555 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
556 |
+
f" {negative_prompt_embeds.shape}."
|
557 |
+
)
|
558 |
+
|
559 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
560 |
+
raise ValueError(
|
561 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
562 |
+
)
|
563 |
+
|
564 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
565 |
+
raise ValueError(
|
566 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
567 |
+
)
|
568 |
+
def prepare_img_latents(
|
569 |
+
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
|
570 |
+
):
|
571 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
572 |
+
raise ValueError(
|
573 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
574 |
+
)
|
575 |
+
|
576 |
+
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
577 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
578 |
+
self.text_encoder_2.to("cpu")
|
579 |
+
torch.cuda.empty_cache()
|
580 |
+
|
581 |
+
image = image.to(device=device, dtype=dtype)
|
582 |
+
|
583 |
+
batch_size = batch_size * num_images_per_prompt
|
584 |
+
|
585 |
+
if image.shape[1] == 4:
|
586 |
+
init_latents = image
|
587 |
+
|
588 |
+
else:
|
589 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
590 |
+
if self.vae.config.force_upcast:
|
591 |
+
image = image.float()
|
592 |
+
self.vae.to(dtype=torch.float32)
|
593 |
+
|
594 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
595 |
+
raise ValueError(
|
596 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
597 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
598 |
+
)
|
599 |
+
|
600 |
+
elif isinstance(generator, list):
|
601 |
+
init_latents = [
|
602 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
603 |
+
for i in range(batch_size)
|
604 |
+
]
|
605 |
+
init_latents = torch.cat(init_latents, dim=0)
|
606 |
+
else:
|
607 |
+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
608 |
+
|
609 |
+
if self.vae.config.force_upcast:
|
610 |
+
self.vae.to(dtype)
|
611 |
+
|
612 |
+
init_latents = init_latents.to(dtype)
|
613 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
614 |
+
|
615 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
616 |
+
# expand init_latents for batch_size
|
617 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
618 |
+
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
619 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
620 |
+
raise ValueError(
|
621 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
622 |
+
)
|
623 |
+
else:
|
624 |
+
init_latents = torch.cat([init_latents], dim=0)
|
625 |
+
|
626 |
+
if add_noise:
|
627 |
+
shape = init_latents.shape
|
628 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
629 |
+
# get latents
|
630 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
631 |
+
|
632 |
+
latents = init_latents
|
633 |
+
|
634 |
+
return latents
|
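A sketch of the noising step at the end of `prepare_img_latents`: once the reference image has been encoded and scaled, it is pushed to the noise level of timestep `t` with the scheduler's forward process. The scheduler and tensors below are stand-ins, not the pipeline's own.

```py
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
clean_latents = torch.randn(1, 4, 128, 128)   # stand-in for VAE-encoded, scaled image latents
noise = torch.randn_like(clean_latents)
t = torch.tensor([750])                       # a fairly late (noisy) timestep
noisy_latents = scheduler.add_noise(clean_latents, noise, t)
print(noisy_latents.shape)  # torch.Size([1, 4, 128, 128])
```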
635 |
+
|
636 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
637 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
638 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
639 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
640 |
+
raise ValueError(
|
641 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
642 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
643 |
+
)
|
644 |
+
|
645 |
+
if latents is None:
|
646 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
647 |
+
else:
|
648 |
+
latents = latents.to(device)
|
649 |
+
|
650 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
651 |
+
latents = latents * self.scheduler.init_noise_sigma
|
652 |
+
return latents
|
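A sketch of the shape arithmetic in `prepare_latents`: SDXL latents have 4 channels at height/8 x width/8, and the initial Gaussian noise is scaled by the scheduler's `init_noise_sigma`. The scheduler below uses its default settings, not the exact SDXL schedule.

```py
import torch

from diffusers import EulerDiscreteScheduler

height, width, vae_scale_factor, num_channels_latents = 1024, 1024, 8, 4
shape = (1, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)

scheduler = EulerDiscreteScheduler()          # default config, for illustration only
latents = torch.randn(shape) * scheduler.init_noise_sigma
print(latents.shape)                          # torch.Size([1, 4, 128, 128])
```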
653 |
+
|
654 |
+
def _get_add_time_ids(
|
655 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
656 |
+
):
|
657 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
658 |
+
|
659 |
+
passed_add_embed_dim = (
|
660 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
661 |
+
)
|
662 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
663 |
+
|
664 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
665 |
+
raise ValueError(
|
666 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
667 |
+
)
|
668 |
+
|
669 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
670 |
+
return add_time_ids
|
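A sketch of the micro-conditioning ids assembled by `_get_add_time_ids`: the three size/crop tuples flatten to six integers, and for the SDXL base model the UNet's `add_embedding` expects 256 * 6 + 1280 = 2816 input features (addition_time_embed_dim times six ids, plus the text_encoder_2 projection dim).

```py
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.float16)
print(add_time_ids)                         # tensor([[1024., 1024., 0., 0., 1024., 1024.]], dtype=torch.float16)
print(256 * add_time_ids.shape[-1] + 1280)  # 2816, the expected add_embedding input width for SDXL base
```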
671 |
+
|
672 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
673 |
+
def upcast_vae(self):
|
674 |
+
dtype = self.vae.dtype
|
675 |
+
self.vae.to(dtype=torch.float32)
|
676 |
+
use_torch_2_0_or_xformers = isinstance(
|
677 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
678 |
+
(
|
679 |
+
AttnProcessor2_0,
|
680 |
+
XFormersAttnProcessor,
|
681 |
+
LoRAXFormersAttnProcessor,
|
682 |
+
LoRAAttnProcessor2_0,
|
683 |
+
),
|
684 |
+
)
|
685 |
+
# if xformers or torch_2_0 is used, the attention block does not need
|
686 |
+
# to be in float32 which can save lots of memory
|
687 |
+
if use_torch_2_0_or_xformers:
|
688 |
+
self.vae.post_quant_conv.to(dtype)
|
689 |
+
self.vae.decoder.conv_in.to(dtype)
|
690 |
+
self.vae.decoder.mid_block.to(dtype)
|
691 |
+
|
692 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
693 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
694 |
+
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
695 |
+
|
696 |
+
The suffixes after the scaling factors represent the stages where they are being applied.
|
697 |
+
|
698 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
699 |
+
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
700 |
+
|
701 |
+
Args:
|
702 |
+
s1 (`float`):
|
703 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
704 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
705 |
+
s2 (`float`):
|
706 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
707 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
708 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
709 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
710 |
+
"""
|
711 |
+
if not hasattr(self, "unet"):
|
712 |
+
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
713 |
+
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
714 |
+
|
715 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
716 |
+
def disable_freeu(self):
|
717 |
+
"""Disables the FreeU mechanism if enabled."""
|
718 |
+
self.unet.disable_freeu()
|
719 |
+
|
720 |
+
@property
|
721 |
+
def guidance_scale(self):
|
722 |
+
return self._guidance_scale
|
723 |
+
|
724 |
+
@property
|
725 |
+
def guidance_rescale(self):
|
726 |
+
return self._guidance_rescale
|
727 |
+
|
728 |
+
@property
|
729 |
+
def clip_skip(self):
|
730 |
+
return self._clip_skip
|
731 |
+
|
732 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
733 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
734 |
+
# corresponds to doing no classifier free guidance.
|
735 |
+
@property
|
736 |
+
def do_classifier_free_guidance(self):
|
737 |
+
return self._guidance_scale > 1
|
738 |
+
|
739 |
+
@property
|
740 |
+
def cross_attention_kwargs(self):
|
741 |
+
return self._cross_attention_kwargs
|
742 |
+
|
743 |
+
@property
|
744 |
+
def denoising_end(self):
|
745 |
+
return self._denoising_end
|
746 |
+
|
747 |
+
@property
|
748 |
+
def num_timesteps(self):
|
749 |
+
return self._num_timesteps
|
750 |
+
|
751 |
+
@torch.no_grad()
|
752 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
753 |
+
def __call__(
|
754 |
+
self,
|
755 |
+
prompt: Union[str, List[str]] = None,
|
756 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
757 |
+
height: Optional[int] = None,
|
758 |
+
width: Optional[int] = None,
|
759 |
+
num_inference_steps: int = 50,
|
760 |
+
denoising_end: Optional[float] = None,
|
761 |
+
guidance_scale: float = 5.0,
|
762 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
763 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
764 |
+
num_images_per_prompt: Optional[int] = 1,
|
765 |
+
eta: float = 0.0,
|
766 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
767 |
+
latents: Optional[torch.FloatTensor] = None,
|
768 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
769 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
770 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
771 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
772 |
+
output_type: Optional[str] = "pil",
|
773 |
+
return_dict: bool = True,
|
774 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
775 |
+
guidance_rescale: float = 0.0,
|
776 |
+
original_size: Optional[Tuple[int, int]] = None,
|
777 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
778 |
+
target_size: Optional[Tuple[int, int]] = None,
|
779 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
780 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
781 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
782 |
+
clip_skip: Optional[int] = None,
|
783 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
784 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
785 |
+
use_prompt_as_null=False,
|
786 |
+
image=None,
|
787 |
+
**kwargs,
|
788 |
+
):
|
789 |
+
r"""
|
790 |
+
Function invoked when calling the pipeline for generation.
|
791 |
+
|
792 |
+
Args:
|
793 |
+
prompt (`str` or `List[str]`, *optional*):
|
794 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
795 |
+
instead.
|
796 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
797 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
798 |
+
used in both text-encoders
|
799 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
800 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
801 |
+
Anything below 512 pixels won't work well for
|
802 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
803 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
804 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
805 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
806 |
+
Anything below 512 pixels won't work well for
|
807 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
808 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
809 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
810 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
811 |
+
expense of slower inference.
|
812 |
+
denoising_end (`float`, *optional*):
|
813 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
814 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
815 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
816 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
817 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
818 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
819 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
820 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
821 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
822 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
823 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
824 |
+
usually at the expense of lower image quality.
|
825 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
826 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
827 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
828 |
+
less than `1`).
|
829 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
830 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
831 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
832 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
833 |
+
The number of images to generate per prompt.
|
834 |
+
eta (`float`, *optional*, defaults to 0.0):
|
835 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
836 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
837 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
838 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
839 |
+
to make generation deterministic.
|
840 |
+
latents (`torch.FloatTensor`, *optional*):
|
841 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
842 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
843 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
844 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
845 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
846 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
847 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
848 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
849 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
850 |
+
argument.
|
851 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
852 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
853 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
854 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
855 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
856 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
857 |
+
input argument.
|
858 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
859 |
+
The output format of the generated image. Choose between
|
860 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
861 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
862 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
863 |
+
of a plain tuple.
|
864 |
+
cross_attention_kwargs (`dict`, *optional*):
|
865 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
866 |
+
`self.processor` in
|
867 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
868 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
869 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
870 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
|
871 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
872 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
873 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
874 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
875 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
876 |
+
explained in section 2.2 of
|
877 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
878 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
879 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
880 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
881 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
882 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
883 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
884 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
885 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
886 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
887 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
888 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
889 |
+
micro-conditioning as explained in section 2.2 of
|
890 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
891 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
892 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
893 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
894 |
+
micro-conditioning as explained in section 2.2 of
|
895 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
896 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
897 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
898 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
899 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
900 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
901 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
902 |
+
callback_on_step_end (`Callable`, *optional*):
|
903 |
+
A function that is called at the end of each denoising step during inference. It is invoked
|
904 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
905 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
906 |
+
`callback_on_step_end_tensor_inputs`.
|
907 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
908 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
909 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
910 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
911 |
+
|
912 |
+
Examples:
|
913 |
+
|
914 |
+
Returns:
|
915 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
916 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
917 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
918 |
+
"""
|
919 |
+
|
920 |
+
|
921 |
+
|
922 |
+
|
923 |
+
callback = kwargs.pop("callback", None)
|
924 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
925 |
+
|
926 |
+
if callback is not None:
|
927 |
+
deprecate(
|
928 |
+
"callback",
|
929 |
+
"1.0.0",
|
930 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
931 |
+
)
|
932 |
+
if callback_steps is not None:
|
933 |
+
deprecate(
|
934 |
+
"callback_steps",
|
935 |
+
"1.0.0",
|
936 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
937 |
+
)
|
938 |
+
|
939 |
+
|
940 |
+
if image is not None:
|
941 |
+
z0 = self.image_processor.preprocess(image)
|
942 |
+
|
943 |
+
|
944 |
+
# 0. Default height and width to unet
|
945 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
946 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
947 |
+
|
948 |
+
original_size = original_size or (height, width)
|
949 |
+
target_size = target_size or (height, width)
|
950 |
+
|
951 |
+
# 1. Check inputs. Raise error if not correct
|
952 |
+
self.check_inputs(
|
953 |
+
prompt,
|
954 |
+
prompt_2,
|
955 |
+
height,
|
956 |
+
width,
|
957 |
+
callback_steps,
|
958 |
+
negative_prompt,
|
959 |
+
negative_prompt_2,
|
960 |
+
prompt_embeds,
|
961 |
+
negative_prompt_embeds,
|
962 |
+
pooled_prompt_embeds,
|
963 |
+
negative_pooled_prompt_embeds,
|
964 |
+
callback_on_step_end_tensor_inputs,
|
965 |
+
)
|
966 |
+
|
967 |
+
self._guidance_scale = guidance_scale
|
968 |
+
self._guidance_rescale = guidance_rescale
|
969 |
+
self._clip_skip = clip_skip
|
970 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
971 |
+
self._denoising_end = denoising_end
|
972 |
+
|
973 |
+
# 2. Define call parameters
|
974 |
+
if prompt is not None and isinstance(prompt, str):
|
975 |
+
batch_size = 1
|
976 |
+
elif prompt is not None and isinstance(prompt, list):
|
977 |
+
batch_size = len(prompt)
|
978 |
+
else:
|
979 |
+
batch_size = prompt_embeds.shape[0]
|
980 |
+
|
981 |
+
device = self._execution_device
|
982 |
+
|
983 |
+
# 3. Encode input prompt
|
984 |
+
lora_scale = (
|
985 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
986 |
+
)
|
987 |
+
|
988 |
+
|
989 |
+
(
|
990 |
+
prompt_embeds,
|
991 |
+
negative_prompt_embeds,
|
992 |
+
pooled_prompt_embeds,
|
993 |
+
negative_pooled_prompt_embeds,
|
994 |
+
) = self.encode_prompt(
|
995 |
+
prompt=prompt,
|
996 |
+
prompt_2=prompt_2,
|
997 |
+
device=device,
|
998 |
+
num_images_per_prompt=num_images_per_prompt,
|
999 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1000 |
+
negative_prompt=negative_prompt,
|
1001 |
+
negative_prompt_2=negative_prompt_2,
|
1002 |
+
prompt_embeds=prompt_embeds,
|
1003 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1004 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1005 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1006 |
+
lora_scale=lora_scale,
|
1007 |
+
clip_skip=self.clip_skip,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
if kwargs.get('target_prompt') is not None:
|
1011 |
+
(
|
1012 |
+
prompt_embeds_,
|
1013 |
+
negative_prompt_embeds_,
|
1014 |
+
pooled_prompt_embeds_,
|
1015 |
+
negative_pooled_prompt_embeds_,
|
1016 |
+
) = self.encode_prompt(
|
1017 |
+
prompt=kwargs['target_prompt'],
|
1018 |
+
prompt_2=prompt_2,
|
1019 |
+
device=device,
|
1020 |
+
num_images_per_prompt=num_images_per_prompt,
|
1021 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1022 |
+
# negative_prompt=negative_prompt,
|
1023 |
+
negative_prompt=None, #if kwargs["target_neg"] is None else kwargs["target_neg"],
|
1024 |
+
# negative_prompt_2=negative_prompt_2,
|
1025 |
+
negative_prompt_2=None,
|
1026 |
+
prompt_embeds=None,
|
1027 |
+
negative_prompt_embeds=None,
|
1028 |
+
pooled_prompt_embeds=None,
|
1029 |
+
negative_pooled_prompt_embeds=None,
|
1030 |
+
lora_scale=lora_scale,
|
1031 |
+
clip_skip=self.clip_skip,
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
prompt_embeds[1:] = prompt_embeds_[1:]
|
1035 |
+
pooled_prompt_embeds[1:] = pooled_prompt_embeds_[1:]
|
1036 |
+
if not kwargs['use_inf_negative_prompt']:
|
1037 |
+
negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
|
1038 |
+
negative_pooled_prompt_embeds[1:] = negative_pooled_prompt_embeds_[1:]
|
1039 |
+
|
1040 |
+
|
1041 |
+
# 4. Prepare timesteps
|
1042 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1043 |
+
|
1044 |
+
timesteps = self.scheduler.timesteps
|
1045 |
+
|
1046 |
+
|
1047 |
+
# 5. Prepare latent variables
|
1048 |
+
num_channels_latents = self.unet.config.in_channels
|
1049 |
+
latents = self.prepare_latents(
|
1050 |
+
batch_size * num_images_per_prompt,
|
1051 |
+
num_channels_latents,
|
1052 |
+
height,
|
1053 |
+
width,
|
1054 |
+
prompt_embeds.dtype,
|
1055 |
+
device,
|
1056 |
+
generator,
|
1057 |
+
latents,
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1061 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1062 |
+
|
1063 |
+
# 7. Prepare added time ids & embeddings
|
1064 |
+
add_text_embeds = pooled_prompt_embeds
|
1065 |
+
if self.text_encoder_2 is None:
|
1066 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1067 |
+
else:
|
1068 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1069 |
+
|
1070 |
+
add_time_ids = self._get_add_time_ids(
|
1071 |
+
original_size,
|
1072 |
+
crops_coords_top_left,
|
1073 |
+
target_size,
|
1074 |
+
dtype=prompt_embeds.dtype,
|
1075 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1076 |
+
)
|
1077 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1078 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1079 |
+
negative_original_size,
|
1080 |
+
negative_crops_coords_top_left,
|
1081 |
+
negative_target_size,
|
1082 |
+
dtype=prompt_embeds.dtype,
|
1083 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1084 |
+
)
|
1085 |
+
else:
|
1086 |
+
negative_add_time_ids = add_time_ids
|
1087 |
+
|
1088 |
+
if self.do_classifier_free_guidance:
|
1089 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1090 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1091 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1092 |
+
|
1093 |
+
prompt_embeds = prompt_embeds.to(device)
|
1094 |
+
add_text_embeds = add_text_embeds.to(device)
|
1095 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1096 |
+
|
1097 |
+
# 8. Denoising loop
|
1098 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1099 |
+
|
1100 |
+
# 8.1 Apply denoising_end
|
1101 |
+
if (
|
1102 |
+
self.denoising_end is not None
|
1103 |
+
and isinstance(self.denoising_end, float)
|
1104 |
+
and self.denoising_end > 0
|
1105 |
+
and self.denoising_end < 1
|
1106 |
+
):
|
1107 |
+
discrete_timestep_cutoff = int(
|
1108 |
+
round(
|
1109 |
+
self.scheduler.config.num_train_timesteps
|
1110 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1111 |
+
)
|
1112 |
+
)
|
1113 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1114 |
+
timesteps = timesteps[:num_inference_steps]
|
1115 |
+
|
1116 |
+
self._num_timesteps = len(timesteps)
|
1117 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1118 |
+
for i, t in enumerate(timesteps):
|
1119 |
+
|
1120 |
+
|
1121 |
+
if image is not None:
|
1122 |
+
zt = self.prepare_img_latents(z0, t.repeat(1), 1, num_images_per_prompt, prompt_embeds.dtype, device, generator, True)  # re-noise z0 to the current timestep t (add_noise)
|
1123 |
+
|
1124 |
+
latents[0] = zt[0]
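# slot 0 of the batch is pinned to the re-noised reference-image latent (zt) at every step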
|
1125 |
+
|
1126 |
+
# expand the latents if we are doing classifier free guidance
|
1127 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1128 |
+
|
1129 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1130 |
+
|
1131 |
+
# predict the noise residual
|
1132 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1133 |
+
noise_pred = self.unet(
|
1134 |
+
latent_model_input,
|
1135 |
+
t,
|
1136 |
+
encoder_hidden_states=prompt_embeds,
|
1137 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1138 |
+
added_cond_kwargs=added_cond_kwargs,
|
1139 |
+
return_dict=False,
|
1140 |
+
)[0]
|
1141 |
+
|
1142 |
+
|
1143 |
+
# perform guidance
|
1144 |
+
if self.do_classifier_free_guidance:
|
1145 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1146 |
+
|
1147 |
+
# tmp_noise_pred_text = noise_pred_text[0]######### reconstruction only
|
1148 |
+
# import pdb; pdb.set_trace()
|
1149 |
+
# if 1 < i < 3 and kwargs["use_advanced_sampling"]:
|
1150 |
+
if i < 3 and kwargs["use_advanced_sampling"]:
|
1151 |
+
noise_pred = noise_pred_uncond + 20.0 * (noise_pred_text - noise_pred_uncond)
|
1152 |
+
# noise_pred[0] = noise_pred_uncond[0] + self.guidance_scale * (noise_pred_text[0] - noise_pred_uncond[0])
|
1153 |
+
else:
|
1154 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
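# Note: for the first 3 steps, when kwargs["use_advanced_sampling"] is set, the branch
# above uses a fixed scale of 20.0 instead of self.guidance_scale, i.e. an extra-strong
# extrapolation uncond + 20.0 * (text - uncond) to lock in the prompt early; afterwards
# standard classifier-free guidance with the configured guidance_scale applies.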
|
1155 |
+
|
1156 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1157 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1158 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1159 |
+
|
1160 |
+
if use_prompt_as_null:
|
1161 |
+
noise_pred[0] = noise_pred_text[0]
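# with use_prompt_as_null, the reference slot keeps its text-conditioned prediction
# (reconstruction-style) instead of the CFG-combined one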
|
1162 |
+
|
1163 |
+
|
1164 |
+
# noise_pred[0] = tmp_noise_pred_text######## reconstruction only
|
1165 |
+
|
1166 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1167 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1168 |
+
|
1169 |
+
if callback_on_step_end is not None:
|
1170 |
+
callback_kwargs = {}
|
1171 |
+
for k in callback_on_step_end_tensor_inputs:
|
1172 |
+
callback_kwargs[k] = locals()[k]
|
1173 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1174 |
+
|
1175 |
+
latents = callback_outputs.pop("latents", latents)
|
1176 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1177 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1178 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
1179 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1180 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1181 |
+
)
|
1182 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
1183 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
1184 |
+
|
1185 |
+
# call the callback, if provided
|
1186 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1187 |
+
progress_bar.update()
|
1188 |
+
if callback is not None and i % callback_steps == 0:
|
1189 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1190 |
+
callback(step_idx, t, latents)
|
1191 |
+
|
1192 |
+
if XLA_AVAILABLE:
|
1193 |
+
xm.mark_step()
|
1194 |
+
|
1195 |
+
if not output_type == "latent":
|
1196 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1197 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1198 |
+
|
1199 |
+
if needs_upcasting:
|
1200 |
+
self.upcast_vae()
|
1201 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1202 |
+
self.enable_vae_slicing()
|
1203 |
+
|
1204 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1205 |
+
|
1206 |
+
# cast back to fp16 if needed
|
1207 |
+
if needs_upcasting:
|
1208 |
+
self.vae.to(dtype=torch.float16)
|
1209 |
+
else:
|
1210 |
+
image = latents
|
1211 |
+
|
1212 |
+
if not output_type == "latent":
|
1213 |
+
# apply watermark if available
|
1214 |
+
if self.watermark is not None:
|
1215 |
+
image = self.watermark.apply_watermark(image)
|
1216 |
+
|
1217 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1218 |
+
|
1219 |
+
# Offload all models
|
1220 |
+
self.maybe_free_model_hooks()
|
1221 |
+
|
1222 |
+
if not return_dict:
|
1223 |
+
return (image,)
|
1224 |
+
|
1225 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
1226 |
+
|
1227 |
+
@torch.no_grad()
|
1228 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
1229 |
+
def inverted_ve_cross_frame_attn(
|
1230 |
+
self,
|
1231 |
+
prompt: Union[str, List[str]] = None,
|
1232 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
1233 |
+
height: Optional[int] = None,
|
1234 |
+
width: Optional[int] = None,
|
1235 |
+
num_inference_steps: int = 50,
|
1236 |
+
denoising_end: Optional[float] = None,
|
1237 |
+
guidance_scale: float = 5.0,
|
1238 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1239 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
1240 |
+
num_images_per_prompt: Optional[int] = 1,
|
1241 |
+
eta: float = 0.0,
|
1242 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1243 |
+
latents: Optional[torch.FloatTensor] = None,
|
1244 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1245 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1246 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1247 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1248 |
+
output_type: Optional[str] = "pil",
|
1249 |
+
return_dict: bool = True,
|
1250 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1251 |
+
guidance_rescale: float = 0.0,
|
1252 |
+
original_size: Optional[Tuple[int, int]] = None,
|
1253 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1254 |
+
target_size: Optional[Tuple[int, int]] = None,
|
1255 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
1256 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1257 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
1258 |
+
clip_skip: Optional[int] = None,
|
1259 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
1260 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
1261 |
+
**kwargs,
|
1262 |
+
):
|
1263 |
+
r"""
|
1264 |
+
Function invoked when calling the pipeline for generation.
|
1265 |
+
|
1266 |
+
Args:
|
1267 |
+
prompt (`str` or `List[str]`, *optional*):
|
1268 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
|
1270 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1271 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1272 |
+
used in both text-encoders
|
1273 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1274 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
1275 |
+
Anything below 512 pixels won't work well for
|
1276 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1277 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1278 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1279 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
1280 |
+
Anything below 512 pixels won't work well for
|
1281 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1282 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1283 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1284 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1285 |
+
expense of slower inference.
|
1286 |
+
denoising_end (`float`, *optional*):
|
1287 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
1288 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
1289 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
1290 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
1291 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
1292 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
1293 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
1294 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1295 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1296 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1297 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1298 |
+
usually at the expense of lower image quality.
|
1299 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1300 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
1301 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
1302 |
+
less than `1`).
|
1303 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
1304 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
1305 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
1306 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1307 |
+
The number of images to generate per prompt.
|
1308 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1309 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1310 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1311 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1312 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1313 |
+
to make generation deterministic.
|
1314 |
+
latents (`torch.FloatTensor`, *optional*):
|
1315 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1316 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1317 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
1318 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1319 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1320 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1321 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1322 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1323 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1324 |
+
argument.
|
1325 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1326 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1327 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1328 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1329 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1330 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
1331 |
+
input argument.
|
1332 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1333 |
+
The output format of the generate image. Choose between
|
1334 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1335 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1336 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
1337 |
+
of a plain tuple.
|
1338 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1339 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1340 |
+
`self.processor` in
|
1341 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1342 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
1343 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
1344 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
1345 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
1346 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
1347 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1348 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
1349 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
1350 |
+
explained in section 2.2 of
|
1351 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1352 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1353 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
1354 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
1355 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1356 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1357 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1358 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
1359 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
1360 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1361 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1362 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
1363 |
+
micro-conditioning as explained in section 2.2 of
|
1364 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1365 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1366 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1367 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
1368 |
+
micro-conditioning as explained in section 2.2 of
|
1369 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1370 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1371 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1372 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
1373 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1374 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1375 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1376 |
+
callback_on_step_end (`Callable`, *optional*):
|
1377 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
1378 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1379 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1380 |
+
`callback_on_step_end_tensor_inputs`.
|
1381 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1382 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1383 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1384 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1385 |
+
|
1386 |
+
Examples:
|
1387 |
+
|
1388 |
+
Returns:
|
1389 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
1390 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
1391 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
1392 |
+
"""
|
1393 |
+
|
1394 |
+
|
1395 |
+
|
1396 |
+
callback = kwargs.pop("callback", None)
|
1397 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
1398 |
+
|
1399 |
+
if callback is not None:
|
1400 |
+
deprecate(
|
1401 |
+
"callback",
|
1402 |
+
"1.0.0",
|
1403 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1404 |
+
)
|
1405 |
+
if callback_steps is not None:
|
1406 |
+
deprecate(
|
1407 |
+
"callback_steps",
|
1408 |
+
"1.0.0",
|
1409 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1410 |
+
)
|
1411 |
+
|
1412 |
+
# 0. Default height and width to unet
|
1413 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1414 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1415 |
+
|
1416 |
+
original_size = original_size or (height, width)
|
1417 |
+
target_size = target_size or (height, width)
|
1418 |
+
|
1419 |
+
# 1. Check inputs. Raise error if not correct
|
1420 |
+
self.check_inputs(
|
1421 |
+
prompt,
|
1422 |
+
prompt_2,
|
1423 |
+
height,
|
1424 |
+
width,
|
1425 |
+
callback_steps,
|
1426 |
+
negative_prompt,
|
1427 |
+
negative_prompt_2,
|
1428 |
+
prompt_embeds,
|
1429 |
+
negative_prompt_embeds,
|
1430 |
+
pooled_prompt_embeds,
|
1431 |
+
negative_pooled_prompt_embeds,
|
1432 |
+
callback_on_step_end_tensor_inputs,
|
1433 |
+
)
|
1434 |
+
|
1435 |
+
self._guidance_scale = guidance_scale
|
1436 |
+
self._guidance_rescale = guidance_rescale
|
1437 |
+
self._clip_skip = clip_skip
|
1438 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
1439 |
+
self._denoising_end = denoising_end
|
1440 |
+
|
1441 |
+
# 2. Define call parameters
|
1442 |
+
if prompt is not None and isinstance(prompt, str):
|
1443 |
+
batch_size = 1
|
1444 |
+
elif prompt is not None and isinstance(prompt, list):
|
1445 |
+
batch_size = len(prompt)
|
1446 |
+
else:
|
1447 |
+
batch_size = prompt_embeds.shape[0]
|
1448 |
+
|
1449 |
+
device = self._execution_device
|
1450 |
+
|
1451 |
+
# 3. Encode input prompt
|
1452 |
+
lora_scale = (
|
1453 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
1454 |
+
)
|
1455 |
+
|
1456 |
+
(
|
1457 |
+
prompt_embeds,
|
1458 |
+
negative_prompt_embeds,
|
1459 |
+
pooled_prompt_embeds,
|
1460 |
+
negative_pooled_prompt_embeds,
|
1461 |
+
) = self.encode_prompt(
|
1462 |
+
prompt=prompt,
|
1463 |
+
prompt_2=prompt_2,
|
1464 |
+
device=device,
|
1465 |
+
num_images_per_prompt=num_images_per_prompt,
|
1466 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1467 |
+
negative_prompt=negative_prompt,
|
1468 |
+
negative_prompt_2=negative_prompt_2,
|
1469 |
+
prompt_embeds=prompt_embeds,
|
1470 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1471 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1472 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1473 |
+
lora_scale=lora_scale,
|
1474 |
+
clip_skip=self.clip_skip,
|
1475 |
+
)
|
1476 |
+
|
1477 |
+
if kwargs['target_prompt'] is not None:
|
1478 |
+
(
|
1479 |
+
prompt_embeds_,
|
1480 |
+
negative_prompt_embeds_,
|
1481 |
+
_,
|
1482 |
+
_,
|
1483 |
+
) = self.encode_prompt(
|
1484 |
+
prompt=kwargs['target_prompt'],
|
1485 |
+
prompt_2=prompt_2,
|
1486 |
+
device=device,
|
1487 |
+
num_images_per_prompt=num_images_per_prompt,
|
1488 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1489 |
+
negative_prompt=kwargs['target_negative_prompt'],
|
1490 |
+
# negative_prompt=None,
|
1491 |
+
# negative_prompt_2=negative_prompt_2,
|
1492 |
+
negative_prompt_2=None,
|
1493 |
+
prompt_embeds=None,
|
1494 |
+
negative_prompt_embeds=None,
|
1495 |
+
pooled_prompt_embeds=None,
|
1496 |
+
negative_pooled_prompt_embeds=None,
|
1497 |
+
lora_scale=lora_scale,
|
1498 |
+
clip_skip=self.clip_skip,
|
1499 |
+
)
|
1500 |
+
prompt_embeds[1:] = prompt_embeds_[1:]
|
1501 |
+
if negative_prompt_embeds_ is not None:
|
1502 |
+
negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
|
1503 |
+
|
1504 |
+
|
1505 |
+
# 4. Prepare timesteps
|
1506 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1507 |
+
|
1508 |
+
timesteps = self.scheduler.timesteps
|
1509 |
+
|
1510 |
+
# 5. Prepare latent variables
|
1511 |
+
num_channels_latents = self.unet.config.in_channels
|
1512 |
+
latents = self.prepare_latents(
|
1513 |
+
batch_size * num_images_per_prompt,
|
1514 |
+
num_channels_latents,
|
1515 |
+
height,
|
1516 |
+
width,
|
1517 |
+
prompt_embeds.dtype,
|
1518 |
+
device,
|
1519 |
+
generator,
|
1520 |
+
latents,
|
1521 |
+
)
|
1522 |
+
|
1523 |
+
|
1524 |
+
latents_ = self.prepare_latents(
|
1525 |
+
batch_size * num_images_per_prompt,
|
1526 |
+
num_channels_latents,
|
1527 |
+
height,
|
1528 |
+
width,
|
1529 |
+
prompt_embeds.dtype,
|
1530 |
+
device,
|
1531 |
+
generator,
|
1532 |
+
# latents,
|
1533 |
+
)
|
1534 |
+
|
1535 |
+
# import pdb; pdb.set_trace()
|
1536 |
+
|
1537 |
+
# latents[1:] = latents_[1:]
|
1538 |
+
latents = torch.cat([latents.unsqueeze(0), latents_[1:]], dim=0)
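# slot 0 keeps the latent prepared above (e.g. a precomputed init latent); the
# remaining batch slots start from freshly sampled noise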
|
1539 |
+
|
1540 |
+
|
1541 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1542 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1543 |
+
|
1544 |
+
# 7. Prepare added time ids & embeddings
|
1545 |
+
add_text_embeds = pooled_prompt_embeds
|
1546 |
+
if self.text_encoder_2 is None:
|
1547 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1548 |
+
else:
|
1549 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1550 |
+
|
1551 |
+
add_time_ids = self._get_add_time_ids(
|
1552 |
+
original_size,
|
1553 |
+
crops_coords_top_left,
|
1554 |
+
target_size,
|
1555 |
+
dtype=prompt_embeds.dtype,
|
1556 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1557 |
+
)
|
1558 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1559 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1560 |
+
negative_original_size,
|
1561 |
+
negative_crops_coords_top_left,
|
1562 |
+
negative_target_size,
|
1563 |
+
dtype=prompt_embeds.dtype,
|
1564 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1565 |
+
)
|
1566 |
+
else:
|
1567 |
+
negative_add_time_ids = add_time_ids
|
1568 |
+
|
1569 |
+
if self.do_classifier_free_guidance:
|
1570 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1571 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1572 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1573 |
+
|
1574 |
+
prompt_embeds = prompt_embeds.to(device)
|
1575 |
+
add_text_embeds = add_text_embeds.to(device)
|
1576 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1577 |
+
|
1578 |
+
# 8. Denoising loop
|
1579 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1580 |
+
|
1581 |
+
# 8.1 Apply denoising_end
|
1582 |
+
if (
|
1583 |
+
self.denoising_end is not None
|
1584 |
+
and isinstance(self.denoising_end, float)
|
1585 |
+
and self.denoising_end > 0
|
1586 |
+
and self.denoising_end < 1
|
1587 |
+
):
|
1588 |
+
discrete_timestep_cutoff = int(
|
1589 |
+
round(
|
1590 |
+
self.scheduler.config.num_train_timesteps
|
1591 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1592 |
+
)
|
1593 |
+
)
|
1594 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1595 |
+
timesteps = timesteps[:num_inference_steps]
|
1596 |
+
|
1597 |
+
self._num_timesteps = len(timesteps)
|
1598 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1599 |
+
for i, t in enumerate(timesteps):
|
1600 |
+
# expand the latents if we are doing classifier free guidance
|
1601 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1602 |
+
|
1603 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1604 |
+
|
1605 |
+
# predict the noise residual
|
1606 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1607 |
+
noise_pred = self.unet(
|
1608 |
+
latent_model_input,
|
1609 |
+
t,
|
1610 |
+
encoder_hidden_states=prompt_embeds,
|
1611 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1612 |
+
added_cond_kwargs=added_cond_kwargs,
|
1613 |
+
return_dict=False,
|
1614 |
+
)[0]
|
1615 |
+
|
1616 |
+
# perform guidance
|
1617 |
+
if self.do_classifier_free_guidance:
|
1618 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1619 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1620 |
+
noise_pred[0] = noise_pred_uncond[0]  # added: keep the first (reference) sample unguided
|
1621 |
+
|
1622 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1623 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1624 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1625 |
+
noise_pred[0] = noise_pred_uncond[0]  # added: keep the first (reference) sample unguided after rescaling
|
1626 |
+
|
1627 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1628 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1629 |
+
|
1630 |
+
if callback_on_step_end is not None:
|
1631 |
+
callback_kwargs = {}
|
1632 |
+
for k in callback_on_step_end_tensor_inputs:
|
1633 |
+
callback_kwargs[k] = locals()[k]
|
1634 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1635 |
+
|
1636 |
+
latents = callback_outputs.pop("latents", latents)
|
1637 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1638 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1639 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
1640 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1641 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1642 |
+
)
|
1643 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
1644 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
1645 |
+
|
1646 |
+
# call the callback, if provided
|
1647 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1648 |
+
progress_bar.update()
|
1649 |
+
if callback is not None and i % callback_steps == 0:
|
1650 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1651 |
+
callback(step_idx, t, latents)
|
1652 |
+
|
1653 |
+
if XLA_AVAILABLE:
|
1654 |
+
xm.mark_step()
|
1655 |
+
|
1656 |
+
if not output_type == "latent":
|
1657 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1658 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1659 |
+
|
1660 |
+
if needs_upcasting:
|
1661 |
+
self.upcast_vae()
|
1662 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1663 |
+
|
1664 |
+
self.enable_vae_slicing()
|
1665 |
+
|
1666 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1667 |
+
|
1668 |
+
# cast back to fp16 if needed
|
1669 |
+
if needs_upcasting:
|
1670 |
+
self.vae.to(dtype=torch.float16)
|
1671 |
+
else:
|
1672 |
+
image = latents
|
1673 |
+
|
1674 |
+
if not output_type == "latent":
|
1675 |
+
# apply watermark if available
|
1676 |
+
if self.watermark is not None:
|
1677 |
+
image = self.watermark.apply_watermark(image)
|
1678 |
+
|
1679 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1680 |
+
|
1681 |
+
# Offload all models
|
1682 |
+
self.maybe_free_model_hooks()
|
1683 |
+
|
1684 |
+
if not return_dict:
|
1685 |
+
return (image,)
|
1686 |
+
|
1687 |
+
return StableDiffusionXLPipelineOutput(images=image)
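A minimal usage sketch for the method above; `pipe` is assumed to be an instance of this pipeline with the style-sharing attention processors installed via `activate_layer` below, and all prompt/seed values are placeholders:

# Sketch only: slot 0 keeps the original prompt (the style reference); slots 1+ are
# re-encoded from target_prompt.
out = pipe.inverted_ve_cross_frame_attn(
    prompt=["a cat, van gogh style"] * 3,                         # hypothetical reference prompt
    num_inference_steps=50,
    guidance_scale=7.0,
    target_prompt=["a cat, van gogh style", "a dog", "a house"],  # hypothetical target prompts
    target_negative_prompt=None,
)
images = out.images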
|
1688 |
+
|
1689 |
+
|
1690 |
+
    @torch.no_grad()
    def activate_layer(self,
                       activate_layer_indices,
                       attn_map_save_steps=[],
                       activate_step_indices=None,
                       use_shared_attention=False,
                       adain_queries=True,
                       adain_keys=True,
                       adain_values=False,
                       ):
        # Swap in the style-sharing attention processors on the selected self-attention
        # layers (slices of ACTIVATE_LAYER_CANDIDATE) and restore the default processor
        # everywhere else. Returns string summaries of the activated layer and step
        # ranges for logging.
        attn_procs = {}
        activate_layer = []
        str_activate_layer = ""
        for activate_layer_index in activate_layer_indices:
            activate_layer += ACTIVATE_LAYER_CANDIDATE[activate_layer_index[0]:activate_layer_index[1]]
            str_activate_layer += str(activate_layer_index)

        str_activate_step = ""
        for activate_step_index in activate_step_indices:
            str_activate_step += str(activate_step_index)

        for name in self.unet.attn_processors.keys():
            if name in activate_layer:
                if not use_shared_attention:
                    attn_procs[name] = CrossFrameAttnProcessor(unet_chunk_size=2,
                                                               attn_map_save_steps=attn_map_save_steps,
                                                               activate_step_indices=activate_step_indices)
                else:
                    activate_save_layer = [
                        'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor',
                        'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor',
                        'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor',
                        'up_blocks.0.attentions.0.transformer_blocks.0.attn1.processor',
                        'mid_block.attentions.0.transformer_blocks.0.attn1.processor'
                    ]
                    if name in activate_save_layer:
                        attn_procs[name] = SharedAttentionProcessor_v2(
                            adain_keys=adain_keys,
                            adain_queries=adain_queries,
                            adain_values=adain_values,
                            attn_map_save_steps=attn_map_save_steps,
                            keys_scale=1.0,
                        )
                    else:
                        attn_procs[name] = SharedAttentionProcessor(
                            # unet_chunk_size=2,
                            # attn_map_save_steps=attn_map_save_steps,
                            # activate_step_indices=activate_step_indices,
                            adain_keys=adain_keys,
                            adain_queries=adain_queries,
                            adain_values=adain_values,
                            keys_scale=1.0,
                        )
            else:
                attn_procs[name] = AttnProcessor()

        self.unet.set_attn_processor(attn_procs)

        return str_activate_layer, str_activate_step
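A hedged usage sketch for `activate_layer` (index ranges are placeholders; `ACTIVATE_LAYER_CANDIDATE` and the processor classes come from elsewhere in this repo):

str_layers, str_steps = pipe.activate_layer(
    activate_layer_indices=[(0, 20)],   # hypothetical slice of ACTIVATE_LAYER_CANDIDATE
    activate_step_indices=[(0, 30)],    # hypothetical denoising-step range
    use_shared_attention=False,         # False -> CrossFrameAttnProcessor on the selected layers
)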

    @torch.no_grad()
    def get_init_latent(self,
                        precomputed_path,
                        seed):
        # Return a cached initial latent for this seed if one exists under
        # precomputed_path; otherwise sample it on CPU, save it, and return it.
        if not os.path.exists(precomputed_path):
            os.makedirs(precomputed_path)

        # search init latents in precomputed latents
        init_latent_name = f'init_latent_{seed}.pt'
        init_latent_path = os.path.join(precomputed_path, init_latent_name)

        # 0. Default height and width to unet
        height = self.default_sample_size * self.vae_scale_factor
        width = self.default_sample_size * self.vae_scale_factor

        num_channels_latents = self.unet.config.in_channels

        if not os.path.exists(init_latent_path):
            print(f'init_latent_{seed}.pt does not exist; sampling a new initial latent')
            # device = self._execution_device
            device = torch.device("cpu")
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

            shape = (1, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)

            init_latent = randn_tensor(shape, generator=generator, dtype=self.dtype, device=device)

            torch.save(init_latent, init_latent_path)
        else:
            print(f'init_latent_{seed}.pt exists; loading the cached initial latent')
            init_latent = torch.load(init_latent_path)

        return init_latent
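Sketch only (path and seed are placeholders): the initial latent is sampled once, cached to disk, and reloaded on later runs so the reference image stays reproducible.

init_latent = pipe.get_init_latent(precomputed_path='./precomputed_latents', seed=42)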
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch
diffusers
transformers
accelerate
einops
kornia
gradio
torchvision
opencv-python
xformers
utils.py
ADDED
@@ -0,0 +1,143 @@
import torch
from diffusers.utils.torch_utils import randn_tensor

import json, os, cv2
from PIL import Image
import numpy as np

def parse_config(config):
    # read a JSON config file and return it as a dict
    with open(config, 'r') as f:
        config = json.load(f)
    return config

def load_config(config):
    # unpack the fields used by the demo from a parsed config dict
    activate_layer_indices_list = config['inference_info']['activate_layer_indices_list']
    activate_step_indices_list = config['inference_info']['activate_step_indices_list']
    ref_seeds = config['reference_info']['ref_seeds']
    inf_seeds = config['inference_info']['inf_seeds']

    attn_map_save_steps = config['inference_info']['attn_map_save_steps']
    precomputed_path = config['precomputed_path']
    guidance_scale = config['guidance_scale']
    use_inf_negative_prompt = config['inference_info']['use_negative_prompt']

    style_name_list = config["style_name_list"]
    ref_object_list = config["reference_info"]["ref_object_list"]
    inf_object_list = config["inference_info"]["inf_object_list"]
    ref_with_style_description = config['reference_info']['with_style_description']
    inf_with_style_description = config['inference_info']['with_style_description']

    use_shared_attention = config['inference_info']['use_shared_attention']
    adain_queries = config['inference_info']['adain_queries']
    adain_keys = config['inference_info']['adain_keys']
    adain_values = config['inference_info']['adain_values']
    use_advanced_sampling = config['inference_info']['use_advanced_sampling']

    out = [
        activate_layer_indices_list, activate_step_indices_list,
        ref_seeds, inf_seeds,
        attn_map_save_steps, precomputed_path, guidance_scale, use_inf_negative_prompt,
        style_name_list, ref_object_list, inf_object_list, ref_with_style_description, inf_with_style_description,
        use_shared_attention, adain_queries, adain_keys, adain_values, use_advanced_sampling
    ]
    return out

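A config dict passed to `load_config` must provide the keys read above. A minimal sketch with placeholder values (not the shipped defaults):

example_config = {
    "precomputed_path": "./precomputed",
    "guidance_scale": 7.0,
    "style_name_list": ["van gogh painting"],
    "reference_info": {
        "ref_seeds": [42],
        "ref_object_list": ["A cat"],
        "with_style_description": True,
    },
    "inference_info": {
        "activate_layer_indices_list": [[[0, 20]]],   # placeholder index ranges
        "activate_step_indices_list": [[[0, 30]]],
        "inf_seeds": [0, 1],
        "inf_object_list": ["A dog"],
        "with_style_description": True,
        "attn_map_save_steps": [],
        "use_negative_prompt": True,
        "use_shared_attention": False,
        "adain_queries": True,
        "adain_keys": True,
        "adain_values": False,
        "use_advanced_sampling": True,
    },
}

out = load_config(example_config)   # returns the 18 fields in the order listed in the function
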
def memory_efficient(model, device):
    try:
        model.to(device)
    except Exception as e:
        print("Error moving model to device:", e)

    try:
        model.enable_model_cpu_offload()
    except AttributeError:
        print("enable_model_cpu_offload is not supported.")
    try:
        model.enable_vae_slicing()
    except AttributeError:
        print("enable_vae_slicing is not supported.")

    try:
        model.enable_vae_tiling()
    except AttributeError:
        print("enable_vae_tiling is not supported.")

    try:
        model.enable_xformers_memory_efficient_attention()
    except AttributeError:
        print("enable_xformers_memory_efficient_attention is not supported.")

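For example (the pipeline variable name is illustrative):

memory_efficient(pipe, device='cuda')  # each optimization is attempted and skipped with a message if unsupported
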
def init_latent(model, device_name='cuda', dtype=torch.float16, seed=None):
    # sample an initial latent at the model's default resolution, optionally seeded
    scale_factor = model.vae_scale_factor
    sample_size = model.default_sample_size
    latent_dim = model.unet.config.in_channels

    height = sample_size * scale_factor
    width = sample_size * scale_factor

    shape = (1, latent_dim, height // scale_factor, width // scale_factor)

    device = torch.device(device_name)
    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

    latent = randn_tensor(shape, generator=generator, dtype=dtype, device=device)

    return latent

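Sketch (model handle and seed are illustrative): seeding makes the starting latent reproducible.

latent_a = init_latent(pipe, device_name='cuda', dtype=torch.float16, seed=7)
latent_b = init_latent(pipe, device_name='cuda', dtype=torch.float16, seed=7)
assert torch.equal(latent_a, latent_b)  # same seed, same generator -> identical noise
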

def get_canny_edge_array(canny_img_path, threshold1=100, threshold2=200):
    canny_image_list = []

    # check if canny_img_path is a directory
    if os.path.isdir(canny_img_path):
        canny_img_list = os.listdir(canny_img_path)
        for canny_img in canny_img_list:
            canny_image_tmp = Image.open(os.path.join(canny_img_path, canny_img))
            # resize image into 1024x1024
            canny_image_tmp = canny_image_tmp.resize((1024, 1024))
            canny_image_tmp = np.array(canny_image_tmp)
            canny_image_tmp = cv2.Canny(canny_image_tmp, threshold1, threshold2)
            canny_image_tmp = canny_image_tmp[:, :, None]
            canny_image_tmp = np.concatenate([canny_image_tmp, canny_image_tmp, canny_image_tmp], axis=2)
            canny_image = Image.fromarray(canny_image_tmp)
            canny_image_list.append(canny_image)

    return canny_image_list

def get_depth_map(image, feature_extractor, depth_estimator, device='cuda'):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))

    return image

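Illustrative usage, assuming a DPT depth model from transformers (the checkpoint id is an example, not pinned by this repo):

from transformers import DPTFeatureExtractor, DPTForDepthEstimation
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
depth_image = get_depth_map(Image.open("input.jpg").convert("RGB"), feature_extractor, depth_estimator)
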
def get_depth_edge_array(depth_img_path, feature_extractor, depth_estimator, device='cuda'):
    depth_image_list = []

    # check if depth_img_path is a directory
    if os.path.isdir(depth_img_path):
        depth_img_list = os.listdir(depth_img_path)
        for depth_img in depth_img_list:
            depth_image_tmp = Image.open(os.path.join(depth_img_path, depth_img)).convert('RGB')

            # get depth map
            depth_map = get_depth_map(depth_image_tmp, feature_extractor, depth_estimator, device)
            depth_image_list.append(depth_map)

    return depth_image_list
visualize_attention_src/__init__.py
ADDED
File without changes
|
visualize_attention_src/pipeline_stable_diffusion_xl_attn.py
ADDED
@@ -0,0 +1,1573 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
20 |
+
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
from diffusers.loaders import (
|
23 |
+
FromSingleFileMixin,
|
24 |
+
StableDiffusionXLLoraLoaderMixin,
|
25 |
+
TextualInversionLoaderMixin,
|
26 |
+
)
|
27 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
28 |
+
from diffusers.models.attention_processor import (
|
29 |
+
AttnProcessor2_0,
|
30 |
+
LoRAAttnProcessor2_0,
|
31 |
+
LoRAXFormersAttnProcessor,
|
32 |
+
XFormersAttnProcessor,
|
33 |
+
)
|
34 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
35 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
36 |
+
from diffusers.utils import (
|
37 |
+
USE_PEFT_BACKEND,
|
38 |
+
deprecate,
|
39 |
+
is_invisible_watermark_available,
|
40 |
+
is_torch_xla_available,
|
41 |
+
logging,
|
42 |
+
replace_example_docstring,
|
43 |
+
scale_lora_layers,
|
44 |
+
unscale_lora_layers,
|
45 |
+
)
|
46 |
+
from diffusers.utils.torch_utils import randn_tensor
|
47 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
48 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
49 |
+
|
50 |
+
|
51 |
+
if is_invisible_watermark_available():
|
52 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
53 |
+
|
54 |
+
if is_torch_xla_available():
|
55 |
+
import torch_xla.core.xla_model as xm
|
56 |
+
|
57 |
+
XLA_AVAILABLE = True
|
58 |
+
else:
|
59 |
+
XLA_AVAILABLE = False
|
60 |
+
|
61 |
+
|
62 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
63 |
+
|
64 |
+
EXAMPLE_DOC_STRING = """
|
65 |
+
Examples:
|
66 |
+
```py
|
67 |
+
>>> import torch
|
68 |
+
>>> from diffusers import StableDiffusionXLPipeline
|
69 |
+
|
70 |
+
>>> pipe = StableDiffusionXLPipeline.from_pretrained(
|
71 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
72 |
+
... )
|
73 |
+
>>> pipe = pipe.to("cuda")
|
74 |
+
|
75 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
76 |
+
>>> image = pipe(prompt).images[0]
|
77 |
+
```
|
78 |
+
"""
|
79 |
+
|
80 |
+
|
81 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
82 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
83 |
+
"""
|
84 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
85 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
86 |
+
"""
|
87 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
88 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
89 |
+
# rescale the results from guidance (fixes overexposure)
|
90 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
91 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
92 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
93 |
+
return noise_cfg
|
94 |
+
|
95 |
+
|
96 |
+
class StableDiffusionXLPipeline(
|
97 |
+
DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
98 |
+
):
|
99 |
+
r"""
|
100 |
+
Pipeline for text-to-image generation using Stable Diffusion XL.
|
101 |
+
|
102 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
103 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
104 |
+
|
105 |
+
In addition the pipeline inherits the following loading methods:
|
106 |
+
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
107 |
+
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
108 |
+
|
109 |
+
as well as the following saving methods:
|
110 |
+
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
111 |
+
|
112 |
+
Args:
|
113 |
+
vae ([`AutoencoderKL`]):
|
114 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
115 |
+
text_encoder ([`CLIPTextModel`]):
|
116 |
+
Frozen text-encoder. Stable Diffusion XL uses the text portion of
|
117 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
118 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
119 |
+
text_encoder_2 ([` CLIPTextModelWithProjection`]):
|
120 |
+
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
|
121 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
122 |
+
specifically the
|
123 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
124 |
+
variant.
|
125 |
+
tokenizer (`CLIPTokenizer`):
|
126 |
+
Tokenizer of class
|
127 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
128 |
+
tokenizer_2 (`CLIPTokenizer`):
|
129 |
+
Second Tokenizer of class
|
130 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
131 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
132 |
+
scheduler ([`SchedulerMixin`]):
|
133 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
134 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
135 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
136 |
+
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
|
137 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
138 |
+
add_watermarker (`bool`, *optional*):
|
139 |
+
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
|
140 |
+
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
|
141 |
+
watermarker will be used.
|
142 |
+
"""
|
143 |
+
    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = self.unet.config.sample_size

        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None

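    # Sanity-check sketch (illustrative, assuming the stock SDXL-base configs): the VAE
    # has four block_out_channels, so vae_scale_factor = 2 ** (4 - 1) = 8, and the UNet's
    # default_sample_size of 128 corresponds to 128 * 8 = 1024 px outputs.
    #
    #     vae_scale_factor = 2 ** (4 - 1)              # 8
    #     default_resolution = 128 * vae_scale_factor  # 1024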
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

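    # Usage sketch (illustrative, not part of the original file): the slicing/tiling
    # switches trade a little decode speed for a large drop in peak VAE memory. The
    # checkpoint name is the stock SDXL base and is an assumption of this example.
    #
    #     import torch
    #     from pipelines.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
    #
    #     pipe = StableDiffusionXLPipeline.from_pretrained(
    #         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    #     ).to("cuda")
    #     pipe.enable_vae_slicing()  # decode the batch one image at a time
    #     pipe.enable_vae_tiling()   # decode each image in overlapping tiles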
def encode_prompt(
|
225 |
+
self,
|
226 |
+
prompt: str,
|
227 |
+
prompt_2: Optional[str] = None,
|
228 |
+
device: Optional[torch.device] = None,
|
229 |
+
num_images_per_prompt: int = 1,
|
230 |
+
do_classifier_free_guidance: bool = True,
|
231 |
+
negative_prompt: Optional[str] = None,
|
232 |
+
negative_prompt_2: Optional[str] = None,
|
233 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
234 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
235 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
236 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
237 |
+
lora_scale: Optional[float] = None,
|
238 |
+
clip_skip: Optional[int] = None,
|
239 |
+
):
|
240 |
+
r"""
|
241 |
+
Encodes the prompt into text encoder hidden states.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
prompt (`str` or `List[str]`, *optional*):
|
245 |
+
prompt to be encoded
|
246 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
247 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
248 |
+
used in both text-encoders
|
249 |
+
device: (`torch.device`):
|
250 |
+
torch device
|
251 |
+
num_images_per_prompt (`int`):
|
252 |
+
number of images that should be generated per prompt
|
253 |
+
do_classifier_free_guidance (`bool`):
|
254 |
+
whether to use classifier free guidance or not
|
255 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
256 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
257 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
258 |
+
less than `1`).
|
259 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
260 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
261 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
262 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
263 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
264 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
265 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
266 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
267 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
268 |
+
argument.
|
269 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
270 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
271 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
272 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
273 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
274 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
275 |
+
input argument.
|
276 |
+
lora_scale (`float`, *optional*):
|
277 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
278 |
+
clip_skip (`int`, *optional*):
|
279 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
280 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
281 |
+
"""
|
282 |
+
device = device or self._execution_device
|
283 |
+
|
284 |
+
# set lora scale so that monkey patched LoRA
|
285 |
+
# function of text encoder can correctly access it
|
286 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
287 |
+
self._lora_scale = lora_scale
|
288 |
+
|
289 |
+
# dynamically adjust the LoRA scale
|
290 |
+
if self.text_encoder is not None:
|
291 |
+
if not USE_PEFT_BACKEND:
|
292 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
293 |
+
else:
|
294 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
295 |
+
|
296 |
+
if self.text_encoder_2 is not None:
|
297 |
+
if not USE_PEFT_BACKEND:
|
298 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
299 |
+
else:
|
300 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
301 |
+
|
302 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
303 |
+
|
304 |
+
if prompt is not None:
|
305 |
+
batch_size = len(prompt)
|
306 |
+
else:
|
307 |
+
batch_size = prompt_embeds.shape[0]
|
308 |
+
|
309 |
+
# Define tokenizers and text encoders
|
310 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
311 |
+
text_encoders = (
|
312 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
313 |
+
)
|
314 |
+
|
315 |
+
if prompt_embeds is None:
|
316 |
+
prompt_2 = prompt_2 or prompt
|
317 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
318 |
+
|
319 |
+
# textual inversion: process multi-vector tokens if necessary
|
320 |
+
prompt_embeds_list = []
|
321 |
+
prompts = [prompt, prompt_2]
|
322 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
323 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
324 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
325 |
+
|
326 |
+
text_inputs = tokenizer(
|
327 |
+
prompt,
|
328 |
+
padding="max_length",
|
329 |
+
max_length=tokenizer.model_max_length,
|
330 |
+
truncation=True,
|
331 |
+
return_tensors="pt",
|
332 |
+
)
|
333 |
+
|
334 |
+
text_input_ids = text_inputs.input_ids
|
335 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
336 |
+
|
337 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
338 |
+
text_input_ids, untruncated_ids
|
339 |
+
):
|
340 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
341 |
+
logger.warning(
|
342 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
343 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
344 |
+
)
|
345 |
+
|
346 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
347 |
+
|
348 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
349 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
350 |
+
if clip_skip is None:
|
351 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
352 |
+
else:
|
353 |
+
# "2" because SDXL always indexes from the penultimate layer.
|
354 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
355 |
+
|
356 |
+
prompt_embeds_list.append(prompt_embeds)
|
357 |
+
|
358 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
359 |
+
|
360 |
+
# get unconditional embeddings for classifier free guidance
|
361 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
362 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
363 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
364 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
365 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
366 |
+
negative_prompt = negative_prompt or ""
|
367 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
368 |
+
|
369 |
+
# normalize str to list
|
370 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
371 |
+
negative_prompt_2 = (
|
372 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
373 |
+
)
|
374 |
+
|
375 |
+
uncond_tokens: List[str]
|
376 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
377 |
+
raise TypeError(
|
378 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
379 |
+
f" {type(prompt)}."
|
380 |
+
)
|
381 |
+
elif batch_size != len(negative_prompt):
|
382 |
+
raise ValueError(
|
383 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
384 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
385 |
+
" the batch size of `prompt`."
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
389 |
+
|
390 |
+
negative_prompt_embeds_list = []
|
391 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
392 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
393 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
394 |
+
|
395 |
+
max_length = prompt_embeds.shape[1]
|
396 |
+
uncond_input = tokenizer(
|
397 |
+
negative_prompt,
|
398 |
+
padding="max_length",
|
399 |
+
max_length=max_length,
|
400 |
+
truncation=True,
|
401 |
+
return_tensors="pt",
|
402 |
+
)
|
403 |
+
|
404 |
+
negative_prompt_embeds = text_encoder(
|
405 |
+
uncond_input.input_ids.to(device),
|
406 |
+
output_hidden_states=True,
|
407 |
+
)
|
408 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
409 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
410 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
411 |
+
|
412 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
413 |
+
|
414 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
415 |
+
|
416 |
+
if self.text_encoder_2 is not None:
|
417 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
418 |
+
else:
|
419 |
+
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
420 |
+
|
421 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
422 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
423 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
424 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
425 |
+
|
426 |
+
if do_classifier_free_guidance:
|
427 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
428 |
+
seq_len = negative_prompt_embeds.shape[1]
|
429 |
+
|
430 |
+
if self.text_encoder_2 is not None:
|
431 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
432 |
+
else:
|
433 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
434 |
+
|
435 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
436 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
437 |
+
|
438 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
439 |
+
bs_embed * num_images_per_prompt, -1
|
440 |
+
)
|
441 |
+
if do_classifier_free_guidance:
|
442 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
443 |
+
bs_embed * num_images_per_prompt, -1
|
444 |
+
)
|
445 |
+
|
446 |
+
if self.text_encoder is not None:
|
447 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
448 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
449 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
450 |
+
|
451 |
+
if self.text_encoder_2 is not None:
|
452 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
453 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
454 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
455 |
+
|
456 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
457 |
+
|
458 |
+
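    # Usage sketch (illustrative): encode_prompt can be called once and its outputs fed
    # back through the `*_embeds` arguments to reuse the same embeddings across seeds.
    # `pipe` is the instance constructed in the sketch near enable_vae_slicing above;
    # `target_prompt=None` is the extra kwarg this pipeline's __call__ looks up.
    #
    #     (
    #         prompt_embeds,
    #         negative_prompt_embeds,
    #         pooled_prompt_embeds,
    #         negative_pooled_prompt_embeds,
    #     ) = pipe.encode_prompt(
    #         prompt="a watercolor fox",
    #         num_images_per_prompt=1,
    #         do_classifier_free_guidance=True,
    #         negative_prompt="blurry, low quality",
    #     )
    #     image = pipe(
    #         prompt_embeds=prompt_embeds,
    #         negative_prompt_embeds=negative_prompt_embeds,
    #         pooled_prompt_embeds=pooled_prompt_embeds,
    #         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    #         target_prompt=None,
    #     ).images[0]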
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

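    # Standalone illustration of the introspection above: only schedulers whose .step()
    # accepts `eta` / `generator` receive them. DDIM takes both; Euler only `generator`.
    #
    #     import inspect
    #     from diffusers import DDIMScheduler, EulerDiscreteScheduler
    #
    #     for cls in (DDIMScheduler, EulerDiscreteScheduler):
    #         params = set(inspect.signature(cls.step).parameters)
    #         print(cls.__name__, "eta" in params, "generator" in params)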
def check_inputs(
|
477 |
+
self,
|
478 |
+
prompt,
|
479 |
+
prompt_2,
|
480 |
+
height,
|
481 |
+
width,
|
482 |
+
callback_steps,
|
483 |
+
negative_prompt=None,
|
484 |
+
negative_prompt_2=None,
|
485 |
+
prompt_embeds=None,
|
486 |
+
negative_prompt_embeds=None,
|
487 |
+
pooled_prompt_embeds=None,
|
488 |
+
negative_pooled_prompt_embeds=None,
|
489 |
+
callback_on_step_end_tensor_inputs=None,
|
490 |
+
):
|
491 |
+
if height % 8 != 0 or width % 8 != 0:
|
492 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
493 |
+
|
494 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
495 |
+
raise ValueError(
|
496 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
497 |
+
f" {type(callback_steps)}."
|
498 |
+
)
|
499 |
+
|
500 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
501 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
502 |
+
):
|
503 |
+
raise ValueError(
|
504 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
505 |
+
)
|
506 |
+
|
507 |
+
if prompt is not None and prompt_embeds is not None:
|
508 |
+
raise ValueError(
|
509 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
510 |
+
" only forward one of the two."
|
511 |
+
)
|
512 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
513 |
+
raise ValueError(
|
514 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
515 |
+
" only forward one of the two."
|
516 |
+
)
|
517 |
+
elif prompt is None and prompt_embeds is None:
|
518 |
+
raise ValueError(
|
519 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
520 |
+
)
|
521 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
522 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
523 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
524 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
525 |
+
|
526 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
527 |
+
raise ValueError(
|
528 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
529 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
530 |
+
)
|
531 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
532 |
+
raise ValueError(
|
533 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
534 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
535 |
+
)
|
536 |
+
|
537 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
538 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
539 |
+
raise ValueError(
|
540 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
541 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
542 |
+
f" {negative_prompt_embeds.shape}."
|
543 |
+
)
|
544 |
+
|
545 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
546 |
+
raise ValueError(
|
547 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
548 |
+
)
|
549 |
+
|
550 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
551 |
+
raise ValueError(
|
552 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
553 |
+
)
|
554 |
+
|
555 |
+
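    # Quick illustration of the dimension check in check_inputs: requested heights and
    # widths must be multiples of 8 (the VAE downscale factor).
    #
    #     for dim in (1024, 1016, 1000):
    #         print(dim, dim % 8 == 0)   # 1024 True, 1016 True, 1000 False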
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

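    # Back-of-the-envelope sketch (illustrative): a 1024x1024 SDXL generation with
    # batch_size=1 and 4 latent channels starts from noise of shape
    # (1, 4, 1024 // 8, 1024 // 8) == (1, 4, 128, 128); for DDIM-like schedulers
    # init_noise_sigma is 1.0, so the scaling above is a no-op.
    #
    #     latents = torch.randn(1, 4, 1024 // 8, 1024 // 8) * 1.0
    #     print(latents.shape)  # torch.Size([1, 4, 128, 128])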
    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

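    # Illustrative arithmetic for SDXL-base's micro-conditioning (config values assumed
    # from the stock checkpoint): six time ids, addition_time_embed_dim = 256 and a
    # text_encoder_2 projection_dim of 1280 give 256 * 6 + 1280 = 2816 features, which
    # is what unet.add_embedding.linear_1 expects.
    #
    #     add_time_ids = [1024, 1024, 0, 0, 1024, 1024]  # original_size + crop_top_left + target_size
    #     print(256 * len(add_time_ids) + 1280)          # 2816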
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
    def upcast_vae(self):
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                LoRAXFormersAttnProcessor,
                LoRAAttnProcessor2_0,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

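    # Alternative to upcasting at decode time (optional; an assumption of this note, not
    # something this file requires): swap in a VAE fine-tuned to stay numerically stable
    # in fp16, using `pipe` from the earlier sketch near enable_vae_slicing.
    #
    #     from diffusers import AutoencoderKL
    #
    #     pipe.vae = AutoencoderKL.from_pretrained(
    #         "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    #     ).to("cuda")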
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stages where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        if not hasattr(self, "unet"):
            raise ValueError("The pipeline must have `unet` for using FreeU.")
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
    def disable_freeu(self):
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def denoising_end(self):
        return self._denoising_end

    @property
    def num_timesteps(self):
        return self._num_timesteps

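    # Optional FreeU sketch (illustrative): the factors below are the values commonly
    # suggested for SDXL in the FreeU repository, not something this file prescribes.
    # `pipe` is the instance from the earlier sketch near enable_vae_slicing.
    #
    #     pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
    #     # ... generate ...
    #     pipe.disable_freeu()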
@torch.no_grad()
|
671 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
672 |
+
def __call__(
|
673 |
+
self,
|
674 |
+
prompt: Union[str, List[str]] = None,
|
675 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
676 |
+
height: Optional[int] = None,
|
677 |
+
width: Optional[int] = None,
|
678 |
+
num_inference_steps: int = 50,
|
679 |
+
denoising_end: Optional[float] = None,
|
680 |
+
guidance_scale: float = 5.0,
|
681 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
682 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
683 |
+
num_images_per_prompt: Optional[int] = 1,
|
684 |
+
eta: float = 0.0,
|
685 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
686 |
+
latents: Optional[torch.FloatTensor] = None,
|
687 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
688 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
689 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
690 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
691 |
+
output_type: Optional[str] = "pil",
|
692 |
+
return_dict: bool = True,
|
693 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
694 |
+
guidance_rescale: float = 0.0,
|
695 |
+
original_size: Optional[Tuple[int, int]] = None,
|
696 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
697 |
+
target_size: Optional[Tuple[int, int]] = None,
|
698 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
699 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
700 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
701 |
+
clip_skip: Optional[int] = None,
|
702 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
703 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
704 |
+
**kwargs,
|
705 |
+
):
|
706 |
+
r"""
|
707 |
+
Function invoked when calling the pipeline for generation.
|
708 |
+
|
709 |
+
Args:
|
710 |
+
prompt (`str` or `List[str]`, *optional*):
|
711 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
712 |
+
instead.
|
713 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
714 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
715 |
+
used in both text-encoders
|
716 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
717 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
718 |
+
Anything below 512 pixels won't work well for
|
719 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
720 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
721 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
722 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
723 |
+
Anything below 512 pixels won't work well for
|
724 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
725 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
726 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
727 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
728 |
+
expense of slower inference.
|
729 |
+
denoising_end (`float`, *optional*):
|
730 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
731 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
732 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
733 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
734 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
735 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
736 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
737 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
738 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
739 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
740 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
741 |
+
usually at the expense of lower image quality.
|
742 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
743 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
744 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
745 |
+
less than `1`).
|
746 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
747 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
748 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
749 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
750 |
+
The number of images to generate per prompt.
|
751 |
+
eta (`float`, *optional*, defaults to 0.0):
|
752 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
753 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
754 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
755 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
756 |
+
to make generation deterministic.
|
757 |
+
latents (`torch.FloatTensor`, *optional*):
|
758 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
759 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
760 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
761 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
762 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
763 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
764 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
765 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
766 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
767 |
+
argument.
|
768 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
769 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
770 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
771 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
772 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
773 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
774 |
+
input argument.
|
775 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
776 |
+
The output format of the generated image. Choose between
|
777 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
778 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
779 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
780 |
+
of a plain tuple.
|
781 |
+
cross_attention_kwargs (`dict`, *optional*):
|
782 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
783 |
+
`self.processor` in
|
784 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
785 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
786 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
787 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
788 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
789 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
790 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
791 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
792 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
793 |
+
explained in section 2.2 of
|
794 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
795 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
796 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
797 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
798 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
799 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
800 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
801 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
802 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
803 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
804 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
805 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
806 |
+
micro-conditioning as explained in section 2.2 of
|
807 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
808 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
809 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
810 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
811 |
+
micro-conditioning as explained in section 2.2 of
|
812 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
813 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
814 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
815 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
816 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
817 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
818 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
819 |
+
callback_on_step_end (`Callable`, *optional*):
|
820 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
821 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
822 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
823 |
+
`callback_on_step_end_tensor_inputs`.
|
824 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
825 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
826 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
827 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
828 |
+
|
829 |
+
Examples:
|
830 |
+
|
831 |
+
Returns:
|
832 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
833 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
834 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
835 |
+
"""
|
836 |
+
|
837 |
+
callback = kwargs.pop("callback", None)
|
838 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
839 |
+
|
840 |
+
if callback is not None:
|
841 |
+
deprecate(
|
842 |
+
"callback",
|
843 |
+
"1.0.0",
|
844 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
845 |
+
)
|
846 |
+
if callback_steps is not None:
|
847 |
+
deprecate(
|
848 |
+
"callback_steps",
|
849 |
+
"1.0.0",
|
850 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
851 |
+
)
|
852 |
+
|
853 |
+
# 0. Default height and width to unet
|
854 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
855 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
856 |
+
|
857 |
+
original_size = original_size or (height, width)
|
858 |
+
target_size = target_size or (height, width)
|
859 |
+
|
860 |
+
# 1. Check inputs. Raise error if not correct
|
861 |
+
self.check_inputs(
|
862 |
+
prompt,
|
863 |
+
prompt_2,
|
864 |
+
height,
|
865 |
+
width,
|
866 |
+
callback_steps,
|
867 |
+
negative_prompt,
|
868 |
+
negative_prompt_2,
|
869 |
+
prompt_embeds,
|
870 |
+
negative_prompt_embeds,
|
871 |
+
pooled_prompt_embeds,
|
872 |
+
negative_pooled_prompt_embeds,
|
873 |
+
callback_on_step_end_tensor_inputs,
|
874 |
+
)
|
875 |
+
|
876 |
+
self._guidance_scale = guidance_scale
|
877 |
+
self._guidance_rescale = guidance_rescale
|
878 |
+
self._clip_skip = clip_skip
|
879 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
880 |
+
self._denoising_end = denoising_end
|
881 |
+
|
882 |
+
# 2. Define call parameters
|
883 |
+
if prompt is not None and isinstance(prompt, str):
|
884 |
+
batch_size = 1
|
885 |
+
elif prompt is not None and isinstance(prompt, list):
|
886 |
+
batch_size = len(prompt)
|
887 |
+
else:
|
888 |
+
batch_size = prompt_embeds.shape[0]
|
889 |
+
|
890 |
+
device = self._execution_device
|
891 |
+
|
892 |
+
# 3. Encode input prompt
|
893 |
+
lora_scale = (
|
894 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
895 |
+
)
|
896 |
+
|
897 |
+
(
|
898 |
+
prompt_embeds,
|
899 |
+
negative_prompt_embeds,
|
900 |
+
pooled_prompt_embeds,
|
901 |
+
negative_pooled_prompt_embeds,
|
902 |
+
) = self.encode_prompt(
|
903 |
+
prompt=prompt,
|
904 |
+
prompt_2=prompt_2,
|
905 |
+
device=device,
|
906 |
+
num_images_per_prompt=num_images_per_prompt,
|
907 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
908 |
+
negative_prompt=negative_prompt,
|
909 |
+
negative_prompt_2=negative_prompt_2,
|
910 |
+
prompt_embeds=prompt_embeds,
|
911 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
912 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
913 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
914 |
+
lora_scale=lora_scale,
|
915 |
+
clip_skip=self.clip_skip,
|
916 |
+
)
|
917 |
+
|
918 |
+
        # Custom behaviour of this repo: when a separate `target_prompt` is passed through
        # **kwargs, re-encode it and overwrite every batch element except the first, so that
        # element 0 keeps the original prompt's embeddings while the remaining images are
        # driven by the target prompt. `.get` makes the kwarg optional instead of raising
        # a KeyError when it is not supplied.
        if kwargs.get('target_prompt') is not None:
            (
                prompt_embeds_,
                negative_prompt_embeds_,
                _,
                _,
            ) = self.encode_prompt(
                prompt=kwargs['target_prompt'],
                prompt_2=prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                # negative_prompt=negative_prompt,
                negative_prompt=None,
                # negative_prompt_2=negative_prompt_2,
                negative_prompt_2=None,
                prompt_embeds=None,
                negative_prompt_embeds=None,
                pooled_prompt_embeds=None,
                negative_pooled_prompt_embeds=None,
                lora_scale=lora_scale,
                clip_skip=self.clip_skip,
            )
            prompt_embeds[1:] = prompt_embeds_[1:]
            negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]


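        # Usage sketch for the custom kwarg (illustrative; `pipe` as in the earlier
        # sketch near enable_vae_slicing): the first image in the batch follows `prompt`,
        # the remaining ones follow `target_prompt`.
        #
        #     images = pipe(
        #         prompt="a cat, Van Gogh style",
        #         target_prompt="a lighthouse at dusk",
        #         num_images_per_prompt=2,
        #         guidance_scale=7.0,
        #     ).images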
# 4. Prepare timesteps
|
946 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
947 |
+
|
948 |
+
timesteps = self.scheduler.timesteps
|
949 |
+
|
950 |
+
# 5. Prepare latent variables
|
951 |
+
num_channels_latents = self.unet.config.in_channels
|
952 |
+
latents = self.prepare_latents(
|
953 |
+
batch_size * num_images_per_prompt,
|
954 |
+
num_channels_latents,
|
955 |
+
height,
|
956 |
+
width,
|
957 |
+
prompt_embeds.dtype,
|
958 |
+
device,
|
959 |
+
generator,
|
960 |
+
latents,
|
961 |
+
)
|
962 |
+
|
963 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
964 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
965 |
+
|
966 |
+
# 7. Prepare added time ids & embeddings
|
967 |
+
add_text_embeds = pooled_prompt_embeds
|
968 |
+
if self.text_encoder_2 is None:
|
969 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
970 |
+
else:
|
971 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
972 |
+
|
973 |
+
add_time_ids = self._get_add_time_ids(
|
974 |
+
original_size,
|
975 |
+
crops_coords_top_left,
|
976 |
+
target_size,
|
977 |
+
dtype=prompt_embeds.dtype,
|
978 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
979 |
+
)
|
980 |
+
if negative_original_size is not None and negative_target_size is not None:
|
981 |
+
negative_add_time_ids = self._get_add_time_ids(
|
982 |
+
negative_original_size,
|
983 |
+
negative_crops_coords_top_left,
|
984 |
+
negative_target_size,
|
985 |
+
dtype=prompt_embeds.dtype,
|
986 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
987 |
+
)
|
988 |
+
else:
|
989 |
+
negative_add_time_ids = add_time_ids
|
990 |
+
|
991 |
+
if self.do_classifier_free_guidance:
|
992 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
993 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
994 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
995 |
+
|
996 |
+
prompt_embeds = prompt_embeds.to(device)
|
997 |
+
add_text_embeds = add_text_embeds.to(device)
|
998 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
999 |
+
|
1000 |
+
# 8. Denoising loop
|
1001 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1002 |
+
|
1003 |
+
# 8.1 Apply denoising_end
|
1004 |
+
if (
|
1005 |
+
self.denoising_end is not None
|
1006 |
+
and isinstance(self.denoising_end, float)
|
1007 |
+
and self.denoising_end > 0
|
1008 |
+
and self.denoising_end < 1
|
1009 |
+
):
|
1010 |
+
discrete_timestep_cutoff = int(
|
1011 |
+
round(
|
1012 |
+
self.scheduler.config.num_train_timesteps
|
1013 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1014 |
+
)
|
1015 |
+
)
|
1016 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1017 |
+
timesteps = timesteps[:num_inference_steps]
|
1018 |
+
|
1019 |
+
self._num_timesteps = len(timesteps)
|
1020 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1021 |
+
for i, t in enumerate(timesteps):
|
1022 |
+
# expand the latents if we are doing classifier free guidance
|
1023 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1024 |
+
|
1025 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1026 |
+
|
1027 |
+
# predict the noise residual
|
1028 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1029 |
+
noise_pred = self.unet(
|
1030 |
+
latent_model_input,
|
1031 |
+
t,
|
1032 |
+
encoder_hidden_states=prompt_embeds,
|
1033 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1034 |
+
added_cond_kwargs=added_cond_kwargs,
|
1035 |
+
return_dict=False,
|
1036 |
+
)[0]
|
1037 |
+
|
1038 |
+
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                    # NOTE: unlike the stock SDXL pipeline, the first three denoising steps
                    # use a hard-coded guidance scale of 15.0 before falling back to the
                    # user-provided `guidance_scale`.
                    if i < 3:
                        noise_pred = noise_pred_uncond + 15.0 * (noise_pred_text - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
|
1052 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1053 |
+
|
1054 |
+
if callback_on_step_end is not None:
|
1055 |
+
callback_kwargs = {}
|
1056 |
+
for k in callback_on_step_end_tensor_inputs:
|
1057 |
+
callback_kwargs[k] = locals()[k]
|
1058 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1059 |
+
|
1060 |
+
latents = callback_outputs.pop("latents", latents)
|
1061 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1062 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1063 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
1064 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1065 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1066 |
+
)
|
1067 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
1068 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
1069 |
+
|
1070 |
+
# call the callback, if provided
|
1071 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1072 |
+
progress_bar.update()
|
1073 |
+
if callback is not None and i % callback_steps == 0:
|
1074 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1075 |
+
callback(step_idx, t, latents)
|
1076 |
+
|
1077 |
+
if XLA_AVAILABLE:
|
1078 |
+
xm.mark_step()
|
1079 |
+
|
1080 |
+
if not output_type == "latent":
|
1081 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1082 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1083 |
+
|
1084 |
+
if needs_upcasting:
|
1085 |
+
self.upcast_vae()
|
1086 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1087 |
+
|
1088 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1089 |
+
|
1090 |
+
# cast back to fp16 if needed
|
1091 |
+
if needs_upcasting:
|
1092 |
+
self.vae.to(dtype=torch.float16)
|
1093 |
+
else:
|
1094 |
+
image = latents
|
1095 |
+
|
1096 |
+
if not output_type == "latent":
|
1097 |
+
# apply watermark if available
|
1098 |
+
if self.watermark is not None:
|
1099 |
+
image = self.watermark.apply_watermark(image)
|
1100 |
+
|
1101 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1102 |
+
|
1103 |
+
# Offload all models
|
1104 |
+
self.maybe_free_model_hooks()
|
1105 |
+
|
1106 |
+
if not return_dict:
|
1107 |
+
return (image,)
|
1108 |
+
|
1109 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
1110 |
+
|
1111 |
+
@torch.no_grad()
|
1112 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
1113 |
+
def inverted_ve_cross_frame_attn(
|
1114 |
+
self,
|
1115 |
+
prompt: Union[str, List[str]] = None,
|
1116 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
1117 |
+
height: Optional[int] = None,
|
1118 |
+
width: Optional[int] = None,
|
1119 |
+
num_inference_steps: int = 50,
|
1120 |
+
denoising_end: Optional[float] = None,
|
1121 |
+
guidance_scale: float = 5.0,
|
1122 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1123 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
1124 |
+
num_images_per_prompt: Optional[int] = 1,
|
1125 |
+
eta: float = 0.0,
|
1126 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1127 |
+
latents: Optional[torch.FloatTensor] = None,
|
1128 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1129 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1130 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1131 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1132 |
+
output_type: Optional[str] = "pil",
|
1133 |
+
return_dict: bool = True,
|
1134 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1135 |
+
guidance_rescale: float = 0.0,
|
1136 |
+
original_size: Optional[Tuple[int, int]] = None,
|
1137 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1138 |
+
target_size: Optional[Tuple[int, int]] = None,
|
1139 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
1140 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1141 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
1142 |
+
clip_skip: Optional[int] = None,
|
1143 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
1144 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
1145 |
+
**kwargs,
|
1146 |
+
):
|
1147 |
+
r"""
|
1148 |
+
Function invoked when calling the pipeline for generation.
|
1149 |
+
|
1150 |
+
Args:
|
1151 |
+
prompt (`str` or `List[str]`, *optional*):
|
1152 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
1153 |
+
instead.
|
1154 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1155 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1156 |
+
used in both text-encoders
|
1157 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1158 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
1159 |
+
Anything below 512 pixels won't work well for
|
1160 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1161 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1162 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1163 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
1164 |
+
Anything below 512 pixels won't work well for
|
1165 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1166 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1167 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1168 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1169 |
+
expense of slower inference.
|
1170 |
+
denoising_end (`float`, *optional*):
|
1171 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
1172 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
1173 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
1174 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
1175 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
1176 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
1177 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
1178 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1179 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1180 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1181 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1182 |
+
usually at the expense of lower image quality.
|
1183 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1184 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
1185 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
1186 |
+
less than `1`).
|
1187 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
1188 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
1189 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
1190 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1191 |
+
The number of images to generate per prompt.
|
1192 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1193 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1194 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1195 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1196 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1197 |
+
to make generation deterministic.
|
1198 |
+
latents (`torch.FloatTensor`, *optional*):
|
1199 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1200 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1201 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
1202 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1203 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1204 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1205 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1206 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1207 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1208 |
+
argument.
|
1209 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1210 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1211 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1212 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1213 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1214 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
1215 |
+
input argument.
|
1216 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1217 |
+
The output format of the generated image. Choose between
|
1218 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1219 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1220 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
1221 |
+
of a plain tuple.
|
1222 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1223 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1224 |
+
`self.processor` in
|
1225 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1226 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
1227 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
1228 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
1229 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
1230 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
1231 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1232 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
1233 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
1234 |
+
explained in section 2.2 of
|
1235 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1236 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1237 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
1238 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
1239 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1240 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1241 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1242 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
1243 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
1244 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1245 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1246 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
1247 |
+
micro-conditioning as explained in section 2.2 of
|
1248 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1249 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1250 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1251 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
1252 |
+
micro-conditioning as explained in section 2.2 of
|
1253 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1254 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1255 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1256 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
1257 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1258 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1259 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1260 |
+
callback_on_step_end (`Callable`, *optional*):
|
1261 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
1262 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1263 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1264 |
+
`callback_on_step_end_tensor_inputs`.
|
1265 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1266 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1267 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1268 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1269 |
+
|
1270 |
+
Examples:
|
1271 |
+
|
1272 |
+
Returns:
|
1273 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
1274 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
1275 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
1276 |
+
"""
|
1277 |
+
|
1278 |
+
callback = kwargs.pop("callback", None)
|
1279 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
1280 |
+
|
1281 |
+
if callback is not None:
|
1282 |
+
deprecate(
|
1283 |
+
"callback",
|
1284 |
+
"1.0.0",
|
1285 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1286 |
+
)
|
1287 |
+
if callback_steps is not None:
|
1288 |
+
deprecate(
|
1289 |
+
"callback_steps",
|
1290 |
+
"1.0.0",
|
1291 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1292 |
+
)
|
1293 |
+
|
1294 |
+
# 0. Default height and width to unet
|
1295 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1296 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1297 |
+
|
1298 |
+
original_size = original_size or (height, width)
|
1299 |
+
target_size = target_size or (height, width)
|
1300 |
+
|
1301 |
+
# 1. Check inputs. Raise error if not correct
|
1302 |
+
self.check_inputs(
|
1303 |
+
prompt,
|
1304 |
+
prompt_2,
|
1305 |
+
height,
|
1306 |
+
width,
|
1307 |
+
callback_steps,
|
1308 |
+
negative_prompt,
|
1309 |
+
negative_prompt_2,
|
1310 |
+
prompt_embeds,
|
1311 |
+
negative_prompt_embeds,
|
1312 |
+
pooled_prompt_embeds,
|
1313 |
+
negative_pooled_prompt_embeds,
|
1314 |
+
callback_on_step_end_tensor_inputs,
|
1315 |
+
)
|
1316 |
+
|
1317 |
+
self._guidance_scale = guidance_scale
|
1318 |
+
self._guidance_rescale = guidance_rescale
|
1319 |
+
self._clip_skip = clip_skip
|
1320 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
1321 |
+
self._denoising_end = denoising_end
|
1322 |
+
|
1323 |
+
# 2. Define call parameters
|
1324 |
+
if prompt is not None and isinstance(prompt, str):
|
1325 |
+
batch_size = 1
|
1326 |
+
elif prompt is not None and isinstance(prompt, list):
|
1327 |
+
batch_size = len(prompt)
|
1328 |
+
else:
|
1329 |
+
batch_size = prompt_embeds.shape[0]
|
1330 |
+
|
1331 |
+
device = self._execution_device
|
1332 |
+
|
1333 |
+
# 3. Encode input prompt
|
1334 |
+
lora_scale = (
|
1335 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
1336 |
+
)
|
1337 |
+
|
1338 |
+
(
|
1339 |
+
prompt_embeds,
|
1340 |
+
negative_prompt_embeds,
|
1341 |
+
pooled_prompt_embeds,
|
1342 |
+
negative_pooled_prompt_embeds,
|
1343 |
+
) = self.encode_prompt(
|
1344 |
+
prompt=prompt,
|
1345 |
+
prompt_2=prompt_2,
|
1346 |
+
device=device,
|
1347 |
+
num_images_per_prompt=num_images_per_prompt,
|
1348 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1349 |
+
negative_prompt=negative_prompt,
|
1350 |
+
negative_prompt_2=negative_prompt_2,
|
1351 |
+
prompt_embeds=prompt_embeds,
|
1352 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1353 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1354 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1355 |
+
lora_scale=lora_scale,
|
1356 |
+
clip_skip=self.clip_skip,
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
if kwargs['target_prompt'] is not None:
|
1360 |
+
(
|
1361 |
+
prompt_embeds_,
|
1362 |
+
negative_prompt_embeds_,
|
1363 |
+
_,
|
1364 |
+
_,
|
1365 |
+
) = self.encode_prompt(
|
1366 |
+
prompt=kwargs['target_prompt'],
|
1367 |
+
prompt_2=prompt_2,
|
1368 |
+
device=device,
|
1369 |
+
num_images_per_prompt=num_images_per_prompt,
|
1370 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1371 |
+
negative_prompt=kwargs['target_negative_prompt'] if kwargs['target_negative_prompt'] is not None else None,
|
1372 |
+
# negative_prompt=None,
|
1373 |
+
# negative_prompt_2=negative_prompt_2,
|
1374 |
+
negative_prompt_2=None,
|
1375 |
+
prompt_embeds=None,
|
1376 |
+
negative_prompt_embeds=None,
|
1377 |
+
pooled_prompt_embeds=None,
|
1378 |
+
negative_pooled_prompt_embeds=None,
|
1379 |
+
lora_scale=lora_scale,
|
1380 |
+
clip_skip=self.clip_skip,
|
1381 |
+
)
|
1382 |
+
prompt_embeds[1:] = prompt_embeds_[1:]
|
1383 |
+
if negative_prompt_embeds_ is not None:
|
1384 |
+
negative_prompt_embeds[1:] = negative_prompt_embeds_[1:]
|
1385 |
+
|
1386 |
+
|
1387 |
+
# 4. Prepare timesteps
|
1388 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1389 |
+
|
1390 |
+
timesteps = self.scheduler.timesteps
|
1391 |
+
|
1392 |
+
# 5. Prepare latent variables
|
1393 |
+
num_channels_latents = self.unet.config.in_channels
|
1394 |
+
latents = self.prepare_latents(
|
1395 |
+
batch_size * num_images_per_prompt,
|
1396 |
+
num_channels_latents,
|
1397 |
+
height,
|
1398 |
+
width,
|
1399 |
+
prompt_embeds.dtype,
|
1400 |
+
device,
|
1401 |
+
generator,
|
1402 |
+
latents,
|
1403 |
+
)
|
1404 |
+
|
1405 |
+
|
1406 |
+
# import pdb; pdb.set_trace()
|
1407 |
+
|
1408 |
+
latents_ = self.prepare_latents(
|
1409 |
+
batch_size * num_images_per_prompt,
|
1410 |
+
num_channels_latents,
|
1411 |
+
height,
|
1412 |
+
width,
|
1413 |
+
prompt_embeds.dtype,
|
1414 |
+
device,
|
1415 |
+
generator,
|
1416 |
+
# latents,
|
1417 |
+
)
|
1418 |
+
|
1419 |
+
# import pdb; pdb.set_trace()
|
1420 |
+
|
1421 |
+
# latents[1:] = latents_[1:]
|
1422 |
+
latents = torch.cat([latents.unsqueeze(0), latents_[1:]], dim=0)
|
1423 |
+
|
1424 |
+
|
1425 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1426 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1427 |
+
|
1428 |
+
# 7. Prepare added time ids & embeddings
|
1429 |
+
add_text_embeds = pooled_prompt_embeds
|
1430 |
+
if self.text_encoder_2 is None:
|
1431 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1432 |
+
else:
|
1433 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1434 |
+
|
1435 |
+
add_time_ids = self._get_add_time_ids(
|
1436 |
+
original_size,
|
1437 |
+
crops_coords_top_left,
|
1438 |
+
target_size,
|
1439 |
+
dtype=prompt_embeds.dtype,
|
1440 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1441 |
+
)
|
1442 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1443 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1444 |
+
negative_original_size,
|
1445 |
+
negative_crops_coords_top_left,
|
1446 |
+
negative_target_size,
|
1447 |
+
dtype=prompt_embeds.dtype,
|
1448 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1449 |
+
)
|
1450 |
+
else:
|
1451 |
+
negative_add_time_ids = add_time_ids
|
1452 |
+
|
1453 |
+
if self.do_classifier_free_guidance:
|
1454 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1455 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1456 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1457 |
+
|
1458 |
+
prompt_embeds = prompt_embeds.to(device)
|
1459 |
+
add_text_embeds = add_text_embeds.to(device)
|
1460 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1461 |
+
|
1462 |
+
# 8. Denoising loop
|
1463 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1464 |
+
|
1465 |
+
# 8.1 Apply denoising_end
|
1466 |
+
if (
|
1467 |
+
self.denoising_end is not None
|
1468 |
+
and isinstance(self.denoising_end, float)
|
1469 |
+
and self.denoising_end > 0
|
1470 |
+
and self.denoising_end < 1
|
1471 |
+
):
|
1472 |
+
discrete_timestep_cutoff = int(
|
1473 |
+
round(
|
1474 |
+
self.scheduler.config.num_train_timesteps
|
1475 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1476 |
+
)
|
1477 |
+
)
|
1478 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1479 |
+
timesteps = timesteps[:num_inference_steps]
|
1480 |
+
|
1481 |
+
self._num_timesteps = len(timesteps)
|
1482 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1483 |
+
for i, t in enumerate(timesteps):
|
1484 |
+
# expand the latents if we are doing classifier free guidance
|
1485 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1486 |
+
|
1487 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1488 |
+
|
1489 |
+
# predict the noise residual
|
1490 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1491 |
+
noise_pred = self.unet(
|
1492 |
+
latent_model_input,
|
1493 |
+
t,
|
1494 |
+
encoder_hidden_states=prompt_embeds,
|
1495 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1496 |
+
added_cond_kwargs=added_cond_kwargs,
|
1497 |
+
return_dict=False,
|
1498 |
+
)[0]
|
1499 |
+
|
1500 |
+
# perform guidance
|
1501 |
+
if self.do_classifier_free_guidance:
|
1502 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1503 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1504 |
+
noise_pred[0] = noise_pred_uncond[0] # added: use the unconditional prediction for the first (reference) sample
|
1505 |
+
|
1506 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1507 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1508 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1509 |
+
noise_pred[0] = noise_pred_uncond[0] # added: keep the first (reference) sample unguided after rescaling
|
1510 |
+
|
1511 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1512 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1513 |
+
|
1514 |
+
if callback_on_step_end is not None:
|
1515 |
+
callback_kwargs = {}
|
1516 |
+
for k in callback_on_step_end_tensor_inputs:
|
1517 |
+
callback_kwargs[k] = locals()[k]
|
1518 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1519 |
+
|
1520 |
+
latents = callback_outputs.pop("latents", latents)
|
1521 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1522 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1523 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
1524 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1525 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1526 |
+
)
|
1527 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
1528 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
1529 |
+
|
1530 |
+
# call the callback, if provided
|
1531 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1532 |
+
progress_bar.update()
|
1533 |
+
if callback is not None and i % callback_steps == 0:
|
1534 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1535 |
+
callback(step_idx, t, latents)
|
1536 |
+
|
1537 |
+
if XLA_AVAILABLE:
|
1538 |
+
xm.mark_step()
|
1539 |
+
|
1540 |
+
if not output_type == "latent":
|
1541 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1542 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1543 |
+
|
1544 |
+
if needs_upcasting:
|
1545 |
+
self.upcast_vae()
|
1546 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1547 |
+
|
1548 |
+
self.enable_vae_slicing()
|
1549 |
+
|
1550 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1551 |
+
|
1552 |
+
# cast back to fp16 if needed
|
1553 |
+
if needs_upcasting:
|
1554 |
+
self.vae.to(dtype=torch.float16)
|
1555 |
+
else:
|
1556 |
+
image = latents
|
1557 |
+
|
1558 |
+
if not output_type == "latent":
|
1559 |
+
# apply watermark if available
|
1560 |
+
if self.watermark is not None:
|
1561 |
+
image = self.watermark.apply_watermark(image)
|
1562 |
+
|
1563 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1564 |
+
|
1565 |
+
# Offload all models
|
1566 |
+
self.maybe_free_model_hooks()
|
1567 |
+
|
1568 |
+
if not return_dict:
|
1569 |
+
return (image,)
|
1570 |
+
|
1571 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
1572 |
+
|
1573 |
+
|
visualize_attention_src/save_attn_map_script.py
ADDED
@@ -0,0 +1,283 @@
1 |
+
import torch
|
2 |
+
from pipelines.inverted_ve_pipeline import CrossFrameAttnProcessor, CrossFrameAttnProcessor_store, ACTIVATE_LAYER_CANDIDATE
|
3 |
+
from diffusers import DDIMScheduler, AutoencoderKL
|
4 |
+
import os
|
5 |
+
from PIL import Image
|
6 |
+
from utils import memory_efficient
|
7 |
+
from diffusers.models.attention_processor import AttnProcessor
|
8 |
+
from pipeline_stable_diffusion_xl_attn import StableDiffusionXLPipeline
|
9 |
+
|
10 |
+
|
11 |
+
def create_image_grid(image_list, rows, cols, padding=10):
|
12 |
+
# Ensure the number of rows and columns doesn't exceed the number of images
|
13 |
+
rows = min(rows, len(image_list))
|
14 |
+
cols = min(cols, len(image_list))
|
15 |
+
|
16 |
+
# Get the dimensions of a single image
|
17 |
+
image_width, image_height = image_list[0].size
|
18 |
+
|
19 |
+
# Calculate the size of the output image
|
20 |
+
grid_width = cols * (image_width + padding) - padding
|
21 |
+
grid_height = rows * (image_height + padding) - padding
|
22 |
+
|
23 |
+
# Create an empty grid image
|
24 |
+
grid_image = Image.new('RGB', (grid_width, grid_height), (255, 255, 255))
|
25 |
+
|
26 |
+
# Paste images into the grid
|
27 |
+
for i, img in enumerate(image_list[:rows * cols]):
|
28 |
+
row = i // cols
|
29 |
+
col = i % cols
|
30 |
+
x = col * (image_width + padding)
|
31 |
+
y = row * (image_height + padding)
|
32 |
+
grid_image.paste(img, (x, y))
|
33 |
+
|
34 |
+
return grid_image
|
35 |
+
|
36 |
+
def transform_variable_name(input_str, attn_map_save_step):
|
37 |
+
# Split the input string into parts using the dot as a separator
|
38 |
+
parts = input_str.split('.')
|
39 |
+
|
40 |
+
# Extract numerical indices from the parts
|
41 |
+
indices = [int(part) if part.isdigit() else part for part in parts]
|
42 |
+
|
43 |
+
# Build the desired output string
|
44 |
+
output_str = f'pipe.unet.{indices[0]}[{indices[1]}].{indices[2]}[{indices[3]}].{indices[4]}[{indices[5]}].{indices[6]}.attn_map[{attn_map_save_step}]'
|
45 |
+
|
46 |
+
return output_str
|
47 |
+
|
48 |
+
|
49 |
+
num_images_per_prompt = 4
|
50 |
+
seeds=[1] #craft_clay
|
51 |
+
|
52 |
+
|
53 |
+
activate_layer_indices_list = [
|
54 |
+
# ((0,28),(108,140)),
|
55 |
+
# ((0,48), (68,140)),
|
56 |
+
# ((0,48), (88,140)),
|
57 |
+
# ((0,48), (108,140)),
|
58 |
+
# ((0,48), (128,140)),
|
59 |
+
# ((0,48), (140,140)),
|
60 |
+
# ((0,28), (68,140)),
|
61 |
+
# ((0,28), (88,140)),
|
62 |
+
# ((0,28), (108,140)),
|
63 |
+
# ((0,28), (128,140)),
|
64 |
+
# ((0,28), (140,140)),
|
65 |
+
# ((0,8), (68,140)),
|
66 |
+
# ((0,8), (88,140)),
|
67 |
+
# ((0,8), (108,140)),
|
68 |
+
# ((0,8), (128,140)),
|
69 |
+
# ((0,8), (140,140)),
|
70 |
+
# ((0,0), (68,140)),
|
71 |
+
# ((0,0), (88,140)),
|
72 |
+
((0,0), (108,140)),
|
73 |
+
# ((0,0), (128,140)),
|
74 |
+
# ((0,0), (140,140))
|
75 |
+
]
|
76 |
+
|
77 |
+
save_layer_list = [
|
78 |
+
# 'up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor', #68
|
79 |
+
# 'up_blocks.0.attentions.1.transformer_blocks.4.attn2.processor', #78
|
80 |
+
# 'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', #88
|
81 |
+
# 'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor', #108
|
82 |
+
# 'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', #128
|
83 |
+
# 'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor', #138
|
84 |
+
|
85 |
+
'up_blocks.0.attentions.2.transformer_blocks.0.attn1.processor', #108
|
86 |
+
'up_blocks.0.attentions.2.transformer_blocks.0.attn2.processor',
|
87 |
+
'up_blocks.0.attentions.2.transformer_blocks.1.attn1.processor',
|
88 |
+
'up_blocks.0.attentions.2.transformer_blocks.1.attn2.processor',
|
89 |
+
'up_blocks.0.attentions.2.transformer_blocks.2.attn1.processor',
|
90 |
+
'up_blocks.0.attentions.2.transformer_blocks.2.attn2.processor',
|
91 |
+
'up_blocks.0.attentions.2.transformer_blocks.3.attn1.processor',
|
92 |
+
'up_blocks.0.attentions.2.transformer_blocks.3.attn2.processor',
|
93 |
+
'up_blocks.0.attentions.2.transformer_blocks.4.attn1.processor',
|
94 |
+
'up_blocks.0.attentions.2.transformer_blocks.4.attn2.processor',
|
95 |
+
'up_blocks.0.attentions.2.transformer_blocks.5.attn1.processor',
|
96 |
+
'up_blocks.0.attentions.2.transformer_blocks.5.attn2.processor',
|
97 |
+
'up_blocks.0.attentions.2.transformer_blocks.6.attn1.processor',
|
98 |
+
'up_blocks.0.attentions.2.transformer_blocks.6.attn2.processor',
|
99 |
+
'up_blocks.0.attentions.2.transformer_blocks.7.attn1.processor',
|
100 |
+
'up_blocks.0.attentions.2.transformer_blocks.7.attn2.processor',
|
101 |
+
'up_blocks.0.attentions.2.transformer_blocks.8.attn1.processor',
|
102 |
+
'up_blocks.0.attentions.2.transformer_blocks.8.attn2.processor',
|
103 |
+
'up_blocks.0.attentions.2.transformer_blocks.9.attn1.processor',
|
104 |
+
'up_blocks.0.attentions.2.transformer_blocks.9.attn2.processor',
|
105 |
+
|
106 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', #128
|
107 |
+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor',
|
108 |
+
'up_blocks.1.attentions.0.transformer_blocks.1.attn1.processor',
|
109 |
+
'up_blocks.1.attentions.0.transformer_blocks.1.attn2.processor',
|
110 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor',
|
111 |
+
'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor',
|
112 |
+
'up_blocks.1.attentions.1.transformer_blocks.1.attn1.processor',
|
113 |
+
'up_blocks.1.attentions.1.transformer_blocks.1.attn2.processor',
|
114 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor',
|
115 |
+
'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor',
|
116 |
+
'up_blocks.1.attentions.2.transformer_blocks.1.attn1.processor',
|
117 |
+
'up_blocks.1.attentions.2.transformer_blocks.1.attn2.processor',
|
118 |
+
]
|
119 |
+
|
120 |
+
attn_map_save_steps = [20]
|
121 |
+
# attn_map_save_steps = [10,20,30,40]
|
122 |
+
|
123 |
+
results_dir = 'saved_attention_map_results'
|
124 |
+
if not os.path.exists(results_dir):
|
125 |
+
os.makedirs(results_dir)
|
126 |
+
|
127 |
+
base_model_path = "runwayml/stable-diffusion-v1-5"
|
128 |
+
vae_model_path = "stabilityai/sd-vae-ft-mse"
|
129 |
+
image_encoder_path = "models/image_encoder/"
|
130 |
+
|
131 |
+
|
132 |
+
object_list = [
|
133 |
+
"cat",
|
134 |
+
# "woman",
|
135 |
+
# "dog",
|
136 |
+
# "horse",
|
137 |
+
# "motorcycle"
|
138 |
+
]
|
139 |
+
|
140 |
+
target_object_list = [
|
141 |
+
# "Null",
|
142 |
+
"dog",
|
143 |
+
# "clock",
|
144 |
+
# "car"
|
145 |
+
# "panda",
|
146 |
+
# "bridge",
|
147 |
+
# "flower"
|
148 |
+
]
|
149 |
+
|
150 |
+
prompt_neg_prompt_pair_dicts = {
|
151 |
+
|
152 |
+
# "line_art": ("line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
|
153 |
+
# "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic"
|
154 |
+
# ) ,
|
155 |
+
|
156 |
+
# "anime": ("anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
157 |
+
# "photo, deformed, black and white, realism, disfigured, low contrast"
|
158 |
+
# ),
|
159 |
+
|
160 |
+
# "Artstyle_Pop_Art" : ("pop Art style {prompt} . bright colors, bold outlines, popular culture themes, ironic or kitsch",
|
161 |
+
# "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, minimalist"
|
162 |
+
# ),
|
163 |
+
|
164 |
+
# "Artstyle_Pointillism": ("pointillism style {prompt} . composed entirely of small, distinct dots of color, vibrant, highly detailed",
|
165 |
+
# "line drawing, smooth shading, large color fields, simplistic"
|
166 |
+
# ),
|
167 |
+
|
168 |
+
# "origami": ("origami style {prompt} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition",
|
169 |
+
# "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
170 |
+
# ),
|
171 |
+
|
172 |
+
"craft_clay": ("play-doh style {prompt} . sculpture, clay art, centered composition, Claymation",
|
173 |
+
"sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
174 |
+
),
|
175 |
+
|
176 |
+
# "low_poly" : ("low-poly style {prompt} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition",
|
177 |
+
# "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo"
|
178 |
+
# ),
|
179 |
+
|
180 |
+
# "Artstyle_watercolor": ("watercolor painting {prompt} . vibrant, beautiful, painterly, detailed, textural, artistic",
|
181 |
+
# "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
|
182 |
+
# ),
|
183 |
+
|
184 |
+
# "Papercraft_Collage" : ("collage style {prompt} . mixed media, layered, textural, detailed, artistic",
|
185 |
+
# "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic"
|
186 |
+
# ),
|
187 |
+
|
188 |
+
# "Artstyle_Impressionist" : ("impressionist painting {prompt} . loose brushwork, vibrant color, light and shadow play, captures feeling over form",
|
189 |
+
# "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy"
|
190 |
+
# )
|
191 |
+
|
192 |
+
}
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
noise_scheduler = DDIMScheduler(
|
197 |
+
num_train_timesteps=1000,
|
198 |
+
beta_start=0.00085,
|
199 |
+
beta_end=0.012,
|
200 |
+
beta_schedule="scaled_linear",
|
201 |
+
clip_sample=False,
|
202 |
+
set_alpha_to_one=False,
|
203 |
+
steps_offset=1,
|
204 |
+
)
|
205 |
+
|
206 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
207 |
+
if device == 'cpu':
|
208 |
+
torch_dtype = torch.float32
|
209 |
+
else:
|
210 |
+
torch_dtype = torch.float16
|
211 |
+
|
212 |
+
vae = AutoencoderKL.from_pretrained(vae_model_path, torch_dtype=torch_dtype)
|
213 |
+
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype)
|
214 |
+
|
215 |
+
|
216 |
+
memory_efficient(vae, device)
|
217 |
+
memory_efficient(pipe, device)
|
218 |
+
|
219 |
+
for seed in seeds:
|
220 |
+
for activate_layer_indices in activate_layer_indices_list:
|
221 |
+
attn_procs = {}
|
222 |
+
activate_layers = []
|
223 |
+
str_activate_layer = ""
|
224 |
+
for activate_layer_index in activate_layer_indices:
|
225 |
+
activate_layers += ACTIVATE_LAYER_CANDIDATE[activate_layer_index[0]:activate_layer_index[1]]
|
226 |
+
str_activate_layer += str(activate_layer_index)
|
227 |
+
|
228 |
+
|
229 |
+
for name in pipe.unet.attn_processors.keys():
|
230 |
+
if name in activate_layers:
|
231 |
+
if name in save_layer_list:
|
232 |
+
print(f"layer:{name}")
|
233 |
+
attn_procs[name] = CrossFrameAttnProcessor_store(unet_chunk_size=2, attn_map_save_steps=attn_map_save_steps)
|
234 |
+
else:
|
235 |
+
print(f"layer:{name}")
|
236 |
+
attn_procs[name] = CrossFrameAttnProcessor(unet_chunk_size=2)
|
237 |
+
else:
|
238 |
+
attn_procs[name] = AttnProcessor()
|
239 |
+
pipe.unet.set_attn_processor(attn_procs)
|
240 |
+
|
241 |
+
|
242 |
+
for target_object in target_object_list:
|
243 |
+
target_prompt = f"A photo of a {target_object}"
|
244 |
+
|
245 |
+
for object in object_list:
|
246 |
+
for key in prompt_neg_prompt_pair_dicts.keys():
|
247 |
+
prompt, negative_prompt = prompt_neg_prompt_pair_dicts[key]
|
248 |
+
|
249 |
+
generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
|
250 |
+
|
251 |
+
images = pipe(
|
252 |
+
prompt=prompt.replace("{prompt}", object),
|
253 |
+
guidance_scale = 7.0,
|
254 |
+
num_images_per_prompt = num_images_per_prompt,
|
255 |
+
target_prompt = target_prompt,
|
256 |
+
generator=generator,
|
257 |
+
|
258 |
+
)[0]
|
259 |
+
|
260 |
+
|
261 |
+
#make grid
|
262 |
+
grid = create_image_grid(images, 1, num_images_per_prompt)
|
263 |
+
|
264 |
+
save_name = f"{key}_src_{object}_tgt_{target_object}_activate_layer_{str_activate_layer}_seed_{seed}.png"
|
265 |
+
save_path = os.path.join(results_dir, save_name)
|
266 |
+
|
267 |
+
grid.save(save_path)
|
268 |
+
|
269 |
+
print("Saved image to: ", save_path)
|
270 |
+
|
271 |
+
#save attn map
|
272 |
+
for attn_map_save_step in attn_map_save_steps:
|
273 |
+
attn_map_save_name = f"attn_map_raw_{key}_src_{object}_tgt_{target_object}_activate_layer_{str_activate_layer}_attn_map_step_{attn_map_save_step}_seed_{seed}.pt"
|
274 |
+
attn_map_dic = {}
|
275 |
+
# for activate_layer in activate_layers:
|
276 |
+
for activate_layer in save_layer_list:
|
277 |
+
attn_map_var_name = transform_variable_name(activate_layer, attn_map_save_step)
|
278 |
+
exec(f"attn_map_dic[\"{activate_layer}\"] = {attn_map_var_name}")
|
279 |
+
|
280 |
+
torch.save(attn_map_dic, os.path.join(results_dir, attn_map_save_name))
|
281 |
+
print("Saved attn map to: ", os.path.join(results_dir, attn_map_save_name))
|
282 |
+
|
283 |
+
|
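The script above serializes the stored attention maps as a dict keyed by processor name. Below is a small sketch, not part of the commit, for loading one of the resulting `.pt` files back; the file name follows the script's default `craft_clay`/`cat`/`dog`, step 20, seed 1 settings, and the shape comment is an inference from the slicing in `visualize_attn_map_script.py`, not something asserted by this commit.

# Hedged sketch: inspect one saved attention-map file.
import os
import torch

results_dir = "saved_attention_map_results"
fname = ("attn_map_raw_craft_clay_src_cat_tgt_dog_activate_layer_(0, 0)(108, 140)"
         "_attn_map_step_20_seed_1.pt")
attn_maps = torch.load(os.path.join(results_dir, fname))

for layer_name, attn in attn_maps.items():
    # roughly (images x heads, query tokens, key tokens), judging by the
    # indexing used later in visualize_attn_map_script.py
    print(layer_name, tuple(attn.shape))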
visualize_attention_src/utils.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
def get_image(image_path, row, col, image_size=1024, grid_width=1):
|
5 |
+
|
6 |
+
left_point = (image_size + grid_width) * col
|
7 |
+
up_point = (image_size + grid_width) * row
|
8 |
+
right_point = left_point + image_size
|
9 |
+
down_point = up_point + image_size
|
10 |
+
|
11 |
+
if type(image_path) is str:
|
12 |
+
image = Image.open(image_path)
|
13 |
+
else:
|
14 |
+
image = image_path
|
15 |
+
croped_image = image.crop((left_point, up_point, right_point, down_point))
|
16 |
+
return croped_image
|
17 |
+
|
18 |
+
def get_image_v2(image_path, row, col, image_size=1024, grid_row_space=1, grid_col_space=1):
|
19 |
+
|
20 |
+
left_point = (image_size + grid_col_space) * col
|
21 |
+
up_point = (image_size + grid_row_space) * row
|
22 |
+
right_point = left_point + image_size
|
23 |
+
down_point = up_point + image_size
|
24 |
+
|
25 |
+
if type(image_path) is str:
|
26 |
+
image = Image.open(image_path)
|
27 |
+
else:
|
28 |
+
image = image_path
|
29 |
+
croped_image = image.crop((left_point, up_point, right_point, down_point))
|
30 |
+
return croped_image
|
31 |
+
|
32 |
+
def create_image(row, col, image_size=1024, grid_width=1, background_color=(255,255,255), top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
|
33 |
+
|
34 |
+
image = Image.new('RGB', (image_size * col + grid_width * (col - 1) + left_padding , image_size * row + grid_width * (row - 1)), background_color)
|
35 |
+
return image
|
36 |
+
|
37 |
+
def paste_image(grid, image, row, col, image_size=1024, grid_width=1, top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
|
38 |
+
left_point = (image_size + grid_width) * col + left_padding
|
39 |
+
up_point = (image_size + grid_width) * row + top_padding
|
40 |
+
right_point = left_point + image_size
|
41 |
+
down_point = up_point + image_size
|
42 |
+
grid.paste(image, (left_point, up_point, right_point, down_point))
|
43 |
+
|
44 |
+
return grid
|
45 |
+
|
46 |
+
def paste_image_v2(grid, image, row, col, grid_size=1024, grid_width=1, top_padding = 0, bottom_padding = 0, left_padding = 0, right_padding = 0):
|
47 |
+
left_point = (grid_size + grid_width) * col + left_padding
|
48 |
+
up_point = (grid_size + grid_width) * row + top_padding
|
49 |
+
|
50 |
+
image_width, image_height = image.size
|
51 |
+
|
52 |
+
right_point = left_point + image_width
|
53 |
+
down_point = up_point + image_height
|
54 |
+
|
55 |
+
grid.paste(image, (left_point, up_point, right_point, down_point))
|
56 |
+
|
57 |
+
return grid
|
58 |
+
|
59 |
+
|
60 |
+
def pivot_figure(file_path, image_size=1024, grid_width=1):
|
61 |
+
if type(file_path) is str:
|
62 |
+
image = Image.open(file_path)
|
63 |
+
else:
|
64 |
+
image = file_path
|
65 |
+
image_col = image.width // image_size
|
66 |
+
image_row = image.height // image_size
|
67 |
+
|
68 |
+
|
69 |
+
grid = create_image(image_col, image_row, image_size, grid_width)
|
70 |
+
|
71 |
+
for row in range(image_row):
|
72 |
+
for col in range(image_col):
|
73 |
+
croped_image = get_image(image, row, col, image_size, grid_width)
|
74 |
+
grid = paste_image(grid, croped_image, col, row, image_size, grid_width)
|
75 |
+
|
76 |
+
return grid
|
77 |
+
|
78 |
+
def horizontal_flip_figure(file_path, image_size=1024, grid_width=1):
|
79 |
+
if type(file_path) is str:
|
80 |
+
image = Image.open(file_path)
|
81 |
+
else:
|
82 |
+
image = file_path
|
83 |
+
image_col = image.width // image_size
|
84 |
+
image_row = image.height // image_size
|
85 |
+
|
86 |
+
grid = create_image(image_row, image_col, image_size, grid_width)
|
87 |
+
|
88 |
+
for row in range(image_row):
|
89 |
+
for col in range(image_col):
|
90 |
+
croped_image = get_image(image, row, image_col - col - 1, image_size, grid_width)
|
91 |
+
grid = paste_image(grid, croped_image, row, col, image_size, grid_width)
|
92 |
+
|
93 |
+
return grid
|
94 |
+
|
95 |
+
def vertical_flip_figure(file_path, image_size=1024, grid_width=1):
|
96 |
+
if type(file_path) is str:
|
97 |
+
image = Image.open(file_path)
|
98 |
+
else:
|
99 |
+
image = file_path
|
100 |
+
|
101 |
+
image_col = image.width // image_size
|
102 |
+
image_row = image.height // image_size
|
103 |
+
|
104 |
+
grid = create_image(image_row, image_col, image_size, grid_width)
|
105 |
+
|
106 |
+
for row in range(image_row):
|
107 |
+
for col in range(image_col):
|
108 |
+
croped_image = get_image(image, image_row - row - 1, col, image_size, grid_width)
|
109 |
+
grid = paste_image(grid, croped_image, row, col, image_size, grid_width)
|
110 |
+
|
111 |
+
return grid
|
visualize_attention_src/visualize_attn_map_script.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
from ipycanvas import Canvas
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from visualize_attention_src.utils import get_image
|
9 |
+
|
10 |
+
exp_dir = "saved_attention_map_results"
|
11 |
+
|
12 |
+
style_name = "line_art"
|
13 |
+
src_name = "cat"
|
14 |
+
tgt_name = "dog"
|
15 |
+
|
16 |
+
steps = ["20"]
|
17 |
+
seed = "4"
|
18 |
+
saved_dtype = "tensor"
|
19 |
+
|
20 |
+
|
21 |
+
attn_map_raws = []
|
22 |
+
for step in steps:
|
23 |
+
attn_map_name_wo_ext = f"attn_map_raw_{style_name}_src_{src_name}_tgt_{tgt_name}_activate_layer_(0, 0)(108, 140)_attn_map_step_{step}_seed_{seed}" # new
|
24 |
+
|
25 |
+
if saved_dtype == 'uint8':
|
26 |
+
attn_map_name = attn_map_name_wo_ext + '_uint8.npy'
|
27 |
+
attn_map_path = os.path.join(exp_dir, attn_map_name)
|
28 |
+
attn_map_raws.append(np.load(attn_map_path, allow_pickle=True))
|
29 |
+
|
30 |
+
else:
|
31 |
+
attn_map_name = attn_map_name_wo_ext + '.pt'
|
32 |
+
attn_map_path = os.path.join(exp_dir, attn_map_name)
|
33 |
+
attn_map_raws.append(torch.load(attn_map_path))
|
34 |
+
print(attn_map_path)
|
35 |
+
|
36 |
+
attn_map_path = os.path.join(exp_dir, attn_map_name)
|
37 |
+
|
38 |
+
print(f"{step} is on memory")
|
39 |
+
|
40 |
+
keys = [key for key in attn_map_raws[0].keys()]
|
41 |
+
|
42 |
+
|
43 |
+
print(len(keys))
|
44 |
+
key = keys[0]
|
45 |
+
|
46 |
+
########################
|
47 |
+
tgt_idx = 3 # column index of the generated (target) image within the saved grid
|
48 |
+
|
49 |
+
attn_map_paired_rgb_grid_name = f"{style_name}_src_{src_name}_tgt_{tgt_name}_scale_1.0_activate_layer_(0, 0)(108, 140)_seed_{seed}.png"
|
50 |
+
|
51 |
+
attn_map_paired_rgb_grid_path = os.path.join(exp_dir, attn_map_paired_rgb_grid_name)
|
52 |
+
print(attn_map_paired_rgb_grid_path)
|
53 |
+
attn_map_paired_rgb_grid = Image.open(attn_map_paired_rgb_grid_path)
|
54 |
+
|
55 |
+
attn_map_src_img = get_image(attn_map_paired_rgb_grid, row = 0, col = 0, image_size = 1024, grid_width = 10)
|
56 |
+
attn_map_tgt_img = get_image(attn_map_paired_rgb_grid, row = 0, col = tgt_idx, image_size = 1024, grid_width = 10)
|
57 |
+
|
58 |
+
|
59 |
+
h, w = 256, 256
|
60 |
+
num_of_grid = 64
|
61 |
+
|
62 |
+
plus_50 = 0
|
63 |
+
|
64 |
+
# key_idx_list = [0,2,4,6,8,10]
|
65 |
+
key_idx_list = [6, 28]
|
66 |
+
# (108 -> 0, 109 -> 1, ... , 140 -> 32)
|
67 |
+
# if swapping attention in the (108, 140) layers, use key_idx_list = [6, 28].
|
68 |
+
# 6==early upblock, 28==late upblock
|
69 |
+
|
70 |
+
saved_attention_map_idx = [0]
|
71 |
+
|
72 |
+
source_image = attn_map_src_img
|
73 |
+
target_image = attn_map_tgt_img
|
74 |
+
|
75 |
+
# resize
|
76 |
+
source_image = source_image.resize((h, w))
|
77 |
+
target_image = target_image.resize((h, w))
|
78 |
+
|
79 |
+
# convert to numpy array
|
80 |
+
source_image = np.array(source_image)
|
81 |
+
target_image = np.array(target_image)
|
82 |
+
|
83 |
+
canvas = Canvas(width=4 * w, height=h * len(key_idx_list), sync_image_data=True)
|
84 |
+
canvas.put_image_data(source_image, w * 3, 0)
|
85 |
+
canvas.put_image_data(target_image, 0, 0)
|
86 |
+
|
87 |
+
canvas.put_image_data(source_image, w * 3, h)
|
88 |
+
canvas.put_image_data(target_image, 0, h)
|
89 |
+
|
90 |
+
# Display the canvas
|
91 |
+
# display(canvas)
|
92 |
+
|
93 |
+
|
94 |
+
def save_to_file(*args, **kwargs):
|
95 |
+
canvas.to_file("my_file1.png")
|
96 |
+
|
97 |
+
|
98 |
+
# Listen to changes on the ``image_data`` trait and call ``save_to_file`` when it changes.
|
99 |
+
canvas.observe(save_to_file, "image_data")
|
100 |
+
|
101 |
+
|
102 |
+
def on_click(x, y):
|
103 |
+
cnt = 0
|
104 |
+
canvas.put_image_data(target_image, 0, 0)
|
105 |
+
|
106 |
+
print(x, y)
|
107 |
+
# draw a point
|
108 |
+
canvas.fill_style = 'red'
|
109 |
+
canvas.fill_circle(x, y, 4)
|
110 |
+
|
111 |
+
for step_i, step in enumerate(range(len(saved_attention_map_idx))):
|
112 |
+
|
113 |
+
attn_map_raw = attn_map_raws[step_i]
|
114 |
+
|
115 |
+
for key_i, key_idx in enumerate(key_idx_list):
|
116 |
+
key = keys[key_idx]
|
117 |
+
|
118 |
+
num_of_grid = int(attn_map_raw[key].shape[-1] ** (0.5))
|
119 |
+
|
120 |
+
# map the click position (x, y) to latent-grid indices
|
121 |
+
grid_x_idx = int(x / (w / num_of_grid))
|
122 |
+
grid_y_idx = int(y / (h / num_of_grid))
|
123 |
+
|
124 |
+
print(grid_x_idx, grid_y_idx)
|
125 |
+
|
126 |
+
grid_idx = grid_x_idx + grid_y_idx * num_of_grid
|
127 |
+
|
128 |
+
attn_map = attn_map_raw[key][tgt_idx * 10:10 + tgt_idx * 10, grid_idx, :]
|
129 |
+
|
130 |
+
attn_map = attn_map.sum(dim=0)
|
131 |
+
|
132 |
+
attn_map = attn_map.reshape(num_of_grid, num_of_grid)
|
133 |
+
|
134 |
+
# convert the attention map to a uint8 image for display
|
135 |
+
attn_map = attn_map.detach().cpu().numpy()
|
136 |
+
# attn_map = attn_map / attn_map.max()
|
137 |
+
# normalized_attn_map = attn_map
|
138 |
+
normalized_attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-8)
|
139 |
+
normalized_attn_map = 1.0 - normalized_attn_map
|
140 |
+
|
141 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * normalized_attn_map), cv2.COLORMAP_JET)
|
142 |
+
heatmap = cv2.resize(heatmap, (w, h))
|
143 |
+
|
144 |
+
attn_map = normalized_attn_map * 255
|
145 |
+
|
146 |
+
attn_map = attn_map.astype(np.uint8)
|
147 |
+
|
148 |
+
attn_map = cv2.cvtColor(attn_map, cv2.COLOR_GRAY2RGB)
|
149 |
+
# attn_map = cv2.cvtColor(attn_map, cv2.COLORMAP_JET)
|
150 |
+
attn_map = cv2.resize(attn_map, (w, h))
|
151 |
+
|
152 |
+
# draw attn_map
|
153 |
+
canvas.put_image_data(attn_map, w + step_i * 4 * w, h * key_i)
|
154 |
+
# canvas.put_image_data(attn_map, w , h*key_i)
|
155 |
+
|
156 |
+
# blend the attention heat map with the source image
|
157 |
+
alpha = 0.85
|
158 |
+
blended_image = cv2.addWeighted(source_image, 1 - alpha, heatmap, alpha, 0)
|
159 |
+
|
160 |
+
# draw blended image
|
161 |
+
canvas.put_image_data(blended_image, w * 2 + step_i * 4 * w, h * key_i)
|
162 |
+
|
163 |
+
cnt += 1
|
164 |
+
|
165 |
+
# Attach the event handler to the canvas
|
166 |
+
|
167 |
+
|
168 |
+
canvas.on_mouse_down(on_click)
|