Spaces:
Runtime error
Runtime error
File size: 4,698 Bytes
12aae2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
"""
from typing import Optional
from collections import namedtuple
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from embeddings import TimestepEmbedding, Timesteps, Positions2d
class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along the time axis of a 5-D tensor.

    The input is unpacked as ``(B, T, H, W, D)``. Every ``(b, h, w)`` location
    is folded into the batch dimension and treated as an independent sequence
    of length ``T``, so attention only mixes information across time.

    Positional information comes from exactly one of two sources:

    * ``rotary_emb`` is given: queries and keys are rotated with it and no
      additive embedding module is created.
    * ``rotary_emb`` is ``None``: an additive embedding of the indices
      ``0..T-1`` (``Timesteps`` followed by ``TimestepEmbedding``) is added to
      the input before the QKV projection.

    Args:
        dim: Size of the last axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Per-head channel width; the QKV projection produces
            ``3 * heads * dim_head`` channels.
        is_causal: Forwarded to ``F.scaled_dot_product_attention``.
        rotary_emb: Optional rotary embedding applied to queries and keys.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        is_causal: bool = True,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads

        # Submodules are registered in the same order as before so parameter
        # initialisation and state_dict ordering are unchanged.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        if rotary_emb is None:
            # No rotary embedding supplied: fall back to an additive one.
            self.time_pos_embedding = nn.Sequential(
                Timesteps(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
        else:
            self.time_pos_embedding = None

        self.is_causal = is_causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the time axis.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``; ``D`` must match the
                ``dim`` the layer was built with (it feeds ``nn.Linear(dim, ...)``).

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        # The unpack also enforces that the input is exactly 5-D.
        B, T, H, W, _ = x.shape

        if self.time_pos_embedding is not None:
            time_emb = self.time_pos_embedding(torch.arange(T, device=x.device))
            # Broadcast the per-frame embedding over batch and both spatial axes.
            x = x + rearrange(time_emb, "t d -> 1 t 1 1 d")

        # Split the fused projection, then move heads out of the channel axis
        # and fold the spatial grid into the batch so T is the sequence axis.
        split_heads = "B T H W (h d) -> (B H W) h T d"
        q, k, v = (
            rearrange(t, split_heads, h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        if self.rotary_emb is not None:
            # NOTE(review): passing `freqs` as the second argument matches the
            # upstream code this file is based on; confirm it agrees with the
            # installed `rotate_queries_or_keys` signature.
            q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
            k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = (t.contiguous() for t in (q, k, v))

        x = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=self.is_causal
        )

        # Undo the fold: restore the spatial grid and merge heads into channels.
        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)
class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the spatial grid of a 5-D tensor.

    The input is unpacked as ``(B, T, H, W, D)``. Every ``(b, t)`` frame is
    folded into the batch dimension and its ``H * W`` positions form one
    sequence, so attention only mixes information within a frame. Attention
    is always non-causal.

    Positional information comes from exactly one of two sources:

    * ``rotary_emb`` is given: axial frequencies for an ``H x W`` grid are
      obtained from it and applied to queries and keys; no additive embedding
      module is created.
    * ``rotary_emb`` is ``None``: an additive embedding of the ``(h, w)`` index
      grid (``Positions2d`` followed by ``TimestepEmbedding``) is added to the
      input before the QKV projection.

    Args:
        dim: Size of the last axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Per-head channel width; the QKV projection produces
            ``3 * heads * dim_head`` channels.
        rotary_emb: Optional rotary embedding applied to queries and keys.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads

        # Submodules are registered in the same order as before so parameter
        # initialisation and state_dict ordering are unchanged.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        if rotary_emb is None:
            # No rotary embedding supplied: fall back to an additive one.
            self.space_pos_embedding = nn.Sequential(
                Positions2d(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
        else:
            self.space_pos_embedding = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the spatial positions of each frame.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``; ``D`` must match the
                ``dim`` the layer was built with (it feeds ``nn.Linear(dim, ...)``).

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        # The unpack also enforces that the input is exactly 5-D.
        B, T, H, W, _ = x.shape

        if self.space_pos_embedding is not None:
            h_steps = torch.arange(H, device=x.device)
            w_steps = torch.arange(W, device=x.device)
            # "ij" indexing keeps the grid in (row, column) = (H, W) order.
            grid = torch.meshgrid(h_steps, w_steps, indexing="ij")
            space_emb = self.space_pos_embedding(grid)
            # Broadcast the per-position embedding over batch and time.
            x = x + rearrange(space_emb, "h w d -> 1 1 h w d")

        # Split the fused projection and pull heads out of the channel axis.
        # H and W stay separate for now because the axial rotary frequencies
        # below are applied on the 2-D grid layout.
        split_heads = "B T H W (h d) -> (B T) h H W d"
        q, k, v = (
            rearrange(t, split_heads, h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        if self.rotary_emb is not None:
            freqs = self.rotary_emb.get_axial_freqs(H, W)
            q = apply_rotary_emb(freqs, q)
            k = apply_rotary_emb(freqs, k)

        # prepare for attn: flatten the grid into a single sequence axis
        flatten_grid = "(B T) h H W d -> (B T) h (H W) d"
        q, k, v = (
            rearrange(t, flatten_grid, B=B, T=T, h=self.heads).contiguous()
            for t in (q, k, v)
        )

        x = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=False
        )

        # Undo the fold: restore time and the grid, merge heads into channels.
        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)
|