Spaces:
Runtime error
Runtime error
"""Masking and sampling logic adapted from the original MaskGIT implementation:
https://github.com/google-research/maskgit
Copyright PolyAI Limited.
"""
from dataclasses import dataclass | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
@dataclass
class State:
    """Holds decoding state data.

    ``@dataclass`` generates the keyword/positional ``__init__`` that
    ``state_init`` relies on (``State(cur_index=..., cur_seqs=..., final_seqs=...)``);
    without it that call raises ``TypeError``.
    """

    # The position of the decoding loop (the current iteration index).
    cur_index: int
    # The sequences currently being decoded; state_init seeds this with its
    # ``init_indices`` argument (presumably token indices — confirm).
    cur_seqs: torch.Tensor
    # One copy of the sequences per decoding iteration, stacked along dim 1.
    final_seqs: torch.Tensor
def state_init(init_indices, num_iter, start_iter=0):
    """Build the initial decoding ``State``.

    Args:
        init_indices: tensor of starting sequences. The three-entry ``tile``
            reps below line up with a 2-D input after ``unsqueeze(1)``
            (presumably ``(batch, seq_len)`` — confirm against callers).
        num_iter: how many copies of ``init_indices`` to stack in
            ``final_seqs``.
        start_iter: value stored as ``cur_index``.

    Returns:
        A ``State`` whose ``cur_seqs`` is ``init_indices`` itself and whose
        ``final_seqs`` repeats ``init_indices`` ``num_iter`` times along a new
        dimension 1.
    """
    history = init_indices.unsqueeze(1).tile((1, num_iter, 1))
    return State(cur_index=start_iter, cur_seqs=init_indices, final_seqs=history)
def schedule(ratio, method="cosine"):
    """Map decoding progress to the fraction of tokens that remain masked.

    Args:
        ratio: decoding progress (presumably in ``[0, 1]``, 0 at the start —
            confirm against callers).
        method: one of
            ``"uniform"`` -> ``1 - ratio``;
            ``"cosine"``  -> ``cos(ratio * pi / 2)``;
            ``"pow<e>"``  -> ``1 - ratio ** e`` (e.g. ``"pow2"``).

    Returns:
        The mask ratio clipped to ``[1e-6, 1]``.

    Raises:
        ValueError: if ``method`` is not recognised, or the ``pow`` exponent
            is not a valid float.
    """
    if method == "uniform":
        mask_ratio = 1. - ratio
    elif "pow" in method:
        exponent = float(method.replace("pow", ""))
        mask_ratio = 1. - ratio**exponent
    elif method == "cosine":
        mask_ratio = np.cos(ratio * (np.pi/2))
    else:
        # Previously fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unknown mask schedule method: {method!r}")
    # Keep the ratio strictly positive and no larger than 1.
    return np.clip(mask_ratio, 1e-6, 1.)
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Choose positions to mask by thresholding Gumbel-perturbed confidence.

    Confidence is ``log(probs) + temperature * noise``, with the noise drawn
    by ``gumbel_noise_like``. The confidences are sorted ascending along the
    last dimension, and the value at index ``mask_len`` becomes the cut-off.
    Every entry strictly below the cut-off is masked.

    Args:
        mask_len: integer-valued tensor (cast with ``.long()``) used as
            ``take_along_dim`` indices into the sorted last dimension. It must
            therefore have as many dims as ``probs`` (presumably
            ``(batch, 1)`` — confirm against callers).
        probs: probabilities whose log is perturbed.
        temperature: multiplier applied to the noise term.

    Returns:
        Boolean tensor shaped like ``probs``; ``True`` marks masked positions.
    """
    noise = gumbel_noise_like(probs)
    perturbed = probs.log() + temperature * noise
    ascending = perturbed.sort(dim=-1).values
    threshold = ascending.take_along_dim(mask_len.long(), dim=-1)
    return perturbed < threshold
def gumbel_noise_like(t):
    """Return Gumbel(0, 1) noise with the same shape, dtype and device as ``t``.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U`` drawn
    uniformly from ``[1e-20, 1)``. The small lower bound keeps the inner
    ``log`` away from ``log(0)``.

    NOTE(review): 1e-20 is below float16's smallest subnormal and presumably
    rounds to 0 for half-precision ``t``, making infinite noise possible —
    confirm inputs are float32/bfloat16.
    """
    uniform = torch.empty_like(t).uniform_(1e-20, 1)
    inner = -torch.log(uniform)
    return -torch.log(inner)
def sample_from_logits(
    logits,
    sample: bool = True,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    return_probs: bool = False
):
    """Pick one token per position from ``logits`` over the last dimension.

    Optional top-k and top-p (nucleus) filtering set excluded logits to
    ``-inf`` before the softmax. The caller's ``logits`` tensor is never
    modified; filtering works on a copy.

    Args:
        logits: tensor of shape ``(..., vocab_size)``.
        sample: if True, draw from the filtered, temperature-scaled
            distribution with ``multinomial``; otherwise take the argmax.
        temperature: divides the logits before the softmax when ``> 0``.
            Values ``<= 0`` fall back to an unscaled softmax; they do *not*
            switch to greedy decoding (use ``sample=False`` for that).
        top_k: keep only logits ``>=`` the k-th largest, so ties are kept.
            ``None`` or ``<= 0`` disables it; values above ``vocab_size``
            keep every token.
        top_p: keep the highest-probability tokens up to and including the
            first one at which the cumulative probability exceeds ``top_p``.
            Probabilities come from an un-tempered softmax, and the top token
            is always kept. ``None`` or ``>= 1`` disables it.
        return_probs: also return each chosen token's probability under the
            distribution used for sampling.

    Returns:
        ``token`` with shape ``logits.shape[:-1]``, or ``(token, token_probs)``
        when ``return_probs`` is True.
    """
    shp = logits.shape[:-1]
    use_top_k = top_k is not None and top_k > 0
    use_top_p = top_p is not None and top_p < 1.0
    if use_top_k or use_top_p:
        # The filters below write -inf in place; copy first so the caller's
        # tensor is left untouched.
        logits = logits.clone()

    # Apply top_k sampling
    if use_top_k:
        # topk raises if k exceeds the size of the last dimension.
        v, _ = logits.topk(min(top_k, logits.size(-1)))
        logits[logits < v[..., [-1]]] = -float("inf")

    # Apply top_p (nucleus) sampling
    if use_top_p:
        v, sorted_indices = logits.sort(descending=True)
        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Right-shift the removal mask so the first token over the threshold
        # (and hence always the top token) is kept.
        sorted_indices_to_remove = F.pad(
            sorted_indices_to_remove, (1, 0), value=False)[..., :-1]
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = -float("inf")

    # Dividing by a non-positive temperature would blow up or flip the
    # ranking, so fall back to the unscaled softmax in that case.
    probs = (
        F.softmax(logits / temperature, dim=-1)
        if temperature > 0
        else logits.softmax(dim=-1)
    )
    if sample:
        # multinomial needs 2-D input. reshape (unlike view) also handles
        # non-contiguous tensors and 1-D logits (0-d result).
        flat = probs.reshape(-1, probs.size(-1))
        token = flat.multinomial(1).squeeze(1).reshape(shp)
    else:
        token = logits.argmax(-1)

    if return_probs:
        token_probs = probs.take_along_dim(
            token.unsqueeze(-1), dim=-1).squeeze(-1)
        return token, token_probs
    return token