import ast
import gc
import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
from decord import VideoReader
import wandb


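# Gather per-timestep coefficients from a 1-D schedule tensor `a` at indices `t`
# and reshape them to broadcast against a batch of samples with shape `x_shape`.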
def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def is_attn(name):
    # Match attention-related sub-modules, e.g. "...attn1", "...attn2", or the
    # "...attentions.N.transformer_blocks" lists walked by set_torch_2_attn below.
    return any(key in name for key in ("attn", "attentions", "transformer_blocks"))


def set_processors(attentions):
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


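# Switch every BasicTransformerBlock in the UNet to the torch 2.x scaled
# dot-product attention processor.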
def set_torch_2_attn(unet):
    optim_count = 0

    for name, module in unet.named_modules():
        if is_attn(name):
            if isinstance(module, torch.nn.ModuleList):
                for m in module:
                    if isinstance(m, BasicTransformerBlock):
                        set_processors([m.attn1, m.attn2])
                        optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scale values at which to generate embedding vectors
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


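# Consistency-model boundary condition scalings (c_skip, c_out) in the LCM
# parameterization; c_skip -> 1 and c_out -> 0 as the timestep approaches 0.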
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    scaled_timestep = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


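# Convert a diffusion model output into a prediction of x_0, following the
# scheduler's `prediction_type` ("epsilon", "sample", or "v_prediction").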
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "sample":
        pred_x_0 = model_output
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )

    return pred_x_0


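# Convert a diffusion model output into a prediction of the added noise epsilon,
# following the scheduler's `prediction_type`.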
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_epsilon = (sample - alphas * model_output) / sigmas
    elif prediction_type == "v_prediction":
        pred_epsilon = alphas * model_output + sigmas * sample
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )

    return pred_epsilon


def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }


def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    params = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v

    return params


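# Build the per-group parameter list passed to the optimizer. LoRA entries are
# collected either as whole parameter lists or by filtering named parameters that
# contain "lora"; non-LoRA models contribute everything except LoRA weights.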
def create_optimizer_params(model_list, lr):
    import itertools

    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()

        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model), extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if "lora" in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        if condition:
            for n, p in model.named_parameters():
                should_negate = "lora" in n and not is_lora
                if should_negate:
                    continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params


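# Freeze/unfreeze parameters by name: "all" unlocks the whole model, otherwise
# only parameters whose names contain one of `trainable_modules` (excluding LoRA
# weights) have their requires_grad flag set to `is_enabled`.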
def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    acc = []
    unfrozen_params = 0

    if trainable_modules is not None:
        unlock_all = any([name == "all" for name in trainable_modules])
        if unlock_all:
            model.requires_grad_(True)
            unfrozen_params = len(list(model.parameters()))
        else:
            model.requires_grad_(False)
            for name, param in model.named_parameters():
                for tm in trainable_modules:
                    if all([tm in name, name not in acc, "lora" not in name]):
                        param.requires_grad_(is_enabled)
                        acc.append(name)
                        unfrozen_params += 1


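# Pseudo-Huber loss: sqrt((pred - target)^2 + c^2) - c, a smooth penalty that
# behaves like L2 near zero and like L1 for large residuals.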
def huber_loss(pred, target, huber_c=0.001):
    loss = torch.sqrt((pred.float() - target.float()) ** 2 + huber_c**2) - huber_c
    return loss.mean()


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        src_to_dtype = src.to(targ.dtype)
        targ.detach().mul_(rate).add_(src_to_dtype, alpha=1 - rate)


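# Sample a fixed set of validation prompts with the pipeline and log the
# resulting clips to Weights & Biases. With motion conditioning enabled, each
# prompt is sampled at three motion guidance scales; otherwise it is sampled at
# 4 and 8 inference steps.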
def log_validation_video(pipeline, args, accelerator, save_fps):
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    validation_prompts = [
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
        "Wolf, turns its head, in the wild",
        "Iron man, walks, on the moon, 8k, high detailed, best quality",
        "With the style of low-poly game art, A majestic, white horse gallops gracefully",
        "a rabbit, low-poly game art style",
    ]

    video_logs = []

    use_motion_cond = bool(getattr(args, "use_motion_cond", False))

    for prompt in validation_prompts:
        if use_motion_cond:
            motion_gs_unit = (args.motion_gs_max - args.motion_gs_min) / 2
            for i in range(3):
                with torch.autocast("cuda"):
                    videos = pipeline(
                        prompt=prompt,
                        frames=args.n_frames,
                        num_inference_steps=8,
                        num_videos_per_prompt=1,
                        fps=args.fps,
                        use_motion_cond=True,
                        motion_gs=motion_gs_unit * i,
                        lcm_origin_steps=args.num_ddim_timesteps,
                        generator=generator,
                    )
                videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
                videos = (
                    (videos * 255)
                    .to(torch.uint8)
                    .permute(0, 2, 1, 3, 4)
                    .cpu()
                    .numpy()
                )
                video_logs.append(
                    {
                        "validation_prompt": f"GS={i * motion_gs_unit}, {prompt}",
                        "videos": videos,
                    }
                )
        else:
            for i in range(2):
                with torch.autocast("cuda"):
                    videos = pipeline(
                        prompt=prompt,
                        frames=args.n_frames,
                        num_inference_steps=4 * (i + 1),
                        num_videos_per_prompt=1,
                        fps=args.fps,
                        use_motion_cond=False,
                        lcm_origin_steps=args.num_ddim_timesteps,
                        generator=generator,
                    )
                videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
                videos = (
                    (videos * 255)
                    .to(torch.uint8)
                    .permute(0, 2, 1, 3, 4)
                    .cpu()
                    .numpy()
                )
                video_logs.append(
                    {
                        "validation_prompt": f"Steps={4 * (i + 1)}, {prompt}",
                        "videos": videos,
                    }
                )

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            formatted_videos = []
            for log in video_logs:
                videos = log["videos"]
                validation_prompt = log["validation_prompt"]
                for video in videos:
                    video = wandb.Video(video, caption=validation_prompt, fps=save_fps)
                    formatted_videos.append(video)

            tracker.log({"validation": formatted_videos})

    del pipeline
    gc.collect()


def tuple_type(s):
    if isinstance(s, tuple):
        return s
    value = ast.literal_eval(s)
    if isinstance(value, tuple):
        return value
    raise TypeError("Argument must be a tuple")


def load_model_checkpoint(model, ckpt):
    def load_checkpoint(model, ckpt, full_strict):
        state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
        model.load_state_dict(state_dict, strict=full_strict)
        del state_dict
        gc.collect()
        return model

    load_checkpoint(model, ckpt, full_strict=True)
    print(">>> model checkpoint loaded.")
    return model


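# Decode `sample_frames` frames from a video at roughly `sample_fps`. Short clips
# (or `uniform_sampling=True`) are sampled uniformly over the whole video;
# otherwise a random window is sampled at the requested frame stride. Returns a
# float tensor of shape (frames, channels, height, width) in [0, 1].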
def read_video_to_tensor(
    path_to_video, sample_fps, sample_frames, uniform_sampling=False
):
    video_reader = VideoReader(path_to_video)
    video_fps = video_reader.get_avg_fps()
    video_frames = len(video_reader)
    video_duration = video_frames / video_fps
    sample_duration = sample_frames / sample_fps
    stride = video_fps / sample_fps

    if uniform_sampling or video_duration <= sample_duration:
        index_range = np.linspace(0, video_frames - 1, sample_frames).astype(np.int32)
    else:
        max_start_frame = video_frames - np.ceil(sample_frames * stride).astype(
            np.int32
        )
        if max_start_frame > 0:
            start_frame = random.randint(0, max_start_frame)
        else:
            start_frame = 0

        index_range = start_frame + np.arange(sample_frames) * stride
        index_range = np.round(index_range).astype(np.int32)

    sampled_frames = video_reader.get_batch(index_range).asnumpy()
    pixel_values = torch.from_numpy(sampled_frames).permute(0, 3, 1, 2).contiguous()
    pixel_values = pixel_values / 255.0
    del video_reader

    return pixel_values


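# Top-k rank loss: keep only the `rank_k` largest entries along the last
# dimension of the reference tensor and pull the generated tensor toward them
# at those positions with an MSE penalty (reference values are detached).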
def calculate_motion_rank_new(tensor_ref, tensor_gen, rank_k=1):
    if rank_k == 0:
        loss = torch.tensor(0.0, device=tensor_ref.device)
    elif rank_k > tensor_ref.shape[-1]:
        raise ValueError(
            "The value of rank_k cannot be larger than the number of frames"
        )
    else:
        _, sorted_indices = torch.sort(tensor_ref, dim=-1)

        mask = torch.zeros_like(tensor_ref, dtype=torch.bool)
        mask.scatter_(-1, sorted_indices[..., -rank_k:], True)

        loss = F.mse_loss(tensor_ref[mask].detach(), tensor_gen[mask])
    return loss


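# Aggregate the rank loss over all temporal attention maps (keyed by module
# name), weight each term by a fixed factor of 100, and return the mean.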
def compute_temp_loss(attention_prob, attention_prob_example):
    temp_attn_prob_loss = []

    for name in attention_prob.keys():
        attn_prob_example = attention_prob_example[name]
        attn_prob = attention_prob[name]

        module_attn_loss = calculate_motion_rank_new(
            attn_prob_example.detach(), attn_prob, rank_k=1
        )
        temp_attn_prob_loss.append(module_attn_loss)

    loss_temp = torch.stack(temp_attn_prob_loss) * 100
    loss = loss_temp.mean()
    return loss