fffiloni commited on
Commit
2eac672
1 Parent(s): 2401398

Upload 3 files

Browse files
Files changed (3) hide show
  1. blora_utils.py +46 -0
  2. requirements.txt +11 -0
  3. train_dreambooth_b-lora_sdxl.py +2029 -0
blora_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ BLOCKS = {
4
+ 'content': ['unet.up_blocks.0.attentions.0'],
5
+ 'style': ['unet.up_blocks.0.attentions.1'],
6
+ }
7
+
8
+
9
+ def is_belong_to_blocks(key, blocks):
10
+ try:
11
+ for g in blocks:
12
+ if g in key:
13
+ return True
14
+ return False
15
+ except Exception as e:
16
+ raise type(e)(f'failed to is_belong_to_block, due to: {e}')
17
+
18
+
19
+ def filter_lora(state_dict, blocks_):
20
+ try:
21
+ return {k: v for k, v in state_dict.items() if is_belong_to_blocks(k, blocks_)}
22
+ except Exception as e:
23
+ raise type(e)(f'failed to filter_lora, due to: {e}')
24
+
25
+
26
+ def scale_lora(state_dict, alpha):
27
+ try:
28
+ return {k: v * alpha for k, v in state_dict.items()}
29
+ except Exception as e:
30
+ raise type(e)(f'failed to scale_lora, due to: {e}')
31
+
32
+
33
+ def get_target_modules(unet, blocks=None):
34
+ try:
35
+ if not blocks:
36
+ blocks = [('.').join(blk.split('.')[1:]) for blk in BLOCKS['content'] + BLOCKS['style']]
37
+
38
+ attns = [attn_processor_name.rsplit('.', 1)[0] for attn_processor_name, _ in unet.attn_processors.items() if
39
+ is_belong_to_blocks(attn_processor_name, blocks)]
40
+
41
+ target_modules = [f'{attn}.{mat}' for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns]
42
+ return target_modules
43
+ except Exception as e:
44
+ raise type(e)(f'failed to get_target_modules, due to: {e}')
45
+
46
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ bitsandbytes==0.36.0.post2
3
+ datasets
4
+ diffusers==0.25.0
5
+ ftfy==6.1.1
6
+ huggingface-hub
7
+ Pillow==9.4.0
8
+ python-slugify==7.0.0
9
+ torch
10
+ torchvision
11
+ transformers
train_dreambooth_b-lora_sdxl.py ADDED
@@ -0,0 +1,2029 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import gc
18
+ import itertools
19
+ import logging
20
+ import math
21
+ import os
22
+ import shutil
23
+ import warnings
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import (
34
+ DistributedDataParallelKwargs,
35
+ ProjectConfiguration,
36
+ set_seed,
37
+ )
38
+ from huggingface_hub import create_repo, upload_folder
39
+ from huggingface_hub.utils import insecure_hashlib
40
+ from packaging import version
41
+ from PIL import Image
42
+ from PIL.ImageOps import exif_transpose
43
+ from torch.utils.data import Dataset
44
+ from torchvision import transforms
45
+ from tqdm.auto import tqdm
46
+ from transformers import AutoTokenizer, PretrainedConfig
47
+
48
+ import diffusers
49
+ from diffusers import (
50
+ AutoencoderKL,
51
+ DDPMScheduler,
52
+ DPMSolverMultistepScheduler,
53
+ StableDiffusionXLPipeline,
54
+ UNet2DConditionModel,
55
+ )
56
+ from diffusers.loaders import LoraLoaderMixin
57
+ from diffusers.models.lora import LoRALinearLayer
58
+ from diffusers.optimization import get_scheduler
59
+ from diffusers.training_utils import compute_snr, unet_lora_state_dict
60
+ from diffusers.utils import check_min_version, is_wandb_available
61
+ from diffusers.utils.import_utils import is_xformers_available
62
+
63
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
64
+ check_min_version("0.24.0.dev0")
65
+
66
+ logger = get_logger(__name__)
67
+
68
+
69
+ # TODO: This function should be removed once training scripts are rewritten in PEFT
70
+ def text_encoder_lora_state_dict(text_encoder):
71
+ state_dict = {}
72
+
73
+ def text_encoder_attn_modules(text_encoder):
74
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
75
+
76
+ attn_modules = []
77
+
78
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
79
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
80
+ name = f"text_model.encoder.layers.{i}.self_attn"
81
+ mod = layer.self_attn
82
+ attn_modules.append((name, mod))
83
+
84
+ return attn_modules
85
+
86
+ for name, module in text_encoder_attn_modules(text_encoder):
87
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
88
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
89
+
90
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
91
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
92
+
93
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
94
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
95
+
96
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
97
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
98
+
99
+ return state_dict
100
+
101
+
102
+ def save_model_card(
103
+ repo_id: str,
104
+ images=None,
105
+ base_model=str,
106
+ train_text_encoder=False,
107
+ instance_prompt=str,
108
+ validation_prompt=str,
109
+ repo_folder=None,
110
+ vae_path=None,
111
+ ):
112
+ img_str = "widget:\n" if images else ""
113
+ for i, image in enumerate(images):
114
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
115
+ img_str += f"""
116
+ - text: '{validation_prompt if validation_prompt else ' '}'
117
+ output:
118
+ url:
119
+ "image_{i}.png"
120
+ """
121
+
122
+ yaml = f"""
123
+ ---
124
+ tags:
125
+ - stable-diffusion-xl
126
+ - stable-diffusion-xl-diffusers
127
+ - text-to-image
128
+ - diffusers
129
+ - lora
130
+ - template:sd-lora
131
+ {img_str}
132
+ base_model: {base_model}
133
+ instance_prompt: {instance_prompt}
134
+ license: openrail++
135
+ ---
136
+ """
137
+
138
+ model_card = f"""
139
+ # SDXL LoRA DreamBooth - {repo_id}
140
+
141
+ <Gallery />
142
+
143
+ ## Model description
144
+
145
+ These are {repo_id} LoRA adaption weights for {base_model}.
146
+
147
+ The weights were trained using [DreamBooth](https://dreambooth.github.io/).
148
+
149
+ LoRA for the text encoder was enabled: {train_text_encoder}.
150
+
151
+ Special VAE used for training: {vae_path}.
152
+
153
+ ## Trigger words
154
+
155
+ You should use {instance_prompt} to trigger the image generation.
156
+
157
+ ## Download model
158
+
159
+ Weights for this model are available in Safetensors format.
160
+
161
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
162
+
163
+ """
164
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
165
+ f.write(yaml + model_card)
166
+
167
+
168
+ def import_model_class_from_model_name_or_path(
169
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
170
+ ):
171
+ text_encoder_config = PretrainedConfig.from_pretrained(
172
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
173
+ )
174
+ model_class = text_encoder_config.architectures[0]
175
+
176
+ if model_class == "CLIPTextModel":
177
+ from transformers import CLIPTextModel
178
+
179
+ return CLIPTextModel
180
+ elif model_class == "CLIPTextModelWithProjection":
181
+ from transformers import CLIPTextModelWithProjection
182
+
183
+ return CLIPTextModelWithProjection
184
+ else:
185
+ raise ValueError(f"{model_class} is not supported.")
186
+
187
+
188
+ def parse_args(input_args=None):
189
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
190
+ parser.add_argument(
191
+ "--pretrained_model_name_or_path",
192
+ type=str,
193
+ default=None,
194
+ required=True,
195
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
196
+ )
197
+ parser.add_argument(
198
+ "--pretrained_vae_model_name_or_path",
199
+ type=str,
200
+ default=None,
201
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
202
+ )
203
+ parser.add_argument(
204
+ "--revision",
205
+ type=str,
206
+ default=None,
207
+ required=False,
208
+ help="Revision of pretrained model identifier from huggingface.co/models.",
209
+ )
210
+ parser.add_argument(
211
+ "--dataset_name",
212
+ type=str,
213
+ default=None,
214
+ help=(
215
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
216
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
217
+ " or to a folder containing files that 🤗 Datasets can understand."
218
+ ),
219
+ )
220
+ parser.add_argument(
221
+ "--dataset_config_name",
222
+ type=str,
223
+ default=None,
224
+ help="The config of the Dataset, leave as None if there's only one config.",
225
+ )
226
+ parser.add_argument(
227
+ "--instance_data_dir",
228
+ type=str,
229
+ default=None,
230
+ help=("A folder containing the training data. "),
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--cache_dir",
235
+ type=str,
236
+ default=None,
237
+ help="The directory where the downloaded models and datasets will be stored.",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--image_column",
242
+ type=str,
243
+ default="image",
244
+ help="The column of the dataset containing the target image. By "
245
+ "default, the standard Image Dataset maps out 'file_name' "
246
+ "to 'image'.",
247
+ )
248
+ parser.add_argument(
249
+ "--caption_column",
250
+ type=str,
251
+ default=None,
252
+ help="The column of the dataset containing the instance prompt for each image",
253
+ )
254
+
255
+ parser.add_argument(
256
+ "--repeats",
257
+ type=int,
258
+ default=1,
259
+ help="How many times to repeat the training data.",
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--class_data_dir",
264
+ type=str,
265
+ default=None,
266
+ required=False,
267
+ help="A folder containing the training data of class images.",
268
+ )
269
+ parser.add_argument(
270
+ "--instance_prompt",
271
+ type=str,
272
+ default=None,
273
+ required=True,
274
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
275
+ )
276
+ parser.add_argument(
277
+ "--class_prompt",
278
+ type=str,
279
+ default=None,
280
+ help="The prompt to specify images in the same class as provided instance images.",
281
+ )
282
+ parser.add_argument(
283
+ "--validation_prompt",
284
+ type=str,
285
+ default=None,
286
+ help="A prompt that is used during validation to verify that the model is learning.",
287
+ )
288
+ parser.add_argument(
289
+ "--num_validation_images",
290
+ type=int,
291
+ default=4,
292
+ help="Number of images that should be generated during validation with `validation_prompt`.",
293
+ )
294
+ parser.add_argument(
295
+ "--validation_epochs",
296
+ type=int,
297
+ default=50,
298
+ help=(
299
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
300
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
301
+ ),
302
+ )
303
+ parser.add_argument(
304
+ "--with_prior_preservation",
305
+ default=False,
306
+ action="store_true",
307
+ help="Flag to add prior preservation loss.",
308
+ )
309
+ parser.add_argument(
310
+ "--prior_loss_weight",
311
+ type=float,
312
+ default=1.0,
313
+ help="The weight of prior preservation loss.",
314
+ )
315
+ parser.add_argument(
316
+ "--num_class_images",
317
+ type=int,
318
+ default=100,
319
+ help=(
320
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
321
+ " class_data_dir, additional images will be sampled with class_prompt."
322
+ ),
323
+ )
324
+ parser.add_argument(
325
+ "--output_dir",
326
+ type=str,
327
+ default="lora-dreambooth-model",
328
+ help="The output directory where the model predictions and checkpoints will be written.",
329
+ )
330
+ parser.add_argument(
331
+ "--seed", type=int, default=None, help="A seed for reproducible training."
332
+ )
333
+ parser.add_argument(
334
+ "--resolution",
335
+ type=int,
336
+ default=1024,
337
+ help=(
338
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
339
+ " resolution"
340
+ ),
341
+ )
342
+ parser.add_argument(
343
+ "--crops_coords_top_left_h",
344
+ type=int,
345
+ default=0,
346
+ help=(
347
+ "Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--crops_coords_top_left_w",
352
+ type=int,
353
+ default=0,
354
+ help=(
355
+ "Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."
356
+ ),
357
+ )
358
+ parser.add_argument(
359
+ "--center_crop",
360
+ default=True,
361
+ action="store_true",
362
+ help=(
363
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
364
+ " cropped. The images will be resized to the resolution first before cropping."
365
+ ),
366
+ )
367
+ parser.add_argument(
368
+ "--train_text_encoder",
369
+ action="store_true",
370
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
371
+ )
372
+ parser.add_argument(
373
+ "--train_batch_size",
374
+ type=int,
375
+ default=4,
376
+ help="Batch size (per device) for the training dataloader.",
377
+ )
378
+ parser.add_argument(
379
+ "--sample_batch_size",
380
+ type=int,
381
+ default=4,
382
+ help="Batch size (per device) for sampling images.",
383
+ )
384
+ parser.add_argument("--num_train_epochs", type=int, default=1)
385
+ parser.add_argument(
386
+ "--max_train_steps",
387
+ type=int,
388
+ default=None,
389
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
390
+ )
391
+ parser.add_argument(
392
+ "--checkpointing_steps",
393
+ type=int,
394
+ default=500,
395
+ help=(
396
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
397
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
398
+ " training using `--resume_from_checkpoint`."
399
+ ),
400
+ )
401
+ parser.add_argument(
402
+ "--checkpoints_total_limit",
403
+ type=int,
404
+ default=None,
405
+ help=("Max number of checkpoints to store."),
406
+ )
407
+ parser.add_argument(
408
+ "--resume_from_checkpoint",
409
+ type=str,
410
+ default=None,
411
+ help=(
412
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
413
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
414
+ ),
415
+ )
416
+ parser.add_argument(
417
+ "--gradient_accumulation_steps",
418
+ type=int,
419
+ default=1,
420
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
421
+ )
422
+ parser.add_argument(
423
+ "--gradient_checkpointing",
424
+ action="store_true",
425
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
426
+ )
427
+ parser.add_argument(
428
+ "--learning_rate",
429
+ type=float,
430
+ default=1e-4,
431
+ help="Initial learning rate (after the potential warmup period) to use.",
432
+ )
433
+
434
+ parser.add_argument(
435
+ "--text_encoder_lr",
436
+ type=float,
437
+ default=5e-6,
438
+ help="Text encoder learning rate to use.",
439
+ )
440
+ parser.add_argument(
441
+ "--scale_lr",
442
+ action="store_true",
443
+ default=False,
444
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
445
+ )
446
+ parser.add_argument(
447
+ "--lr_scheduler",
448
+ type=str,
449
+ default="constant",
450
+ help=(
451
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
452
+ ' "constant", "constant_with_warmup"]'
453
+ ),
454
+ )
455
+
456
+ parser.add_argument(
457
+ "--snr_gamma",
458
+ type=float,
459
+ default=None,
460
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
461
+ "More details here: https://arxiv.org/abs/2303.09556.",
462
+ )
463
+ parser.add_argument(
464
+ "--lr_warmup_steps",
465
+ type=int,
466
+ default=500,
467
+ help="Number of steps for the warmup in the lr scheduler.",
468
+ )
469
+ parser.add_argument(
470
+ "--lr_num_cycles",
471
+ type=int,
472
+ default=1,
473
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
474
+ )
475
+ parser.add_argument(
476
+ "--lr_power",
477
+ type=float,
478
+ default=1.0,
479
+ help="Power factor of the polynomial scheduler.",
480
+ )
481
+ parser.add_argument(
482
+ "--dataloader_num_workers",
483
+ type=int,
484
+ default=0,
485
+ help=(
486
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
487
+ ),
488
+ )
489
+
490
+ parser.add_argument(
491
+ "--optimizer",
492
+ type=str,
493
+ default="AdamW",
494
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
495
+ )
496
+
497
+ parser.add_argument(
498
+ "--use_8bit_adam",
499
+ action="store_true",
500
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
501
+ )
502
+
503
+ parser.add_argument(
504
+ "--adam_beta1",
505
+ type=float,
506
+ default=0.9,
507
+ help="The beta1 parameter for the Adam and Prodigy optimizers.",
508
+ )
509
+ parser.add_argument(
510
+ "--adam_beta2",
511
+ type=float,
512
+ default=0.999,
513
+ help="The beta2 parameter for the Adam and Prodigy optimizers.",
514
+ )
515
+ parser.add_argument(
516
+ "--prodigy_beta3",
517
+ type=float,
518
+ default=None,
519
+ help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
520
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
521
+ )
522
+ parser.add_argument(
523
+ "--prodigy_decouple",
524
+ type=bool,
525
+ default=True,
526
+ help="Use AdamW style decoupled weight decay",
527
+ )
528
+ parser.add_argument(
529
+ "--adam_weight_decay",
530
+ type=float,
531
+ default=1e-04,
532
+ help="Weight decay to use for unet params",
533
+ )
534
+ parser.add_argument(
535
+ "--adam_weight_decay_text_encoder",
536
+ type=float,
537
+ default=1e-03,
538
+ help="Weight decay to use for text_encoder",
539
+ )
540
+
541
+ parser.add_argument(
542
+ "--adam_epsilon",
543
+ type=float,
544
+ default=1e-08,
545
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
546
+ )
547
+
548
+ parser.add_argument(
549
+ "--prodigy_use_bias_correction",
550
+ type=bool,
551
+ default=True,
552
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
553
+ )
554
+ parser.add_argument(
555
+ "--prodigy_safeguard_warmup",
556
+ type=bool,
557
+ default=True,
558
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
559
+ "Ignored if optimizer is adamW",
560
+ )
561
+ parser.add_argument(
562
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
563
+ )
564
+ parser.add_argument(
565
+ "--push_to_hub",
566
+ action="store_true",
567
+ help="Whether or not to push the model to the Hub.",
568
+ )
569
+ parser.add_argument(
570
+ "--hub_token",
571
+ type=str,
572
+ default=None,
573
+ help="The token to use to push to the Model Hub.",
574
+ )
575
+ parser.add_argument(
576
+ "--hub_model_id",
577
+ type=str,
578
+ default=None,
579
+ help="The name of the repository to keep in sync with the local `output_dir`.",
580
+ )
581
+ parser.add_argument(
582
+ "--logging_dir",
583
+ type=str,
584
+ default="logs",
585
+ help=(
586
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
587
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
588
+ ),
589
+ )
590
+ parser.add_argument(
591
+ "--allow_tf32",
592
+ action="store_true",
593
+ help=(
594
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
595
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
596
+ ),
597
+ )
598
+ parser.add_argument(
599
+ "--report_to",
600
+ type=str,
601
+ default="tensorboard",
602
+ help=(
603
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
604
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
605
+ ),
606
+ )
607
+ parser.add_argument(
608
+ "--mixed_precision",
609
+ type=str,
610
+ default=None,
611
+ choices=["no", "fp16", "bf16"],
612
+ help=(
613
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
614
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
615
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
616
+ ),
617
+ )
618
+ parser.add_argument(
619
+ "--prior_generation_precision",
620
+ type=str,
621
+ default=None,
622
+ choices=["no", "fp32", "fp16", "bf16"],
623
+ help=(
624
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
625
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
626
+ ),
627
+ )
628
+ parser.add_argument(
629
+ "--local_rank",
630
+ type=int,
631
+ default=-1,
632
+ help="For distributed training: local_rank",
633
+ )
634
+ parser.add_argument(
635
+ "--enable_xformers_memory_efficient_attention",
636
+ action="store_true",
637
+ help="Whether or not to use xformers.",
638
+ )
639
+ parser.add_argument(
640
+ "--rank",
641
+ type=int,
642
+ default=4,
643
+ help=("The dimension of the LoRA update matrices."),
644
+ )
645
+
646
+ if input_args is not None:
647
+ args = parser.parse_args(input_args)
648
+ else:
649
+ args = parser.parse_args()
650
+
651
+ if args.dataset_name is None and args.instance_data_dir is None:
652
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
653
+
654
+ if args.dataset_name is not None and args.instance_data_dir is not None:
655
+ raise ValueError(
656
+ "Specify only one of `--dataset_name` or `--instance_data_dir`"
657
+ )
658
+
659
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
660
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
661
+ args.local_rank = env_local_rank
662
+
663
+ if args.with_prior_preservation:
664
+ if args.class_data_dir is None:
665
+ raise ValueError("You must specify a data directory for class images.")
666
+ if args.class_prompt is None:
667
+ raise ValueError("You must specify prompt for class images.")
668
+ else:
669
+ # logger is not available yet
670
+ if args.class_data_dir is not None:
671
+ warnings.warn(
672
+ "You need not use --class_data_dir without --with_prior_preservation."
673
+ )
674
+ if args.class_prompt is not None:
675
+ warnings.warn(
676
+ "You need not use --class_prompt without --with_prior_preservation."
677
+ )
678
+
679
+ return args
680
+
681
+
682
+ class DreamBoothDataset(Dataset):
683
+ """
684
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
685
+ It pre-processes the images.
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ instance_data_root,
691
+ instance_prompt,
692
+ class_prompt,
693
+ class_data_root=None,
694
+ class_num=None,
695
+ size=1024,
696
+ repeats=1,
697
+ center_crop=False,
698
+ ):
699
+ self.size = size
700
+ self.center_crop = center_crop
701
+
702
+ self.instance_prompt = instance_prompt
703
+ self.custom_instance_prompts = None
704
+ self.class_prompt = class_prompt
705
+
706
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
707
+ # we load the training data using load_dataset
708
+ if args.dataset_name is not None:
709
+ try:
710
+ from datasets import load_dataset
711
+ except ImportError:
712
+ raise ImportError(
713
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
714
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
715
+ "local folder containing images only, specify --instance_data_dir instead."
716
+ )
717
+ # Downloading and loading a dataset from the hub.
718
+ # See more about loading custom images at
719
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
720
+ dataset = load_dataset(
721
+ args.dataset_name,
722
+ args.dataset_config_name,
723
+ cache_dir=args.cache_dir,
724
+ )
725
+ # Preprocessing the datasets.
726
+ column_names = dataset["train"].column_names
727
+
728
+ # 6. Get the column names for input/target.
729
+ if args.image_column is None:
730
+ image_column = column_names[0]
731
+ logger.info(f"image column defaulting to {image_column}")
732
+ else:
733
+ image_column = args.image_column
734
+ if image_column not in column_names:
735
+ raise ValueError(
736
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
737
+ )
738
+ instance_images = dataset["train"][image_column]
739
+
740
+ if args.caption_column is None:
741
+ logger.info(
742
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
743
+ "contains captions/prompts for the images, make sure to specify the "
744
+ "column as --caption_column"
745
+ )
746
+ self.custom_instance_prompts = None
747
+ else:
748
+ if args.caption_column not in column_names:
749
+ raise ValueError(
750
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
751
+ )
752
+ custom_instance_prompts = dataset["train"][args.caption_column]
753
+ # create final list of captions according to --repeats
754
+ self.custom_instance_prompts = []
755
+ for caption in custom_instance_prompts:
756
+ self.custom_instance_prompts.extend(
757
+ itertools.repeat(caption, repeats)
758
+ )
759
+ else:
760
+ self.instance_data_root = Path(instance_data_root)
761
+ if not self.instance_data_root.exists():
762
+ raise ValueError("Instance images root doesn't exists.")
763
+
764
+ instance_images = [
765
+ Image.open(path) for path in list(Path(instance_data_root).iterdir())
766
+ ]
767
+ self.custom_instance_prompts = None
768
+
769
+ self.instance_images = []
770
+ for img in instance_images:
771
+ self.instance_images.extend(itertools.repeat(img, repeats))
772
+ self.num_instance_images = len(self.instance_images)
773
+ self._length = self.num_instance_images
774
+
775
+ if class_data_root is not None:
776
+ self.class_data_root = Path(class_data_root)
777
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
778
+ self.class_images_path = list(self.class_data_root.iterdir())
779
+ if class_num is not None:
780
+ self.num_class_images = min(len(self.class_images_path), class_num)
781
+ else:
782
+ self.num_class_images = len(self.class_images_path)
783
+ self._length = max(self.num_class_images, self.num_instance_images)
784
+ else:
785
+ self.class_data_root = None
786
+
787
+ self.image_transforms = transforms.Compose(
788
+ [
789
+ transforms.Resize(
790
+ size, interpolation=transforms.InterpolationMode.BILINEAR
791
+ ),
792
+ # transforms.CenterCrop(size)
793
+ # if center_crop
794
+ # else transforms.RandomCrop(size),
795
+ transforms.ToTensor(),
796
+ transforms.Normalize([0.5], [0.5]),
797
+ ]
798
+ )
799
+
800
+ def __len__(self):
801
+ return self._length
802
+
803
+ def __getitem__(self, index):
804
+ example = {}
805
+ instance_image = self.instance_images[index % self.num_instance_images]
806
+ # instance_image = exif_transpose(instance_image)
807
+
808
+ if not instance_image.mode == "RGB":
809
+ instance_image = instance_image.convert("RGB")
810
+ example["instance_images"] = self.image_transforms(instance_image)
811
+
812
+ if self.custom_instance_prompts:
813
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
814
+ if caption:
815
+ example["instance_prompt"] = caption
816
+ else:
817
+ example["instance_prompt"] = self.instance_prompt
818
+
819
+ else: # costum prompts were provided, but length does not match size of image dataset
820
+ example["instance_prompt"] = self.instance_prompt
821
+
822
+ if self.class_data_root:
823
+ class_image = Image.open(
824
+ self.class_images_path[index % self.num_class_images]
825
+ )
826
+ class_image = exif_transpose(class_image)
827
+
828
+ if not class_image.mode == "RGB":
829
+ class_image = class_image.convert("RGB")
830
+ example["class_images"] = self.image_transforms(class_image)
831
+ example["class_prompt"] = self.class_prompt
832
+
833
+ return example
834
+
835
+
836
+ def collate_fn(examples, with_prior_preservation=False):
837
+ pixel_values = [example["instance_images"] for example in examples]
838
+ prompts = [example["instance_prompt"] for example in examples]
839
+
840
+ # Concat class and instance examples for prior preservation.
841
+ # We do this to avoid doing two forward passes.
842
+ if with_prior_preservation:
843
+ pixel_values += [example["class_images"] for example in examples]
844
+ prompts += [example["class_prompt"] for example in examples]
845
+
846
+ pixel_values = torch.stack(pixel_values)
847
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
848
+
849
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
850
+ return batch
851
+
852
+
853
+ class PromptDataset(Dataset):
854
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
855
+
856
+ def __init__(self, prompt, num_samples):
857
+ self.prompt = prompt
858
+ self.num_samples = num_samples
859
+
860
+ def __len__(self):
861
+ return self.num_samples
862
+
863
+ def __getitem__(self, index):
864
+ example = {}
865
+ example["prompt"] = self.prompt
866
+ example["index"] = index
867
+ return example
868
+
869
+
870
+ def tokenize_prompt(tokenizer, prompt):
871
+ text_inputs = tokenizer(
872
+ prompt,
873
+ padding="max_length",
874
+ max_length=tokenizer.model_max_length,
875
+ truncation=True,
876
+ return_tensors="pt",
877
+ )
878
+ text_input_ids = text_inputs.input_ids
879
+ return text_input_ids
880
+
881
+
882
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
883
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
884
+ prompt_embeds_list = []
885
+
886
+ for i, text_encoder in enumerate(text_encoders):
887
+ if tokenizers is not None:
888
+ tokenizer = tokenizers[i]
889
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
890
+ else:
891
+ assert text_input_ids_list is not None
892
+ text_input_ids = text_input_ids_list[i]
893
+
894
+ prompt_embeds = text_encoder(
895
+ text_input_ids.to(text_encoder.device),
896
+ output_hidden_states=True,
897
+ )
898
+
899
+ # We are only ALWAYS interested in the pooled output of the final text encoder
900
+ pooled_prompt_embeds = prompt_embeds[0]
901
+ prompt_embeds = prompt_embeds.hidden_states[-2]
902
+ bs_embed, seq_len, _ = prompt_embeds.shape
903
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
904
+ prompt_embeds_list.append(prompt_embeds)
905
+
906
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
907
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
908
+ return prompt_embeds, pooled_prompt_embeds
909
+
910
+
911
+ def is_belong_to_groups(key: str, groups: list) -> bool:
912
+ try:
913
+ for g in groups:
914
+ if key.startswith(g):
915
+ return True
916
+ return False
917
+ except Exception as e:
918
+ raise type(e)(f'failed to is_belong_to_groups, due to: {e}')
919
+
920
+
921
+ def filter_lora_layers(lora_state_dict: dict, groups: list) -> dict:
922
+ try:
923
+ return {k: v for k, v in lora_state_dict.items() if is_belong_to_groups(k, groups)}
924
+ except Exception as e:
925
+ raise type(e)(f'failed to filter_lora_layers, due to: {e}')
926
+
927
+
928
+ def main(args):
929
+ logging_dir = Path(args.output_dir, args.logging_dir)
930
+
931
+ accelerator_project_config = ProjectConfiguration(
932
+ project_dir=args.output_dir, logging_dir=logging_dir
933
+ )
934
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
935
+ accelerator = Accelerator(
936
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
937
+ mixed_precision=args.mixed_precision,
938
+ log_with=args.report_to,
939
+ project_config=accelerator_project_config,
940
+ kwargs_handlers=[kwargs],
941
+ )
942
+
943
+ if args.report_to == "wandb":
944
+ if not is_wandb_available():
945
+ raise ImportError(
946
+ "Make sure to install wandb if you want to use it for logging during training."
947
+ )
948
+ import wandb
949
+
950
+ # Make one log on every process with the configuration for debugging.
951
+ logging.basicConfig(
952
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
953
+ datefmt="%m/%d/%Y %H:%M:%S",
954
+ level=logging.INFO,
955
+ )
956
+ logger.info(accelerator.state, main_process_only=False)
957
+ if accelerator.is_local_main_process:
958
+ transformers.utils.logging.set_verbosity_warning()
959
+ diffusers.utils.logging.set_verbosity_info()
960
+ else:
961
+ transformers.utils.logging.set_verbosity_error()
962
+ diffusers.utils.logging.set_verbosity_error()
963
+
964
+ # If passed along, set the training seed now.
965
+ if args.seed is not None:
966
+ set_seed(args.seed)
967
+
968
+ # Generate class images if prior preservation is enabled.
969
+ if args.with_prior_preservation:
970
+ class_images_dir = Path(args.class_data_dir)
971
+ if not class_images_dir.exists():
972
+ class_images_dir.mkdir(parents=True)
973
+ cur_class_images = len(list(class_images_dir.iterdir()))
974
+
975
+ if cur_class_images < args.num_class_images:
976
+ torch_dtype = (
977
+ torch.float16 if accelerator.device.type == "cuda" else torch.float32
978
+ )
979
+ if args.prior_generation_precision == "fp32":
980
+ torch_dtype = torch.float32
981
+ elif args.prior_generation_precision == "fp16":
982
+ torch_dtype = torch.float16
983
+ elif args.prior_generation_precision == "bf16":
984
+ torch_dtype = torch.bfloat16
985
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
986
+ args.pretrained_model_name_or_path,
987
+ torch_dtype=torch_dtype,
988
+ revision=args.revision,
989
+ )
990
+ pipeline.set_progress_bar_config(disable=True)
991
+
992
+ num_new_images = args.num_class_images - cur_class_images
993
+ logger.info(f"Number of class images to sample: {num_new_images}.")
994
+
995
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
996
+ sample_dataloader = torch.utils.data.DataLoader(
997
+ sample_dataset, batch_size=args.sample_batch_size
998
+ )
999
+
1000
+ sample_dataloader = accelerator.prepare(sample_dataloader)
1001
+ pipeline.to(accelerator.device)
1002
+
1003
+ for example in tqdm(
1004
+ sample_dataloader,
1005
+ desc="Generating class images",
1006
+ disable=not accelerator.is_local_main_process,
1007
+ ):
1008
+ images = pipeline(example["prompt"]).images
1009
+
1010
+ for i, image in enumerate(images):
1011
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
1012
+ image_filename = (
1013
+ class_images_dir
1014
+ / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
1015
+ )
1016
+ image.save(image_filename)
1017
+
1018
+ del pipeline
1019
+ if torch.cuda.is_available():
1020
+ torch.cuda.empty_cache()
1021
+
1022
+ # Handle the repository creation
1023
+ if accelerator.is_main_process:
1024
+ if args.output_dir is not None:
1025
+ os.makedirs(args.output_dir, exist_ok=True)
1026
+
1027
+ if args.push_to_hub:
1028
+ repo_id = create_repo(
1029
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
1030
+ exist_ok=True,
1031
+ token=args.hub_token,
1032
+ ).repo_id
1033
+
1034
+ # Load the tokenizers
1035
+ tokenizer_one = AutoTokenizer.from_pretrained(
1036
+ args.pretrained_model_name_or_path,
1037
+ subfolder="tokenizer",
1038
+ revision=args.revision,
1039
+ use_fast=False,
1040
+ )
1041
+ tokenizer_two = AutoTokenizer.from_pretrained(
1042
+ args.pretrained_model_name_or_path,
1043
+ subfolder="tokenizer_2",
1044
+ revision=args.revision,
1045
+ use_fast=False,
1046
+ )
1047
+
1048
+ # import correct text encoder classes
1049
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
1050
+ args.pretrained_model_name_or_path, args.revision
1051
+ )
1052
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
1053
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
1054
+ )
1055
+
1056
+ # Load scheduler and models
1057
+ noise_scheduler = DDPMScheduler.from_pretrained(
1058
+ args.pretrained_model_name_or_path, subfolder="scheduler"
1059
+ )
1060
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1061
+ args.pretrained_model_name_or_path,
1062
+ subfolder="text_encoder",
1063
+ revision=args.revision,
1064
+ )
1065
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1066
+ args.pretrained_model_name_or_path,
1067
+ subfolder="text_encoder_2",
1068
+ revision=args.revision,
1069
+ )
1070
+ vae_path = (
1071
+ args.pretrained_model_name_or_path
1072
+ if args.pretrained_vae_model_name_or_path is None
1073
+ else args.pretrained_vae_model_name_or_path
1074
+ )
1075
+ vae = AutoencoderKL.from_pretrained(
1076
+ vae_path,
1077
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1078
+ revision=args.revision,
1079
+ )
1080
+ unet = UNet2DConditionModel.from_pretrained(
1081
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
1082
+ )
1083
+
1084
+ # We only train the additional adapter LoRA layers
1085
+ vae.requires_grad_(False)
1086
+ text_encoder_one.requires_grad_(False)
1087
+ text_encoder_two.requires_grad_(False)
1088
+ unet.requires_grad_(False)
1089
+
1090
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
1091
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1092
+ weight_dtype = torch.float32
1093
+ if accelerator.mixed_precision == "fp16":
1094
+ weight_dtype = torch.float16
1095
+ elif accelerator.mixed_precision == "bf16":
1096
+ weight_dtype = torch.bfloat16
1097
+
1098
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
1099
+ unet.to(accelerator.device, dtype=weight_dtype)
1100
+
1101
+ # The VAE is always in float32 to avoid NaN losses.
1102
+ vae.to(accelerator.device, dtype=torch.float32)
1103
+
1104
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1105
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1106
+
1107
+ if args.enable_xformers_memory_efficient_attention:
1108
+ if is_xformers_available():
1109
+ import xformers
1110
+
1111
+ xformers_version = version.parse(xformers.__version__)
1112
+ if xformers_version == version.parse("0.0.16"):
1113
+ logger.warn(
1114
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
1115
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1116
+ )
1117
+ unet.enable_xformers_memory_efficient_attention()
1118
+ else:
1119
+ raise ValueError(
1120
+ "xformers is not available. Make sure it is installed correctly"
1121
+ )
1122
+
1123
+ if args.gradient_checkpointing:
1124
+ unet.enable_gradient_checkpointing()
1125
+ if args.train_text_encoder:
1126
+ text_encoder_one.gradient_checkpointing_enable()
1127
+ text_encoder_two.gradient_checkpointing_enable()
1128
+
1129
+ BLORA_BLOCKS = ['up_blocks.0.attentions.0',
1130
+ 'up_blocks.0.attentions.1']
1131
+
1132
+ # now we will add new LoRA weights to the attention layers
1133
+ # Set correct lora layers
1134
+ unet_lora_parameters = []
1135
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
1136
+ # Parse the attention module.
1137
+ if not is_belong_to_groups(attn_processor_name, BLORA_BLOCKS):
1138
+ continue
1139
+ attn_module = unet
1140
+ for n in attn_processor_name.split(".")[:-1]:
1141
+ attn_module = getattr(attn_module, n)
1142
+
1143
+ # Set the `lora_layer` attribute of the attention-related matrices.
1144
+ attn_module.to_q.set_lora_layer(
1145
+ LoRALinearLayer(
1146
+ in_features=attn_module.to_q.in_features,
1147
+ out_features=attn_module.to_q.out_features,
1148
+ rank=args.rank,
1149
+ )
1150
+ )
1151
+ attn_module.to_k.set_lora_layer(
1152
+ LoRALinearLayer(
1153
+ in_features=attn_module.to_k.in_features,
1154
+ out_features=attn_module.to_k.out_features,
1155
+ rank=args.rank,
1156
+ )
1157
+ )
1158
+ attn_module.to_v.set_lora_layer(
1159
+ LoRALinearLayer(
1160
+ in_features=attn_module.to_v.in_features,
1161
+ out_features=attn_module.to_v.out_features,
1162
+ rank=args.rank,
1163
+ )
1164
+ )
1165
+ attn_module.to_out[0].set_lora_layer(
1166
+ LoRALinearLayer(
1167
+ in_features=attn_module.to_out[0].in_features,
1168
+ out_features=attn_module.to_out[0].out_features,
1169
+ rank=args.rank,
1170
+ )
1171
+ )
1172
+
1173
+ # Accumulate the LoRA params to optimize.
1174
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
1175
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
1176
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
1177
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
1178
+
1179
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
1180
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
1181
+ if args.train_text_encoder:
1182
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
1183
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
1184
+ text_encoder_one, dtype=torch.float32, rank=args.rank
1185
+ )
1186
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
1187
+ text_encoder_two, dtype=torch.float32, rank=args.rank
1188
+ )
1189
+
1190
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1191
+ def save_model_hook(models, weights, output_dir):
1192
+ if accelerator.is_main_process:
1193
+ # there are only two options here. Either are just the unet attn processor layers
1194
+ # or there are the unet and text encoder atten layers
1195
+ unet_lora_layers_to_save = None
1196
+ text_encoder_one_lora_layers_to_save = None
1197
+ text_encoder_two_lora_layers_to_save = None
1198
+
1199
+ for model in models:
1200
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
1201
+ unet_lora_layers_to_save = unet_lora_state_dict(model)
1202
+ elif isinstance(
1203
+ model, type(accelerator.unwrap_model(text_encoder_one))
1204
+ ):
1205
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(
1206
+ model
1207
+ )
1208
+ elif isinstance(
1209
+ model, type(accelerator.unwrap_model(text_encoder_two))
1210
+ ):
1211
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(
1212
+ model
1213
+ )
1214
+ else:
1215
+ raise ValueError(f"unexpected save model: {model.__class__}")
1216
+
1217
+ # make sure to pop weight so that corresponding model is not saved again
1218
+ weights.pop()
1219
+
1220
+ StableDiffusionXLPipeline.save_lora_weights(
1221
+ output_dir,
1222
+ unet_lora_layers=unet_lora_layers_to_save,
1223
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
1224
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
1225
+ )
1226
+
1227
+ def load_model_hook(models, input_dir):
1228
+ unet_ = None
1229
+ text_encoder_one_ = None
1230
+ text_encoder_two_ = None
1231
+
1232
+ while len(models) > 0:
1233
+ model = models.pop()
1234
+
1235
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
1236
+ unet_ = model
1237
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
1238
+ text_encoder_one_ = model
1239
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
1240
+ text_encoder_two_ = model
1241
+ else:
1242
+ raise ValueError(f"unexpected save model: {model.__class__}")
1243
+
1244
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
1245
+ LoraLoaderMixin.load_lora_into_unet(
1246
+ lora_state_dict, network_alphas=network_alphas, unet=unet_
1247
+ )
1248
+
1249
+ text_encoder_state_dict = {
1250
+ k: v for k, v in lora_state_dict.items() if "text_encoder." in k
1251
+ }
1252
+ LoraLoaderMixin.load_lora_into_text_encoder(
1253
+ text_encoder_state_dict,
1254
+ network_alphas=network_alphas,
1255
+ text_encoder=text_encoder_one_,
1256
+ )
1257
+
1258
+ text_encoder_2_state_dict = {
1259
+ k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k
1260
+ }
1261
+ LoraLoaderMixin.load_lora_into_text_encoder(
1262
+ text_encoder_2_state_dict,
1263
+ network_alphas=network_alphas,
1264
+ text_encoder=text_encoder_two_,
1265
+ )
1266
+
1267
+ accelerator.register_save_state_pre_hook(save_model_hook)
1268
+ accelerator.register_load_state_pre_hook(load_model_hook)
1269
+
1270
+ # Enable TF32 for faster training on Ampere GPUs,
1271
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1272
+ if args.allow_tf32:
1273
+ torch.backends.cuda.matmul.allow_tf32 = True
1274
+
1275
+ if args.scale_lr:
1276
+ args.learning_rate = (
1277
+ args.learning_rate
1278
+ * args.gradient_accumulation_steps
1279
+ * args.train_batch_size
1280
+ * accelerator.num_processes
1281
+ )
1282
+
1283
+ # Optimization parameters
1284
+ unet_lora_parameters_with_lr = {
1285
+ "params": unet_lora_parameters,
1286
+ "lr": args.learning_rate,
1287
+ }
1288
+ if args.train_text_encoder:
1289
+ # different learning rate for text encoder and unet
1290
+ text_lora_parameters_one_with_lr = {
1291
+ "params": text_lora_parameters_one,
1292
+ "weight_decay": args.adam_weight_decay_text_encoder,
1293
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1294
+ }
1295
+ text_lora_parameters_two_with_lr = {
1296
+ "params": text_lora_parameters_two,
1297
+ "weight_decay": args.adam_weight_decay_text_encoder,
1298
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1299
+ }
1300
+ params_to_optimize = [
1301
+ unet_lora_parameters_with_lr,
1302
+ text_lora_parameters_one_with_lr,
1303
+ text_lora_parameters_two_with_lr,
1304
+ ]
1305
+ else:
1306
+ params_to_optimize = [unet_lora_parameters_with_lr]
1307
+
1308
+ # Optimizer creation
1309
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
1310
+ logger.warn(
1311
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
1312
+ "Defaulting to adamW"
1313
+ )
1314
+ args.optimizer = "adamw"
1315
+
1316
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
1317
+ logger.warn(
1318
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
1319
+ f"set to {args.optimizer.lower()}"
1320
+ )
1321
+
1322
+ if args.optimizer.lower() == "adamw":
1323
+ if args.use_8bit_adam:
1324
+ try:
1325
+ import bitsandbytes as bnb
1326
+ except ImportError:
1327
+ raise ImportError(
1328
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1329
+ )
1330
+
1331
+ optimizer_class = bnb.optim.AdamW8bit
1332
+ else:
1333
+ optimizer_class = torch.optim.AdamW
1334
+
1335
+ optimizer = optimizer_class(
1336
+ params_to_optimize,
1337
+ betas=(args.adam_beta1, args.adam_beta2),
1338
+ weight_decay=args.adam_weight_decay,
1339
+ eps=args.adam_epsilon,
1340
+ )
1341
+
1342
+ if args.optimizer.lower() == "prodigy":
1343
+ try:
1344
+ import prodigyopt
1345
+ except ImportError:
1346
+ raise ImportError(
1347
+ "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
1348
+ )
1349
+
1350
+ optimizer_class = prodigyopt.Prodigy
1351
+
1352
+ if args.learning_rate <= 0.1:
1353
+ logger.warn(
1354
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
1355
+ )
1356
+ if args.train_text_encoder and args.text_encoder_lr:
1357
+ logger.warn(
1358
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
1359
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
1360
+ f"When using prodigy only learning_rate is used as the initial learning rate."
1361
+ )
1362
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
1363
+ # --learning_rate
1364
+ params_to_optimize[1]["lr"] = args.learning_rate
1365
+ params_to_optimize[2]["lr"] = args.learning_rate
1366
+
1367
+ optimizer = optimizer_class(
1368
+ params_to_optimize,
1369
+ lr=args.learning_rate,
1370
+ betas=(args.adam_beta1, args.adam_beta2),
1371
+ beta3=args.prodigy_beta3,
1372
+ weight_decay=args.adam_weight_decay,
1373
+ eps=args.adam_epsilon,
1374
+ decouple=args.prodigy_decouple,
1375
+ use_bias_correction=args.prodigy_use_bias_correction,
1376
+ safeguard_warmup=args.prodigy_safeguard_warmup,
1377
+ )
1378
+
1379
+ # Dataset and DataLoaders creation:
1380
+ train_dataset = DreamBoothDataset(
1381
+ instance_data_root=args.instance_data_dir,
1382
+ instance_prompt=args.instance_prompt,
1383
+ class_prompt=args.class_prompt,
1384
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1385
+ class_num=args.num_class_images,
1386
+ size=args.resolution,
1387
+ repeats=args.repeats,
1388
+ center_crop=args.center_crop,
1389
+ )
1390
+
1391
+ train_dataloader = torch.utils.data.DataLoader(
1392
+ train_dataset,
1393
+ batch_size=args.train_batch_size,
1394
+ shuffle=True,
1395
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1396
+ num_workers=args.dataloader_num_workers,
1397
+ )
1398
+
1399
+ # Computes additional embeddings/ids required by the SDXL UNet.
1400
+ # regular text embeddings (when `train_text_encoder` is not True)
1401
+ # pooled text embeddings
1402
+ # time ids
1403
+
1404
+ def compute_time_ids():
1405
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1406
+ original_size = (args.resolution, args.resolution)
1407
+ target_size = (args.resolution, args.resolution)
1408
+ crops_coords_top_left = (
1409
+ args.crops_coords_top_left_h,
1410
+ args.crops_coords_top_left_w,
1411
+ )
1412
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1413
+ add_time_ids = torch.tensor([add_time_ids])
1414
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1415
+ return add_time_ids
1416
+
1417
+ if not args.train_text_encoder:
1418
+ tokenizers = [tokenizer_one, tokenizer_two]
1419
+ text_encoders = [text_encoder_one, text_encoder_two]
1420
+
1421
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1422
+ with torch.no_grad():
1423
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1424
+ text_encoders, tokenizers, prompt
1425
+ )
1426
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1427
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1428
+ return prompt_embeds, pooled_prompt_embeds
1429
+
1430
+ # Handle instance prompt.
1431
+ instance_time_ids = compute_time_ids()
1432
+
1433
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
1434
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
1435
+ # the redundant encoding.
1436
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1437
+ (
1438
+ instance_prompt_hidden_states,
1439
+ instance_pooled_prompt_embeds,
1440
+ ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers)
1441
+
1442
+ # Handle class prompt for prior-preservation.
1443
+ if args.with_prior_preservation:
1444
+ class_time_ids = compute_time_ids()
1445
+ if not args.train_text_encoder:
1446
+ (
1447
+ class_prompt_hidden_states,
1448
+ class_pooled_prompt_embeds,
1449
+ ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers)
1450
+
1451
+ # Clear the memory here
1452
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1453
+ del tokenizers, text_encoders
1454
+ gc.collect()
1455
+ torch.cuda.empty_cache()
1456
+
1457
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
1458
+ # pack the statically computed variables appropriately here. This is so that we don't
1459
+ # have to pass them to the dataloader.
1460
+ add_time_ids = instance_time_ids
1461
+ if args.with_prior_preservation:
1462
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
1463
+
1464
+ if not train_dataset.custom_instance_prompts:
1465
+ if not args.train_text_encoder:
1466
+ prompt_embeds = instance_prompt_hidden_states
1467
+ unet_add_text_embeds = instance_pooled_prompt_embeds
1468
+ if args.with_prior_preservation:
1469
+ prompt_embeds = torch.cat(
1470
+ [prompt_embeds, class_prompt_hidden_states], dim=0
1471
+ )
1472
+ unet_add_text_embeds = torch.cat(
1473
+ [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0
1474
+ )
1475
+ # If we're optimizing the text encoder (whether the instance prompt is shared by all images or custom prompts are provided), we need to tokenize and encode the
1476
+ # batch prompts on all training steps
1477
+ else:
1478
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
1479
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
1480
+ if args.with_prior_preservation:
1481
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
1482
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
1483
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
1484
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
1485
+
1486
+ # Scheduler and math around the number of training steps.
1487
+ overrode_max_train_steps = False
1488
+ num_update_steps_per_epoch = math.ceil(
1489
+ len(train_dataloader) / args.gradient_accumulation_steps
1490
+ )
1491
+ if args.max_train_steps is None:
1492
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1493
+ overrode_max_train_steps = True
1494
+
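+ # Worked example: a dataloader with 10 batches and --gradient_accumulation_steps 2
+ # gives 5 update steps per epoch; with --num_train_epochs 100 and no explicit
+ # --max_train_steps, max_train_steps is therefore set to 500.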
1495
+ lr_scheduler = get_scheduler(
1496
+ args.lr_scheduler,
1497
+ optimizer=optimizer,
1498
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1499
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1500
+ num_cycles=args.lr_num_cycles,
1501
+ power=args.lr_power,
1502
+ )
1503
+
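+ # Both step counts are scaled by accelerator.num_processes (e.g. --lr_warmup_steps 500
+ # on 2 processes gives the scheduler 1000 warmup steps), to account for how accelerate
+ # advances the schedule under multi-process training.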
1504
+ # Prepare everything with our `accelerator`.
1505
+ if args.train_text_encoder:
1506
+ (
1507
+ unet,
1508
+ text_encoder_one,
1509
+ text_encoder_two,
1510
+ optimizer,
1511
+ train_dataloader,
1512
+ lr_scheduler,
1513
+ ) = accelerator.prepare(
1514
+ unet,
1515
+ text_encoder_one,
1516
+ text_encoder_two,
1517
+ optimizer,
1518
+ train_dataloader,
1519
+ lr_scheduler,
1520
+ )
1521
+ else:
1522
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1523
+ unet, optimizer, train_dataloader, lr_scheduler
1524
+ )
1525
+
1526
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1527
+ num_update_steps_per_epoch = math.ceil(
1528
+ len(train_dataloader) / args.gradient_accumulation_steps
1529
+ )
1530
+ if overrode_max_train_steps:
1531
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1532
+ # Afterwards we recalculate our number of training epochs
1533
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1534
+
1535
+ # We need to initialize the trackers we use, and also store our configuration.
1536
+ # The trackers initialize automatically on the main process.
1537
+ if accelerator.is_main_process:
1538
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1539
+
1540
+ # Train!
1541
+ total_batch_size = (
1542
+ args.train_batch_size
1543
+ * accelerator.num_processes
1544
+ * args.gradient_accumulation_steps
1545
+ )
1546
+
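+ # Worked example: --train_batch_size 1 on 2 GPUs with --gradient_accumulation_steps 4
+ # gives an effective total batch size of 1 * 2 * 4 = 8.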
1547
+ logger.info("***** Running training *****")
1548
+ logger.info(f" Num examples = {len(train_dataset)}")
1549
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1550
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1551
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1552
+ logger.info(
1553
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
1554
+ )
1555
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1556
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1557
+ global_step = 0
1558
+ first_epoch = 0
1559
+
1560
+ # Potentially load in the weights and states from a previous save
1561
+ if args.resume_from_checkpoint:
1562
+ if args.resume_from_checkpoint != "latest":
1563
+ path = os.path.basename(args.resume_from_checkpoint)
1564
+ else:
1565
+ # Get the most recent checkpoint
1566
+ dirs = os.listdir(args.output_dir)
1567
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1568
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1569
+ path = dirs[-1] if len(dirs) > 0 else None
1570
+
1571
+ if path is None:
1572
+ accelerator.print(
1573
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1574
+ )
1575
+ args.resume_from_checkpoint = None
1576
+ initial_global_step = 0
1577
+ else:
1578
+ accelerator.print(f"Resuming from checkpoint {path}")
1579
+ accelerator.load_state(os.path.join(args.output_dir, path))
1580
+ global_step = int(path.split("-")[1])
1581
+
1582
+ initial_global_step = global_step
1583
+ first_epoch = global_step // num_update_steps_per_epoch
1584
+
1585
+ else:
1586
+ initial_global_step = 0
1587
+
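+ # Example: resuming from "checkpoint-1500" sets global_step (and initial_global_step)
+ # to 1500 and first_epoch to 1500 // num_update_steps_per_epoch.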
1588
+ progress_bar = tqdm(
1589
+ range(0, args.max_train_steps),
1590
+ initial=initial_global_step,
1591
+ desc="Steps",
1592
+ # Only show the progress bar once on each machine.
1593
+ disable=not accelerator.is_local_main_process,
1594
+ )
1595
+
1596
+ for epoch in range(first_epoch, args.num_train_epochs):
1597
+ unet.train()
1598
+ if args.train_text_encoder:
1599
+ text_encoder_one.train()
1600
+ text_encoder_two.train()
1601
+
1602
+ # Set the embeddings' requires_grad = True so that gradient checkpointing works
1603
+ text_encoder_one.text_model.embeddings.requires_grad_(True)
1604
+ text_encoder_two.text_model.embeddings.requires_grad_(True)
1605
+
1606
+ for step, batch in enumerate(train_dataloader):
1607
+ with accelerator.accumulate(unet):
1608
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1609
+ prompts = batch["prompts"]
1610
+
1611
+ # Encode batch prompts when custom prompts are provided for each image.
1612
+ if train_dataset.custom_instance_prompts:
1613
+ if not args.train_text_encoder:
1614
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
1615
+ prompts, text_encoders, tokenizers
1616
+ )
1617
+ else:
1618
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
1619
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
1620
+
1621
+ # Convert images to latent space
1622
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1623
+ model_input = model_input * vae.config.scaling_factor
1624
+ if args.pretrained_vae_model_name_or_path is None:
1625
+ model_input = model_input.to(weight_dtype)
1626
+
1627
+ # Sample noise that we'll add to the latents
1628
+ noise = torch.randn_like(model_input)
1629
+ bsz = model_input.shape[0]
1630
+ # Sample a random timestep for each image
1631
+ timesteps = torch.randint(
1632
+ 0,
1633
+ noise_scheduler.config.num_train_timesteps,
1634
+ (bsz,),
1635
+ device=model_input.device,
1636
+ )
1637
+ timesteps = timesteps.long()
1638
+
1639
+ # Add noise to the model input according to the noise magnitude at each timestep
1640
+ # (this is the forward diffusion process)
1641
+ noisy_model_input = noise_scheduler.add_noise(
1642
+ model_input, noise, timesteps
1643
+ )
1644
+
1645
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
1646
+ if not train_dataset.custom_instance_prompts:
1647
+ elems_to_repeat_text_embeds = (
1648
+ bsz // 2 if args.with_prior_preservation else bsz
1649
+ )
1650
+ elems_to_repeat_time_ids = (
1651
+ bsz // 2 if args.with_prior_preservation else bsz
1652
+ )
1653
+ else:
1654
+ elems_to_repeat_text_embeds = 1
1655
+ elems_to_repeat_time_ids = (
1656
+ bsz // 2 if args.with_prior_preservation else bsz
1657
+ )
1658
+
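+ # Example: with prior preservation and --train_batch_size 1, each batch holds one
+ # instance and one class image (bsz = 2), so the statically computed instance/class
+ # embeddings are repeated bsz // 2 = 1 time to match the batch.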
1659
+ # Predict the noise residual
1660
+ if not args.train_text_encoder:
1661
+ unet_added_conditions = {
1662
+ "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
1663
+ "text_embeds": unet_add_text_embeds.repeat(
1664
+ elems_to_repeat_text_embeds, 1
1665
+ ),
1666
+ }
1667
+ prompt_embeds_input = prompt_embeds.repeat(
1668
+ elems_to_repeat_text_embeds, 1, 1
1669
+ )
1670
+ model_pred = unet(
1671
+ noisy_model_input,
1672
+ timesteps,
1673
+ prompt_embeds_input,
1674
+ added_cond_kwargs=unet_added_conditions,
1675
+ ).sample
1676
+ else:
1677
+ unet_added_conditions = {
1678
+ "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)
1679
+ }
1680
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1681
+ text_encoders=[text_encoder_one, text_encoder_two],
1682
+ tokenizers=None,
1683
+ prompt=None,
1684
+ text_input_ids_list=[tokens_one, tokens_two],
1685
+ )
1686
+ unet_added_conditions.update(
1687
+ {
1688
+ "text_embeds": pooled_prompt_embeds.repeat(
1689
+ elems_to_repeat_text_embeds, 1
1690
+ )
1691
+ }
1692
+ )
1693
+ prompt_embeds_input = prompt_embeds.repeat(
1694
+ elems_to_repeat_text_embeds, 1, 1
1695
+ )
1696
+ model_pred = unet(
1697
+ noisy_model_input,
1698
+ timesteps,
1699
+ prompt_embeds_input,
1700
+ added_cond_kwargs=unet_added_conditions,
1701
+ ).sample
1702
+
1703
+ # Get the target for loss depending on the prediction type
1704
+ if noise_scheduler.config.prediction_type == "epsilon":
1705
+ target = noise
1706
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1707
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1708
+ else:
1709
+ raise ValueError(
1710
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
1711
+ )
1712
+
1713
+ if args.with_prior_preservation:
1714
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1715
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1716
+ target, target_prior = torch.chunk(target, 2, dim=0)
1717
+
1718
+ # Compute prior loss
1719
+ prior_loss = F.mse_loss(
1720
+ model_pred_prior.float(), target_prior.float(), reduction="mean"
1721
+ )
1722
+
1723
+ if args.snr_gamma is None:
1724
+ loss = F.mse_loss(
1725
+ model_pred.float(), target.float(), reduction="mean"
1726
+ )
1727
+ else:
1728
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1729
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1730
+ # This is discussed in Section 4.2 of the same paper.
1731
+ snr = compute_snr(noise_scheduler, timesteps)
1732
+ base_weight = (
1733
+ torch.stack(
1734
+ [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
1735
+ ).min(dim=1)[0]
1736
+ / snr
1737
+ )
1738
+
1739
+ if noise_scheduler.config.prediction_type == "v_prediction":
1740
+ # Velocity objective needs to be floored to an SNR weight of one.
1741
+ mse_loss_weights = base_weight + 1
1742
+ else:
1743
+ # Epsilon and sample both use the same loss weights.
1744
+ mse_loss_weights = base_weight
1745
+
1746
+ loss = F.mse_loss(
1747
+ model_pred.float(), target.float(), reduction="none"
1748
+ )
1749
+ loss = (
1750
+ loss.mean(dim=list(range(1, len(loss.shape))))
1751
+ * mse_loss_weights
1752
+ )
1753
+ loss = loss.mean()
1754
+
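+ # Min-SNR worked example (with --snr_gamma 5): a timestep whose SNR is 25 gets a
+ # loss weight of min(25, 5) / 25 = 0.2, while one with SNR 1 keeps weight 1.0;
+ # for v-prediction the weight is additionally offset by +1 as above.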
1755
+ if args.with_prior_preservation:
1756
+ # Add the prior loss to the instance loss.
1757
+ loss = loss + args.prior_loss_weight * prior_loss
1758
+
1759
+ accelerator.backward(loss)
1760
+ if accelerator.sync_gradients:
1761
+ params_to_clip = (
1762
+ itertools.chain(
1763
+ unet_lora_parameters,
1764
+ text_lora_parameters_one,
1765
+ text_lora_parameters_two,
1766
+ )
1767
+ if args.train_text_encoder
1768
+ else unet_lora_parameters
1769
+ )
1770
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1771
+ optimizer.step()
1772
+ lr_scheduler.step()
1773
+ optimizer.zero_grad()
1774
+
1775
+ # Checks if the accelerator has performed an optimization step behind the scenes
1776
+ if accelerator.sync_gradients:
1777
+ progress_bar.update(1)
1778
+ global_step += 1
1779
+
1780
+ if accelerator.is_main_process:
1781
+ if global_step % args.checkpointing_steps == 0:
1782
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1783
+ if args.checkpoints_total_limit is not None:
1784
+ checkpoints = os.listdir(args.output_dir)
1785
+ checkpoints = [
1786
+ d for d in checkpoints if d.startswith("checkpoint")
1787
+ ]
1788
+ checkpoints = sorted(
1789
+ checkpoints, key=lambda x: int(x.split("-")[1])
1790
+ )
1791
+
1792
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1793
+ if len(checkpoints) >= args.checkpoints_total_limit:
1794
+ num_to_remove = (
1795
+ len(checkpoints) - args.checkpoints_total_limit + 1
1796
+ )
1797
+ removing_checkpoints = checkpoints[0:num_to_remove]
1798
+
1799
+ logger.info(
1800
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1801
+ )
1802
+ logger.info(
1803
+ f"removing checkpoints: {', '.join(removing_checkpoints)}"
1804
+ )
1805
+
1806
+ for removing_checkpoint in removing_checkpoints:
1807
+ removing_checkpoint = os.path.join(
1808
+ args.output_dir, removing_checkpoint
1809
+ )
1810
+ shutil.rmtree(removing_checkpoint)
1811
+
1812
+ save_path = os.path.join(
1813
+ args.output_dir, f"checkpoint-{global_step}"
1814
+ )
1815
+ accelerator.save_state(save_path)
1816
+ logger.info(f"Saved state to {save_path}")
1817
+
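+ # Example: with --checkpoints_total_limit 3 and checkpoint-500, checkpoint-1000 and
+ # checkpoint-1500 already on disk, checkpoint-500 is deleted before the new
+ # checkpoint-{global_step} directory is saved, keeping at most 3 checkpoints.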
1818
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1819
+ progress_bar.set_postfix(**logs)
1820
+ accelerator.log(logs, step=global_step)
1821
+
1822
+ if global_step >= args.max_train_steps:
1823
+ break
1824
+
1825
+ if accelerator.is_main_process:
1826
+ if (
1827
+ args.validation_prompt is not None
1828
+ and epoch % args.validation_epochs == 0
1829
+ ):
1830
+ logger.info(
1831
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1832
+ f" {args.validation_prompt}."
1833
+ )
1834
+ # create pipeline
1835
+ if not args.train_text_encoder:
1836
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1837
+ args.pretrained_model_name_or_path,
1838
+ subfolder="text_encoder",
1839
+ revision=args.revision,
1840
+ )
1841
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1842
+ args.pretrained_model_name_or_path,
1843
+ subfolder="text_encoder_2",
1844
+ revision=args.revision,
1845
+ )
1846
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1847
+ args.pretrained_model_name_or_path,
1848
+ vae=vae,
1849
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1850
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1851
+ unet=accelerator.unwrap_model(unet),
1852
+ revision=args.revision,
1853
+ torch_dtype=weight_dtype,
1854
+ )
1855
+
1856
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1857
+ scheduler_args = {}
1858
+
1859
+ if "variance_type" in pipeline.scheduler.config:
1860
+ variance_type = pipeline.scheduler.config.variance_type
1861
+
1862
+ if variance_type in ["learned", "learned_range"]:
1863
+ variance_type = "fixed_small"
1864
+
1865
+ scheduler_args["variance_type"] = variance_type
1866
+
1867
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1868
+ pipeline.scheduler.config, **scheduler_args
1869
+ )
1870
+
1871
+ pipeline = pipeline.to(accelerator.device)
1872
+ pipeline.set_progress_bar_config(disable=True)
1873
+
1874
+ # run inference
1875
+ generator = (
1876
+ torch.Generator(device=accelerator.device).manual_seed(args.seed)
1877
+ if args.seed
1878
+ else None
1879
+ )
1880
+ pipeline_args = {"prompt": args.validation_prompt,
1881
+ 'num_images_per_prompt': args.num_validation_images}
1882
+
1883
+ images = pipeline(**pipeline_args, generator=generator).images
1884
+
1885
+ for tracker in accelerator.trackers:
1886
+ if tracker.name == "tensorboard":
1887
+ np_images = np.stack([np.asarray(img) for img in images])
1888
+ tracker.writer.add_images(
1889
+ "validation", np_images, epoch, dataformats="NHWC"
1890
+ )
1891
+ if tracker.name == "wandb":
1892
+ tracker.log(
1893
+ {
1894
+ "validation": [
1895
+ wandb.Image(
1896
+ image, caption=f"{i}: {args.validation_prompt}"
1897
+ )
1898
+ for i, image in enumerate(images)
1899
+ ]
1900
+ }
1901
+ )
1902
+
1903
+ del pipeline
1904
+ torch.cuda.empty_cache()
1905
+
1906
+ # Save the lora layers
1907
+ accelerator.wait_for_everyone()
1908
+ if accelerator.is_main_process:
1909
+ unet = accelerator.unwrap_model(unet)
1910
+ unet = unet.to(torch.float32)
1911
+ unet_lora_layers = unet_lora_state_dict(unet)
1912
+
1913
+ if args.train_text_encoder:
1914
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1915
+ text_encoder_lora_layers = text_encoder_lora_state_dict(
1916
+ text_encoder_one.to(torch.float32)
1917
+ )
1918
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1919
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(
1920
+ text_encoder_two.to(torch.float32)
1921
+ )
1922
+ else:
1923
+ text_encoder_lora_layers = None
1924
+ text_encoder_2_lora_layers = None
1925
+
1926
+ StableDiffusionXLPipeline.save_lora_weights(
1927
+ save_directory=args.output_dir,
1928
+ unet_lora_layers=unet_lora_layers,
1929
+ text_encoder_lora_layers=text_encoder_lora_layers,
1930
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1931
+ )
1932
+
1933
+ # Remove unused models to free GPU memory
1934
+ unet = unet.cpu()
1935
+ text_encoder_one = text_encoder_one.cpu()
1936
+ text_encoder_two = text_encoder_two.cpu()
1937
+ del unet, text_encoder_one, text_encoder_two
1938
+ del optimizer
1939
+ if args.train_text_encoder:
1940
+ del text_encoder_lora_layers, text_encoder_2_lora_layers
1941
+
1942
+ # Final inference
1943
+ # Load previous pipeline
1944
+ vae = AutoencoderKL.from_pretrained(
1945
+ vae_path,
1946
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1947
+ revision=args.revision,
1948
+ torch_dtype=weight_dtype,
1949
+ )
1950
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1951
+ args.pretrained_model_name_or_path,
1952
+ vae=vae,
1953
+ revision=args.revision,
1954
+ torch_dtype=weight_dtype,
1955
+ )
1956
+
1957
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1958
+ scheduler_args = {}
1959
+
1960
+ if "variance_type" in pipeline.scheduler.config:
1961
+ variance_type = pipeline.scheduler.config.variance_type
1962
+
1963
+ if variance_type in ["learned", "learned_range"]:
1964
+ variance_type = "fixed_small"
1965
+
1966
+ scheduler_args["variance_type"] = variance_type
1967
+
1968
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1969
+ pipeline.scheduler.config, **scheduler_args
1970
+ )
1971
+
1972
+ # load attention processors
1973
+ pipeline.load_lora_weights(args.output_dir)
1974
+
1975
+ # run inference
1976
+ images = []
1977
+ if args.validation_prompt and args.num_validation_images > 0:
1978
+ pipeline = pipeline.to(accelerator.device)
1979
+ generator = (
1980
+ torch.Generator(device=accelerator.device).manual_seed(args.seed)
1981
+ if args.seed
1982
+ else None
1983
+ )
1984
+ images = pipeline(args.validation_prompt, num_images_per_prompt=args.num_validation_images,
1985
+ generator=generator).images
1986
+
1987
+
1988
+ for tracker in accelerator.trackers:
1989
+ if tracker.name == "tensorboard":
1990
+ np_images = np.stack([np.asarray(img) for img in images])
1991
+ tracker.writer.add_images(
1992
+ "test", np_images, epoch, dataformats="NHWC"
1993
+ )
1994
+ if tracker.name == "wandb":
1995
+ tracker.log(
1996
+ {
1997
+ "test": [
1998
+ wandb.Image(
1999
+ image, caption=f"{i}: {args.validation_prompt}"
2000
+ )
2001
+ for i, image in enumerate(images)
2002
+ ]
2003
+ }
2004
+ )
2005
+
2006
+ if args.push_to_hub:
2007
+ save_model_card(
2008
+ repo_id,
2009
+ images=images,
2010
+ base_model=args.pretrained_model_name_or_path,
2011
+ train_text_encoder=args.train_text_encoder,
2012
+ instance_prompt=args.instance_prompt,
2013
+ validation_prompt=args.validation_prompt,
2014
+ repo_folder=args.output_dir,
2015
+ vae_path=args.pretrained_vae_model_name_or_path,
2016
+ )
2017
+ upload_folder(
2018
+ repo_id=repo_id,
2019
+ folder_path=args.output_dir,
2020
+ commit_message="End of training",
2021
+ ignore_patterns=["step_*", "epoch_*"],
2022
+ )
2023
+
2024
+ accelerator.end_training()
2025
+
2026
+
2027
+ if __name__ == "__main__":
2028
+ args = parse_args()
2029
+ main(args)
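+ # A minimal, illustrative launch command (assuming the CLI flags mirror the args.*
+ # attributes used above; paths, prompts and hyperparameters are placeholders):
+ #
+ #   accelerate launch train_dreambooth_b-lora_sdxl.py \
+ #     --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ #     --instance_data_dir="./example_images" \
+ #     --instance_prompt="a [v] image" \
+ #     --output_dir="./b_lora_output" \
+ #     --resolution=1024 \
+ #     --train_batch_size=1 \
+ #     --gradient_accumulation_steps=4 \
+ #     --learning_rate=5e-5 \
+ #     --max_train_steps=1000 \
+ #     --checkpointing_steps=500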