# https://raw.githubusercontent.com/CompVis/latent-diffusion/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/modules/distributions/distributions.py
import torch
import numpy as np
from pdb import set_trace as st  # debugging helper; not used in this module


class AbstractDistribution:
    """Minimal interface: a distribution that can be sampled and has a mode."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    """Point mass at ``value``: both ``sample`` and ``mode`` return it unchanged."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


@torch.jit.script
def soft_clamp20(x: torch.Tensor) -> torch.Tensor:
    """Softly clamp ``x`` into (-20, 20) as ``20 * tanh(x / 20)``.

    Unlike a hard clamp this is differentiable everywhere and close to the
    identity for small |x|; used (as in LSGM) to bound the log-variance.
    """
    return x.div(20.).tanh().mul(20.)


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by a stacked ``[mean, logvar]`` tensor.

    ``parameters`` is split into two equal halves along dim 1: the first half is
    the mean, the second the log-variance. ``kl`` and the default ``nll`` reduce
    over dims 1, 2, 3, so a 4-D (B, 2C, H, W) layout is presumably expected --
    TODO confirm against callers.

    Args:
        parameters: tensor whose dim 1 holds ``mean`` and ``logvar`` concatenated.
        deterministic: if True, ``std``/``var`` are zeroed so ``sample`` returns
            the mean, and ``kl``/``nll`` return a single zero.
        soft_clamp: if True, bound logvar smoothly into (-20, 20) with
            ``soft_clamp20`` (as in LSGM); otherwise hard-clamp it to [-30, 20].
    """

    def __init__(self, parameters, deterministic=False, soft_clamp=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        if soft_clamp:
            # As in LSGM: bound logvar to (-20, 20) while keeping gradients alive.
            self.logvar = soft_clamp20(self.logvar)
        else:
            self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device/dtype of ``mean``.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Reparameterised sample ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        # eps is drawn with the default (CPU) generator and then moved; this keeps
        # the random stream identical to the original LDM implementation.
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    # Reference: NVlabs/LSGM util/distributions.py (Normal.log_p).
    def log_p(self, samples):
        """Element-wise log-density of ``samples`` (no reduction).

        Computes ``-0.5 * ((x - mean) / std)**2 - 0.5 * log(2*pi) - log(std)``,
        with ``log(std) = 0.5 * logvar``. Used for the negative encoder-entropy
        term. Not meaningful when ``deterministic`` is True (std is zero).
        """
        normalized_samples = (samples - self.mean) / self.std
        return (-0.5 * normalized_samples * normalized_samples
                - 0.5 * np.log(2 * np.pi) - 0.5 * self.logvar)

    def normal_entropy(self):
        """Element-wise differential entropy ``0.5 * logvar + 0.5 * (log(2*pi) + 1)``.

        Left unreduced (the caller averages over all dims, following the
        eps-loss convention); lets a loss supervise logvar directly.
        """
        return 0.5 * self.logvar + 0.5 * (np.log(2 * np.pi) + 1)

    def kl(self, other=None):
        """KL(self || other), summed over dims 1, 2, 3.

        With ``other=None`` the reference is the standard normal N(0, I).
        Returns a single zero on the parameters' device when ``deterministic``.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a single zero on the parameters' device when ``deterministic``.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """Element-wise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ).

    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12

    Shapes are broadcast, so batches can be compared to scalars, among other
    use cases. At least one argument must be a ``torch.Tensor``; it supplies the
    device/dtype for any scalar log-variances.
    """
    tensor = next((obj for obj in (mean1, logvar1, mean2, logvar2)
                   if isinstance(obj, torch.Tensor)), None)
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                  + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))