import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class SinusoidalEmbedding(nn.Module):
    def __init__(self, emb_min_freq=1.0, emb_max_freq=1000.0, embedding_dims=32):
        super().__init__()
        # log-spaced frequencies between emb_min_freq and emb_max_freq
        frequencies = torch.exp(
            torch.linspace(np.log(emb_min_freq), np.log(emb_max_freq), embedding_dims // 2)
        )
        self.register_buffer("angular_speeds", 2.0 * torch.pi * frequencies)

    def forward(self, x):
        embeddings = torch.cat(
            [torch.sin(self.angular_speeds * x), torch.cos(self.angular_speeds * x)], dim=-1
        )
        return embeddings
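

# Illustrative usage sketch (not part of the original module; the _demo_ helper and
# shapes are assumptions): with x holding one scalar per example, shape (bs, 1), the
# frequencies broadcast over the last dim and the output is (bs, embedding_dims).
def _demo_sinusoidal_embedding():
    emb = SinusoidalEmbedding(embedding_dims=32)
    noise_levels = torch.rand(8, 1)  # hypothetical batch of 8 scalar noise levels
    assert emb(noise_levels).shape == (8, 32)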


class MHAttention(nn.Module):
    def __init__(self, is_causal=False, dropout_level=0.0, n_heads=4):
        super().__init__()
        self.is_causal = is_causal
        self.dropout_level = dropout_level
        self.n_heads = n_heads

    def forward(self, q, k, v, attn_mask=None):
        assert q.size(-1) == k.size(-1)  # q and k must share the head dimension
        assert k.size(-2) == v.size(-2)  # k and v must share the sequence length
        # split embed_dim into n_heads heads: (bs, n, h*d) -> (bs, h, n, d)
        q, k, v = [rearrange(x, "bs n (h d) -> bs h n d", h=self.n_heads) for x in [q, k, v]]
        out = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            is_causal=self.is_causal,
            dropout_p=self.dropout_level if self.training else 0.0,
        )
        # merge the heads back: (bs, h, n, d) -> (bs, n, h*d)
        out = rearrange(out, "bs h n d -> bs n (h d)", h=self.n_heads)
        return out
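

# Illustrative usage sketch (hypothetical helper, not in the original source):
# q, k, v are (bs, n, embed_dim) with embed_dim divisible by n_heads; the output
# keeps q's shape.
def _demo_mh_attention():
    mha = MHAttention(n_heads=4)
    q = k = v = torch.randn(2, 10, 64)  # per-head dim = 64 / 4 = 16
    assert mha(q, k, v).shape == (2, 10, 64)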


class SelfAttention(nn.Module):
    def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
        super().__init__()
        self.qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.mha = MHAttention(is_causal, dropout_level, n_heads)

    def forward(self, x):
        q, k, v = self.qkv_linear(x).chunk(3, dim=2)
        return self.mha(q, k, v)
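

# Illustrative usage sketch (hypothetical helper): self-attention is shape-preserving,
# since q, k and v are all projected from the same (bs, n, embed_dim) input.
def _demo_self_attention():
    attn = SelfAttention(embed_dim=128, n_heads=4)
    x = torch.randn(2, 10, 128)
    assert attn(x).shape == x.shape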


class CrossAttention(nn.Module):
    def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
        super().__init__()
        self.kv_linear = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.q_linear = nn.Linear(embed_dim, embed_dim, bias=False)
        self.mha = MHAttention(is_causal, dropout_level, n_heads)

    def forward(self, x, y):
        q = self.q_linear(x)
        k, v = self.kv_linear(y).chunk(2, dim=2)
        return self.mha(q, k, v)
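

# Illustrative usage sketch (hypothetical helper): queries come from x and keys/values
# from y, so the two sequence lengths may differ; the output follows x's shape.
def _demo_cross_attention():
    xattn = CrossAttention(embed_dim=128, n_heads=4)
    x = torch.randn(2, 10, 128)  # e.g. image tokens
    y = torch.randn(2, 7, 128)   # e.g. conditioning tokens
    assert xattn(x, y).shape == x.shape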


class MLP(nn.Module):
    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_multiplier * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_multiplier * embed_dim, embed_dim),
            nn.Dropout(dropout_level),
        )

    def forward(self, x):
        return self.mlp(x)
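

# Illustrative usage sketch (hypothetical helper): the standard transformer MLP
# expands to mlp_multiplier * embed_dim and projects back, preserving the shape.
def _demo_mlp():
    mlp = MLP(embed_dim=128, mlp_multiplier=4, dropout_level=0.1)
    x = torch.randn(2, 10, 128)
    assert mlp(x).shape == x.shape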


class MLPSepConv(nn.Module):
    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
        """see: https://github.com/ofsoundof/LocalViT"""
        super().__init__()
        self.mlp = nn.Sequential(
            # a conv with kernel size 1 is equivalent to the Linear layer in a "regular" transformer MLP
            nn.Conv2d(embed_dim, mlp_multiplier * embed_dim, kernel_size=1, padding="same"),
            nn.Conv2d(
                mlp_multiplier * embed_dim,
                mlp_multiplier * embed_dim,
                kernel_size=3,
                padding="same",
                groups=mlp_multiplier * embed_dim,
            ),  # <- depthwise conv
            nn.GELU(),
            nn.Conv2d(mlp_multiplier * embed_dim, embed_dim, kernel_size=1, padding="same"),
            nn.Dropout(dropout_level),
        )

    def forward(self, x):
        w = h = int(np.sqrt(x.size(1)))  # only square images for now
        x = rearrange(x, "bs (h w) d -> bs d h w", h=h, w=w)
        x = self.mlp(x)
        x = rearrange(x, "bs d h w -> bs (h w) d")
        return x
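

# Illustrative usage sketch (hypothetical helper): MLPSepConv reshapes the token
# sequence into a square feature map, so the token count must be a perfect square
# (here 64 tokens -> an 8x8 grid).
def _demo_mlp_sep_conv():
    mlp = MLPSepConv(embed_dim=32, mlp_multiplier=2, dropout_level=0.1)
    x = torch.randn(2, 64, 32)  # 64 tokens = 8x8 grid
    assert mlp(x).shape == x.shape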


class DecoderBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        is_causal: bool,
        mlp_multiplier: int,
        dropout_level: float,
        mlp_class: type[MLP] | type[MLPSepConv],
    ):
        super().__init__()
        # the per-head dimension is fixed at 64, so n_heads scales with embed_dim
        self.self_attention = SelfAttention(embed_dim, is_causal, dropout_level, n_heads=embed_dim // 64)
        self.cross_attention = CrossAttention(
            embed_dim, is_causal=False, dropout_level=0.0, n_heads=embed_dim // 64
        )
        self.mlp = mlp_class(embed_dim, mlp_multiplier, dropout_level)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # pre-norm residual blocks: self-attention, cross-attention on y, then the MLP
        x = self.self_attention(self.norm1(x)) + x
        x = self.cross_attention(self.norm2(x), y) + x
        x = self.mlp(self.norm3(x)) + x
        return x
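

# Illustrative usage sketch (hypothetical helper): a full decoder block with the
# depthwise-conv MLP; x must form a square token grid for MLPSepConv, and y is the
# conditioning sequence consumed by cross-attention.
def _demo_decoder_block():
    block = DecoderBlock(
        embed_dim=128,  # gives n_heads = 128 // 64 = 2
        is_causal=False,
        mlp_multiplier=2,
        dropout_level=0.1,
        mlp_class=MLPSepConv,
    )
    x = torch.randn(2, 64, 128)  # 64 tokens = 8x8 grid
    y = torch.randn(2, 7, 128)
    assert block(x, y).shape == x.shape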