# LN3Diff / dit/dit_decoder.py
# Last commit 87c126b by NIRVANALAN: "release file"
import torch
import torch.nn as nn
import numpy as np
import math
from einops import rearrange
from pdb import set_trace as st
# from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
from .dit_models_xformers import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
# from .dit_models import DiT, DiTBlock, DiT_models, get_2d_sincos_pos_embed, modulate, FinalLayer
def modulate2(x, shift, scale):
    """
    Apply an affine modulation ``x * (1 + scale) + shift``.

    No axis is inserted into ``shift``/``scale``, so they must already
    broadcast against ``x`` (e.g. one set of values per token).
    """
    return shift + x * (scale + 1)
class DiTBlock2(DiTBlock):
    """
    DiT block with adaptive layer norm zero (adaLN-Zero) conditioning whose
    modulation is applied per token.

    ``c`` is mapped by ``adaLN_modulation`` to six chunks: shift, scale and
    gate for the attention branch, then the same three for the MLP branch.
    They are applied through ``modulate2`` and multiplied in directly,
    without inserting a token axis.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4, **block_kwargs):
        super().__init__(hidden_size, num_heads, mlp_ratio, **block_kwargs)

    def forward(self, x, c):
        (attn_shift, attn_scale, attn_gate,
         mlp_shift, mlp_scale, mlp_gate) = self.adaLN_modulation(c).chunk(
             6, dim=-1)

        # Attention branch, gated residual.
        attn_in = modulate2(self.norm1(x), attn_shift, attn_scale)
        x = x + attn_gate * self.attn(attn_in)

        # MLP branch, gated residual.
        mlp_in = modulate2(self.norm2(x), mlp_shift, mlp_scale)
        x = x + mlp_gate * self.mlp(mlp_in)
        return x
class FinalLayer2(FinalLayer):
    """
    DiT output head (the MAE-style ``decoder_pred`` with adaLN), with
    per-token modulation.

    ``c`` is mapped by ``adaLN_modulation`` to a shift and a scale. These are
    applied to the normalised tokens via ``modulate2`` before the final
    linear projection.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__(hidden_size, patch_size, out_channels)

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate2(normed, shift, scale))
class DiT2(DiT):
    """
    A conditional ViT decoder built on top of ``DiT``.

    The parent's input embedder (``x_embedder``), timestep embedder
    (``t_embedder``) and ``final_layer`` are deleted. Tokens start from the
    positional embedding, and every block (``DiTBlock2``) is modulated per
    token by the conditioning tensor passed to ``forward``. ``forward``
    returns hidden tokens; it does not unpatchify.
    """

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 vit_blk=...):
        """
        Args:
            input_size ... roll_out: forwarded unchanged to ``DiT.__init__``.
            plane_n: number of planes the token axis is split into for the
                within-plane blocks when ``roll_out`` is enabled.
            return_all_layers: if True, ``forward`` returns a list with the
                output of every block instead of only the last one.
            vit_blk: accepted for signature compatibility only. It is
                ignored, and ``DiTBlock2`` is always used.
        """
        super().__init__(input_size,
                         patch_size,
                         in_channels,
                         hidden_size,
                         depth,
                         num_heads,
                         mlp_ratio,
                         class_dropout_prob,
                         num_classes,
                         learn_sigma,
                         mixing_logit_init,
                         mixed_prediction,
                         context_dim,
                         roll_out,
                         vit_blk=DiTBlock2,
                         final_layer_blk=FinalLayer2)

        # No patch/timestep embedding and no output head: tokens start from
        # the positional embedding and are returned as-is by forward().
        del self.x_embedder
        del self.t_embedder
        del self.final_layer
        # Release cached allocator memory held by the deleted modules.
        torch.cuda.empty_cache()

        # NOTE(review): explicitly disabled here; where it is read is not
        # visible in this file — verify against the parent class.
        self.clip_text_proj = None
        self.plane_n = plane_n
        self.return_all_layers = return_all_layers

    def forward(self, c, *args, **kwargs):
        """
        Run the block stack over the positional embedding, conditioned on ``c``.

        Args:
            c: per-token conditioning tensor indexed as ``(B, T, D)``. When
                ``roll_out`` is enabled its token axis is split as ``(n l)``
                with ``n = plane_n``, so ``T`` must be divisible by
                ``plane_n``. Assumed to match ``pos_embed`` in ``T`` and ``D``
                — TODO confirm against callers.
            *args, **kwargs: accepted for interface compatibility; unused.

        Returns:
            If ``return_all_layers`` is False, the output of the last block.

            With ``roll_out``, the blocks alternate layouts:
              * even-indexed blocks attend within each plane on
                ``(B * plane_n, T / plane_n, D)`` tokens;
              * odd-indexed blocks attend globally on ``(B, T, D)`` tokens.
            The returned tensor is in the layout of the last block.

            If ``return_all_layers`` is True, a list with one entry per block:
              * under ``roll_out``, every entry uses the per-plane layout
                ``(B * plane_n, T / plane_n, D)``;
              * otherwise, entries are ``(B, T, D)``.
        """
        # (B, T, D): one copy of the positional embedding per batch element.
        x = self.pos_embed.repeat(c.shape[0], 1, 1)
        all_layers = []

        if self.roll_out:
            # Per-plane view of the conditioning. It is loop-invariant, so
            # build it once instead of once per within-plane block.
            c_planes = rearrange(c,
                                 'b (n l) c -> (b n) l c ',
                                 n=self.plane_n)

        for blk_idx, block in enumerate(self.blocks):
            if not self.roll_out:
                x = block(x, c)  # (B, T, D)
                if self.return_all_layers:
                    # Record this path too; otherwise the result is an
                    # empty list whenever roll_out is disabled.
                    all_layers.append(x)
            elif blk_idx % 2 == 0:
                # Within-plane self-attention: move planes into the batch axis.
                x = rearrange(x, 'b (n l) c -> (b n) l c ', n=self.plane_n)
                x = block(x, c_planes)
                if self.return_all_layers:
                    all_layers.append(x)
            else:
                # Global attention: merge planes back into the token axis.
                x = rearrange(x, '(b n) l c -> b (n l) c ', n=self.plane_n)
                x = block(x, c)
                if self.return_all_layers:
                    # Stored in the per-plane layout so all entries agree.
                    all_layers.append(
                        rearrange(x,
                                  'b (n l) c -> (b n) l c',
                                  n=self.plane_n))

        if self.return_all_layers:
            return all_layers
        return x
