import numpy as np
import torch
import torch.nn as nn
from torch import tensor as tt
from typing import Optional, Tuple
import pyro
import pyro.distributions as dist
import warnings
from atoms_detection.vae_image_utils import imcoordgrid, to_onehot, transform_coordinates

warnings.filterwarnings("ignore", module="torchvision.datasets")

# VAE model set-up
def set_deterministic_mode(seed: int) -> None:
    """
    Seeds the RNGs and configures cuDNN for reproducible runs
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def make_fc_layers(in_dim: int,
                   hidden_dim: int = 128,
                   num_layers: int = 2,
                   activation: str = "tanh"
                   ) -> nn.Sequential:
    """
    Generates a module with stacked fully-connected (aka dense) layers
    """
    activations = {"tanh": nn.Tanh, "lrelu": nn.LeakyReLU, "softplus": nn.Softplus}
    fc_layers = []
    for i in range(num_layers):
        hidden_dim_ = in_dim if i == 0 else hidden_dim
        fc_layers.extend(
            [nn.Linear(hidden_dim_, hidden_dim), activations[activation]()])
    return nn.Sequential(*fc_layers)
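
# Example (illustration only): make_fc_layers(784, 128, 2, "tanh") builds
# nn.Sequential(nn.Linear(784, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh())
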
class fcEncoderNet(nn.Module):
    """
    Simple fully-connected inference (encoder) network
    """
    def __init__(self,
                 in_dim: Tuple[int, ...],
                 latent_dim: int = 2,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 softplus_out: bool = False
                 ) -> None:
        """
        Initializes module parameters
        """
        super(fcEncoderNet, self).__init__()
        if len(in_dim) not in [1, 2, 3]:
            raise ValueError("in_dim must be (h, w), (h, w, c), or (h*w*c,)")
        self.in_dim = torch.prod(tt(in_dim)).item()
        self.fc_layers = make_fc_layers(
            self.in_dim, hidden_dim, num_layers, activation)
        self.fc11 = nn.Linear(hidden_dim, latent_dim)
        self.fc12 = nn.Linear(hidden_dim, latent_dim)
        self.activation_out = nn.Softplus() if softplus_out else lambda x: x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass
        """
        x = x.view(-1, self.in_dim)
        x = self.fc_layers(x)
        mu = self.fc11(x)
        # with softplus_out=True this is a positive scale (sd), not a log-sigma;
        # rVAE.guide consumes it directly as the scale of q(z|x)
        sd = self.activation_out(self.fc12(x))
        return mu, sd
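
# Example (illustration only): fcEncoderNet((28, 28), latent_dim=2) maps
# x of shape (B, 28, 28) to mu and sd, each of shape (B, 2)
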
class fcDecoderNet(nn.Module):
    """
    Standard decoder for VAE
    """
    def __init__(self,
                 out_dim: Tuple[int, ...],
                 latent_dim: int,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 sigmoid_out: bool = True,
                 ) -> None:
        """
        Initializes module parameters
        """
        super(fcDecoderNet, self).__init__()
        if len(out_dim) not in [1, 2, 3]:
            raise ValueError("out_dim must be (h, w), (h, w, c), or (h*w*c,)")
        self.reshape = out_dim
        out_dim = torch.prod(tt(out_dim)).item()
        self.fc_layers = make_fc_layers(
            latent_dim, hidden_dim, num_layers, activation)
        self.out = nn.Linear(hidden_dim, out_dim)
        self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.fc_layers(z)
        x = self.activation_out(self.out(x))
        return x.view(-1, *self.reshape)

class rDecoderNet(nn.Module):
    """
    Spatial generator (decoder) network with fully-connected layers
    """
    def __init__(self,
                 out_dim: Tuple[int, ...],
                 latent_dim: int,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 sigmoid_out: bool = True
                 ) -> None:
        """
        Initializes module parameters
        """
        super(rDecoderNet, self).__init__()
        if len(out_dim) not in [1, 2, 3]:
            raise ValueError("out_dim must be (h, w), (h, w, c), or (h*w*c,)")
        self.reshape = out_dim
        out_dim = torch.prod(tt(out_dim)).item()
        self.coord_latent = coord_latent(latent_dim, hidden_dim)
        self.fc_layers = make_fc_layers(
            hidden_dim, hidden_dim, num_layers, activation)
        self.out = nn.Linear(hidden_dim, 1)  # need to generalize to multi-channel (c > 1)
        self.activation_out = nn.Sigmoid() if sigmoid_out else lambda x: x

    def forward(self, x_coord: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        x = self.coord_latent(x_coord, z)
        x = self.fc_layers(x)
        x = self.activation_out(self.out(x))
        return x.view(-1, *self.reshape)

class coord_latent(nn.Module):
    """
    The "spatial" part of the rVAE's decoder that allows for translational
    and rotational invariance (based on https://arxiv.org/abs/1909.11663)
    """
    def __init__(self,
                 latent_dim: int,
                 out_dim: int,
                 activation_out: bool = True) -> None:
        """
        Initializes module parameters
        """
        super(coord_latent, self).__init__()
        self.fc_coord = nn.Linear(2, out_dim)
        self.fc_latent = nn.Linear(latent_dim, out_dim, bias=False)
        self.activation = nn.Tanh() if activation_out else None

    def forward(self,
                x_coord: torch.Tensor,
                z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        """
        batch_dim, n = x_coord.size()[:2]
        x_coord = x_coord.reshape(batch_dim * n, -1)
        h_x = self.fc_coord(x_coord)
        h_x = h_x.reshape(batch_dim, n, -1)
        h_z = self.fc_latent(z)
        # broadcast the latent embedding over all pixel coordinates
        h = h_x.add(h_z.unsqueeze(1))
        h = h.reshape(batch_dim * n, -1)
        if self.activation is not None:
            h = self.activation(h)
        return h
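
# Example (hypothetical shapes, illustration only): for a batch of 32 images on a
# 28x28 grid, x_coord has shape (32, 784, 2) and z has shape (32, latent_dim);
# coord_latent projects both to (32, 784, 128), sums them, and returns (32*784, 128):
#
#   cl = coord_latent(latent_dim=2, out_dim=128)
#   h = cl(torch.randn(32, 784, 2), torch.randn(32, 2))  # torch.Size([25088, 128])
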
class rVAE(nn.Module):
    """
    Variational autoencoder with rotational and/or translational invariance
    """
    def __init__(self,
                 in_dim: Tuple[int, int],
                 latent_dim: int = 2,
                 coord: int = 3,
                 num_classes: int = 0,
                 hidden_dim_e: int = 128,
                 hidden_dim_d: int = 128,
                 num_layers_e: int = 2,
                 num_layers_d: int = 2,
                 activation: str = "tanh",
                 softplus_sd: bool = True,
                 sigmoid_out: bool = True,
                 seed: int = 1,
                 **kwargs
                 ) -> None:
        """
        Initializes rVAE's modules and parameters
        """
        super(rVAE, self).__init__()
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if coord not in [0, 1, 2, 3]:
            raise ValueError("'coord' argument must be 0, 1, 2 or 3")
        self.encoder_net = fcEncoderNet(
            in_dim, latent_dim + coord, hidden_dim_e,
            num_layers_e, activation, softplus_sd)
        dnet = rDecoderNet if coord in [1, 2, 3] else fcDecoderNet
        self.decoder_net = dnet(
            in_dim, latent_dim + num_classes, hidden_dim_d,
            num_layers_d, activation, sigmoid_out)
        self.z_dim = latent_dim + coord
        self.coord = coord
        self.num_classes = num_classes
        self.grid = imcoordgrid(in_dim).to(self.device)
        self.dx_prior = tt(kwargs.get("dx_prior", 0.1)).to(self.device)
        self.to(self.device)
    def model(self,
              x: torch.Tensor,
              y: Optional[torch.Tensor] = None,
              **kwargs: float) -> torch.Tensor:
        """
        Defines the model p(x|z)p(z)
        """
        # register PyTorch module `decoder_net` with Pyro
        pyro.module("decoder_net", self.decoder_net)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        reshape_ = torch.prod(tt(x.shape[1:])).item()
        with pyro.plate("data", x.shape[0]):
            # set up hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            with pyro.poutine.scale(scale=beta):
                z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            if self.coord > 0:  # rotationally- and/or translationally-invariant mode
                # split latent variable into parts for rotation
                # and/or translation and image content
                phi, dx, z = self.split_latent(z)
                if torch.sum(dx) != 0:
                    dx = (dx * self.dx_prior).unsqueeze(1)
                # transform coordinate grid
                grid = self.grid.expand(x.shape[0], *self.grid.shape)
                x_coord_prime = transform_coordinates(grid, phi, dx)
            # add class label (if any)
            if y is not None:
                y = to_onehot(y, self.num_classes)
                z = torch.cat([z, y], dim=-1)
            # decode the latent code z together with the transformed coordinates (if any)
            dec_args = (x_coord_prime, z) if self.coord else (z,)
            loc_img = self.decoder_net(*dec_args)
            # score against actual images ("binary cross-entropy loss")
            pyro.sample(
                "obs", dist.Bernoulli(loc_img.view(-1, reshape_), validate_args=False).to_event(1),
                obs=x.view(-1, reshape_))
    def guide(self,
              x: torch.Tensor,
              y: Optional[torch.Tensor] = None,
              **kwargs: float) -> torch.Tensor:
        """
        Defines the guide q(z|x)
        """
        # register PyTorch module `encoder_net` with Pyro
        pyro.module("encoder_net", self.encoder_net)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder_net(x)
            # sample the latent code z
            with pyro.poutine.scale(scale=beta):
                pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
    def split_latent(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Splits latent variable into parts for rotation
        and/or translation and image content
        """
        phi, dx = tt(0), tt(0)
        # rotation + translation
        if self.coord == 3:
            phi = z[:, 0]   # encoded angle
            dx = z[:, 1:3]  # translation
            z = z[:, 3:]    # image content
        # translation only
        elif self.coord == 2:
            dx = z[:, :2]
            z = z[:, 2:]
        # rotation only
        elif self.coord == 1:
            phi = z[:, 0]
            z = z[:, 1:]
        return phi, dx, z
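
    # Example (coord=3, illustration only): for z of shape (B, 3 + latent_dim),
    # split_latent returns phi = z[:, 0] (angle, shape (B,)), dx = z[:, 1:3]
    # (translation, shape (B, 2)), and z[:, 3:] as the image-content latents
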
    def _encode(self, x_new: torch.Tensor, **kwargs: int) -> torch.Tensor:
        """
        Encodes data using a trained inference (encoder) network
        in a batch-by-batch fashion
        """
        def inference(x_i: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                encoded = self.encoder_net(x_i)
            encoded = torch.cat(encoded, -1).cpu()
            return encoded

        x_new = x_new.to(self.device)
        num_batches = kwargs.get("num_batches", 10)
        batch_size = len(x_new) // num_batches
        z_encoded = []
        for i in range(num_batches):
            x_i = x_new[i*batch_size:(i+1)*batch_size]
            z_encoded.append(inference(x_i))
        # encode the remainder that did not fit into the batches above
        x_i = x_new[num_batches*batch_size:]
        if len(x_i) > 0:
            z_encoded.append(inference(x_i))
        return torch.cat(z_encoded)
    def encode(self, x_new: torch.Tensor, **kwargs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes data using a trained inference (encoder) network
        (this is basically a wrapper for self._encode)
        """
        if isinstance(x_new, torch.utils.data.DataLoader):
            x_new = x_new.dataset.tensors[0]
        z = self._encode(x_new, **kwargs)
        z_loc = z[:, :self.z_dim]
        z_scale = z[:, self.z_dim:]
        return z_loc, z_scale
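
# A minimal training sketch (not part of the original module; `imgs` is a
# hypothetical (N, H, W) tensor of images scaled to [0, 1]):
#
#   from pyro.infer import SVI, Trace_ELBO
#   from pyro.optim import Adam
#
#   rvae = rVAE(in_dim=(28, 28), latent_dim=2, coord=3)
#   svi = SVI(rvae.model, rvae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
#   for epoch in range(100):
#       loss = svi.step(imgs.to(rvae.device))
#   z_loc, z_scale = rvae.encode(imgs)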