"""
References:
- VQGAN: https://github.com/CompVis/taming-transformers
- MAE: https://github.com/facebookresearch/mae
"""
import numpy as np
import math
import functools
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Mlp
from timm.layers import DropPath
from timm.layers.helpers import to_2tuple
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from dit import PatchEmbed
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False, dim=1):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
if dim == 1:
self.dims = [1, 2, 3]
elif dim == 2:
self.dims = [1, 2]
else:
raise NotImplementedError
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def mode(self):
return self.mean
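# Illustrative sketch (an assumption about typical usage, not called anywhere in
# this file): the encoder below emits token-wise "moments" of shape
# [B, N, 2 * latent_dim] and builds the posterior with dim=2, so mean/logvar are
# split along the last axis and reductions run over dims [1, 2]. The class does
# not define a KL term; the closed form below is the standard VAE KL to N(0, I).
def _example_posterior_kl(moments):
    # moments: [B, N, 2 * latent_dim], e.g. torch.randn(2, 576, 32)
    posterior = DiagonalGaussianDistribution(moments, dim=2)
    z = posterior.sample()  # reparameterized sample, [B, N, latent_dim]
    kl = 0.5 * torch.sum(
        posterior.mean ** 2 + posterior.var - 1.0 - posterior.logvar,
        dim=posterior.dims,
    )  # per-example KL divergence, shape [B]
    return z, kl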
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads,
frame_height,
frame_width,
qkv_bias=False,
attn_drop=0.0,
proj_drop=0.0,
is_causal=False,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.frame_height = frame_height
self.frame_width = frame_width
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.is_causal = is_causal
rotary_freqs = RotaryEmbedding(
dim=head_dim // 4,
freqs_for="pixel",
max_freq=frame_height*frame_width,
).get_axial_freqs(frame_height, frame_width)
self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
def forward(self, x):
B, N, C = x.shape
assert N == self.frame_height * self.frame_width
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
if self.rotary_freqs is not None:
q = rearrange(q, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
k = rearrange(k, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
q = apply_rotary_emb(self.rotary_freqs, q)
k = apply_rotary_emb(self.rotary_freqs, k)
q = rearrange(q, "b h H W d -> b h (H W) d")
k = rearrange(k, "b h H W d -> b h (H W) d")
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            # gate dropout on training mode; SDPA does not check it on its own
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=self.is_causal,
        )
x = attn.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
frame_height,
frame_width,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
attn_causal=False,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads,
frame_height,
frame_width,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
is_causal=attn_causal,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
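# Shape-check sketch (illustrative only): an AttentionBlock runs on a flattened
# token grid of length frame_height * frame_width, with axial rotary embeddings
# applied to q/k inside Attention. The 18 x 32 grid below matches the 360 x 640
# input at patch size 20 used by ViT_L_20_Shallow_Encoder further down.
def _example_attention_block():
    block = AttentionBlock(
        dim=1024,
        num_heads=16,
        frame_height=18,
        frame_width=32,
        qkv_bias=True,
    )
    tokens = torch.randn(2, 18 * 32, 1024)  # [B, H*W, C]
    return block(tokens).shape              # torch.Size([2, 576, 1024])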
class AutoencoderKL(nn.Module):
def __init__(
self,
latent_dim,
input_height=256,
input_width=256,
patch_size=16,
enc_dim=768,
enc_depth=6,
enc_heads=12,
dec_dim=768,
dec_depth=6,
dec_heads=12,
mlp_ratio=4.0,
norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
use_variational=True,
**kwargs,
):
super().__init__()
self.input_height = input_height
self.input_width = input_width
self.patch_size = patch_size
self.seq_h = input_height // patch_size
self.seq_w = input_width // patch_size
self.seq_len = self.seq_h * self.seq_w
self.patch_dim = 3 * patch_size**2
self.latent_dim = latent_dim
self.enc_dim = enc_dim
self.dec_dim = dec_dim
# patch
self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
# encoder
self.encoder = nn.ModuleList(
[
AttentionBlock(
enc_dim,
enc_heads,
self.seq_h,
self.seq_w,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for i in range(enc_depth)
]
)
self.enc_norm = norm_layer(enc_dim)
# bottleneck
self.use_variational = use_variational
mult = 2 if self.use_variational else 1
self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
# decoder
self.decoder = nn.ModuleList(
[
AttentionBlock(
dec_dim,
dec_heads,
self.seq_h,
self.seq_w,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for i in range(dec_depth)
]
)
self.dec_norm = norm_layer(dec_dim)
self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
# initialize this weight first
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0.0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, x):
# patchify
bsz, _, h, w = x.shape
x = x.reshape(
bsz,
3,
self.seq_h,
self.patch_size,
self.seq_w,
self.patch_size,
).permute(
[0, 1, 3, 5, 2, 4]
) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
x = x.reshape(
bsz, self.patch_dim, self.seq_h, self.seq_w
) # --> [b, cxpxp, h, w]
x = x.permute([0, 2, 3, 1]).reshape(
bsz, self.seq_len, self.patch_dim
) # --> [b, hxw, cxpxp]
return x
def unpatchify(self, x):
bsz = x.shape[0]
# unpatchify
x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
[0, 3, 1, 2]
) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
x = x.reshape(
bsz,
3,
self.patch_size,
self.patch_size,
self.seq_h,
self.seq_w,
).permute(
[0, 1, 4, 2, 5, 3]
) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
x = x.reshape(
bsz,
3,
self.input_height,
self.input_width,
) # [b, c, hxp, wxp]
return x
def encode(self, x):
# patchify
x = self.patch_embed(x)
# encoder
for blk in self.encoder:
x = blk(x)
x = self.enc_norm(x)
# bottleneck
moments = self.quant_conv(x)
if not self.use_variational:
moments = torch.cat((moments, torch.zeros_like(moments)), 2)
posterior = DiagonalGaussianDistribution(
moments, deterministic=(not self.use_variational), dim=2
)
return posterior
def decode(self, z):
# bottleneck
z = self.post_quant_conv(z)
# decoder
for blk in self.decoder:
z = blk(z)
z = self.dec_norm(z)
# predictor
z = self.predictor(z)
# unpatchify
dec = self.unpatchify(z)
return dec
def autoencode(self, input, sample_posterior=True):
posterior = self.encode(input)
if self.use_variational and sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior, z
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def forward(self, inputs, labels, split="train"):
rec, post, latent = self.autoencode(inputs)
return rec, post, latent
def get_last_layer(self):
return self.predictor.weight
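# Property-check sketch (illustrative): patchify and unpatchify are exact
# inverses for tensors at the configured resolution, which is why the
# predictor's per-token output can be reshaped straight back into image space.
def _example_patchify_roundtrip(model, images):
    # images: [B, 3, model.input_height, model.input_width]
    patches = model.patchify(images)      # [B, seq_len, 3 * patch_size**2]
    restored = model.unpatchify(patches)  # back to [B, 3, H, W]
    return torch.equal(images, restored)  # True: pure reshape/permute round trip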
def ViT_L_20_Shallow_Encoder(**kwargs):
if "latent_dim" in kwargs:
latent_dim = kwargs.pop("latent_dim")
else:
latent_dim = 16
return AutoencoderKL(
latent_dim=latent_dim,
patch_size=20,
enc_dim=1024,
enc_depth=6,
enc_heads=16,
dec_dim=1024,
dec_depth=12,
dec_heads=16,
input_height=360,
input_width=640,
**kwargs,
)
VAE_models = {
"vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
}
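# Smoke-test sketch (illustrative, guarded so it never runs on import): builds
# the registered ViT-L/20 model and round-trips a dummy 360 x 640 batch. This
# assumes the local `dit.PatchEmbed` accepts (height, width, patch_size,
# in_channels, embed_dim), as used in AutoencoderKL.__init__ above.
if __name__ == "__main__":
    model = VAE_models["vit-l-20-shallow-encoder"](latent_dim=16)
    images = torch.randn(1, 3, 360, 640)               # [B, C, H, W]
    recon, posterior, latent = model(images, labels=None)
    print(recon.shape)   # expected: torch.Size([1, 3, 360, 640])
    print(latent.shape)  # expected: torch.Size([1, 576, 16])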