# class DiT2_DPT(DiT2):
# def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1, num_classes=1000, learn_sigma=True, mixing_logit_init=-3, mixed_prediction=True, context_dim=False, roll_out=False, plane_n=3, vit_blk=...):
# super().__init__(input_size, patch_size, in_channels, hidden_size, depth, num_heads, mlp_ratio, class_dropout_prob, num_classes, learn_sigma, mixing_logit_init, mixed_prediction, context_dim, roll_out, plane_n, vit_blk)
# self.return_all_layers = True
#################################################################################
# DiT2 Configs #
#################################################################################
def DiT2_XL_2(**kwargs):
    """Build an XL DiT2: 28 blocks, hidden size 1152, 16 heads, patch size 2."""
    return DiT2(patch_size=2, hidden_size=1152, num_heads=16, depth=28,
                **kwargs)
def DiT2_XL_2_half(**kwargs):
    """Build a half-depth XL DiT2: 28 // 2 blocks, hidden size 1152, 16 heads, patch size 2."""
    half_depth = 28 // 2
    return DiT2(patch_size=2, hidden_size=1152, num_heads=16,
                depth=half_depth, **kwargs)
def DiT2_XL_4(**kwargs):
    """Build an XL DiT2: 28 blocks, hidden size 1152, 16 heads, patch size 4."""
    return DiT2(patch_size=4, hidden_size=1152, num_heads=16, depth=28,
                **kwargs)
def DiT2_XL_8(**kwargs):
    """Build an XL DiT2: 28 blocks, hidden size 1152, 16 heads, patch size 8."""
    return DiT2(patch_size=8, hidden_size=1152, num_heads=16, depth=28,
                **kwargs)
