# Source: pheme/modules/masking_logic.py (author: taras-sereda, commit 96ee597,
# "minimal set of files to run inference; pheme-small checkpoint").
"""Masking and sampling logic adapted from MaskGIT original paper:
https://github.com/google-research/maskgit
Copyright PolyAI Limited.
"""
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
@dataclass
class State:
    """Holds decoding state data."""

    # Position of the decoding loop (iteration counter); ``state_init`` seeds
    # it with ``start_iter``.
    cur_index: int
    # Sequences currently being decoded; ``state_init`` seeds this with the
    # initial token indices.
    cur_seqs: torch.Tensor
    # One stored copy of the sequences per decoding iteration; ``state_init``
    # tiles the initial indices ``num_iter`` times along dim 1 (presumably
    # giving (batch, num_iter, seq_len) -- TODO confirm against callers).
    final_seqs: torch.Tensor
def state_init(init_indices, num_iter, start_iter=0):
    """Build the initial decoding ``State``.

    Args:
        init_indices: Starting token indices, presumably a 2-D tensor shaped
            (batch, seq_len) -- the ``(1, num_iter, 1)`` repeat pattern below
            only lines up with the new dim 1 for 2-D input (TODO confirm).
        num_iter: Number of copies of ``init_indices`` stored in ``final_seqs``.
        start_iter: Initial value of the loop position ``cur_index``.

    Returns:
        ``State(cur_index=start_iter, cur_seqs=init_indices,
        final_seqs=<init_indices repeated num_iter times along dim 1>)``.
    """
    # Add an iteration axis at dim 1, then replicate along it once per iteration.
    with_iter_axis = torch.unsqueeze(init_indices, 1)
    per_iteration = torch.tile(with_iter_axis, (1, num_iter, 1))
    return State(
        cur_index=start_iter,
        cur_seqs=init_indices,
        final_seqs=per_iteration,
    )
def schedule(ratio, method="cosine"):
    """Return the mask ratio for a given decoding progress.

    Args:
        ratio: Decoding progress (scalar or NumPy array), typically in [0, 1].
        method: Schedule name:
            ``"uniform"`` -> ``1 - ratio``;
            ``"pow<k>"`` (e.g. ``"pow2"``) -> ``1 - ratio**k``;
            ``"cosine"`` -> ``cos(ratio * pi / 2)``.

    Returns:
        The mask ratio clipped into ``[1e-6, 1]``.

    Raises:
        ValueError: If ``method`` is not a supported schedule (or the
            exponent after ``"pow"`` is not a valid float).
    """
    if method == "uniform":
        mask_ratio = 1. - ratio
    elif "pow" in method:
        exponent = float(method.replace("pow", ""))
        mask_ratio = 1. - ratio**exponent
    elif method == "cosine":
        mask_ratio = np.cos(ratio * (np.pi / 2))
    else:
        # Without this branch an unknown name fell through and surfaced as an
        # opaque UnboundLocalError on ``mask_ratio`` below.
        raise ValueError(f"Unknown mask schedule method: {method!r}")
    # Clip into [1e-6, 1]; the floor keeps the ratio strictly positive.
    mask_ratio = np.clip(mask_ratio, 1e-6, 1.)
    return mask_ratio
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Select positions to mask using Gumbel-perturbed confidence scores.

    Each entry's score is ``log(probs) + temperature * g`` with ``g`` from
    ``gumbel_noise_like``. The scores are sorted ascending along the last
    dim; entries strictly below the score at sorted index ``mask_len`` are
    masked, i.e. (ties aside) the ``mask_len`` lowest-scoring entries.

    Args:
        mask_len: Per-row index into the last dim; cast with ``.long()`` and
            passed to ``torch.take_along_dim``, so it presumably has shape
            (batch, 1) -- TODO confirm. Values must be valid indices.
        probs: Per-position probabilities (a log is taken).
        temperature: Multiplier applied to the Gumbel noise.

    Returns:
        Boolean tensor shaped like ``probs``; True marks positions to mask.
    """
    # Perturb log-confidence with temperature-scaled Gumbel noise.
    gumbel = gumbel_noise_like(probs)
    scores = torch.log(probs) + temperature * gumbel
    # Cut-off = the score found at index ``mask_len`` of the ascending sort.
    ascending = torch.sort(scores, dim=-1).values
    threshold = torch.take_along_dim(ascending, mask_len.long(), dim=-1)
    return scores < threshold
def gumbel_noise_like(t):
noise = torch.zeros_like(t).uniform_(1e-20, 1)
return -torch.log(-torch.log(noise))
def sample_from_logits(
    logits,
    sample: bool = True,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    return_probs: bool = False
):
    """Pick one token per row of ``logits``, with optional top-k / top-p filtering.

    The caller's ``logits`` tensor is never modified; filtering produces new
    tensors.

    Args:
        logits: Unnormalised scores; the last dimension indexes the vocabulary.
        sample: If True, draw from the filtered softmax with
            ``torch.multinomial``; otherwise take the argmax of the filtered
            logits.
        temperature: Divides the logits before the softmax when ``> 0``;
            non-positive values fall back to an unscaled softmax. Top-k/top-p
            filtering is applied to the unscaled logits.
        top_k: If given, keep only logits at or above the k-th largest value
            (ties included). Values larger than the vocabulary keep everything.
        top_p: If given and ``< 1``, nucleus filtering: keep the
            highest-probability tokens up to and including the first one at
            which the cumulative probability exceeds ``top_p`` (always at
            least one token).
        return_probs: If True, also return each chosen token's probability
            under the filtered, temperature-scaled distribution.

    Returns:
        ``token`` shaped ``logits.shape[:-1]``, or ``(token, token_probs)``
        when ``return_probs`` is True.
    """
    shp = logits.shape[:-1]

    # Apply top_k filtering.
    if top_k is not None:
        # Clamp so an oversized top_k keeps every token instead of making
        # ``topk`` raise.
        top_k = min(top_k, logits.size(-1))
        v, _ = logits.topk(top_k)
        # Out-of-place: the previous in-place assignment clobbered the
        # caller's tensor.
        logits = logits.masked_fill(logits < v[..., [-1]], -float("inf"))

    # Apply top_p (nucleus) filtering.
    if top_p is not None and top_p < 1.0:
        v, sorted_indices = logits.sort(descending=True)
        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Right shift indices_to_remove to keep 1st token over threshold.
        sorted_indices_to_remove = F.pad(
            sorted_indices_to_remove, (1, 0), value=False)[..., :-1]
        # Map the removal flags from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("inf"))

    # Normalise (optionally temperature-scaled) logits into probabilities.
    probs = (
        F.softmax(logits / temperature, dim=-1)
        if temperature > 0
        else logits.softmax(dim=-1)
    )

    if sample:
        # Flatten leading dims for multinomial, then restore them; reshape
        # (not view) also handles 1-D logits, yielding a 0-d token tensor.
        token = (
            probs.reshape(-1, probs.size(-1))
            .multinomial(1)
            .squeeze(1)
            .reshape(shp)
        )
    else:
        token = logits.argmax(-1)

    if return_probs:
        token_probs = probs.take_along_dim(
            token.unsqueeze(-1), dim=-1).squeeze(-1)
        return token, token_probs
    return token