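# Finite scalar quantization (FSQ) with temporal downsampling: features are
# compressed along the time axis, quantized by a grouped residual FSQ, and
# upsampled back to the input resolution.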
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
@dataclass
class FSQResult:
    z: torch.Tensor        # quantized features, upsampled back to the input resolution
    codes: torch.Tensor    # discrete code indices
    latents: torch.Tensor  # downsampled (pre-quantization) latents

class DownsampleFiniteScalarQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int, ...] = (8, 5, 5, 5),  # 8 * 5 * 5 * 5 = 1000, approximately 2**10
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # Each downsample stage: a strided conv (kernel_size == stride == factor)
        # followed by a ConvNeXt block at the reduced resolution.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        # Mirror of the downsample stack: transposed convs restore the temporal
        # resolution, applied in reverse stage order.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)
    def forward(self, z) -> FSQResult:
        original_shape = z.shape
        z = self.downsample(z)
        # GroupedResidualFSQ expects (batch, length, dim), so transpose the last
        # two axes with .mT on the way in and out.
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match the original length
        diff = original_shape[-1] - result.z.shape[-1]
        if diff > 0:  # upsampled output is too short: pad both ends
            left = diff // 2
            right = diff - left
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:  # upsampled output is too long: crop both ends
            left = (-diff) // 2
            right = (-diff) - left
            result.z = result.z[..., left : result.z.shape[-1] - right]

        return result
    def encode(self, z):
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        # (groups, batch, length, residual levels) -> (batch, codebooks, length)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices
    def decode(self, indices: torch.Tensor):
        # (batch, codebooks, length) -> (groups, batch, length, residual levels)
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q
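

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption for illustration, not part of the
    # upstream API): round-trip a dummy (batch, channels, length) tensor through
    # forward() and through encode()/decode(). Because of the relative import
    # above, run it as a module from the package root, e.g. `python -m <package>.fsq`.
    quantizer = DownsampleFiniteScalarQuantize(
        input_dim=512,
        n_codebooks=9,
        downsample_factor=(2, 2),  # total temporal compression of 4x
    )
    x = torch.randn(2, 512, 100)

    result = quantizer(x)
    # forward() pads/crops back to the input length, so shapes should match.
    assert result.z.shape == x.shape

    codes = quantizer.encode(x)  # (batch, n_groups * n_codebooks, length / 4)
    z_q = quantizer.decode(codes)
    print(codes.shape, z_q.shape)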