""" | |
References: | |
- VQGAN: https://github.com/CompVis/taming-transformers | |
- MAE: https://github.com/facebookresearch/mae | |
""" | |
import numpy as np
import math
import functools
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Mlp
from timm.layers import DropPath  # used by AttentionBlock for stochastic depth
from timm.layers.helpers import to_2tuple
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

from dit import PatchEmbed


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False, dim=1):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        if dim == 1:
            self.dims = [1, 2, 3]
        elif dim == 2:
            self.dims = [1, 2]
        else:
            raise NotImplementedError
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(
            device=self.parameters.device
        )
        return x

    def mode(self):
        return self.mean
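

# Usage sketch (illustrative only, not called anywhere in this module): the
# posterior is built from a tensor that concatenates mean and log-variance
# along `dim`, and sample() draws via the reparameterization trick. The
# function name and sizes below are hypothetical.
def _example_posterior_sample():
    moments = torch.randn(2, 8, 2 * 16)  # [batch, tokens, 2 * latent_dim]
    posterior = DiagonalGaussianDistribution(moments, deterministic=False, dim=2)
    z = posterior.sample()    # mean + std * eps, differentiable w.r.t. moments
    z_det = posterior.mode()  # just the mean, used when sampling is disabled
    return z.shape, z_det.shape  # both torch.Size([2, 8, 16])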


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        frame_height,
        frame_width,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        is_causal=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.frame_height = frame_height
        self.frame_width = frame_width
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.is_causal = is_causal
        rotary_freqs = RotaryEmbedding(
            dim=head_dim // 4,
            freqs_for="pixel",
            max_freq=frame_height * frame_width,
        ).get_axial_freqs(frame_height, frame_width)
        self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)

    def forward(self, x):
        B, N, C = x.shape
        assert N == self.frame_height * self.frame_width
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        if self.rotary_freqs is not None:
            q = rearrange(q, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
            k = rearrange(k, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
            q = apply_rotary_emb(self.rotary_freqs, q)
            k = apply_rotary_emb(self.rotary_freqs, k)
            q = rearrange(q, "b h H W d -> b h (H W) d")
            k = rearrange(k, "b h H W d -> b h (H W) d")
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop if self.training else 0.0,  # attention dropout only during training
            is_causal=self.is_causal,
        )
        x = attn.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
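

# Shape sketch (illustrative only, never called by the model): Attention expects
# a flat token sequence of length frame_height * frame_width so the axial rotary
# embedding can be applied per 2D grid position. Sizes below are hypothetical.
def _example_attention_shapes():
    attn = Attention(dim=64, num_heads=4, frame_height=3, frame_width=5, qkv_bias=True)
    tokens = torch.randn(2, 3 * 5, 64)  # [batch, H * W, dim]
    out = attn(tokens)
    return out.shape  # torch.Size([2, 15, 64])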


class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        frame_height,
        frame_width,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        attn_causal=False,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads,
            frame_height,
            frame_width,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            is_causal=attn_causal,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
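

# Sketch (illustrative only, unused by the model): AttentionBlock is a pre-norm
# transformer block, and with drop_path > 0 its residual branches are randomly
# skipped per sample during training (stochastic depth). In eval mode DropPath
# is a pass-through, so the block is deterministic. Sizes are hypothetical.
def _example_stochastic_depth():
    block = AttentionBlock(dim=64, num_heads=4, frame_height=3, frame_width=5,
                           qkv_bias=True, drop_path=0.1)
    tokens = torch.randn(2, 3 * 5, 64)
    block.eval()  # disables DropPath (and any dropout)
    out_a, out_b = block(tokens), block(tokens)
    return torch.allclose(out_a, out_b)  # True in eval mode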


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        latent_dim,
        input_height=256,
        input_width=256,
        patch_size=16,
        enc_dim=768,
        enc_depth=6,
        enc_heads=12,
        dec_dim=768,
        dec_depth=6,
        dec_heads=12,
        mlp_ratio=4.0,
        norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
        use_variational=True,
        **kwargs,
    ):
        super().__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.patch_size = patch_size
        self.seq_h = input_height // patch_size
        self.seq_w = input_width // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.patch_dim = 3 * patch_size**2

        self.latent_dim = latent_dim
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim

        # patch
        self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)

        # encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    enc_dim,
                    enc_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(enc_depth)
            ]
        )
        self.enc_norm = norm_layer(enc_dim)

        # bottleneck
        self.use_variational = use_variational
        mult = 2 if self.use_variational else 1
        self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
        self.post_quant_conv = nn.Linear(latent_dim, dec_dim)

        # decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    dec_dim,
                    dec_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_dim)
        self.predictor = nn.Linear(dec_dim, self.patch_dim)  # decoder to patch

        # initialize this weight first
        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0.0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        # patchify
        bsz, _, h, w = x.shape
        x = x.reshape(
            bsz,
            3,
            self.seq_h,
            self.patch_size,
            self.seq_w,
            self.patch_size,
        ).permute(
            [0, 1, 3, 5, 2, 4]
        )  # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
        x = x.reshape(
            bsz, self.patch_dim, self.seq_h, self.seq_w
        )  # --> [b, cxpxp, h, w]
        x = x.permute([0, 2, 3, 1]).reshape(
            bsz, self.seq_len, self.patch_dim
        )  # --> [b, hxw, cxpxp]
        return x

    def unpatchify(self, x):
        bsz = x.shape[0]
        # unpatchify
        x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
            [0, 3, 1, 2]
        )  # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
        x = x.reshape(
            bsz,
            3,
            self.patch_size,
            self.patch_size,
            self.seq_h,
            self.seq_w,
        ).permute(
            [0, 1, 4, 2, 5, 3]
        )  # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
        x = x.reshape(
            bsz,
            3,
            self.input_height,
            self.input_width,
        )  # [b, c, hxp, wxp]
        return x

    def encode(self, x):
        # patchify
        x = self.patch_embed(x)

        # encoder
        for blk in self.encoder:
            x = blk(x)
        x = self.enc_norm(x)

        # bottleneck
        moments = self.quant_conv(x)
        if not self.use_variational:
            moments = torch.cat((moments, torch.zeros_like(moments)), 2)
        posterior = DiagonalGaussianDistribution(
            moments, deterministic=(not self.use_variational), dim=2
        )
        return posterior

    def decode(self, z):
        # bottleneck
        z = self.post_quant_conv(z)

        # decoder
        for blk in self.decoder:
            z = blk(z)
        z = self.dec_norm(z)

        # predictor
        z = self.predictor(z)

        # unpatchify
        dec = self.unpatchify(z)
        return dec

    def autoencode(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if self.use_variational and sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior, z

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def forward(self, inputs, labels, split="train"):
        rec, post, latent = self.autoencode(inputs)
        return rec, post, latent

    def get_last_layer(self):
        return self.predictor.weight
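

# Usage sketch (illustrative only, not called anywhere in this module): the full
# autoencoding path. `vae` is assumed to be an AutoencoderKL instance and
# `images` a float tensor of shape [B, 3, input_height, input_width].
@torch.no_grad()
def _example_autoencode(vae, images):
    posterior = vae.encode(images)  # DiagonalGaussianDistribution over [B, seq_len, latent_dim]
    z = posterior.sample()          # stochastic latents (training-style)
    recon = vae.decode(z)           # [B, 3, input_height, input_width]
    # equivalently, in one call, here with deterministic (mode) latents:
    recon_det, posterior_det, z_det = vae.autoencode(images, sample_posterior=False)
    # patchify/unpatchify are exact inverses of each other:
    assert torch.allclose(vae.unpatchify(vae.patchify(images)), images)
    return recon, recon_det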


def ViT_L_20_Shallow_Encoder(**kwargs):
    if "latent_dim" in kwargs:
        latent_dim = kwargs.pop("latent_dim")
    else:
        latent_dim = 16
    return AutoencoderKL(
        latent_dim=latent_dim,
        patch_size=20,
        enc_dim=1024,
        enc_depth=6,
        enc_heads=16,
        dec_dim=1024,
        dec_depth=12,
        dec_heads=16,
        input_height=360,
        input_width=640,
        **kwargs,
    )


VAE_models = {
    "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
}
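

# Minimal smoke test (illustrative only; assumes `dit.PatchEmbed` is importable
# and that CPU execution is acceptable). Builds the registered ViT-L/20 config
# and checks that a 360x640 input reconstructs to the same shape.
if __name__ == "__main__":
    vae = VAE_models["vit-l-20-shallow-encoder"](latent_dim=16)
    dummy = torch.randn(1, 3, 360, 640)
    with torch.no_grad():
        recon, posterior, latent = vae(dummy, labels=None)
    print(recon.shape)   # torch.Size([1, 3, 360, 640])
    print(latent.shape)  # torch.Size([1, 576, 16]) -- (360 // 20) * (640 // 20) tokens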