import functools
import json
import logging
import operator
import os
from typing import Tuple

import colossalai
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url

pretrained_models = {
    "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
    "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
    "Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt",
    "PixArt-XL-2-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
    "PixArt-XL-2-SAM-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
    "PixArt-XL-2-512x512.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
    "PixArt-XL-2-1024-MS.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
}


def reparameter(ckpt, name=None):
    """
    Adapt a released DiT/Latte/PixArt checkpoint to this codebase: unwrap the EMA /
    state_dict container where needed, insert a singleton dimension into the 2D
    patch-embedding weight, and drop the stored positional-embedding buffers.
    """
    if "DiT" in name:
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    elif "Latte" in name:
        ckpt = ckpt["ema"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
        del ckpt["temp_embed"]
    elif "PixArt" in name:
        ckpt = ckpt["state_dict"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    return ckpt
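
# Shape note: `unsqueeze(2)` turns a 2D patch-embedding weight of shape (hidden, in_ch, p, p)
# into (hidden, in_ch, 1, p, p), matching what a Conv3d patch embedder with temporal kernel
# size 1 expects (the presumed target here).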


def find_model(model_name):
    """
    Find a pre-trained checkpoint: if `model_name` is a key of `pretrained_models`,
    download and re-parameterize it; otherwise treat `model_name` as a local checkpoint
    path and load it, stripping positional-embedding buffers and unwrapping an EMA
    container if present.
    """
    if model_name in pretrained_models:
        model = download_model(model_name)
        model = reparameter(model, model_name)
        return model
    else:
        assert os.path.isfile(model_name), f"Could not find checkpoint at {model_name}"
        checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
        if "pos_embed_temporal" in checkpoint:
            del checkpoint["pos_embed_temporal"]
        if "pos_embed" in checkpoint:
            del checkpoint["pos_embed"]
        if "ema" in checkpoint:
            checkpoint = checkpoint["ema"]
        return checkpoint
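
# Usage sketch: a key of `pretrained_models` triggers a download, anything else is read as a file.
#   state_dict = find_model("DiT-XL-2-256x256.pt")   # fetched into pretrained_models/
#   state_dict = find_model("outputs/ema.pt")        # hypothetical local checkpoint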


def download_model(model_name):
    """
    Download a pre-trained checkpoint listed in `pretrained_models` into the local
    `pretrained_models/` directory, skipping the download if the file already exists.
    """
    assert model_name in pretrained_models
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = pretrained_models[model_name]
        download_url(web_path, "pretrained_models", model_name)
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def load_from_sharded_state_dict(model, ckpt_path):
    # Load a booster-sharded checkpoint (saved under `<ckpt_path>/model`) back into `model`.
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))


def model_sharding(model: torch.nn.Module):
    # Flatten every parameter, pad it so its length divides evenly by the world size,
    # and keep only this rank's slice. Record the original shapes beforehand
    # (see record_model_param_shape) so that model_gathering can undo this.
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for _, param in model.named_parameters():
        padding_size = (world_size - param.numel() % world_size) % world_size
        if padding_size > 0:
            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
        else:
            padding_param = param.data.view(-1)
        shards = padding_param.split(padding_param.numel() // world_size)
        param.data = shards[global_rank]
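
# Round-trip sketch (assuming torch.distributed is already initialized):
#   shape_dict = record_model_param_shape(ema)   # remember full shapes
#   model_sharding(ema)                          # each rank keeps a 1/world_size flat slice
#   ...train...
#   model_gathering(ema, shape_dict)             # rank 0 reassembles full parameters for saving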


def load_json(file_path: str):
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
    # Truncate a flat, padded tensor back to the number of elements in `original_shape`.
    return tensor[: functools.reduce(operator.mul, original_shape)]
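
# For example, a parameter of shape (3, 5) padded to 16 elements for world_size=4 is
# gathered back as a flat tensor of length 16; remove_padding keeps the first 3*5 = 15
# elements before the view() back to (3, 5) in model_gathering below.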


def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
    # Inverse of model_sharding: all-gather every parameter's shards and, on rank 0 only,
    # strip the padding and restore the original shape recorded in `model_shape_dict`.
    global_rank = dist.get_rank()
    global_size = dist.get_world_size()
    for name, param in model.named_parameters():
        all_params = [torch.empty_like(param.data) for _ in range(global_size)]
        dist.all_gather(all_params, param.data, group=dist.group.WORLD)
        if global_rank == 0:
            all_params = torch.cat(all_params)
            param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
    dist.barrier()


def record_model_param_shape(model: torch.nn.Module) -> dict:
    # Snapshot the full shape of every parameter; needed to undo model_sharding later.
    param_shape = {}
    for name, param in model.named_parameters():
        param_shape[name] = param.shape
    return param_shape


def save(
    booster: Booster,
    model: nn.Module,
    ema: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    global_step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
    shape_dict: dict,
):
    save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    # Sharded model checkpoint handled by the booster plugin.
    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)

    # The EMA weights are kept sharded during training: gather them on rank 0,
    # dump a plain state_dict, then re-shard rank 0's copy so training can continue.
    model_gathering(ema, shape_dict)
    global_rank = dist.get_rank()
    if global_rank == 0:
        torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
        model_sharding(ema)

    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
    if lr_scheduler is not None:
        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "global_step": global_step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
    dist.barrier()
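
# Resulting layout (sketch), e.g. for epoch 3, global step 1200:
#   <save_dir>/epoch3-global_step1200/
#       model/                  sharded model weights (booster format)
#       ema.pt                  full EMA state_dict (written by rank 0)
#       optimizer/              sharded optimizer state
#       lr_scheduler/           scheduler state (if a scheduler was passed)
#       running_states.json     epoch / step / global_step / sample_start_index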


def load(
    booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, "model"))
    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
    booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    dist.barrier()
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
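
# Resume sketch: point `load` at a directory produced by `save`, e.g.
#   epoch, step, sample_start_index = load(
#       booster, model, ema, optimizer, lr_scheduler, "outputs/epoch3-global_step1200"
#   )  # "outputs/..." is a hypothetical run directory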


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout on rank 0 and discards
    records on all other ranks.
    """
    if dist.get_rank() == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
        )
        logger = logging.getLogger(__name__)
    else:
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger
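
# Usage sketch (after distributed init):
#   logger = create_logger("outputs/logs")  # hypothetical directory; must already exist for the FileHandler
#   logger.info("Training started")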


def load_checkpoint(model, ckpt_path, save_as_pt=True):
    if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        # Single-file checkpoint: resolve (or download) it via find_model and load non-strictly.
        state_dict = find_model(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif os.path.isdir(ckpt_path):
        # Booster-sharded checkpoint directory; optionally consolidate it into a single .pt file.
        load_from_sharded_state_dict(model, ckpt_path)
        if save_as_pt:
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)
            print(f"Model checkpoint saved to {save_path}")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
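
# Example (sketch): with `model` already constructed,
#   load_checkpoint(model, "PixArt-XL-2-512x512.pth")         # resolved via find_model / download
#   load_checkpoint(model, "outputs/epoch3-global_step1200")  # hypothetical dir containing model/ shards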