peteli committed
Commit
4f23115
1 Parent(s): 6228581

Upload 3 files

Files changed (3)
  1. train_dreambooth_lora.py +966 -0
  2. untitled.streamlit.py +32 -0
  3. utils.py +133 -0
train_dreambooth_lora.py ADDED
@@ -0,0 +1,966 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from paddlenlp.utils.log import logger
16
+ logger.set_level("WARNING")
17
+ import paddle
18
+ import argparse
19
+ import contextlib
20
+ import gc
21
+ import hashlib
22
+ import math
23
+ import os
24
+ import sys
25
+ import warnings
26
+ from pathlib import Path
27
+ from typing import Optional
28
+
29
+ import numpy as np
30
+ import paddle
31
+ import paddle.nn as nn
32
+ import paddle.nn.functional as F
33
+ import requests
34
+ from huggingface_hub import HfFolder, create_repo, upload_folder, whoami
35
+ from paddle.distributed.fleet.utils.hybrid_parallel_util import (
36
+ fused_allreduce_gradients,
37
+ )
38
+ from utils import context_nologging, _retry
39
+ from paddle.io import BatchSampler, DataLoader, Dataset, DistributedBatchSampler
40
+ from paddle.optimizer import AdamW
41
+ from paddle.vision import BaseTransform, transforms
42
+ from PIL import Image
43
+ from tqdm.auto import tqdm
44
+
45
+ from paddlenlp.trainer import set_seed
46
+ from paddlenlp.transformers import AutoTokenizer, PretrainedConfig
47
+ from ppdiffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ DiffusionPipeline,
51
+ DPMSolverMultistepScheduler,
52
+ UNet2DConditionModel,
53
+ )
54
+ from ppdiffusers.loaders import AttnProcsLayers
55
+ from ppdiffusers.modeling_utils import freeze_params, unwrap_model
56
+ from ppdiffusers.models.cross_attention import LoRACrossAttnProcessor
57
+ from ppdiffusers.optimization import get_scheduler
58
+ from ppdiffusers.utils import image_grid
59
+
60
+ def str2bool(v):
61
+ if v.lower() in ("yes", "true", "t", "y", "1"):
62
+ return True
63
+ elif v.lower() in ("no", "false", "f", "n", "0"):
64
+ return False
65
+ else:
66
+ raise argparse.ArgumentTypeError("Unsupported value encountered.")
67
+
68
+ def url_or_path_join(*path_list):
69
+ return os.path.join(*path_list) if os.path.isdir(os.path.join(*path_list)) else "/".join(path_list)
70
+
71
+
72
+ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
73
+ img_str = ""
74
+ for i, image in enumerate(images):
75
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
76
+ img_str += f"![img_{i}](./image_{i}.png)\n"
77
+
78
+ yaml = f"""
79
+ ---
80
+ license: creativeml-openrail-m
81
+ base_model: {base_model}
82
+ instance_prompt: {prompt}
83
+ tags:
84
+ - stable-diffusion
85
+ - stable-diffusion-ppdiffusers
86
+ - text-to-image
87
+ - ppdiffusers
88
+ - lora
89
+ inference: false
90
+ ---
91
+ """
92
+ model_card = f"""
93
+ # LoRA DreamBooth - {repo_name}
94
+ These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
95
+ {img_str}
96
+ """
97
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
98
+ f.write(yaml + model_card)
99
+
100
+
101
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
102
+ try:
103
+ text_encoder_config = PretrainedConfig.from_pretrained(
104
+ url_or_path_join(pretrained_model_name_or_path, "text_encoder")
105
+ )
106
+ model_class = text_encoder_config.architectures[0]
107
+ except Exception:
108
+ model_class = "LDMBertModel"
109
+ if model_class == "CLIPTextModel":
110
+ from paddlenlp.transformers import CLIPTextModel
111
+
112
+ return CLIPTextModel
113
+ elif model_class == "RobertaSeriesModelWithTransformation":
114
+ from ppdiffusers.pipelines.alt_diffusion.modeling_roberta_series import (
115
+ RobertaSeriesModelWithTransformation,
116
+ )
117
+
118
+ return RobertaSeriesModelWithTransformation
119
+ elif model_class == "BertModel":
120
+ from paddlenlp.transformers import BertModel
121
+
122
+ return BertModel
123
+ elif model_class == "LDMBertModel":
124
+ from ppdiffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
125
+ LDMBertModel,
126
+ )
127
+
128
+ return LDMBertModel
129
+ else:
130
+ raise ValueError(f"{model_class} is not supported.")
131
+
132
+
133
+ class Lambda(BaseTransform):
134
+ def __init__(self, fn, keys=None):
135
+ super().__init__(keys)
136
+ self.fn = fn
137
+
138
+ def _apply_image(self, img):
139
+ return self.fn(img)
140
+
141
+
142
+ def get_report_to(args):
143
+ if args.report_to == "visualdl":
144
+ from visualdl import LogWriter
145
+
146
+ writer = LogWriter(logdir=args.logging_dir)
147
+ elif args.report_to == "tensorboard":
148
+ from tensorboardX import SummaryWriter
149
+
150
+ writer = SummaryWriter(logdir=args.logging_dir)
151
+ else:
152
+ raise ValueError("report_to must be in ['visualdl', 'tensorboard']")
153
+ return writer
154
+
155
+
156
+ def parse_args(input_args=None):
157
+ parser = argparse.ArgumentParser(description="Simple example of a training dreambooth lora script.")
158
+ parser.add_argument(
159
+ "--pretrained_model_name_or_path",
160
+ type=str,
161
+ default=None,
162
+ required=True,
163
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
164
+ )
165
+ parser.add_argument(
166
+ "--tokenizer_name",
167
+ type=str,
168
+ default=None,
169
+ help="Pretrained tokenizer name or path if not the same as model_name",
170
+ )
171
+ parser.add_argument(
172
+ "--instance_data_dir",
173
+ type=str,
174
+ default=None,
175
+ required=True,
176
+ help="A folder containing the training data of instance images.",
177
+ )
178
+ parser.add_argument(
179
+ "--class_data_dir",
180
+ type=str,
181
+ default=None,
182
+ required=False,
183
+ help="A folder containing the training data of class images.",
184
+ )
185
+ parser.add_argument(
186
+ "--instance_prompt",
187
+ type=str,
188
+ default=None,
189
+ required=True,
190
+ help="The prompt with identifier specifying the instance",
191
+ )
192
+ parser.add_argument(
193
+ "--class_prompt",
194
+ type=str,
195
+ default=None,
196
+ help="The prompt to specify images in the same class as provided instance images.",
197
+ )
198
+ parser.add_argument(
199
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
200
+ )
201
+ parser.add_argument(
202
+ "--num_validation_images",
203
+ type=int,
204
+ default=4,
205
+ help="Number of images that should be generated during validation with `validation_prompt`.",
206
+ )
207
+ parser.add_argument(
208
+ "--validation_steps",
209
+ type=int,
210
+ default=50,
211
+ help=(
212
+ "Run dreambooth validation every X global steps. Dreambooth validation consists of running the prompt"
213
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
214
+ ),
215
+ )
216
+ parser.add_argument(
217
+ "--with_prior_preservation",
218
+ default=False,
219
+ action="store_true",
220
+ help="Flag to add prior preservation loss.",
221
+ )
222
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
223
+ parser.add_argument(
224
+ "--num_class_images",
225
+ type=int,
226
+ default=100,
227
+ help=(
228
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
229
+ " class_data_dir, additional images will be sampled with class_prompt."
230
+ ),
231
+ )
232
+ parser.add_argument(
233
+ "--lora_rank",
234
+ type=int,
235
+ default=4,
236
+ help=(
237
+ "The rank of the LoRA update matrices."
238
+ ),
239
+ )
240
+
241
+ parser.add_argument(
242
+ "--output_dir",
243
+ type=str,
244
+ default="lora-dreambooth-model",
245
+ help="The output directory where the model predictions and checkpoints will be written.",
246
+ )
247
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
248
+ parser.add_argument(
249
+ "--height",
250
+ type=int,
251
+ default=None,
252
+ help=(
253
+ "The height for input images, all the images in the train/validation dataset will be resized to this"
254
+ " height"
255
+ ),
256
+ )
257
+ parser.add_argument(
258
+ "--width",
259
+ type=int,
260
+ default=None,
261
+ help=(
262
+ "The width for input images, all the images in the train/validation dataset will be resized to this"
263
+ " width"
264
+ ),
265
+ )
266
+ parser.add_argument(
267
+ "--resolution",
268
+ type=int,
269
+ default=512,
270
+ help=(
271
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
272
+ " resolution"
273
+ ),
274
+ )
275
+ parser.add_argument(
276
+ "--center_crop",
277
+ default=False,
278
+ action="store_true",
279
+ help=(
280
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
281
+ " cropped. The images will be resized to the resolution first before cropping."
282
+ ),
283
+ )
284
+ parser.add_argument(
285
+ "--random_flip",
286
+ action="store_true",
287
+ help="whether to randomly flip images horizontally",
288
+ )
289
+ parser.add_argument(
290
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
291
+ )
292
+ parser.add_argument(
293
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
294
+ )
295
+ parser.add_argument("--num_train_epochs", type=int, default=1)
296
+ parser.add_argument(
297
+ "--max_train_steps",
298
+ type=int,
299
+ default=500,
300
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
301
+ )
302
+ parser.add_argument(
303
+ "--checkpointing_steps",
304
+ type=int,
305
+ default=100,
306
+ help=("Save a checkpoint of the training state every X updates."),
307
+ )
308
+ parser.add_argument(
309
+ "--gradient_accumulation_steps",
310
+ type=int,
311
+ default=1,
312
+ help="Number of update steps to accumulate before performing a backward/update pass.",
313
+ )
314
+ parser.add_argument(
315
+ "--gradient_checkpointing",
316
+ action="store_true",
317
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
318
+ )
319
+ parser.add_argument(
320
+ "--learning_rate",
321
+ type=float,
322
+ default=5e-4,
323
+ help="Initial learning rate (after the potential warmup period) to use.",
324
+ )
325
+ parser.add_argument(
326
+ "--scale_lr",
327
+ action="store_true",
328
+ default=False,
329
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
330
+ )
331
+ parser.add_argument(
332
+ "--lr_scheduler",
333
+ type=str,
334
+ default="constant",
335
+ help=(
336
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
337
+ ' "constant", "constant_with_warmup"]'
338
+ ),
339
+ )
340
+ parser.add_argument(
341
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
342
+ )
343
+ parser.add_argument(
344
+ "--lr_num_cycles",
345
+ type=int,
346
+ default=1,
347
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
348
+ )
349
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
350
+ parser.add_argument(
351
+ "--dataloader_num_workers",
352
+ type=int,
353
+ default=0,
354
+ help=(
355
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
356
+ ),
357
+ )
358
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
359
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
360
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
361
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
362
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
363
+ parser.add_argument("--push_to_hub", type=str2bool, nargs="?", const=True, default=False, help="Whether or not to push the model to the Hub.")
364
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
365
+ parser.add_argument(
366
+ "--hub_model_id",
367
+ type=str,
368
+ default=None,
369
+ help="The name of the repository to keep in sync with the local `output_dir`.",
370
+ )
371
+ parser.add_argument(
372
+ "--logging_dir",
373
+ type=str,
374
+ default="logs",
375
+ help=(
376
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) or [VisualDL](https://www.paddlepaddle.org.cn/paddle/visualdl) log directory. Will default to"
377
+ "*output_dir/logs"
378
+ ),
379
+ )
380
+ parser.add_argument(
381
+ "--report_to",
382
+ type=str,
383
+ default="visualdl",
384
+ choices=["tensorboard", "visualdl"],
385
+ help="Log writer type.",
386
+ )
387
+ if input_args is not None:
388
+ args = parser.parse_args(input_args)
389
+ else:
390
+ args = parser.parse_args()
391
+
392
+ if args.instance_data_dir is None:
393
+ raise ValueError("You must specify a train data directory.")
394
+
395
+ if args.with_prior_preservation:
396
+ if args.class_data_dir is None:
397
+ raise ValueError("You must specify a data directory for class images.")
398
+ if args.class_prompt is None:
399
+ raise ValueError("You must specify prompt for class images.")
400
+ else:
401
+ # logger is not available yet
402
+ if args.class_data_dir is not None:
403
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
404
+ if args.class_prompt is not None:
405
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
406
+
407
+ args.logging_dir = os.path.join(args.output_dir, args.logging_dir)
408
+ if (args.height is None or args.width is None) and args.resolution is not None:
409
+ args.height = args.width = args.resolution
410
+
411
+ return args
412
+
413
+
414
+ class DreamBoothDataset(Dataset):
415
+ """
416
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
417
+ It pre-processes the images and tokenizes the prompts.
418
+ """
419
+
420
+ def __init__(
421
+ self,
422
+ instance_data_root,
423
+ instance_prompt,
424
+ tokenizer,
425
+ class_data_root=None,
426
+ class_prompt=None,
427
+ height=512,
428
+ width=512,
429
+ center_crop=False,
430
+ interpolation="bilinear",
431
+ random_flip=False,
432
+ ):
433
+ self.height = height
434
+ self.width = width
435
+ self.center_crop = center_crop
436
+ self.tokenizer = tokenizer
437
+
438
+ self.instance_data_root = Path(instance_data_root)
439
+ if not self.instance_data_root.exists():
440
+ raise ValueError("Instance images root doesn't exist.")
441
+ ext = ["png", "jpg", "jpeg", "bmp", "PNG", "JPG", "JPEG", "BMP"]
442
+ self.instance_images_path = []
443
+ for p in Path(instance_data_root).iterdir():
444
+ if any(suffix in p.name for suffix in ext):
445
+ self.instance_images_path.append(p)
446
+ self.num_instance_images = len(self.instance_images_path)
447
+ self.instance_prompt = instance_prompt
448
+ self._length = self.num_instance_images
449
+
450
+ if class_data_root is not None:
451
+ self.class_data_root = Path(class_data_root)
452
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
453
+ self.class_images_path = []
454
+ for p in Path(class_data_root).iterdir():
455
+ if any(suffix in p.name for suffix in ext):
456
+ self.class_images_path.append(p)
457
+ self.num_class_images = len(self.class_images_path)
458
+ self._length = max(self.num_class_images, self.num_instance_images)
459
+ self.class_prompt = class_prompt
460
+ else:
461
+ self.class_data_root = None
462
+
463
+ self.image_transforms = transforms.Compose(
464
+ [
465
+ transforms.Resize((height, width), interpolation=interpolation),
466
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
467
+ transforms.RandomHorizontalFlip() if random_flip else Lambda(lambda x: x),
468
+ transforms.ToTensor(),
469
+ transforms.Normalize([0.5], [0.5]),
470
+ ]
471
+ )
472
+
473
+ def __len__(self):
474
+ return self._length
475
+
476
+ def __getitem__(self, index):
477
+ example = {}
478
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
479
+ if not instance_image.mode == "RGB":
480
+ instance_image = instance_image.convert("RGB")
481
+ example["instance_images"] = self.image_transforms(instance_image)
482
+ example["instance_prompt_ids"] = self.tokenizer(
483
+ self.instance_prompt,
484
+ padding="do_not_pad",
485
+ truncation=True,
486
+ max_length=self.tokenizer.model_max_length,
487
+ return_attention_mask=False,
488
+ ).input_ids
489
+
490
+ if self.class_data_root:
491
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
492
+ if not class_image.mode == "RGB":
493
+ class_image = class_image.convert("RGB")
494
+ example["class_images"] = self.image_transforms(class_image)
495
+ example["class_prompt_ids"] = self.tokenizer(
496
+ self.class_prompt,
497
+ padding="do_not_pad",
498
+ truncation=True,
499
+ max_length=self.tokenizer.model_max_length,
500
+ return_attention_mask=False,
501
+ ).input_ids
502
+
503
+ return example
504
+
505
+
506
+ class PromptDataset(Dataset):
507
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
508
+
509
+ def __init__(self, prompt, num_samples):
510
+ self.prompt = prompt
511
+ self.num_samples = num_samples
512
+
513
+ def __len__(self):
514
+ return self.num_samples
515
+
516
+ def __getitem__(self, index):
517
+ example = {}
518
+ example["prompt"] = self.prompt
519
+ example["index"] = index
520
+ return example
521
+
522
+
523
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
524
+ if token is None:
525
+ token = HfFolder.get_token()
526
+ if organization is None:
527
+ username = whoami(token)["name"]
528
+ return f"{username}/{model_id}"
529
+ else:
530
+ return f"{organization}/{model_id}"
531
+
532
+
533
+ def main():
534
+ paddle.randn((1,))
535
+ args = parse_args()
536
+ rank = paddle.distributed.get_rank()
537
+ is_main_process = rank == 0
538
+ num_processes = paddle.distributed.get_world_size()
539
+ if num_processes > 1:
540
+ paddle.distributed.init_parallel_env()
541
+
542
+ # If passed along, set the training seed now.
543
+ if args.seed is not None:
544
+ set_seed(args.seed)
545
+
546
+ # Generate class images if prior preservation is enabled.
547
+ if args.with_prior_preservation:
548
+ class_images_dir = Path(args.class_data_dir)
549
+ if not class_images_dir.exists():
550
+ class_images_dir.mkdir(parents=True)
551
+ cur_class_images = len(list(class_images_dir.iterdir()))
552
+
553
+ if cur_class_images < args.num_class_images:
554
+ with context_nologging():
555
+ pipeline = DiffusionPipeline.from_pretrained(
556
+ args.pretrained_model_name_or_path,
557
+ safety_checker=None,
558
+ )
559
+ pipeline.set_progress_bar_config(disable=True)
560
+
561
+ num_new_images = args.num_class_images - cur_class_images
562
+ logger.info(f"Number of class images to sample: {num_new_images}.")
563
+
564
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
565
+ batch_sampler = (
566
+ DistributedBatchSampler(sample_dataset, batch_size=args.sample_batch_size, shuffle=False)
567
+ if num_processes > 1
568
+ else BatchSampler(sample_dataset, batch_size=args.sample_batch_size, shuffle=False)
569
+ )
570
+ sample_dataloader = DataLoader(
571
+ sample_dataset, batch_sampler=batch_sampler, num_workers=args.dataloader_num_workers
572
+ )
573
+
574
+ for example in tqdm(sample_dataloader, desc="Generating class images", disable=not is_main_process, ncols=100):
575
+ images = pipeline(example["prompt"]).images
576
+
577
+ for i, image in enumerate(images):
578
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
579
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
580
+ image.save(image_filename)
581
+ pipeline.to("cpu")
582
+ del pipeline
583
+ gc.collect()
584
+
585
+ if is_main_process:
586
+ if args.output_dir is not None:
587
+ os.makedirs(args.output_dir, exist_ok=True)
588
+
589
+ print("Downloading the model weights, please be patient...")
590
+ with context_nologging():
591
+ # Load the tokenizer
592
+ if args.tokenizer_name:
593
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
594
+ elif args.pretrained_model_name_or_path:
595
+ tokenizer = AutoTokenizer.from_pretrained(url_or_path_join(args.pretrained_model_name_or_path, "tokenizer"))
596
+
597
+ # import correct text encoder class
598
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
599
+
600
+ # Load scheduler and models
601
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
602
+ text_encoder = text_encoder_cls.from_pretrained(
603
+ url_or_path_join(args.pretrained_model_name_or_path, "text_encoder")
604
+ )
605
+ text_config = text_encoder.config if isinstance(text_encoder.config, dict) else text_encoder.config.to_dict()
606
+ if text_config.get("use_attention_mask", None) is not None and text_config["use_attention_mask"]:
607
+ use_attention_mask = True
608
+ else:
609
+ use_attention_mask = False
610
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
611
+ unet = UNet2DConditionModel.from_pretrained(
612
+ args.pretrained_model_name_or_path,
613
+ subfolder="unet",
614
+ )
615
+
616
+ # We only train the additional adapter LoRA layers
617
+ freeze_params(vae.parameters())
618
+ freeze_params(text_encoder.parameters())
619
+ freeze_params(unet.parameters())
620
+
621
+ # now we will add new LoRA weights to the attention layers
622
+ # It's important to realize here how many attention weights will be added and of which sizes
623
+ # The sizes of the attention layers consist only of two different variables:
624
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
625
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
626
+
627
+ # Let's first see how many attention processors we will have to set.
628
+ # For Stable Diffusion, it should be equal to:
629
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
630
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
631
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
632
+ # => 32 layers
633
+
634
+ # Set correct lora layers
635
+ lora_attn_procs = {}
636
+ for name in unet.attn_processors.keys():
637
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
638
+ if name.startswith("mid_block"):
639
+ hidden_size = unet.config.block_out_channels[-1]
640
+ elif name.startswith("up_blocks"):
641
+ block_id = int(name[len("up_blocks.")])
642
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
643
+ elif name.startswith("down_blocks"):
644
+ block_id = int(name[len("down_blocks.")])
645
+ hidden_size = unet.config.block_out_channels[block_id]
646
+
647
+ lora_attn_procs[name] = LoRACrossAttnProcessor(
648
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.lora_rank
649
+ )
650
+
651
+ unet.set_attn_processor(lora_attn_procs)
652
+ lora_layers = AttnProcsLayers(unet.attn_processors)
653
+
654
+ # Dataset and DataLoaders creation:
655
+ train_dataset = DreamBoothDataset(
656
+ instance_data_root=args.instance_data_dir,
657
+ instance_prompt=args.instance_prompt,
658
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
659
+ class_prompt=args.class_prompt,
660
+ tokenizer=tokenizer,
661
+ height=args.height,
662
+ width=args.width,
663
+ center_crop=args.center_crop,
664
+ interpolation="bilinear",
665
+ random_flip=args.random_flip,
666
+ )
667
+
668
+ def collate_fn(examples):
669
+ input_ids = [example["instance_prompt_ids"] for example in examples]
670
+ pixel_values = [example["instance_images"] for example in examples]
671
+
672
+ # Concat class and instance examples for prior preservation.
673
+ # We do this to avoid doing two forward passes.
674
+ if args.with_prior_preservation:
675
+ input_ids += [example["class_prompt_ids"] for example in examples]
676
+ pixel_values += [example["class_images"] for example in examples]
677
+
678
+ pixel_values = paddle.stack(pixel_values).astype("float32")
679
+
680
+ input_ids = tokenizer.pad(
681
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pd"
682
+ ).input_ids
683
+
684
+ return {
685
+ "input_ids": input_ids,
686
+ "pixel_values": pixel_values,
687
+ }
688
+
689
+ train_sampler = (
690
+ DistributedBatchSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True)
691
+ if num_processes > 1
692
+ else BatchSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True)
693
+ )
694
+ train_dataloader = DataLoader(
695
+ train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn, num_workers=args.dataloader_num_workers
696
+ )
697
+
698
+ # Scheduler and math around the number of training steps.
699
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
700
+ if args.max_train_steps is None:
701
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
702
+ # Afterwards we recalculate our number of training epochs
703
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
704
+
705
+ if args.scale_lr:
706
+ args.learning_rate = (
707
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * num_processes
708
+ )
709
+
710
+ lr_scheduler = get_scheduler(
711
+ args.lr_scheduler,
712
+ learning_rate=args.learning_rate,
713
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
714
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
715
+ num_cycles=args.lr_num_cycles,
716
+ power=args.lr_power,
717
+ )
718
+
719
+ # Optimizer creation
720
+ optimizer = AdamW(
721
+ learning_rate=lr_scheduler,
722
+ parameters=lora_layers.parameters(),
723
+ beta1=args.adam_beta1,
724
+ beta2=args.adam_beta2,
725
+ weight_decay=args.adam_weight_decay,
726
+ epsilon=args.adam_epsilon,
727
+ grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) if args.max_grad_norm > 0 else None,
728
+ )
729
+
730
+ if num_processes > 1:
731
+ unet = paddle.DataParallel(unet)
732
+
733
+ if is_main_process:
734
+ logger.info("----------- Configuration Arguments -----------")
735
+ for arg, value in sorted(vars(args).items()):
736
+ logger.info("%s: %s" % (arg, value))
737
+ logger.info("------------------------------------------------")
738
+ writer = get_report_to(args)
739
+
740
+ # Train!
741
+ total_batch_size = args.train_batch_size * num_processes * args.gradient_accumulation_steps
742
+
743
+ logger.info("***** Running training *****")
744
+ logger.info(f" Num examples = {len(train_dataset)}")
745
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
746
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
747
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
748
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
749
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
750
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
751
+
752
+ # Only show the progress bar once on each machine.
753
+ progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process, ncols=100)
754
+ progress_bar.set_description("Train Steps")
755
+ global_step = 0
756
+ vae.eval()
757
+ text_encoder.eval()
758
+
759
+ for epoch in range(args.num_train_epochs):
760
+ unet.train()
761
+ for step, batch in enumerate(train_dataloader):
762
+ # Convert images to latent space
763
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
764
+ latents = latents * 0.18215
765
+
766
+ # Sample noise that we'll add to the latents
767
+ noise = paddle.randn(latents.shape)
768
+ batch_size = latents.shape[0]
769
+ # Sample a random timestep for each image
770
+ timesteps = paddle.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,)).cast("int64")
771
+
772
+ # Add noise to the latents according to the noise magnitude at each timestep
773
+ # (this is the forward diffusion process)
774
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
775
+
776
+ if num_processes > 1 and (
777
+ args.gradient_checkpointing or ((step + 1) % args.gradient_accumulation_steps != 0)
778
+ ):
779
+ # grad acc, no_sync when (step + 1) % args.gradient_accumulation_steps != 0:
780
+ # gradient_checkpointing, no_sync every where
781
+ # gradient_checkpointing + grad_acc, no_sync every where
782
+ unet_ctx_manager = unet.no_sync()
783
+ else:
784
+ unet_ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
785
+
786
+ if use_attention_mask:
787
+ attention_mask = (batch["input_ids"] != tokenizer.pad_token_id).cast("int64")
788
+ else:
789
+ attention_mask = None
790
+ encoder_hidden_states = text_encoder(batch["input_ids"], attention_mask=attention_mask)[0]
791
+
792
+ with unet_ctx_manager:
793
+ # Predict the noise residual / sample
794
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
795
+
796
+ # Get the target for loss depending on the prediction type
797
+ if noise_scheduler.config.prediction_type == "epsilon":
798
+ target = noise
799
+ elif noise_scheduler.config.prediction_type == "v_prediction":
800
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
801
+ else:
802
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
803
+
804
+ if args.with_prior_preservation:
805
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
806
+ model_pred, model_pred_prior = model_pred.chunk(2, axis=0)
807
+ target, target_prior = target.chunk(2, axis=0)
808
+
809
+ # Compute instance loss
810
+ loss = F.mse_loss(model_pred, target, reduction="mean")
811
+
812
+ # Compute prior loss
813
+ prior_loss = F.mse_loss(model_pred_prior, target_prior, reduction="mean")
814
+
815
+ # Add the prior loss to the instance loss.
816
+ loss = loss + args.prior_loss_weight * prior_loss
817
+ else:
818
+ loss = F.mse_loss(model_pred, target, reduction="mean")
819
+
820
+ if args.gradient_accumulation_steps > 1:
821
+ loss = loss / args.gradient_accumulation_steps
822
+ loss.backward()
823
+
824
+ if (step + 1) % args.gradient_accumulation_steps == 0:
825
+ if num_processes > 1 and args.gradient_checkpointing:
826
+ fused_allreduce_gradients(lora_layers.parameters(), None)
827
+ optimizer.step()
828
+ lr_scheduler.step()
829
+ optimizer.clear_grad()
830
+ progress_bar.update(1)
831
+ global_step += 1
832
+ step_loss = loss.item() * args.gradient_accumulation_steps
833
+ logs = {
834
+ "epoch": str(epoch).zfill(4),
835
+ "step_loss": round(step_loss, 10),
836
+ "lr": lr_scheduler.get_lr(),
837
+ }
838
+ progress_bar.set_postfix(**logs)
839
+
840
+ if is_main_process:
841
+ for name, val in logs.items():
842
+ if name == "epoch":
843
+ continue
844
+ writer.add_scalar(f"train/{name}", val, step=global_step)
845
+
846
+ if global_step % args.checkpointing_steps == 0:
847
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
848
+ with context_nologging():
849
+ unwrap_model(unet).save_attn_procs(save_path)
850
+ print(f"\n Saved lora weights to {save_path}")
851
+
852
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
853
+ with context_nologging():
854
+ logger.info(
855
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
856
+ f" {args.validation_prompt}."
857
+ )
858
+ # create pipeline
859
+ pipeline = DiffusionPipeline.from_pretrained(
860
+ args.pretrained_model_name_or_path,
861
+ unet=unwrap_model(unet),
862
+ safety_checker=None,
863
+ )
864
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
865
+ pipeline.set_progress_bar_config(disable=True)
866
+
867
+ # run inference
868
+ generator = paddle.Generator().manual_seed(args.seed) if args.seed else None
869
+ images = [
870
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
871
+ for _ in range(args.num_validation_images)
872
+ ]
873
+ png_save_path = os.path.join(args.output_dir, "validation_images")
874
+ os.makedirs(png_save_path, exist_ok=True)
875
+ if len(images) == 1:
876
+ grid_image = images[0]
877
+ elif len(images) == 2:
878
+ grid_image = image_grid(images, 1, 2)
879
+ else:
880
+ display_images = 2 * (len(images) // 2)
881
+ grid_image = image_grid(images[:display_images], 2, display_images // 2)
882
+ grid_image.save(os.path.join(png_save_path, f"{global_step}.png"))
883
+
884
+ np_images = np.stack([np.asarray(img) for img in images])
885
+
886
+ if args.report_to == "tensorboard":
887
+ writer.add_images("test", np_images, epoch, dataformats="NHWC")
888
+ else:
889
+ writer.add_image("test", np_images, epoch, dataformats="NHWC")
890
+ del pipeline
891
+ gc.collect()
892
+
893
+ if global_step >= args.max_train_steps:
894
+ break
895
+ # Save the lora layers
896
+ if is_main_process:
897
+ unet = unwrap_model(unet)
898
+ unet.save_attn_procs(args.output_dir)
899
+
900
+ # Final inference
901
+ # Load previous pipeline
902
+ with context_nologging():
903
+ pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, safety_checker=None)
904
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
905
+ pipeline.set_progress_bar_config(disable=True)
906
+ # load attention processors
907
+ pipeline.unet.load_attn_procs(args.output_dir)
908
+
909
+ # run inference
910
+ if args.validation_prompt and args.num_validation_images > 0:
911
+ generator = paddle.Generator().manual_seed(args.seed) if args.seed else None
912
+ images = [
913
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
914
+ for _ in range(args.num_validation_images)
915
+ ]
916
+ np_images = np.stack([np.asarray(img) for img in images])
917
+
918
+ if args.report_to == "tensorboard":
919
+ writer.add_images("test", np_images, epoch, dataformats="NHWC")
920
+ else:
921
+ writer.add_image("test", np_images, epoch, dataformats="NHWC")
922
+
923
+ writer.close()
924
+
925
+ # logic to push to HF Hub
926
+ if args.push_to_hub:
927
+ if args.hub_model_id is None:
928
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
929
+ else:
930
+ repo_name = args.hub_model_id
931
+
932
+ _retry(
933
+ create_repo,
934
+ func_kwargs={"repo_id": repo_name, "exist_ok": True, "token": args.hub_token},
935
+ base_wait_time=1.0,
936
+ max_retries=5,
937
+ max_wait_time=10.0,
938
+ )
939
+
940
+ save_model_card(
941
+ repo_name,
942
+ images=images,
943
+ base_model=args.pretrained_model_name_or_path,
944
+ prompt=args.instance_prompt,
945
+ repo_folder=args.output_dir,
946
+ )
947
+ # Upload model
948
+ logger.info(f"Pushing to {repo_name}")
949
+ _retry(
950
+ upload_folder,
951
+ func_kwargs={
952
+ "repo_id": repo_name,
953
+ "repo_type": "model",
954
+ "folder_path": args.output_dir,
955
+ "commit_message": "End of training",
956
+ "token": args.hub_token,
957
+ "ignore_patterns": ["checkpoint-*/*", "logs/*"],
958
+ },
959
+ base_wait_time=1.0,
960
+ max_retries=5,
961
+ max_wait_time=20.0,
962
+ )
963
+
964
+
965
+ if __name__ == "__main__":
966
+ main()
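
The final-inference block of this script reloads the trained weights with `load_attn_procs`. For reference, a minimal standalone sketch of the same reload path: the base model name is the one used by the Streamlit demo below, the directory is the script's default `--output_dir`, and the prompt is a placeholder, not a value from this commit.

from ppdiffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Load the frozen base pipeline, then attach the LoRA attention processors saved by save_attn_procs().
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs("lora-dreambooth-model")  # default --output_dir of this script

# The prompt is illustrative; use the instance prompt the LoRA was trained on.
image = pipe("your instance prompt", num_inference_steps=25).images[0]
image.save("lora_sample.png")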
untitled.streamlit.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ os.system("pip install paddlenlp==2.5.2")
3
+ os.system("pip install ppdiffusers==0.11.1")
4
+
5
+ from ppdiffusers import DiffusionPipeline, DPMSolverMultistepScheduler
6
+ import paddle
7
+ import streamlit as st
8
+
9
+ st.header("Paint your hometown with LoRA and DreamBooth")
10
+ st.image("lora_outputs/validation_images/700.png")
11
+
12
+ st.subheader("Set the prompt")
13
+
14
+ st.write("Type in a few key phrases about your hometown and see whether the model can paint the hometown you have in mind!")
15
+
16
+ pr1 = st.text_input('Your hometown:', "")
17
+
18
+ test = st.button("⚡⚡ Start generating ⚡⚡")
19
+ if test:
20
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
21
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
22
+
23
+ prompt = pr1
24
+ image = pipe(prompt).images[0]
25
+
26
+ image.save("demo.png")
27
+ st.image("demo.png")
28
+
29
+ st.success('Inference finished!', icon="✅")
30
+ st.write("On mobile, long-press the image to save it; on a PC, right-click the image to save it locally.")
31
+
32
+
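
The demo above samples from the base checkpoint only. If the LoRA weights produced by train_dreambooth_lora.py are meant to be applied, the attention processors would also need to be attached; a hedged sketch, where the "lora_outputs" directory is inferred from the image path the app displays and is an assumption:

from ppdiffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Base pipeline as in the demo, plus the trained LoRA attention processors.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs("lora_outputs")  # assumed training output directory
image = pipe("a key phrase about your hometown").images[0]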
utils.py ADDED
@@ -0,0 +1,133 @@
1
+ from paddlenlp.utils.serialization import load_torch
2
+ import paddle
3
+ import safetensors.numpy
4
+ import os
5
+ import ppdiffusers
6
+ from contextlib import contextmanager
7
+
8
+ @contextmanager
9
+ def context_nologging():
10
+ ppdiffusers.utils.logging.set_verbosity_error()
11
+ try:
12
+ yield
13
+ finally:
14
+ ppdiffusers.utils.logging.set_verbosity_info()
15
+
16
+
17
+ __all__ = ['convert_paddle_lora_to_safetensor_lora', 'convert_pytorch_lora_to_paddle_lora']
18
+
19
+ def convert_paddle_lora_to_safetensor_lora(paddle_file, safe_file=None):
20
+ if not os.path.exists(paddle_file):
21
+ print(f"{paddle_file} does not exist!")
22
+ return
23
+ if safe_file is None:
24
+ safe_file = paddle_file.replace("paddle_lora_weights.pdparams", "pytorch_lora_weights.safetensors")
25
+
26
+ tensors = paddle.load(paddle_file)
27
+ new_tensors = {}
28
+ for k, v in tensors.items():
29
+ new_tensors[k] = v.cpu().numpy().T
30
+ safetensors.numpy.save_file(new_tensors, safe_file)
31
+ print(f"Saved the file to {safe_file}!")
32
+
33
+ def convert_pytorch_lora_to_paddle_lora(pytorch_file, paddle_file=None):
34
+ if not os.path.exists(pytorch_file):
35
+ print(f"{pytorch_file} does not exist!")
36
+ return
37
+ if paddle_file is None:
38
+ paddle_file = pytorch_file.replace("pytorch_lora_weights.bin", "paddle_lora_weights.pdparams")
39
+
40
+ tensors = load_torch(pytorch_file)
41
+ new_tensors = {}
42
+ for k, v in tensors.items():
43
+ new_tensors[k] = v.T
44
+ paddle.save(new_tensors, paddle_file)
45
+ print(f"Saved the file to {paddle_file}!")
46
+
47
+
48
+
49
+ import time
50
+ from typing import Optional, Type
51
+ import paddle
52
+ import requests
53
+ from huggingface_hub import create_repo, upload_folder, get_full_repo_name
54
+
55
+ # Since HF sometimes timeout, we need to retry uploads
56
+ # Credit: https://github.com/huggingface/datasets/blob/06ae3f678651bfbb3ca7dd3274ee2f38e0e0237e/src/datasets/utils/file_utils.py#L265
57
+ def _retry(
58
+ func,
59
+ func_args: Optional[tuple] = None,
60
+ func_kwargs: Optional[dict] = None,
61
+ exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException,
62
+ max_retries: int = 0,
63
+ base_wait_time: float = 0.5,
64
+ max_wait_time: float = 2,
65
+ ):
66
+ func_args = func_args or ()
67
+ func_kwargs = func_kwargs or {}
68
+ retry = 0
69
+ while True:
70
+ try:
71
+ return func(*func_args, **func_kwargs)
72
+ except exceptions as err:
73
+ if retry >= max_retries:
74
+ raise err
75
+ else:
76
+ sleep_time = min(max_wait_time, base_wait_time * 2**retry) # Exponential backoff
77
+ print(f"{func} timed out, retrying in {sleep_time}s... [{retry}/{max_retries}]")
78
+ time.sleep(sleep_time)
79
+ retry += 1
80
+
81
+ def upload_lora_folder(upload_dir, repo_name, pretrained_model_name_or_path, prompt, hub_token=None):
82
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
83
+ _retry(
84
+ create_repo,
85
+ func_kwargs={"repo_id": repo_name, "exist_ok": True, "token": hub_token},
86
+ base_wait_time=1.0,
87
+ max_retries=5,
88
+ max_wait_time=10.0,
89
+ )
90
+ save_model_card(
91
+ repo_name,
92
+ base_model=pretrained_model_name_or_path,
93
+ prompt=prompt,
94
+ repo_folder=upload_dir,
95
+ )
96
+ # Upload model
97
+ print(f"Pushing to {repo_name}")
98
+ _retry(
99
+ upload_folder,
100
+ func_kwargs={
101
+ "repo_id": repo_name,
102
+ "repo_type": "model",
103
+ "folder_path": upload_dir,
104
+ "commit_message": "submit best ckpt",
105
+ "token": hub_token,
106
+ "ignore_patterns": ["checkpoint-*/*", "logs/*", "validation_images/*"],
107
+ },
108
+ base_wait_time=1.0,
109
+ max_retries=5,
110
+ max_wait_time=20.0,
111
+ )
112
+
113
+ def save_model_card(repo_name, base_model=str, prompt=str, repo_folder=None):
114
+ yaml = f"""
115
+ ---
116
+ license: creativeml-openrail-m
117
+ base_model: {base_model}
118
+ instance_prompt: {prompt}
119
+ tags:
120
+ - stable-diffusion
121
+ - stable-diffusion-ppdiffusers
122
+ - text-to-image
123
+ - ppdiffusers
124
+ - lora
125
+ inference: false
126
+ ---
127
+ """
128
+ model_card = f"""
129
+ # LoRA DreamBooth - {repo_name}
130
+ The LoRA weights in this repository were trained from {base_model} using the [DreamBooth](https://dreambooth.github.io/) technique with the prompt {prompt}.
131
+ """
132
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
133
+ f.write(yaml + model_card)