import torch
import torch.nn as nn
import numpy as np
import math

from einops import rearrange
from pdb import set_trace as st

from .dit_models_xformers import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer


def modulate2(x, shift, scale):
    # Token-wise modulation: shift/scale are applied element-wise with no per-sample
    # unsqueeze, so the conditioning may carry one vector per token.
    return x * (1 + scale) + shift


class DiTBlock2(DiTBlock):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning, applied
    token-wise via modulate2: the conditioning `c` is expected to broadcast against
    the token sequence rather than being a single per-sample vector.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
        super().__init__(hidden_size, num_heads, mlp_ratio, **block_kwargs)

    def forward(self, x, c):
        # Project the conditioning into shift/scale/gate terms for attention and MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
            c).chunk(6, dim=-1)

        x = x + gate_msa * self.attn(
            modulate2(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp * self.mlp(
            modulate2(self.norm2(x), shift_mlp, scale_mlp))
        return x


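# Illustrative sketch (not used by the model): how the token-wise adaLN above behaves,
# assuming the parent block's adaLN_modulation is the reference DiT SiLU -> Linear
# projection to 6 * hidden_size. Names and shapes here are hypothetical.
def _tokenwise_adaln_example():
    B, L, D = 2, 8, 16
    x = torch.randn(B, L, D)                     # token sequence
    cond = torch.randn(B, L, D)                  # one conditioning vector per token
    adaLN = nn.Sequential(nn.SiLU(), nn.Linear(D, 6 * D))
    norm = nn.LayerNorm(D, elementwise_affine=False)
    shift_msa, scale_msa, gate_msa, _, _, _ = adaLN(cond).chunk(6, dim=-1)
    # Each chunk is (B, L, D), so modulate2 broadcasts element-wise per token,
    # with no per-sample unsqueeze.
    out = x + gate_msa * modulate2(norm(x), shift_msa, scale_msa)
    return out.shape                             # torch.Size([2, 8, 16])

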
class FinalLayer2(FinalLayer):
    """
    The final layer of DiT: the equivalent of MAE's decoder_pred, with adaLN
    conditioning (token-wise, via modulate2).
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__(hidden_size, patch_size, out_channels)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate2(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT2(DiT):

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 vit_blk=...):
        # Note: vit_blk is accepted for signature compatibility but ignored;
        # DiTBlock2 / FinalLayer2 are always used.
        super().__init__(input_size,
                         patch_size,
                         in_channels,
                         hidden_size,
                         depth,
                         num_heads,
                         mlp_ratio,
                         class_dropout_prob,
                         num_classes,
                         learn_sigma,
                         mixing_logit_init,
                         mixed_prediction,
                         context_dim,
                         roll_out,
                         vit_blk=DiTBlock2,
                         final_layer_blk=FinalLayer2)

        # The patch/timestep embedders and the final layer from the parent DiT are
        # unused here (conditioning arrives directly as tokens), so free them.
        del self.x_embedder
        del self.t_embedder
        del self.final_layer
        torch.cuda.empty_cache()
        self.clip_text_proj = None
        self.plane_n = plane_n
        self.return_all_layers = return_all_layers

    def forward(self, c, *args, **kwargs):
        """
        Forward pass of DiT2.
        c: (N, L, C) tensor of conditioning tokens (used for token-wise adaLN); the
        transformer input x is initialized from the learned positional embedding.
        """
        # Start from the positional embedding, broadcast over the batch.
        x = self.pos_embed.repeat(
            c.shape[0], 1, 1)

        if self.return_all_layers:
            all_layers = []

        for blk_idx, block in enumerate(self.blocks):
            if self.roll_out:
                if blk_idx % 2 == 0:
                    # Even blocks: fold the planes into the batch dim and attend
                    # within each plane.
                    x = rearrange(x, 'b (n l) c -> (b n) l c ', n=self.plane_n)
                    x = block(x,
                              rearrange(c,
                                        'b (n l) c -> (b n) l c ',
                                        n=self.plane_n))

                    if self.return_all_layers:
                        all_layers.append(x)
                else:
                    # Odd blocks: merge the planes back and attend over the
                    # concatenated token sequence.
                    x = rearrange(x, '(b n) l c -> b (n l) c ', n=self.plane_n)
                    x = block(x, c)

                    if self.return_all_layers:
                        # Store in the per-plane layout, consistent with the
                        # even blocks.
                        all_layers.append(
                            rearrange(x,
                                      'b (n l) c -> (b n) l c',
                                      n=self.plane_n))
            else:
                x = block(x, c)

        if self.return_all_layers:
            return all_layers
        else:
            return x


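# Illustrative sketch (not used by the model): the roll-out path above alternates between
# per-plane attention (planes folded into the batch dim on even blocks) and cross-plane
# attention over the concatenated tokens (odd blocks). The rearranges round-trip exactly;
# shapes below are hypothetical.
def _rollout_rearrange_example(plane_n=3):
    B, L, D = 2, 4, 16                                                   # L tokens per plane
    tokens = torch.randn(B, plane_n * L, D)                              # concatenated plane tokens
    per_plane = rearrange(tokens, 'b (n l) c -> (b n) l c', n=plane_n)   # (B * plane_n, L, D)
    merged = rearrange(per_plane, '(b n) l c -> b (n l) c', n=plane_n)   # back to (B, plane_n * L, D)
    assert torch.equal(tokens, merged)
    return per_plane.shape, merged.shape

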
def DiT2_XL_2(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_XL_2_half(**kwargs):
    return DiT2(depth=28 // 2,
                hidden_size=1152,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_XL_4(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=4,
                num_heads=16,
                **kwargs)


def DiT2_XL_8(**kwargs):
    return DiT2(depth=28,
                hidden_size=1152,
                patch_size=8,
                num_heads=16,
                **kwargs)


def DiT2_L_2(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_L_2_half(**kwargs):
    return DiT2(depth=24 // 2,
                hidden_size=1024,
                patch_size=2,
                num_heads=16,
                **kwargs)


def DiT2_L_4(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=4,
                num_heads=16,
                **kwargs)


def DiT2_L_8(**kwargs):
    return DiT2(depth=24,
                hidden_size=1024,
                patch_size=8,
                num_heads=16,
                **kwargs)


def DiT2_B_2(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=2,
                num_heads=12,
                **kwargs)


def DiT2_B_4(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=4,
                num_heads=12,
                **kwargs)


def DiT2_B_8(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=8,
                num_heads=12,
                **kwargs)


def DiT2_B_16(**kwargs):
    return DiT2(depth=12,
                hidden_size=768,
                patch_size=16,
                num_heads=12,
                **kwargs)


def DiT2_S_2(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)


def DiT2_S_4(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)


def DiT2_S_8(**kwargs):
    return DiT2(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT2_models = {
    'DiT2-XL/2': DiT2_XL_2,
    'DiT2-XL/2/half': DiT2_XL_2_half,
    'DiT2-XL/4': DiT2_XL_4,
    'DiT2-XL/8': DiT2_XL_8,
    'DiT2-L/2': DiT2_L_2,
    'DiT2-L/2/half': DiT2_L_2_half,
    'DiT2-L/4': DiT2_L_4,
    'DiT2-L/8': DiT2_L_8,
    'DiT2-B/2': DiT2_B_2,
    'DiT2-B/4': DiT2_B_4,
    'DiT2-B/8': DiT2_B_8,
    'DiT2-B/16': DiT2_B_16,
    'DiT2-S/2': DiT2_S_2,
    'DiT2-S/4': DiT2_S_4,
    'DiT2-S/8': DiT2_S_8,
}
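
# Usage sketch (illustrative): a variant is selected by name from the registry, e.g.
# DiT2_models['DiT2-B/2'](roll_out=True, plane_n=3, return_all_layers=False). Any
# constructor kwargs left unspecified (input_size, in_channels, num_classes, ...) fall
# back to the DiT2 defaults above before being forwarded to the parent DiT in
# dit_models_xformers, whose behavior is assumed here rather than shown.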