Spaces:
Runtime error
Runtime error
File size: 622 Bytes
0bf81ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
# Large negative stand-in for -infinity; mask_pad() writes it into every
# position where the mask is 0.
NEGATIVE_INF = -100000.0
# Smaller-magnitude variant for half precision -- presumably chosen to stay
# inside float16's finite range (about +/-65504); TODO confirm with callers,
# nothing in the visible code references it.
HALF_NEGATIVE_INF = -60000.0 # half precision
def get_first_sentence(txt, min_len=5):
    """Return the first sentence of ``txt``, ignoring text after the EOS marker.

    Processing steps:

    1. If ``'<|endoftext|>'`` occurs at an index greater than 0, everything
       from the marker onwards is dropped.
    2. Newlines are replaced by single spaces.
    3. The text is split on ``'. '``; the first piece is the candidate.

    Args:
        txt: Text to extract the sentence from.
        min_len: Minimum length (measured before stripping whitespace) the
            candidate must have to be accepted.

    Returns:
        The stripped candidate followed by ``'.'`` when it is at least
        ``min_len`` characters long; otherwise the whole (truncated,
        newline-free) text unchanged.

    NOTE(review): a period is always appended to an accepted candidate, so a
    candidate that already ends in ``'.'`` comes back ending in ``'..'``.
    """
    eos = '<|endoftext|>'
    eos_idx = txt.find(eos)
    # A marker at index 0 is deliberately left alone (same as before):
    # cutting there would leave an empty string.
    if eos_idx > 0:
        # Keep the text *before* the marker. The previous slice,
        # ``txt[eos_idx:]``, kept the marker and everything after it and
        # threw away the actual content.
        txt = txt[:eos_idx]

    txt = txt.replace('\n', ' ')
    first = txt.split('. ')[0]
    if len(first) >= min_len:
        return f'{first.strip()}.'
    # Candidate too short to count as a sentence: fall back to the full text.
    return txt
def logits_to_entropy(logits):
    """Compute the entropy of the categorical distribution given by ``logits``.

    Args:
        logits: Tensor of unnormalized log-probabilities, passed as the
            ``logits`` argument of ``torch.distributions.Categorical``.

    Returns:
        The result of ``Categorical(logits=logits).entropy()``.
    """
    return torch.distributions.Categorical(logits=logits).entropy()
def mask_pad(value, mask, pad_value=None):
    """Keep ``value`` where ``mask`` is 1 and write a fill value where it is 0.

    Computed arithmetically as ``value * mask + pad_value * (1 - mask)``, so
    the formula is only a clean select when ``mask`` holds 0s and 1s.

    Args:
        value: Values to mask.
        mask: Mask combined with ``value`` by multiplication.
        pad_value: Fill value for masked-out positions. Defaults to the
            module-level ``NEGATIVE_INF``; pass another constant (for
            example ``HALF_NEGATIVE_INF``) when a different fill is needed.

    Returns:
        ``value * mask + pad_value * (1 - mask)``.
    """
    if pad_value is None:
        # Preserve the original behaviour when no fill value is supplied.
        pad_value = NEGATIVE_INF
    return value * mask + pad_value * (1 - mask)
|