# tld/denoiser.py
"""transformer based denoiser"""
import torch
from einops.layers.torch import Rearrange
from torch import nn
from tld.transformer_blocks import DecoderBlock, MLPSepConv, SinusoidalEmbedding


class DenoiserTransBlock(nn.Module):
    def __init__(
        self,
        patch_size: int,
        img_size: int,
        embed_dim: int,
        dropout: float,
        n_layers: int,
        mlp_multiplier: int = 4,
        n_channels: int = 4,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.img_size = img_size
        self.n_channels = n_channels
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.n_layers = n_layers
        self.mlp_multiplier = mlp_multiplier

        seq_len = (self.img_size // self.patch_size) ** 2
        patch_dim = self.n_channels * self.patch_size * self.patch_size

        # Patchify with a strided conv, flatten the spatial grid into a
        # token sequence, then project each patch to the embedding dimension.
        self.patchify_and_embed = nn.Sequential(
            nn.Conv2d(
                self.n_channels,
                patch_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            ),
            Rearrange("bs d h w -> bs (h w) d"),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, self.embed_dim),
            nn.LayerNorm(self.embed_dim),
        )

        # Inverse of the patchify step: fold the token sequence back into
        # an image of shape (bs, n_channels, img_size, img_size).
        self.rearrange2 = Rearrange(
            "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
            h=self.img_size // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
        )

        # Learned positional embeddings; the index tensor is precomputed once
        # and registered as a buffer so it moves with the module's device.
        self.pos_embed = nn.Embedding(seq_len, self.embed_dim)
        self.register_buffer("precomputed_pos_enc", torch.arange(0, seq_len).long())

        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    embed_dim=self.embed_dim,
                    mlp_multiplier=self.mlp_multiplier,
                    # non-causal block: we denoise the entire image at once,
                    # so no masking is needed
                    is_causal=False,
                    dropout_level=self.dropout,
                    mlp_class=MLPSepConv,
                )
                for _ in range(self.n_layers)
            ]
        )

        self.out_proj = nn.Sequential(nn.Linear(self.embed_dim, patch_dim), self.rearrange2)

    def forward(self, x, cond):
        x = self.patchify_and_embed(x)
        pos_enc = self.precomputed_pos_enc[: x.size(1)].expand(x.size(0), -1)
        x = x + self.pos_embed(pos_enc)

        for block in self.decoder_blocks:
            x = block(x, cond)

        return self.out_proj(x)
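
# Shape walk-through for DenoiserTransBlock (illustrative numbers, not from
# the original file; assumes img_size=32, patch_size=2, n_channels=4,
# embed_dim=256):
#   input x:         (bs, 4, 32, 32)
#   after Conv2d:    (bs, 16, 16, 16)   # patch_dim = 4 * 2 * 2 = 16
#   after Rearrange: (bs, 256, 16)      # seq_len = (32 // 2) ** 2 = 256 tokens
#   after Linear:    (bs, 256, 256)     # projected to embed_dim
#   out_proj maps each token back to patch_dim and folds the sequence into a
#   (bs, 4, 32, 32) image, so the output shape matches the input.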


class Denoiser(nn.Module):
    def __init__(
        self,
        image_size: int,
        noise_embed_dims: int,
        patch_size: int,
        embed_dim: int,
        dropout: float,
        n_layers: int,
        text_emb_size: int = 768,
    ):
        super().__init__()

        self.image_size = image_size
        self.noise_embed_dims = noise_embed_dims
        self.embed_dim = embed_dim

        # Embed the scalar noise level with sinusoidal features, then refine
        # with a small MLP.
        self.fourier_feats = nn.Sequential(
            SinusoidalEmbedding(embedding_dims=noise_embed_dims),
            nn.Linear(noise_embed_dims, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

        self.denoiser_trans_block = DenoiserTransBlock(patch_size, image_size, embed_dim, dropout, n_layers)
        self.norm = nn.LayerNorm(self.embed_dim)
        self.label_proj = nn.Linear(text_emb_size, self.embed_dim)

    def forward(self, x, noise_level, label):
        # Build a two-token conditioning sequence from the noise-level and
        # text-label embeddings.
        noise_level = self.fourier_feats(noise_level).unsqueeze(1)
        label = self.label_proj(label).unsqueeze(1)

        noise_label_emb = torch.cat([noise_level, label], dim=1)  # (bs, 2, embed_dim)
        noise_label_emb = self.norm(noise_label_emb)

        x = self.denoiser_trans_block(x, noise_label_emb)

        return x
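

# A minimal smoke test, not part of the original module. The argument values
# below are illustrative assumptions: a 32x32 4-channel latent, a 768-dim
# pooled text embedding, and a per-sample noise level of shape (bs, 1)
# (SinusoidalEmbedding is assumed to broadcast over that trailing dim). It
# only checks that a forward pass runs and preserves the input shape.
if __name__ == "__main__":
    model = Denoiser(
        image_size=32,
        noise_embed_dims=128,
        patch_size=2,
        embed_dim=256,
        dropout=0.1,
        n_layers=4,
    )
    x = torch.randn(2, 4, 32, 32)   # noisy latents: (bs, n_channels, h, w)
    noise_level = torch.rand(2, 1)  # one noise level per sample
    label = torch.randn(2, 768)     # pooled text embedding (text_emb_size=768)

    with torch.no_grad():
        out = model(x, noise_level, label)

    assert out.shape == x.shape  # the denoiser predicts a same-shaped output
    print(out.shape)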