""" References: - VQGAN: https://github.com/CompVis/taming-transformers - MAE: https://github.com/facebookresearch/mae """ import numpy as np import math import functools from collections import namedtuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.vision_transformer import Mlp from timm.layers.helpers import to_2tuple from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb from dit import PatchEmbed class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False, dim=1): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim) if dim == 1: self.dims = [1, 2, 3] elif dim == 2: self.dims = [1, 2] else: raise NotImplementedError self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like(self.mean).to( device=self.parameters.device ) def sample(self): x = self.mean + self.std * torch.randn(self.mean.shape).to( device=self.parameters.device ) return x def mode(self): return self.mean class Attention(nn.Module): def __init__( self, dim, num_heads, frame_height, frame_width, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, is_causal=False, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.frame_height = frame_height self.frame_width = frame_width self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = attn_drop self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.is_causal = is_causal rotary_freqs = RotaryEmbedding( dim=head_dim // 4, freqs_for="pixel", max_freq=frame_height*frame_width, ).get_axial_freqs(frame_height, frame_width) self.register_buffer("rotary_freqs", rotary_freqs, persistent=False) def forward(self, x): B, N, C = x.shape assert N == self.frame_height * self.frame_width qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) if self.rotary_freqs is not None: q = rearrange(q, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width) k = rearrange(k, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width) q = apply_rotary_emb(self.rotary_freqs, q) k = apply_rotary_emb(self.rotary_freqs, k) q = rearrange(q, "b h H W d -> b h (H W) d") k = rearrange(k, "b h H W d -> b h (H W) d") attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop, is_causal=self.is_causal, ) x = attn.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class AttentionBlock(nn.Module): def __init__( self, dim, num_heads, frame_height, frame_width, mlp_ratio=4.0, qkv_bias=False, drop=0.0, attn_drop=0.0, attn_causal=False, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads, frame_height, frame_width, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, is_causal=attn_causal, ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) def forward(self, x): x = x + 
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        latent_dim,
        input_height=256,
        input_width=256,
        patch_size=16,
        enc_dim=768,
        enc_depth=6,
        enc_heads=12,
        dec_dim=768,
        dec_depth=6,
        dec_heads=12,
        mlp_ratio=4.0,
        norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
        use_variational=True,
        **kwargs,
    ):
        super().__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.patch_size = patch_size
        self.seq_h = input_height // patch_size
        self.seq_w = input_width // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.patch_dim = 3 * patch_size**2

        self.latent_dim = latent_dim
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim

        # patch
        self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)

        # encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    enc_dim,
                    enc_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(enc_depth)
            ]
        )
        self.enc_norm = norm_layer(enc_dim)

        # bottleneck
        self.use_variational = use_variational
        mult = 2 if self.use_variational else 1
        self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
        self.post_quant_conv = nn.Linear(latent_dim, dec_dim)

        # decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    dec_dim,
                    dec_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_dim)
        self.predictor = nn.Linear(dec_dim, self.patch_dim)  # decoder to patch

        # initialize this weight first
        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0.0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        # patchify
        bsz, _, h, w = x.shape
        x = x.reshape(
            bsz,
            3,
            self.seq_h,
            self.patch_size,
            self.seq_w,
            self.patch_size,
        ).permute(
            [0, 1, 3, 5, 2, 4]
        )  # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
        x = x.reshape(
            bsz, self.patch_dim, self.seq_h, self.seq_w
        )  # --> [b, cxpxp, h, w]
        x = x.permute([0, 2, 3, 1]).reshape(
            bsz, self.seq_len, self.patch_dim
        )  # --> [b, hxw, cxpxp]
        return x

    def unpatchify(self, x):
        bsz = x.shape[0]
        # unpatchify
        x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
            [0, 3, 1, 2]
        )  # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
        x = x.reshape(
            bsz,
            3,
            self.patch_size,
            self.patch_size,
            self.seq_h,
            self.seq_w,
        ).permute(
            [0, 1, 4, 2, 5, 3]
        )  # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
        x = x.reshape(
            bsz,
            3,
            self.input_height,
            self.input_width,
        )  # [b, c, hxp, wxp]
        return x

    def encode(self, x):
        # patchify
        x = self.patch_embed(x)
        # encoder
        for blk in self.encoder:
            x = blk(x)
        x = self.enc_norm(x)
        # bottleneck
        moments = self.quant_conv(x)
        if not self.use_variational:
            moments = torch.cat((moments, torch.zeros_like(moments)), 2)
        posterior = DiagonalGaussianDistribution(
            moments, deterministic=(not self.use_variational), dim=2
        )
        return posterior

    def decode(self, z):
        # bottleneck
        z = self.post_quant_conv(z)
        # decoder
        for blk in self.decoder:
            z = blk(z)
        z = self.dec_norm(z)
        # predictor
        z = self.predictor(z)
        # unpatchify
        dec = self.unpatchify(z)
        return dec

    def autoencode(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if self.use_variational and sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior, z

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def forward(self, inputs, labels, split="train"):
        rec, post, latent = self.autoencode(inputs)
        return rec, post, latent

    def get_last_layer(self):
        return self.predictor.weight


def ViT_L_20_Shallow_Encoder(**kwargs):
    if "latent_dim" in kwargs:
        latent_dim = kwargs.pop("latent_dim")
    else:
        latent_dim = 16
    return AutoencoderKL(
        latent_dim=latent_dim,
        patch_size=20,
        enc_dim=1024,
        enc_depth=6,
        enc_heads=16,
        dec_dim=1024,
        dec_depth=12,
        dec_heads=16,
        input_height=360,
        input_width=640,
        **kwargs,
    )


VAE_models = {
    "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
}
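

if __name__ == "__main__":
    # Usage sketch (not part of the original module): builds a deliberately small
    # AutoencoderKL instead of the heavier ViT-L-20 config, runs one reconstruction
    # round trip on random data, and prints the resulting shapes. The 64x64 input,
    # patch_size=16, and latent_dim=4 below are arbitrary illustrative choices; the
    # local `dit.PatchEmbed` dependency must still be importable for this to run.
    model = AutoencoderKL(
        latent_dim=4,
        input_height=64,
        input_width=64,
        patch_size=16,
        enc_dim=128,
        enc_depth=2,
        enc_heads=4,
        dec_dim=128,
        dec_depth=2,
        dec_heads=4,
    )
    model.eval()

    x = torch.randn(2, 3, 64, 64)  # dummy batch of RGB frames
    with torch.no_grad():
        rec, posterior, latent = model(x, labels=None)

    print(rec.shape)     # torch.Size([2, 3, 64, 64]) -- same shape as the input
    print(latent.shape)  # torch.Size([2, 16, 4])  == (batch, seq_len, latent_dim)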