#!/usr/bin/env python3
import copy
from dataclasses import asdict, dataclass

import numpy as np
import torch
import torchvision
import torchvision.utils as vutils
import wandb
from accelerate import Accelerator
from diffusers import AutoencoderKL
from PIL.Image import Image
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from tld.denoiser import Denoiser
from tld.diffusion import DiffusionGenerator
def eval_gen(diffuser: DiffusionGenerator, labels: Tensor) -> Image:
    """Generate 64 samples (each label repeated 8x), tile them into a grid, and save to disk."""
    class_guidance = 4.5
    seed = 10
    out, _ = diffuser.generate(
        labels=torch.repeat_interleave(labels, 8, dim=0),
        num_imgs=64,
        class_guidance=class_guidance,
        seed=seed,
        n_iter=40,
        exponent=1,
        sharp_f=0.1,
    )

    # Map outputs from [-1, 1] to [0, 1] before converting to a PIL image.
    out = to_pil((vutils.make_grid((out + 1) / 2, nrow=8, padding=4)).float().clip(0, 1))
    out.save(f"emb_val_cfg:{class_guidance}_seed:{seed}.png")

    return out
def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_parameters_per_layer(model: nn.Module):
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel()} parameters")


to_pil = torchvision.transforms.ToPILImage()
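
# update_ema keeps a shadow copy of the weights: ema <- alpha * ema + (1 - alpha) * online.
# With the default alpha = 0.999 the average has an effective window of roughly
# 1 / (1 - alpha) = 1000 recent optimizer steps, smoothing out training noise.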
def update_ema(ema_model: nn.Module, model: nn.Module, alpha: float = 0.999):
    with torch.no_grad():
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(alpha).add_(model_param.data, alpha=1 - alpha)
@dataclass
class ModelConfig:
    embed_dim: int = 512
    n_layers: int = 6
    clip_embed_size: int = 768
    scaling_factor: int = 8
    patch_size: int = 2
    image_size: int = 32
    n_channels: int = 4
    dropout: float = 0
    mlp_multiplier: int = 4
    batch_size: int = 128
    class_guidance: int = 3
    lr: float = 3e-4
    n_epoch: int = 100
    alpha: float = 0.999
    noise_embed_dims: int = 128
    diffusion_n_iter: int = 35
    from_scratch: bool = True
    run_id: str = ""
    model_name: str = ""
    beta_a: float = 0.75
    beta_b: float = 0.75
    save_and_eval_every_iters: int = 1000
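
# With the defaults above, image_size=32 and patch_size=2 give the denoiser
# (32 / 2) ** 2 = 256 tokens per image, and n_channels=4 matches the 4-channel
# latents produced by the SDXL VAE loaded in main() below.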
@dataclass
class DataConfig:
    latent_path: str  # path to a .npy file of precomputed image latents
    text_emb_path: str  # path to a .npy file of matching text embeddings
    val_path: str  # path to a .npy file of validation text embeddings
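
# Given the config defaults, the latents are expected with shape
# [N, n_channels, image_size, image_size] and the text embeddings [N, clip_embed_size].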
def main(config: ModelConfig, dataconfig: DataConfig) -> None:
    """Main train loop to be used with accelerate."""
    accelerator = Accelerator(mixed_precision="fp16", log_with="wandb")

    accelerator.print("Loading Data:")
    latent_train_data = torch.tensor(np.load(dataconfig.latent_path), dtype=torch.float32)
    train_label_embeddings = torch.tensor(np.load(dataconfig.text_emb_path), dtype=torch.float32)
    emb_val = torch.tensor(np.load(dataconfig.val_path), dtype=torch.float32)
    dataset = TensorDataset(latent_train_data, train_label_embeddings)
    train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    if accelerator.is_main_process:
        vae = vae.to(accelerator.device)

    model = Denoiser(
        image_size=config.image_size,
        noise_embed_dims=config.noise_embed_dims,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        dropout=config.dropout,
        n_layers=config.n_layers,
    )

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    accelerator.print("Compiling model:")
    model = torch.compile(model)

    if not config.from_scratch:
        # Resume: pull the checkpoint from wandb, then restore model, optimizer, and step.
        accelerator.print("Loading Model:")
        wandb.restore(
            config.model_name, run_path=f"apapiu/cifar_diffusion/runs/{config.run_id}", replace=True
        )
        full_state_dict = torch.load(config.model_name)
        model.load_state_dict(full_state_dict["model_ema"])
        optimizer.load_state_dict(full_state_dict["opt_state"])
        global_step = full_state_dict["global_step"]
    else:
        global_step = 0
    if accelerator.is_local_main_process:
        # The EMA copy is what gets evaluated and checkpointed.
        ema_model = copy.deepcopy(model).to(accelerator.device)
        diffuser = DiffusionGenerator(ema_model, vae, accelerator.device, torch.float32)

    accelerator.print("Preparing model:")
    model, train_loader, optimizer = accelerator.prepare(model, train_loader, optimizer)

    accelerator.init_trackers(project_name="cifar_diffusion", config=asdict(config))

    accelerator.print(count_parameters(model))
    count_parameters_per_layer(model)  # prints its own output, returns None
    ### Train:
    for i in range(1, config.n_epoch + 1):
        accelerator.print(f"epoch: {i}")

        for x, y in tqdm(train_loader):
            x = x / config.scaling_factor

            # Sample a per-example noise level from Beta(beta_a, beta_b) and mix signal
            # and noise; the model is then trained to predict the clean latent x directly.
            noise_level = torch.tensor(
                np.random.beta(config.beta_a, config.beta_b, len(x)), device=accelerator.device
            )
            signal_level = 1 - noise_level
            noise = torch.randn_like(x)

            x_noisy = noise_level.view(-1, 1, 1, 1) * noise + signal_level.view(-1, 1, 1, 1) * x

            x_noisy = x_noisy.float()
            noise_level = noise_level.float()
            label = y

            # Classifier-free guidance: zero out ~15% of the labels so the model
            # also learns an unconditional path.
            prob = 0.15
            mask = torch.rand(y.size(0), device=accelerator.device) < prob
            label[mask] = 0  # OR replacement_vector
            if global_step % config.save_and_eval_every_iters == 0:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    ## Eval and saving:
                    out = eval_gen(diffuser=diffuser, labels=emb_val)
                    out.save("img.jpg")
                    accelerator.log({f"step: {global_step}": wandb.Image("img.jpg")})

                    opt_unwrapped = accelerator.unwrap_model(optimizer)
                    full_state_dict = {
                        "model_ema": ema_model.state_dict(),
                        "opt_state": opt_unwrapped.state_dict(),
                        "global_step": global_step,
                    }
                    accelerator.save(full_state_dict, config.model_name)
                    wandb.save(config.model_name)
            model.train()

            with accelerator.accumulate(model):
                ### Train step:
                optimizer.zero_grad()

                pred = model(x_noisy, noise_level.view(-1, 1), label)
                loss = loss_fn(pred, x)
                accelerator.log({"train_loss": loss.item()}, step=global_step)
                accelerator.backward(loss)
                optimizer.step()

                if accelerator.is_main_process:
                    update_ema(ema_model, model, alpha=config.alpha)

            global_step += 1

    accelerator.end_training()
# To launch from a notebook instead of the CLI:
# args = (config, dataconfig)
# notebook_launcher(main, args, num_processes=1)
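
# A minimal single-process entry point (a sketch: the .npy paths and model name
# below are placeholders, not files shipped with the repo):
if __name__ == "__main__":
    data_config = DataConfig(
        latent_path="train_latents.npy",
        text_emb_path="train_text_emb.npy",
        val_path="val_text_emb.npy",
    )
    main(ModelConfig(model_name="model.pt"), data_config)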