import numpy as np
import torch
import torch.nn as nn
from einops import rearrange


class SinusoidalEmbedding(nn.Module):
    def __init__(self, emb_min_freq=1.0, emb_max_freq=1000.0, embedding_dims=32):
        super(SinusoidalEmbedding, self).__init__()

        frequencies = torch.exp(
            torch.linspace(np.log(emb_min_freq), np.log(emb_max_freq), embedding_dims // 2)
        )

        self.register_buffer("angular_speeds", 2.0 * torch.pi * frequencies)

    def forward(self, x):
        embeddings = torch.cat(
            [torch.sin(self.angular_speeds * x), torch.cos(self.angular_speeds * x)], dim=-1
        )
        return embeddings


class MHAttention(nn.Module):
    def __init__(self, is_causal=False, dropout_level=0.0, n_heads=4):
        super().__init__()
        self.is_causal = is_causal
        self.dropout_level = dropout_level
        self.n_heads = n_heads

    def forward(self, q, k, v, attn_mask=None):
        assert q.size(-1) == k.size(-1)
        assert k.size(-2) == v.size(-2)

        q, k, v = [rearrange(x, "bs n (h d) -> bs h n d", h=self.n_heads) for x in [q, k, v]]

        out = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            is_causal=self.is_causal,
            dropout_p=self.dropout_level if self.training else 0,
        )

        out = rearrange(out, "bs h n d -> bs n (h d)", h=self.n_heads)

        return out


class SelfAttention(nn.Module):
    def __init__(self, embed_dim, is_causal=False, dropout_level=0.0, n_heads=4):
        super().__init__()
        self.qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.mha = MHAttention(is_causal, dropout_level, n_heads)

    def forward(self, x):
        q, k, v = self.qkv_linear(x).chunk(3, dim=2)
        return self.mha(q, k, v)


class CrossAttention(nn.Module):
    def __init__(self, embed_dim, is_causal=False, dropout_level=0, n_heads=4):
        super().__init__()
        self.kv_linear = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.q_linear = nn.Linear(embed_dim, embed_dim, bias=False)
        self.mha = MHAttention(is_causal, dropout_level, n_heads)

    def forward(self, x, y):
        q = self.q_linear(x)
        k, v = self.kv_linear(y).chunk(2, dim=2)
        return self.mha(q, k, v)


class MLP(nn.Module):
    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_multiplier * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_multiplier * embed_dim, embed_dim),
            nn.Dropout(dropout_level),
        )

    def forward(self, x):
        return self.mlp(x)


class MLPSepConv(nn.Module):
    def __init__(self, embed_dim, mlp_multiplier, dropout_level):
        """see: https://github.com/ofsoundof/LocalViT"""
        super().__init__()
        self.mlp = nn.Sequential(
            # this Conv with kernel size 1 is equivalent to the Linear layer in a "regular" transformer MLP
            nn.Conv2d(embed_dim, mlp_multiplier * embed_dim, kernel_size=1, padding="same"),
            nn.Conv2d(
                mlp_multiplier * embed_dim,
                mlp_multiplier * embed_dim,
                kernel_size=3,
                padding="same",
                groups=mlp_multiplier * embed_dim,
            ),  # <- depthwise conv
            nn.GELU(),
            nn.Conv2d(mlp_multiplier * embed_dim, embed_dim, kernel_size=1, padding="same"),
            nn.Dropout(dropout_level),
        )

    def forward(self, x):
        w = h = int(np.sqrt(x.size(1)))  # only square images for now
        x = rearrange(x, "bs (h w) d -> bs d h w", h=h, w=w)
        x = self.mlp(x)
        x = rearrange(x, "bs d h w -> bs (h w) d")
        return x


class DecoderBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        is_causal: bool,
        mlp_multiplier: int,
        dropout_level: float,
        mlp_class: type[MLP] | type[MLPSepConv],
    ):
        super().__init__()
        self.self_attention = SelfAttention(
            embed_dim, is_causal, dropout_level, n_heads=embed_dim // 64
        )
        self.cross_attention = CrossAttention(
            embed_dim, is_causal=False, dropout_level=0, n_heads=embed_dim // 64
        )
        self.mlp = mlp_class(embed_dim, mlp_multiplier, dropout_level)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = self.self_attention(self.norm1(x)) + x
        x = self.cross_attention(self.norm2(x), y) + x
        x = self.mlp(self.norm3(x)) + x
        return x
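
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): it only shows
# how the classes above are assumed to compose. The batch size, sequence
# length, embed_dim=128, and the use of SinusoidalEmbedding as a conditioning
# signal for cross-attention are illustrative assumptions, not prescribed by
# the code above.
if __name__ == "__main__":
    bs, seq_len, embed_dim = 2, 64, 128  # 64 tokens -> an 8x8 grid for MLPSepConv

    block = DecoderBlock(
        embed_dim=embed_dim,
        is_causal=False,
        mlp_multiplier=4,
        dropout_level=0.1,
        mlp_class=MLPSepConv,
    )

    x = torch.randn(bs, seq_len, embed_dim)  # token sequence, shape (bs, n, d)
    noise_level = torch.rand(bs, 1)  # hypothetical scalar conditioning in [0, 1)
    cond = SinusoidalEmbedding(embedding_dims=embed_dim)(noise_level)  # (bs, embed_dim)
    cond = cond.unsqueeze(1)  # (bs, 1, embed_dim): a single conditioning token

    out = block(x, cond)  # x self-attends, then cross-attends to the conditioning token
    print(out.shape)  # torch.Size([2, 64, 128])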