""" |
|
Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr, |
|
which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion. |
|
|
|
Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full |
|
details on how multinomial diffusion works -- thanks to the original authors! |
|
""" |
|
|
|

import torch
from torch import Tensor
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from typing import Union

# Smallest value allowed as an argument to torch.log, to avoid -inf in the log one-hot encodings.
MIN_LOG_ARG = 1e-7


def log_1_min_a(a): return torch.log((1 - a.exp()).clamp_(min=1e-30))


def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
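
# Sanity-check note (added here, not part of the original repos): `log_add_exp` is a numerically
# stable log(exp(a) + exp(b)); for finite inputs it should agree with torch.logaddexp, e.g.
#   a, b = torch.randn(4), torch.randn(4)
#   assert torch.allclose(log_add_exp(a, b), torch.logaddexp(a, b))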


def extract(a: Tensor, t, x_shape):
    """ Given 1D vector of alpha/alpha_cum/betas, get index at `t` of shape (bs,), and then
    broadcast it to the number of dims in `x_shape`.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
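
# Shape example (illustrative): with `a` of shape (T,), `t` of shape (bs,) and
# x_shape == (bs, seq_len, K), extract(a, t, x_shape) returns shape (bs, 1, 1),
# ready to broadcast against a (bs, seq_len, K) tensor.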


def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32):
    """ Convert indices `x` (bs, ...) to approx one-hot log-probs of shape (bs, ..., num_classes) """
    assert x.max().item() < num_classes, \
        f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    if dim == 1:
        permute_order = (0, -1) + tuple(range(1, len(x.size())))
        x_onehot = x_onehot.permute(permute_order)

    log_x = torch.log(x_onehot.to(dtype).clamp(min=MIN_LOG_ARG))

    return log_x


def sum_except_batch(x: Tensor, num_dims=1) -> Tensor:
    '''
    Sums all dimensions except the first `num_dims` batch dimensions.
    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, number of batch dims (default=1)
    Returns:
        x_sum: Tensor, shape (batch_size,)
    '''
    return x.reshape(*x.shape[:num_dims], -1).sum(-1)
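
# For example (illustrative): sum_except_batch(torch.ones(2, 3, 4)) returns tensor([12., 12.]).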


class MultinomialDiffusion():
    def __init__(self, num_classes, timesteps=100, diffusion_s=0.008,
                 loss_type='vb_stochastic', parametrization='x0',
                 dtype=torch.float32,
                 device='cpu'):
        super(MultinomialDiffusion, self).__init__()
        assert loss_type in ('vb_stochastic',)
        assert parametrization in ('x0', 'direct')

        self.num_classes = num_classes
        self.loss_type = loss_type
        self.num_timesteps = timesteps
        self.parametrization = parametrization

        alphas = self.cosine_beta_schedule(timesteps, diffusion_s)

        alphas = alphas.to(torch.float64)
        log_alpha = alphas.log()
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        # Sanity checks: alpha_t + (1 - alpha_t) == 1 in the log domain, and the cumulative
        # products are consistent with the per-step values.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=-1) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        self.log_alpha = log_alpha.to(dtype).to(device)
        self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device)
        self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device)
        self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device)

    @staticmethod
    def cosine_beta_schedule(timesteps, s=0.008) -> Tensor:
        """
        Cosine schedule as proposed in https://arxiv.org/abs/2102.09672 .
        Returns alpha parameters, NOT betas.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
        alphas = torch.clamp(alphas, 0.001, 1.0)
        return torch.sqrt(alphas)
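
    # The schedule above follows alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi / 2) from
    # Nichol & Dhariwal (2021), with per-step alpha_t = alpha_bar(t) / alpha_bar(t-1).
    # Note that, as written, this adaptation returns sqrt(alpha_t) rather than alpha_t itself.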

    def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor:
        """ Get KL divergence between two categorical distributions specified by `log_prob1` and `log_prob2`.
        The probability dimension is assumed to be `dim` (i.e. log_prob1.exp().sum(dim=`dim`) should be a tensor of ones).
        """
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim)
        return kl
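
    # Equivalent formulation (illustrative, not from the original code): this matches
    #   F.kl_div(log_prob2, log_prob1, log_target=True, reduction='none').sum(dim)
    # since F.kl_div(input, target, log_target=True) computes target.exp() * (target - input).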

    def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
        """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1 - alpha_t)/K) in the log domain,
        given `log_x_t` as the log one-hot encoding of x_t.

        Recall that due to the symmetry property we can compute
        this value using x_t instead of x_{t-1} (see appendix A of https://arxiv.org/pdf/2102.05379.pdf).
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)

        # alpha_t * x_t + (1 - alpha_t) / K, computed in the log domain
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )
        return log_probs

    def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor:
        """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1 - alpha_t)/K) in the log domain,
        given `log_x_t` as the log one-hot encoding of x_t, with the noise level scaled along the
        sequence axis according to the resampling counter `c`.

        Recall that due to the symmetry property we can compute
        this value using x_t instead of x_{t-1} (see appendix A of https://arxiv.org/pdf/2102.05379.pdf).
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)

        # Position-dependent sigmoid ramp along the sequence axis: positions well before
        # roughly seq_len * (c / jump_len) keep alpha_t close to 1 (almost no extra noise),
        # while later positions are noised at the usual rate.
        xax = torch.arange(0, log_x_t.shape[1], 1).to(log_x_t.device)
        aa = log_x_t.shape[1] * (c / jump_len)
        sig = 1 / (1 + torch.exp(-(xax - aa + 20) / 8))
        log_alpha_t = (torch.log(1 / sig)[None, :, None] + log_alpha_t).clamp(-torch.inf, 0)
        log_1_min_alpha_t = torch.log(sig)[None, :, None] + log_1_min_alpha_t

        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )
        return log_probs

    def q_pred(self, log_x_start: Tensor, t) -> Tensor:
        """ Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K) in the log domain,
        given `log_x_start` as the log probabilities of x_0.
        """
        dt = log_x_start.dtype
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs

    def q_posterior(self, log_x_start, log_x_t, t):
        """ Compute `q(x_{t-1} | x_t, x_0) = q(x_t | x_{t-1}, x_0) * q(x_{t-1} | x_0) / q(x_t | x_0)`,
        where q(x_t | x_{t-1}, x_0) = q(x_t | x_{t-1}).
        """
        # q(x_{t-1} | x_0): where t - 1 would be negative, clamp to zero and substitute
        # log_x_start directly below.
        t_minus_1 = t - 1
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # q(x_t | x_{t-1}) * q(x_{t-1} | x_0), then normalize over the class dimension.
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x_t, t, log_x0_pred):
        """ Predict `p(x_{t-1} | x_t)` using `q(x_{t-1} | x_t, hat{x0})`, where `hat{x0}` is given by
        the log probabilities from the model as `log_x0_pred` (bs, ..., K) and x_t is given by
        `log_x_t` of shape `(bs, ..., K)`.
        """
        log_model_pred = self.q_posterior(
            log_x_start=log_x0_pred, log_x_t=log_x_t, t=t)
        return log_model_pred

    def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor:
        """ Sample from categorical `logprobs` (bs, ..., probs), where the position of the probability
        dimension is specified by `dim`.

        Returns sampled long indices of shape `(bs, ...)`.
        """
        uniform = torch.rand_like(logprobs)
        gumbel_noise = -torch.log((-torch.log(uniform.clamp_(min=MIN_LOG_ARG))).clamp_(min=MIN_LOG_ARG))
        sample = (gumbel_noise + logprobs).argmax(dim=dim)
        return sample
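
    # Sampling above uses the Gumbel-max trick: adding independent Gumbel(0, 1) noise
    # (-log(-log(U)) with U ~ Uniform(0, 1)) to the log-probabilities and taking the argmax
    # draws an index distributed according to softmax(logprobs).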

    def q_sample(self, log_x_start, t):
        """ Draw `x_t` ~ q(x_t | x_0). `log_x_start` has shape `(bs, ..., K)`; returns sampled
        indices of shape `(bs, ...)`.
        """
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        sample = self.log_sample_categorical(log_EV_qxt_x0)

        return sample

    def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t,
                   detach_mean=False, include_kl_prior=True):
        """ Get loss given one-hot log x_0, one-hot log x_t, t, and model prediction `log_x0_pred`.
        Parameters:
            - `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K)
            - `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K)
            - `t`: diffusion timestep (bs,)
            - `log_x0_pred`: model prediction of the log probabilities of x0, i.e. hat{x0}
            - `include_kl_prior`: add last two terms to the model loss (does not change the optimization problem)
        """
        dtype = log_x_start.dtype
        log_true_prob = self.q_posterior(
            log_x_start=log_x_start, log_x_t=log_x_t, t=t)

        log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        # KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)), used for t > 0
        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        # -log p(x_0 | x_1) reconstruction term, used where t == 0
        decoder_nll = - (log_x_start.exp() * log_model_prob).sum(dim=-1)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).to(dtype)
        loss = mask * decoder_nll + (1. - mask) * kl

        if include_kl_prior:
            pt = torch.ones_like(t, dtype=dtype)
            kl_prior = self.kl_prior(log_x_start)
            loss = (loss / pt) + kl_prior

        return loss

    def kl_prior(self, log_x_start: Tensor) -> Tensor:
        """ Compute KL(q(x_T | x_0) || p(x_T)), where p(x_T) is the uniform prior 1/K.
        For categorical distributions this equals -H_{q}(x_T | x_0) + H_{p}(x_T)
        (via the relation between KL divergence and entropy).

        Given `log_x_start` (bs, ..., probs), returns the KL prior of shape (bs,).
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device, dtype=torch.long)

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # log of the uniform prior over the K classes
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)


def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32):
    """ Convert indices `x` to zero-mean one-hot style logits of shape (..., vocab_size). """
    x = F.one_hot(x, num_classes=vocab_size).to(dtype)
    x = x * (vocab_size / (vocab_size - 1)) - 1 / (vocab_size - 1)
    return x
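
# For example (illustrative): index2logit(torch.tensor([2]), vocab_size=5) gives
# tensor([[-0.25, -0.25, 1.00, -0.25, -0.25]]), and each row sums to zero.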


@dataclass
class DSH():
    """ Hyperparameters controlling the diffusion sampling/inference loop. """

    jump_len: int = 1                   # RePaint-style jump length (timesteps per resample jump)
    jump_n_sample: int = 1              # RePaint-style number of resamplings per jump point
    last_greedy: bool = False           # use argmax instead of sampling for the final (t=0) step
    x_0_temp: float = 1.0               # temperature applied to the model's x0 logits
    guidance_w: float = 1.0             # classifier-free guidance weight (1.0 disables guidance)
    enable_kevin_scaled_inference: bool = True  # use the position-scaled forward noising when jumping back
    T_override: Union[None, int] = None

    deep_clone: bool = False            # prepend the prompt codes (deep clone) during inference
    q0_override_steps: int = 0          # timesteps at or below this value stop overriding quant0
    progress: bool = False              # display a progress bar during inference
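
# A hypothetical configuration sketch (values chosen purely for illustration):
#   dsh = DSH(jump_len=10, jump_n_sample=10, x_0_temp=0.9, guidance_w=2.0, progress=True)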


def get_schedule(t_T, jump_len=10, jump_n_sample=10):
    """ Build a RePaint-style timestep schedule: descend from t_T - 1 to 0, but at every
    `jump_len`-th timestep jump back up `jump_len` steps, `jump_n_sample - 1` times.
    Ends with a -1 sentinel.
    """
    jumps = {}
    for j in range(0, t_T - jump_len, jump_len):
        jumps[j] = jump_n_sample - 1
    t = t_T
    ts = []
    while t >= 1:
        t = t - 1
        ts.append(t)
        if jumps.get(t, 0) > 0:
            jumps[t] = jumps[t] - 1
            for _ in range(jump_len):
                t = t + 1
                ts.append(t)
    ts.append(-1)
    return ts
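
# Worked example (computed from the function above):
#   get_schedule(5, jump_len=2, jump_n_sample=2)
# returns [4, 3, 2, 3, 4, 3, 2, 1, 0, 1, 2, 1, 0, -1], i.e. the sampler revisits earlier
# timesteps before finally descending to 0.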


def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH):
    """ Single forward diffusion (noising) step q: add one step of multinomial noise to `x` at timestep `t`. """
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype)
    if c is not None: x = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len)
    else: x = diff.q_pred_one_timestep(log_x_t, t)
    x = diff.log_sample_categorical(x)
    return x


def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None,
                      last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH):
    """ Reverse diffusion step p: predict x_{t-1} given x, t, x_known, m. Optionally do not sample the model
    output for t=0, but rather use the greedy argmax with `last_greedy`.
    """
    x = batch[4]
    t = batch[-1]
    if x_known is None: x_known = torch.zeros_like(x)
    if m is None: m = torch.zeros_like(x)

    # Predict hat{x0} logits with the model; the permute moves the class dimension to the last axis.
    x_0_pred = model(*batch)
    x_0_pred = x_0_pred.permute(0, 1, 3, 2)

    if dsh.guidance_w != 1:
        # Classifier-free guidance: blend conditional and unconditional predictions.
        uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True)
        uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2)
        x_0_pred = dsh.guidance_w*x_0_pred + (1-dsh.guidance_w)*uncond_x_0_pred

    x_0_pred = x_0_pred / temperature
    log_x_0_pred = F.log_softmax(x_0_pred, dim=-1)
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype)

    log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred)

    # Mix the predicted distributions within each block of `ensemble_size` batch entries:
    # each entry keeps weight (1 - a_t) on its own prediction plus a_t/ensemble_size spread
    # across its block. With ensemble_size=1 this is a no-op.
    a_t = alphas[t[0]] if alphas is not None else 0
    mat = torch.eye(ensemble_size, device=x.device)*(1-a_t)
    mat += 1/ensemble_size * a_t
    mat = torch.block_diag(*([mat]*(x.shape[0]//ensemble_size)))
    log_model_pred = (mat[..., None, None].log().to(x.dtype) + log_model_pred[None])
    log_model_pred = torch.logsumexp(log_model_pred, dim=1)

    # Sample x_{t-1} for the unknown (generated) positions.
    if (t == 0).all() and last_greedy:
        x_tm1_unknown = log_model_pred.argmax(dim=-1)
    else:
        x_tm1_unknown = diff.log_sample_categorical(log_model_pred)

    # Obtain x_{t-1} for the known (conditioning) positions by noising x_known to the current timestep.
    x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype)
    if (t == 0).all():
        x_tm1_known = x_known
    else:
        x_tm1_known = diff.q_sample(x_known_log, t)

    # Combine known and unknown positions according to the mask `m`.
    x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long())
    return x_tm1, x_0_pred


@torch.inference_mode()
def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16,
                             retain_quant0: bool = True, dsh=DSH):
    """ If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion step. """

    c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch

    device = c_text.device
    bs = c_text.shape[0]
    x_quant0 = x[..., 0].clone()
    # Start from pure categorical noise, keeping the first quantizer level from the input.
    x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device)
    x[..., 0] = x_quant0

    times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len)

    # Known values and mask: only the first quantizer level is treated as known.
    x_known = torch.zeros_like(x)
    x_known[..., 0] = x[..., 0]
    m = torch.zeros_like(x).bool()
    m[..., 0] = True

    offset = 0
    if dsh.deep_clone:
        print(f"Note: using deep clone. Assuming input `c_phones` is concatenated prompt and output phones.",
              "Also assuming no padded indices in `c_codes`.")
        prompt = c_codes
        x = torch.cat((prompt, x), dim=1)
        x_known = torch.cat((prompt, x_known), dim=1)
        x_padding_mask = torch.cat((
            torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device),
            x_padding_mask), dim=-1
        )

        m = torch.cat((torch.ones_like(prompt), m), dim=1)
        x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1)
        offset = c_codes_lengths[0]

        print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}")

    c = 0

    alphas = torch.linspace(1, 0, T).to(device)

    pb = zip(times[:-1], times[1:])
    if dsh.progress:
        from fastprogress import progress_bar
        pb = progress_bar(pb, total=len(times)-1)

    for t_last, t_cur in pb:

        t = torch.ones((bs,), dtype=torch.long, device=x.device) * t_last
        if t_cur < t_last:
            # Reverse (denoising) step.
            if c > dsh.jump_n_sample:
                c = 0
            c += 1/dsh.jump_len

            cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t)
            x, x_0_pred = reverse_diffusion(diff, model, cbatch, x_known, m, temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh)
        else:
            # Jump back up: apply one forward (noising) step, RePaint style.
            if dsh.enable_kevin_scaled_inference: x = forward_diffusion(diff, dtype, x, t, c=c, dsh=dsh)
            else: x = forward_diffusion(diff, dtype, x, t, c=None, dsh=dsh)

        if retain_quant0 and dsh.q0_override_steps < t_last:
            x[..., 0] = x_quant0

    x = x[:, offset:]
    return x
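
# A minimal usage sketch (hypothetical values; constructing `model` and `batch` depends on the
# surrounding transfusion-asr training/inference code, which is not part of this file):
#   diff = MultinomialDiffusion(num_classes=1024, timesteps=200, device='cuda')
#   dsh = DSH(jump_len=10, jump_n_sample=10, x_0_temp=0.9, progress=True)
#   codes = perform_simple_inference(model, batch, diff, T=200, dtype=torch.float32, dsh=dsh)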