import json
from time import time
import argparse
import logging
import os
from pathlib import Path
import math

import numpy as np
from PIL import Image
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from accelerate.utils import DistributedType
from diffusers.optimization import get_scheduler
from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
from peft.utils import get_peft_model_state_dict
from huggingface_hub import snapshot_download
from safetensors.torch import save_file
from diffusers.models import AutoencoderKL

from OmniGen import OmniGen, OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
from OmniGen.train_helper import training_losses
from OmniGen.utils import (
    create_logger,
    update_ema,
    requires_grad,
    center_crop_arr,
    crop_arr,
    vae_encode,
    vae_encode_list,
)
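

# Fine-tuning entry point: set up Accelerate, load the OmniGen transformer and a
# frozen VAE, optionally wrap the backbone with LoRA adapters, then run a
# diffusion-style training loop with gradient accumulation, optional EMA weights,
# periodic loss logging, and periodic checkpointing.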
def main(args):
    # Setup accelerator:
    from accelerate import DistributedDataParallelKwargs as DDPK
    kwargs = DDPK(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=args.results_dir,
        kwargs_handlers=[kwargs],
    )
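    # find_unused_parameters=False assumes every registered parameter receives a
    # gradient each step; set it to True if DDP reports unused-parameter errors.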
    device = accelerator.device
    accelerator.init_trackers("tensorboard_log", config=args.__dict__)

    # Setup an experiment folder:
    checkpoint_dir = f"{args.results_dir}/checkpoints"  # Stores saved model checkpoints
    logger = create_logger(args.results_dir)
    if accelerator.is_main_process:
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger.info(f"Experiment directory created at {args.results_dir}")
        json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w'))

    # Create model:
    if not os.path.exists(args.model_name_or_path):
        cache_folder = os.getenv('HF_HUB_CACHE')
        args.model_name_or_path = snapshot_download(
            repo_id=args.model_name_or_path,
            cache_dir=cache_folder,
            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
        )
        logger.info(f"Downloaded model to {args.model_name_or_path}")
    model = OmniGen.from_pretrained(args.model_name_or_path)
    model.llm.config.use_cache = False  # KV cache is unnecessary during training
    model.llm.gradient_checkpointing_enable()
    model = model.to(device)

    if args.vae_path is None:
        vae_path = os.path.join(args.model_name_or_path, "vae")
        if os.path.exists(vae_path):
            vae = AutoencoderKL.from_pretrained(vae_path).to(device)
        else:
            logger.info("No VAE found in model folder, downloading stabilityai/sdxl-vae from HF")
            logger.info("If you have a local VAE, specify its path with --vae_path")
            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
    else:
        vae = AutoencoderKL.from_pretrained(args.vae_path).to(device)
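
    # Precision split: the transformer trains in the mixed-precision dtype chosen
    # below, while the VAE stays in fp32; the SDXL VAE is commonly reported to be
    # numerically unstable in fp16.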
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(dtype=torch.float32)
    model.to(weight_dtype)

    processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)

    requires_grad(vae, False)  # the VAE is frozen; it only encodes images to latents
    if args.use_lora:
        if accelerator.distributed_type == DistributedType.FSDP:
            raise NotImplementedError("FSDP does not support LoRA")
        requires_grad(model, False)
        transformer_lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            init_lora_weights="gaussian",
            target_modules=["qkv_proj", "o_proj"],
        )
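        # LoRA targets only the attention projections ("qkv_proj", "o_proj"); with
        # lora_alpha equal to r, PEFT's adapter scaling alpha/r is 1.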
        model.llm.enable_input_require_grads()
        model = get_peft_model(model, transformer_lora_config)
        model.to(weight_dtype)
        transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        # Sanity check: print which parameters remain trainable (the LoRA layers):
        for n, p in model.named_parameters():
            print(n, p.requires_grad)
        opt = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay)
    else:
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)

    ema = None
    if args.use_ema:
        ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
        requires_grad(ema, False)

    # Setup data:
    crop_func = crop_arr
    if not args.keep_raw_resolution:
        crop_func = center_crop_arr
    image_transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
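    # Normalize(mean=0.5, std=0.5) maps pixels from [0, 1] to [-1, 1], the input
    # range AutoencoderKL expects.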

    dataset = DatasetFromJson(
        json_file=args.json_file,
        image_path=args.image_path,
        processer=processor,
        image_transform=image_transform,
        max_input_length_limit=args.max_input_length_limit,
        condition_dropout_prob=args.condition_dropout_prob,
        keep_raw_resolution=args.keep_raw_resolution,
    )
    collate_fn = TrainDataCollator(
        pad_token_id=processor.text_tokenizer.eos_token_id,
        hidden_size=model.llm.config.hidden_size,
        keep_raw_resolution=args.keep_raw_resolution,
    )
    loader = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=args.batch_size_per_device,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2,
    )
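    # drop_last=True keeps every batch the same size so all processes perform an
    # identical number of steps; prefetch_factor only takes effect when
    # num_workers > 0.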
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(dataset):,} samples")

    num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=opt,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=max_train_steps * args.gradient_accumulation_steps,
    )
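    # lr_scheduler.step() runs once per micro-batch in the loop below, so warmup and
    # total steps are multiplied by gradient_accumulation_steps to keep the schedule
    # aligned with actual optimizer updates.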

    # Prepare models for training:
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    if ema is not None:
        update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
        ema.eval()  # EMA model should always be in eval mode

    if ema is not None:
        model, ema = accelerator.prepare(model, ema)
    else:
        model = accelerator.prepare(model)
    opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler)

    # Variables for monitoring/logging purposes:
    train_steps, log_steps = 0, 0
    running_loss = 0
    start_time = time()

    if accelerator.is_main_process:
        logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        if accelerator.is_main_process:
            logger.info(f"Beginning epoch {epoch}...")
        for data in loader:
            with accelerator.accumulate(model):
                with torch.no_grad():
                    output_images = data['output_images']
                    input_pixel_values = data['input_pixel_values']
                    if isinstance(output_images, list):
                        output_images = vae_encode_list(vae, output_images, weight_dtype)
                        if input_pixel_values is not None:
                            input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype)
                    else:
                        output_images = vae_encode(vae, output_images, weight_dtype)
                        if input_pixel_values is not None:
                            input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype)
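
                # Latents are computed under no_grad above because the VAE is frozen;
                # only the transformer receives gradients from the loss below.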
                model_kwargs = dict(
                    input_ids=data['input_ids'],
                    input_img_latents=input_pixel_values,
                    input_image_sizes=data['input_image_sizes'],
                    attention_mask=data['attention_mask'],
                    position_ids=data['position_ids'],
                    padding_latent=data['padding_images'],
                    past_key_values=None,
                    return_past_key_values=False,
                )
                loss_dict = training_losses(model, output_images, model_kwargs)
                loss = loss_dict["loss"].mean()
                running_loss += loss.item()
                accelerator.backward(loss)
                if args.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                opt.step()
                lr_scheduler.step()
                opt.zero_grad()
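
                # Under accelerator.accumulate(), Accelerate's optimizer wrapper
                # defers the underlying opt.step() until the accumulation boundary,
                # so calling step()/zero_grad() on every micro-batch is safe.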
                log_steps += 1
                train_steps += 1
                accelerator.log({"training_loss": loss.item()}, step=train_steps)
                if train_steps % args.gradient_accumulation_steps == 0:
                    if accelerator.sync_gradients and ema is not None:
                        update_ema(ema, model)

                if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    torch.cuda.synchronize()
                    end_time = time()
                    steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)
                    # Reduce loss history over all processes:
                    avg_loss = torch.tensor(running_loss / log_steps, device=device)
                    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / accelerator.num_processes
                    if accelerator.is_main_process:
                        cur_lr = opt.param_groups[0]["lr"]
                        logger.info(
                            f"(step={int(train_steps / args.gradient_accumulation_steps):07d}) "
                            f"Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, "
                            f"Epoch: {train_steps / len(loader):.2f}, LR: {cur_lr}"
                        )
                    # Reset monitoring variables:
                    running_loss = 0
                    log_steps = 0
                    start_time = time()

                if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                    if accelerator.distributed_type == DistributedType.FSDP:
                        state_dict = accelerator.get_state_dict(model)
                        ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
                    else:
                        if not args.use_lora:
                            state_dict = model.module.state_dict()
                        ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
                    if accelerator.is_main_process:
                        if args.use_lora:
                            checkpoint_path = f"{checkpoint_dir}/{int(train_steps / args.gradient_accumulation_steps):07d}/"
                            os.makedirs(checkpoint_path, exist_ok=True)
                            model.module.save_pretrained(checkpoint_path)
                        else:
                            checkpoint_path = f"{checkpoint_dir}/{int(train_steps / args.gradient_accumulation_steps):07d}/"
                            os.makedirs(checkpoint_path, exist_ok=True)
                            torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
                            processor.text_tokenizer.save_pretrained(checkpoint_path)
                            model.llm.config.save_pretrained(checkpoint_path)

                            if ema_state_dict is not None:
                                checkpoint_path = f"{checkpoint_dir}/{int(train_steps / args.gradient_accumulation_steps):07d}_ema"
                                os.makedirs(checkpoint_path, exist_ok=True)
                                torch.save(ema_state_dict, os.path.join(checkpoint_path, "model.pt"))
                                processor.text_tokenizer.save_pretrained(checkpoint_path)
                                model.llm.config.save_pretrained(checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")
                    dist.barrier()  # make all ranks wait until the checkpoint is written

    accelerator.end_training()
    model.eval()
    if accelerator.is_main_process:
        logger.info("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--model_name_or_path", type=str, default="OmniGen")
    parser.add_argument("--json_file", type=str)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--batch_size_per_device", type=int, default=1)
    parser.add_argument("--vae_path", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_every", type=int, default=100)
    parser.add_argument("--ckpt_every", type=int, default=20000)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_input_length_limit", type=int, default=1024)
    parser.add_argument("--condition_dropout_prob", type=float, default=0.1)
    parser.add_argument("--adam_weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--keep_raw_resolution",
        action="store_true",
        help="Keep images near their original resolution instead of center-cropping, enabling training at multiple resolutions.",
    )
parser.add_argument("--max_image_size", type=int, default=1344) | |
parser.add_argument( | |
"--use_lora", | |
action="store_true", | |
) | |
parser.add_argument( | |
"--lora_rank", | |
type=int, | |
default=8 | |
) | |
parser.add_argument( | |
"--use_ema", | |
action="store_true", | |
help="Whether or not to use ema.", | |
) | |
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"].'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current"
            " system or the flag passed with the `accelerate launch` command. Use this argument to override"
            " the accelerate config."
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )

    args = parser.parse_args()
    assert args.max_image_size % 16 == 0, "Image size must be divisible by 16."

    main(args)
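
# Example launch (hypothetical paths; adjust for your setup):
#   accelerate launch train.py \
#       --model_name_or_path Shitao/OmniGen-v1 \
#       --json_file ./toy_data/train.jsonl \
#       --image_path ./toy_data/images \
#       --results_dir ./results \
#       --batch_size_per_device 1 \
#       --use_lora --lora_rank 8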