def DiT2_L_2(**kwargs):
    """Build an L DiT2: 24 blocks, hidden size 1024, 16 heads, patch size 2."""
    return DiT2(patch_size=2, hidden_size=1024, num_heads=16, depth=24,
                **kwargs)
def DiT2_L_2_half(**kwargs):
    """Build a half-depth L DiT2: 24 // 2 blocks, hidden size 1024, 16 heads, patch size 2."""
    half_depth = 24 // 2
    return DiT2(patch_size=2, hidden_size=1024, num_heads=16,
                depth=half_depth, **kwargs)
def DiT2_L_4(**kwargs):
    """Build an L DiT2: 24 blocks, hidden size 1024, 16 heads, patch size 4."""
    return DiT2(patch_size=4, hidden_size=1024, num_heads=16, depth=24,
                **kwargs)
def DiT2_L_8(**kwargs):
    """Build an L DiT2: 24 blocks, hidden size 1024, 16 heads, patch size 8."""
    return DiT2(patch_size=8, hidden_size=1024, num_heads=16, depth=24,
                **kwargs)
def DiT2_B_2(**kwargs):
    """Build a B DiT2: 12 blocks, hidden size 768, 12 heads, patch size 2."""
    return DiT2(patch_size=2, hidden_size=768, num_heads=12, depth=12,
                **kwargs)
def DiT2_B_4(**kwargs):
    """Build a B DiT2: 12 blocks, hidden size 768, 12 heads, patch size 4."""
    return DiT2(patch_size=4, hidden_size=768, num_heads=12, depth=12,
                **kwargs)
def DiT2_B_8(**kwargs):
    """Build a B DiT2: 12 blocks, hidden size 768, 12 heads, patch size 8."""
    return DiT2(patch_size=8, hidden_size=768, num_heads=12, depth=12,
                **kwargs)
def DiT2_B_16(**kwargs):
    """Build a B DiT2 with patch size 16 (the authors' own config): 12 blocks, hidden size 768, 12 heads."""
    return DiT2(patch_size=16, hidden_size=768, num_heads=12, depth=12,
                **kwargs)
def DiT2_S_2(**kwargs):
    """Build an S DiT2: 12 blocks, hidden size 384, 6 heads, patch size 2."""
    return DiT2(patch_size=2,
                hidden_size=384,
                num_heads=6,
                depth=12,
                **kwargs)
def DiT2_S_4(**kwargs):
    """Build an S DiT2: 12 blocks, hidden size 384, 6 heads, patch size 4."""
    return DiT2(patch_size=4,
                hidden_size=384,
                num_heads=6,
                depth=12,
                **kwargs)
def DiT2_S_8(**kwargs):
    """Build an S DiT2: 12 blocks, hidden size 384, 6 heads, patch size 8."""
    return DiT2(patch_size=8,
                hidden_size=384,
                num_heads=6,
                depth=12,
                **kwargs)
# Registry mapping config names ("DiT2-<size>/<patch>[/half]") to the
# factory functions above. Insertion order is preserved.
DiT2_models = dict([
    ('DiT2-XL/2', DiT2_XL_2),
    ('DiT2-XL/2/half', DiT2_XL_2_half),
    ('DiT2-XL/4', DiT2_XL_4),
    ('DiT2-XL/8', DiT2_XL_8),
    ('DiT2-L/2', DiT2_L_2),
    ('DiT2-L/2/half', DiT2_L_2_half),
    ('DiT2-L/4', DiT2_L_4),
    ('DiT2-L/8', DiT2_L_8),
    ('DiT2-B/2', DiT2_B_2),
    ('DiT2-B/4', DiT2_B_4),
    ('DiT2-B/8', DiT2_B_8),
    ('DiT2-B/16', DiT2_B_16),
    ('DiT2-S/2', DiT2_S_2),
    ('DiT2-S/4', DiT2_S_4),
    ('DiT2-S/8', DiT2_S_8),
])