# esper/utils.py
# Committed by jiwan-chung: commit 0bf81ba ("demo init"), 622 bytes.
import torch
# Finite stand-in for -inf; mask_pad writes it into padded positions.
NEGATIVE_INF = -1e5
# float16 tops out at 65504 in magnitude, so -1e5 would overflow to -inf there;
# this smaller value stays representable.
HALF_NEGATIVE_INF = -6e4  # half precision
def get_first_sentence(txt, min_len=5):
    """Return the first sentence of ``txt``.

    Everything from the first ``'<|endoftext|>'`` marker onward is discarded
    and newlines are turned into spaces. The text is then split at the first
    ``'. '``. If that leading segment is at least ``min_len`` characters long,
    it is returned stripped and terminated by exactly one period. Otherwise
    the whole cleaned text is returned unchanged.

    Args:
        txt: Raw text, presumably decoder output that may contain an
            end-of-text marker followed by junk (verify against callers).
        min_len: Minimum length of the leading segment, before stripping,
            for it to be accepted as a sentence.

    Returns:
        The first sentence, or the cleaned full text when the first segment
        is too short.
    """
    eos = '<|endoftext|>'
    # Keep only what precedes the marker. The old slice ``txt[eos_idx:]`` kept
    # the marker and the trailing junk instead, and its ``> 0`` check ignored
    # a marker at index 0.
    txt = txt.partition(eos)[0]
    txt = txt.replace('\n', ' ')

    head, sep, _ = txt.partition('. ')
    if len(head) < min_len:
        return txt
    head = head.strip()
    # Without a '. ' separator the segment is the tail of the text and may
    # already carry its own final period, so don't double it.
    if not sep and head.endswith('.'):
        return head
    return f'{head}.'
def logits_to_entropy(logits):
    """Entropy of the categorical distribution parameterised by ``logits``.

    ``logits`` are unnormalised log-probabilities. As with
    ``torch.distributions.Categorical``, the last dimension indexes the
    categories and any leading dimensions are batch dimensions. The result
    therefore has shape ``logits.shape[:-1]`` and is measured in nats.
    """
    return torch.distributions.Categorical(logits=logits).entropy()
def mask_pad(value, mask):
    """Replace the positions of ``value`` where ``mask`` is 0 with ``NEGATIVE_INF``.

    Computes ``value * mask + NEGATIVE_INF * (1 - mask)``. Where ``mask`` is 1
    the value passes through; where it is 0 the result is ``NEGATIVE_INF``.

    NOTE(review): ``mask`` is presumably a 0/1 numeric tensor broadcastable
    to ``value``; confirm with callers. Because ``1 - mask`` is computed, bool
    masks appear to be unsupported. A masked-out position whose ``value`` is
    +/-inf or NaN yields NaN (``inf * 0``). When the arithmetic runs in
    float16, -1e5 is outside the fp16 range and becomes -inf.
    """
    kept = value * mask
    filler = (1 - mask) * NEGATIVE_INF
    return kept + filler