KdaiP's picture
Upload 80 files
3dd84f8 verified
raw
history blame
1.53 kB
import torch
from torch import nn
class UnifyGenerator(nn.Module):
def __init__(
self,
backbone: nn.Module,
head: nn.Module,
vq: nn.Module | None = None,
):
super().__init__()
self.backbone = backbone
self.head = head
self.vq = vq
def forward(self, x: torch.Tensor, template=None) -> torch.Tensor:
x = self.backbone(x)
if self.vq is not None:
vq_result = self.vq(x)
x = vq_result.z
x = self.head(x, template=template)
if x.ndim == 2:
x = x[:, None, :]
if self.vq is not None:
return x, vq_result
return x
def encode(self, x: torch.Tensor) -> torch.Tensor:
if self.vq is None:
raise ValueError("VQ module is not present in the model.")
x = self.backbone(x)
vq_result = self.vq(x)
return vq_result.codes
def decode(self, codes: torch.Tensor, template=None) -> torch.Tensor:
if self.vq is None:
raise ValueError("VQ module is not present in the model.")
x = self.vq.from_codes(codes)[0]
x = self.head(x, template=template)
if x.ndim == 2:
x = x[:, None, :]
return x
def remove_parametrizations(self):
if hasattr(self.backbone, "remove_parametrizations"):
self.backbone.remove_parametrizations()
if hasattr(self.head, "remove_parametrizations"):
self.head.remove_parametrizations()