# atom-detection / atoms_detection / vae_svi_train.py
# Author: Romain Graux
# "Initial commit with ml code and webapp" (commit b2ffc9b)
from typing import Optional, Type
import torch
import torch.nn as nn
import pyro
import pyro.infer as infer
import pyro.optim as optim
import warnings
#from vae_model import set_deterministic_mode as set_deterministic_mode
from atoms_detection.vae_model import set_deterministic_mode as set_deterministic_mode
warnings.filterwarnings("ignore", module="torchvision.datasets")
class SVItrainer:
    """
    Stochastic variational inference (SVI) trainer for
    unsupervised and class-conditioned variational models.

    Args:
        model: initialized variational model instance exposing ``model`` and
            ``guide`` callables, as expected by :class:`pyro.infer.SVI`.
        optimizer: pyro optimizer; defaults to Adam with lr = 1e-3.
        loss: ELBO objective; defaults to ``Trace_ELBO``.
        seed: seed forwarded to ``set_deterministic_mode`` for reproducibility.
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer: Optional[optim.PyroOptim] = None,
                 loss: Optional[infer.ELBO] = None,
                 seed: int = 1
                 ) -> None:
        """
        Initializes the trainer's parameters.
        """
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if optimizer is None:
            optimizer = optim.Adam({"lr": 1.0e-3})
        if loss is None:
            loss = infer.Trace_ELBO()
        self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss)
        # Per-epoch averaged losses, appended by step()
        self.loss_history = {"training_loss": [], "test_loss": []}
        self.current_epoch = 0

    def train(self,
              train_loader: torch.utils.data.DataLoader,
              **kwargs: float) -> float:
        """
        Trains a single epoch.

        Returns:
            Average per-sample ELBO loss over the training set.
        """
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch returned by the data loader
        for data in train_loader:
            if len(data) == 1:  # VAE mode: unsupervised, inputs only
                x = data[0]
                loss = self.svi.step(x.to(self.device), **kwargs)
            else:  # VED or cVAE mode: inputs and targets/conditions
                x, y = data
                loss = self.svi.step(
                    x.to(self.device), y.to(self.device), **kwargs)
            # do ELBO gradient and accumulate loss
            epoch_loss += loss
        return epoch_loss / len(train_loader.dataset)

    def evaluate(self,
                 test_loader: torch.utils.data.DataLoader,
                 **kwargs: float) -> float:
        """
        Evaluates current model state on the test set.

        Returns:
            Average per-sample ELBO loss over the test set.
        """
        # initialize loss accumulator
        test_loss = 0.
        # compute the loss over the entire test set
        with torch.no_grad():
            for data in test_loader:
                if len(data) == 1:  # VAE mode
                    x = data[0]
                    # NOTE: evaluate_loss (not step) — step() would take a
                    # gradient step, i.e. train on the test set, and cannot
                    # run under torch.no_grad()
                    loss = self.svi.evaluate_loss(x.to(self.device), **kwargs)
                else:  # VED or cVAE mode
                    x, y = data
                    loss = self.svi.evaluate_loss(
                        x.to(self.device), y.to(self.device), **kwargs)
                test_loss += loss
        return test_loss / len(test_loader.dataset)

    def step(self,
             train_loader: torch.utils.data.DataLoader,
             test_loader: Optional[torch.utils.data.DataLoader] = None,
             **kwargs: float) -> None:
        """
        Single training and (optionally) evaluation step.
        """
        self.loss_history["training_loss"].append(self.train(train_loader, **kwargs))
        if test_loader is not None:
            self.loss_history["test_loss"].append(self.evaluate(test_loader, **kwargs))
        self.current_epoch += 1

    def print_statistics(self) -> None:
        """
        Prints training and test (if any) losses for current epoch.
        """
        e = self.current_epoch
        if len(self.loss_history["test_loss"]) > 0:
            template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1],
                                  self.loss_history["test_loss"][-1]))
        else:
            template = 'Epoch: {} Training loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1]))
def init_dataloader(*args: torch.Tensor, **kwargs: int
                    ) -> torch.utils.data.DataLoader:
    """
    Wraps input tensor(s) into a ``TensorDataset`` and returns a
    ``DataLoader`` over it.

    Args:
        *args: one or more tensors sharing the same first (sample) dimension.
        **kwargs: ``batch_size`` (default 100) and ``shuffle`` (default True,
            matching the previous hard-coded behavior).

    Returns:
        A ``torch.utils.data.DataLoader`` instance over the tensors.
    """
    batch_size = kwargs.get("batch_size", 100)
    # shuffle is now configurable (e.g. shuffle=False for evaluation loaders)
    shuffle = kwargs.get("shuffle", True)
    tensor_set = torch.utils.data.TensorDataset(*args)
    data_loader = torch.utils.data.DataLoader(
        dataset=tensor_set, batch_size=batch_size, shuffle=shuffle)
    return data_loader