Spaces:
Sleeping
Sleeping
from typing import Optional, Type | |
import torch | |
import torch.nn as nn | |
import pyro | |
import pyro.infer as infer | |
import pyro.optim as optim | |
import warnings | |
#from vae_model import set_deterministic_mode as set_deterministic_mode | |
from atoms_detection.vae_model import set_deterministic_mode as set_deterministic_mode | |
warnings.filterwarnings("ignore", module="torchvision.datasets") | |
class SVItrainer:
    """
    Stochastic variational inference (SVI) trainer for
    unsupervised and class-conditioned variational models.

    Builds a ``pyro.infer.SVI`` object from ``model.model`` and
    ``model.guide`` and records, per epoch, the average training loss
    and (optionally) the average test loss in ``loss_history``.
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer: Optional[optim.PyroOptim] = None,
                 loss: Optional[infer.ELBO] = None,
                 seed: int = 1
                 ) -> None:
        """
        Initializes the trainer's parameters.

        Args:
            model: Model object exposing ``model`` and ``guide`` callables,
                which are handed to ``infer.SVI``.
            optimizer: Pyro optimizer; defaults to ``optim.Adam`` with
                learning rate 1e-3.
            loss: ELBO objective; defaults to ``infer.Trace_ELBO()``.
            seed: Seed forwarded to ``set_deterministic_mode``.
        """
        # Reset Pyro's global parameter store before building the SVI object.
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if optimizer is None:
            optimizer = optim.Adam({"lr": 1.0e-3})
        if loss is None:
            loss = infer.Trace_ELBO()
        self.svi = infer.SVI(model.model, model.guide, optimizer, loss=loss)
        self.loss_history = {"training_loss": [], "test_loss": []}
        self.current_epoch = 0

    def _batch_to_device(self, data) -> tuple:
        """
        Converts one mini-batch into positional arguments for the SVI calls.

        A 1-element batch becomes ``(x,)`` (VAE mode); a 2-element batch
        becomes ``(x, y)`` (VED or cVAE mode). Any other length raises
        ``ValueError`` from tuple unpacking. Tensors are moved to
        ``self.device``.
        """
        if len(data) == 1:  # VAE mode
            return (data[0].to(self.device),)
        x, y = data  # VED or cVAE mode
        return x.to(self.device), y.to(self.device)

    def train(self,
              train_loader: torch.utils.data.DataLoader,
              **kwargs: float) -> float:
        """
        Trains a single epoch.

        Performs one ``svi.step`` (loss evaluation plus optimizer update)
        per mini-batch. Extra keyword arguments are forwarded to the model
        and guide via ``svi.step``.

        Returns:
            The summed mini-batch losses divided by the number of samples
            in ``train_loader.dataset``.
        """
        epoch_loss = 0.
        for data in train_loader:
            epoch_loss += self.svi.step(*self._batch_to_device(data), **kwargs)
        return epoch_loss / len(train_loader.dataset)

    def evaluate(self,
                 test_loader: torch.utils.data.DataLoader,
                 **kwargs: float) -> float:
        """
        Evaluates the current model state on a single epoch.

        Uses ``svi.evaluate_loss`` so the parameters and optimizer state are
        left untouched. ``svi.step`` always invokes the optimizer, which
        (e.g. via Adam momentum) would update the model on test data even
        under ``torch.no_grad()``.

        Returns:
            The summed mini-batch losses divided by the number of samples
            in ``test_loader.dataset``.
        """
        test_loss = 0.
        with torch.no_grad():
            for data in test_loader:
                test_loss += self.svi.evaluate_loss(
                    *self._batch_to_device(data), **kwargs)
        return test_loss / len(test_loader.dataset)

    def step(self,
             train_loader: torch.utils.data.DataLoader,
             test_loader: Optional[torch.utils.data.DataLoader] = None,
             **kwargs: float) -> None:
        """
        Single training and (optionally) evaluation step.

        Appends the epoch's training loss (and test loss when
        ``test_loader`` is given) to ``loss_history`` and increments
        ``current_epoch``.
        """
        self.loss_history["training_loss"].append(self.train(train_loader, **kwargs))
        if test_loader is not None:
            self.loss_history["test_loss"].append(self.evaluate(test_loader, **kwargs))
        self.current_epoch += 1

    def print_statistics(self) -> None:
        """
        Prints training and test (if any) losses for the current epoch.

        Reads the last recorded entries of ``loss_history``, so ``step``
        must have been called at least once.
        """
        e = self.current_epoch
        if len(self.loss_history["test_loss"]) > 0:
            template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1],
                                  self.loss_history["test_loss"][-1]))
        else:
            template = 'Epoch: {} Training loss: {:.4f}'
            print(template.format(e, self.loss_history["training_loss"][-1]))
def init_dataloader(*args: torch.Tensor, **kwargs: int
                    ) -> torch.utils.data.DataLoader:
    """
    Wraps one or more tensors in a mini-batch DataLoader.

    Args:
        *args: Tensors with the same size along dim 0 (required by
            ``TensorDataset``); each batch yields one slice per tensor,
            in the given order.
        **kwargs: ``batch_size`` (default 100) and ``shuffle``
            (default True). Other keywords are ignored.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TensorDataset`` of ``args``.
    """
    batch_size = kwargs.get("batch_size", 100)
    # Default True keeps the original always-shuffle behaviour.
    shuffle = kwargs.get("shuffle", True)
    tensor_set = torch.utils.data.TensorDataset(*args)
    return torch.utils.data.DataLoader(
        dataset=tensor_set, batch_size=batch_size, shuffle=shuffle)