File size: 20,322 Bytes
8520a55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 |
"""
Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr,
which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion.
Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full
details on how multinomial diffusion works -- thanks to the original authors!
"""
import torch
from torch import Tensor
from torch.functional import F
import numpy as np
from dataclasses import dataclass
from typing import Union
# -------------- Multinomial utility functions -----------
MIN_LOG_ARG = 1e-7 # originally was 1e-40
def log_1_min_a(a): return torch.log((1 - a.exp()).clamp_(min=1e-30))
def log_add_exp(a, b):
maximum = torch.max(a, b)
return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
def extract(a: Tensor, t, x_shape):
""" Given 1D vector of alpha/alpha_cum/betas, get index at `t` of shape (bs,), and then
broadcast it to number of dims in `x_shape`.
"""
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32):
""" Convert indices `x` (bs, ...) to approx one-hot log-probs of shape (bs, ..., num_classes) """
assert x.max().item() < num_classes, \
f'Error: {x.max().item()} >= {num_classes}'
x_onehot = F.one_hot(x, num_classes)
if dim == 1:
permute_order = (0, -1) + tuple(range(1, len(x.size())))
x_onehot = x_onehot.permute(permute_order)
else:
pass
log_x = torch.log(x_onehot.to(dtype).clamp(min=MIN_LOG_ARG)) # so min(log_x) will be -30
return log_x
def sum_except_batch(x: Tensor, num_dims=1) -> Tensor:
'''
Sums all dimensions except the first.
Args:
x: Tensor, shape (batch_size, ...)
num_dims: int, number of batch dims (default=1)
Returns:
x_sum: Tensor, shape (batch_size,)
'''
return x.reshape(*x.shape[:num_dims], -1).sum(-1)
# -------------- Multinomial diffusion class -------------
class MultinomialDiffusion():
def __init__(self, num_classes, timesteps=100, diffusion_s=0.008,
loss_type='vb_stochastic', parametrization='x0',
dtype=torch.float32,
device='cpu'):
super(MultinomialDiffusion, self).__init__()
assert loss_type in ('vb_stochastic',)
assert parametrization in ('x0', 'direct')
self.num_classes = num_classes
self.loss_type = loss_type
self.num_timesteps = timesteps
self.parametrization = parametrization
alphas = self.cosine_beta_schedule(timesteps, diffusion_s)
alphas = alphas.to(torch.float64)
log_alpha = alphas.log()
log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1)
log_1_min_alpha = log_1_min_a(log_alpha) # = log(betas)
log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha) # = log(1- \bar{a})
a = log_add_exp(log_alpha, log_1_min_alpha) # log(1-beta + beta) = log(1) = 0
assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
assert (torch.cumsum(log_alpha, dim=-1) - log_cumprod_alpha).abs().sum().item() < 1.e-5
# Convert to float32 and register buffers.
self.log_alpha = log_alpha.to(dtype).to(device)
self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device)
self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device)
self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device)
@staticmethod
def cosine_beta_schedule(timesteps, s=0.008) -> Tensor:
"""
cosine schedule as proposed in https://arxiv.org/abs/2102.09672 .
Returns alpha parameters, NOT Beta
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
alphas = torch.clamp(alphas, 0.001, 1.0)
return torch.sqrt(alphas)
def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor:
""" Get KL divergence between two categorical distributions specified with `log_prob1` and `log_prob2`.
Assumed probability dim is `dim` (i.e. log_prob1.exp().sum(dim=`dim`) should be tensor of ones)
"""
kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim)
return kl
def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
""" Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain
given `log_x_t` as log one-hot encoding of x_t.
Recall due to symmetry property we can compute
this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf)
"""
dt = log_x_t.dtype
log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)
# alpha_t * E[xt] + (1 - alpha_t) 1 / K
log_probs = log_add_exp(
log_x_t + log_alpha_t,
log_1_min_alpha_t - np.log(self.num_classes)
)
return log_probs
def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor:
""" Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K in the log-domain
given `log_x_t` as log one-hot encoding of x_t.
Recall due to symmetry property we can compute
this value using x_t instead of x_{t-1} (se appendix A of https://arxiv.org/pdf/2102.05379.pdf)
"""
dt = log_x_t.dtype
log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)
# Magic
xax = torch.arange(0,log_x_t.shape[1],1).to(log_x_t.device)
aa=log_x_t.shape[1]*(c/jump_len)
sig = 1/(1+torch.exp(-(xax-aa+20)/8))
log_alpha_t = (torch.log(1/sig)[None,:,None] + log_alpha_t).clamp(-torch.inf, 0)
log_1_min_alpha_t = torch.log(sig)[None,:,None] + log_1_min_alpha_t
# alpha_t * E[xt] + (1 - alpha_t) 1 / K
log_probs = log_add_exp(
log_x_t + log_alpha_t,
log_1_min_alpha_t - np.log(self.num_classes)
)
return log_probs
def q_pred(self, log_x_start: Tensor, t) -> Tensor:
""" Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K ) in log domain,
given `log_x_start` of log probs of x_0.
"""
dt = log_x_start.dtype
log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt)
log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt)
log_probs = log_add_exp(
log_x_start + log_cumprod_alpha_t,
log_1_min_cumprod_alpha - np.log(self.num_classes)
)
return log_probs
def q_posterior(self, log_x_start, log_x_t, t):
""" Compute `q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)`
where q(xt | xt-1, x0) = q(xt | xt-1).
"""
# q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
# where q(xt | xt-1, x0) = q(xt | xt-1).
t_minus_1 = t - 1
# Remove negative values, will not be used anyway for final decoder
t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1) # log( q(x_{t-1} | x_0) )
# if t == 0, then log( q(x_0 | x_0) ) = log( one_hot(x_0) ), not even random at that point.
# so, where t == 0
num_axes = (1,) * (len(log_x_start.size()) - 1)
t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start) # broadcast to non-batch axes
log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)
# where it is zero, replace
# with log one-hot encoding of x0.
# Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
# Not very easy to see why this is true. But it is :)
# log_EV_qxtmin_x0 ~ q(x_{t-1} | x_0)
# q_pred_one_timestep(log_x_t, t) ~ q(x_t | x_{t-1}) (which due to symmetry can be computed using x_t)
unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t) # numerator of bayes
# approximate denominator with just a normalizing sum.
log_EV_xtmin_given_xt_given_xstart = \
unnormed_logprobs \
- torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
return log_EV_xtmin_given_xt_given_xstart
def p_pred(self, log_x_t, t, log_x0_pred):
""" Predict `p(x_{t-1} | x_t)` using `q(xt-1 | xt, hat{x0})`, where `hat{x0}` is given by
log probabilities from model as `log_x0_pred` (bs, ...., K) and x_t is given by
`log_x_t` of shape `(bs, ..., K)`
"""
# log_x_recon = self.predict_start(log_x, t=t) # model itself predicts x_0
# log_x0_pred
log_model_pred = self.q_posterior(
log_x_start=log_x0_pred, log_x_t=log_x_t, t=t)
return log_model_pred
def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor:
""" Sample from categorical `logprobs` (bs, ..., probs), where position of probs is specified
by `dim`.
Returns sampled long indices of shape `(bs, ...)`
"""
uniform = torch.rand_like(logprobs)
gumbel_noise = -torch.log( (-torch.log(uniform.clamp_(min=MIN_LOG_ARG)) ).clamp_(min=MIN_LOG_ARG))
sample = (gumbel_noise + logprobs).argmax(dim=dim)
return sample
def q_sample(self, log_x_start, t):
""" Draw `x_t` ~ q(x_t | x_0) . `log_x_start` is of shape `(bs, ..., K)`, returns result of same shape """
log_EV_qxt_x0 = self.q_pred(log_x_start, t)
sample = self.log_sample_categorical(log_EV_qxt_x0)
# log_sample = index_to_log_onehot(sample, self.num_classes)
return sample #log_sample
def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t,
detach_mean=False, include_kl_prior=True):
""" Get loss given one-hot log x_0, one-hot log x_t, t, and model prediction `log_x0_pred`.
Parameters:
- `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K)
- `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K)
- `t`: diffusion timestep (bs,)
- `log_x0_pred`: model prediction of log probabilities of x0, i.e. hat{x0}.
- `include_kl_prior`: add last two terms to model loss (does not change optimization problem).
"""
dtype = log_x_start.dtype
log_true_prob = self.q_posterior(
log_x_start=log_x_start, log_x_t=log_x_t, t=t)
log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred)
if detach_mean:
log_model_prob = log_model_prob.detach()
kl = self.multinomial_kl(log_true_prob, log_model_prob)
kl = sum_except_batch(kl)
# Add L_0, -log(p(x_0 | x_1))
decoder_nll = - (log_x_start.exp() * log_model_prob).sum(dim=-1)
decoder_nll = sum_except_batch(decoder_nll)
mask = (t == torch.zeros_like(t)).to(dtype)
loss = mask * decoder_nll + (1. - mask) * kl # only add L0 if t == 0.
if include_kl_prior:
pt = torch.ones_like(t, dtype=dtype)
kl_prior = self.kl_prior(log_x_start)
loss = (kl) + kl_prior
return loss
def kl_prior(self, log_x_start: Tensor) -> Tensor:
""" This function computes -H_{q}(x_T | x_0)+H_{p}(x_T), which
by some math (see wiki for KL div relation to conditional entropy).
So KL(q(x_T | x_0) || 1/K) = -H_{q}(x_T | x_0)+H_{p}(x_T) for categorical distribution.
Given `log_x_start` (bs, ..., probs), return KL prior of shape (bs,)
"""
b = log_x_start.size(0)
device = log_x_start.device
ones = torch.ones(b, device=device, dtype=torch.long)
log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones) # q(x_T | x_0)
log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob)) # log(1/K), broadcast to q(x_T|x_0) shape
kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
return sum_except_batch(kl_prior)
def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32):
x = F.one_hot(x, num_classes=vocab_size).to(dtype)
x = x * (vocab_size/(vocab_size - 1)) - 1/(vocab_size - 1)
return x
# ------------------------------
# Functions adapted from the full
@dataclass
class DSH():
# Diffusion Sampling Hyperparameters [DSH] (Section 4)
jump_len: int = 1 # j in RePaint paper [default 10] (Section 4.1)
jump_n_sample: int = 1 # r in RePaint paper [default 10] (Section 4.1)
last_greedy: bool = False # whether to not sample at t=0, but take argmax prediction. [default False]
x_0_temp: float = 1.0 # reweight temp for model prediction of x0
guidance_w: float = 1.0 # classifier free guidance weight [default 1.5] (Section 4.3)
enable_kevin_scaled_inference: bool = True # sequentially progressive diffusion [default True] (Section 4.2)
T_override: Union[None, int] = None # allow variable transcription sizes during inference (Section 4.4)
deep_clone: bool = False # whether to do deep clone.
q0_override_steps: int = 0 # number of steps that we allow overriding the input quant level 0 inputs.
progress: bool = False # whether to show progress bar
def get_schedule(t_T, jump_len=10, jump_n_sample=10):
jumps = {}
for j in range(0, t_T - jump_len, jump_len):
jumps[j] = jump_n_sample - 1
t = t_T
ts = []
while t >= 1:
t = t-1
ts.append(t)
if jumps.get(t, 0) > 0:
jumps[t] = jumps[t] - 1
for _ in range(jump_len):
t = t + 1
ts.append(t)
ts.append(-1)
return ts
def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH):
"""Simple forward diffusion process p"""
log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype)
if c is not None: x = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len)
else: x = diff.q_pred_one_timestep(log_x_t, t)
x = diff.log_sample_categorical(x)
return x
def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None,
last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH):
"""Reverse diffusion process q: predict x_{t-1} given x, t, x_known, m. Optionally do not sample model output
for t=0, but rather use the greedy argmax with `last_greedy`.
"""
x = batch[4]
t = batch[-1]
if x_known is None: x_known = torch.zeros_like(x)
if m is None: m = torch.zeros_like(x)
# Equation 8b
# for b in batch:
# print(f"{b.shape}: {b}")
x_0_pred = model(*batch) # (bs, seq_len, logit_dim, n_quant)
x_0_pred = x_0_pred.permute(0, 1, 3, 2) # (bs, seq_len, n_quant, dim)
if dsh.guidance_w != 1:
uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True)
uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2)
x_0_pred = dsh.guidance_w*x_0_pred + (1-dsh.guidance_w)*uncond_x_0_pred
x_0_pred = x_0_pred / temperature
log_x_0_pred = F.log_softmax(x_0_pred, dim=-1)
log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype)
# print("PRE: ", log_x_t.shape, t.shape, log_x_0_pred.shape)
log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred) # p(x_{t-1} | x_{t})
a_t = alphas[t[0]] if alphas is not None else 0
mat = torch.eye(ensemble_size, device=x.device)*(1-a_t)
mat += 1/ensemble_size * a_t
mat = torch.block_diag(*([mat]*(x.shape[0]//ensemble_size)))
log_model_pred = ( (mat[..., None, None] ).log().to(x.dtype) + log_model_pred[None])
log_model_pred = torch.logsumexp(log_model_pred, dim=1)
if (t==0).all() and last_greedy: # Do not sample at t=0
x_tm1_unknown = log_model_pred.argmax(dim=-1)
else:
x_tm1_unknown = diff.log_sample_categorical(log_model_pred)
# Equation 8a
x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype)
if (t==0).all(): # Do not sample at t=0
x_tm1_known = x_known
else:
x_tm1_known = diff.q_sample(x_known_log, t)
# Equation 8c
x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long())
return x_tm1, x_0_pred
@torch.inference_mode()
def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16,
retain_quant0: bool = True, dsh=DSH):
""" If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion step. """
# (bs=1, N), (bs, seq_len2, 8), (bs,)
c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch
device = c_text.device
bs = c_text.shape[0]
x_quant0 = x[..., 0].clone() # (bs, seq_len) 0th quant level
x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device)
# CRITICAL LINE: override quantization level 0 with provided quant0 level.
x[..., 0] = x_quant0
# RePaint paper resample scheduling
times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len)
x_known = torch.zeros_like(x)
x_known[..., 0] = x[..., 0] # override L0 codes
m = torch.zeros_like(x).bool()
# (bs, seq_len, 8)
m[..., 0] = True
offset = 0
if dsh.deep_clone:
print(f"Note: using deep clone. Assuming input `c_phones` is concatenated prompt and output phones.",
"Also assuming no padded indices in `c_codes`.")
prompt = c_codes
x = torch.cat((prompt, x), dim=1) # (bs=1, sl1 + sl2, 8)
x_known = torch.cat((prompt, x_known), dim=1)
x_padding_mask = torch.cat((
torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device),
x_padding_mask), dim=-1
)
# (bs=1, :up to prompt duration, all 8 codebooks) = True/masked.
m = torch.cat((torch.ones_like(prompt), m), dim=1)
x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1)
offset = c_codes_lengths[0]
print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}")
c = 0 # sequentially progressive diffusion offset (Section 4.2)
# ensemble bs (not in paper)
alphas = torch.linspace(1, 0, T).to(device)
pb = zip(times[:-1], times[1:])
if dsh.progress:
from fastprogress import progress_bar
pb = progress_bar(pb, total=len(times)-1)
# See RePaint paper algorithm
for t_last, t_cur in pb:
t = torch.ones((bs,), dtype=torch.long, device=x.device) * (t_last)
if t_cur < t_last:
if c > dsh.jump_n_sample:
c = 0
c += 1/dsh.jump_len
# Reverse diffusion: q
cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t)
x, x_0_pred = reverse_diffusion(diff, model, cbatch, x_known, m, temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh)
else:
# Forward diffusion: p
if dsh.enable_kevin_scaled_inference: x = forward_diffusion(diff, dtype, x, t, c=c, dsh=dsh)
else: x = forward_diffusion(diff, dtype, x, t, c=None, dsh=dsh)
if retain_quant0 and dsh.q0_override_steps < t_last:
x[..., 0] = x_quant0
# crop offset:
x = x[:, offset:]
return x
|