# Page-header residue from the file viewer this source was captured from:
#   Spaces: Runtime error
#   File size: 1,526 Bytes
#   3dd84f8  (presumably the revision hash - confirm)
import torch
from torch import nn
class UnifyGenerator(nn.Module):
def __init__(
self,
backbone: nn.Module,
head: nn.Module,
vq: nn.Module | None = None,
):
super().__init__()
self.backbone = backbone
self.head = head
self.vq = vq
def forward(self, x: torch.Tensor, template=None) -> torch.Tensor:
x = self.backbone(x)
if self.vq is not None:
vq_result = self.vq(x)
x = vq_result.z
x = self.head(x, template=template)
if x.ndim == 2:
x = x[:, None, :]
if self.vq is not None:
return x, vq_result
return x
def encode(self, x: torch.Tensor) -> torch.Tensor:
if self.vq is None:
raise ValueError("VQ module is not present in the model.")
x = self.backbone(x)
vq_result = self.vq(x)
return vq_result.codes
def decode(self, codes: torch.Tensor, template=None) -> torch.Tensor:
if self.vq is None:
raise ValueError("VQ module is not present in the model.")
x = self.vq.from_codes(codes)[0]
x = self.head(x, template=template)
if x.ndim == 2:
x = x[:, None, :]
return x
def remove_parametrizations(self):
if hasattr(self.backbone, "remove_parametrizations"):
self.backbone.remove_parametrizations()
if hasattr(self.head, "remove_parametrizations"):
self.head.remove_parametrizations() |