from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet


@dataclass
class FSQResult:
    """Outputs of ``DownsampleFiniteScalarQuantize.forward``."""

    # Quantized representation after upsampling, zero-padded or cropped so its
    # last axis matches the length of the forward() input.
    z: torch.Tensor
    # Code indices returned by the grouped residual FSQ, with the last two axes
    # swapped (``indices.mT``).
    codes: torch.Tensor
    # Downsampled representation *before* quantization.
    latents: torch.Tensor


class DownsampleFiniteScalarQuantize(nn.Module):
    """Downsample, quantize with grouped residual FSQ, then upsample back.

    Each downsampling stage is ``FishConvNet(kernel_size=stride=factor)``
    followed by a ``ConvNeXtBlock``; the upsampling path mirrors it in reverse
    with ``FishTransConvNet``.

    Args:
        input_dim: Channel dimension of the input.
        n_codebooks: Number of residual quantizers (``num_quantizers``).
        n_groups: Number of FSQ groups.
        levels: FSQ levels per code dimension. The default (8, 5, 5, 5) gives
            8*5*5*5 = 1000 codes, approximately 2**10.
        downsample_factor: Kernel size and stride of each downsampling stage.
        downsample_dims: Output channels of each downsampling stage. Defaults to
            ``input_dim`` for every stage; must have one entry per stage.

    Raises:
        ValueError: If ``downsample_dims`` and ``downsample_factor`` have
            different lengths.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int, ...] = (8, 5, 5, 5),  # Approximate 2**10
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        # Every stage needs an output width; a mismatch would otherwise surface
        # as an IndexError below or as a quantizer dim that doesn't match the
        # last conv's output.
        if len(downsample_dims) != len(downsample_factor):
            raise ValueError(
                "downsample_dims and downsample_factor must have the same length, "
                f"got {len(downsample_dims)} and {len(downsample_factor)}"
            )

        # Channel width before/after each stage: all_dims[i] -> all_dims[i + 1].
        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror of the downsampling path, walked from the last stage back.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for Conv1d/Linear."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Layers built with bias=False have m.bias = None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _match_length(x: torch.Tensor, length: int) -> torch.Tensor:
        """Zero-pad or crop the last axis of ``x`` to ``length``, roughly centered."""
        diff = length - x.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            return F.pad(x, (left, right))
        if diff < 0:
            # Here left, right <= 0: drop -left items at the start and -right at
            # the end. (Slicing as ``left:-right`` is wrong: diff=-1 yields the
            # empty slice ``-1:0``.)
            return x[..., -left : x.shape[-1] + right]
        return x

    def forward(self, z) -> FSQResult:
        """Quantize ``z`` and reconstruct it at the original last-axis length.

        Args:
            z: Input tensor; presumably (batch, input_dim, time) as expected by
                the conv stages — TODO confirm against FishConvNet.

        Returns:
            FSQResult with the reconstructed ``z`` (last axis matching the
            input), the code indices, and the pre-quantization latents.
        """
        original_length = z.shape[-1]
        z = self.downsample(z)
        # .mT swaps the last two axes so the channel axis comes last for the
        # quantizer (assumes channels-before-time layout; see Args).
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        # Strided up/down-sampling need not round-trip the exact length.
        result.z = self._match_length(self.upsample(result.z), original_length)
        return result

    def encode(self, z):
        """Return code indices for ``z`` with shape (b, g * r, l).

        The quantizer's indices are laid out as (g, b, l, r) per the rearrange
        pattern; groups and residual levels are flattened into one axis.
        """
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor):
        """Inverse of ``encode``: map (b, g * r, l) indices back to features.

        Unlike ``forward``, no pad/crop is applied because the original input
        length is unknown here.
        """
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q