# nCloze Streamlit app: builds multiple-choice cloze tests with a masked
# language model. This section holds imports, one-time resource downloads,
# tunable constants and the word lists used to filter distractor candidates.
import json
import random
import re  # used by the tokenizer suffix-mask regexes; previously only reachable via the wildcard import below
import urllib.request
from math import exp, log, sqrt
from string import punctuation
from time import sleep

import nltk
import numpy as np
import pandas as pd
import stanza
import streamlit as st
import torch as T
from nltk import FreqDist
from nltk.corpus import gutenberg, stopwords
from nltk.stem.porter import *  # provides PorterStemmer; wildcard kept for compatibility
from numpy import ndarray
from torch import Tensor, device
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoTokenizer

nltk.download('stopwords')
nltk.download('gutenberg')

# Cosine similarity over dim 0, i.e. between two 1-D embedding vectors.
cos = T.nn.CosineSimilarity(dim=0)

# Dictionary word lists and profanity list, pinned to a fixed commit of the nCloze repo.
NCLOZE_DATA_URL = "https://github.com/ondovb/nCloze/raw/1b57ab719c367c070aeba8a53e71a536ce105091/"
urllib.request.urlretrieve(NCLOZE_DATA_URL + 'dict-info.txt', 'dict-info.txt')
sleep(1)
urllib.request.urlretrieve(NCLOZE_DATA_URL + 'dict-unix.txt', 'dict-unix.txt')
sleep(1)
urllib.request.urlretrieve(NCLOZE_DATA_URL + 'profanity.json', 'profanity.json')

CONTEXTUAL_EMBEDDING_LAYERS = [12]  # default hidden-state layers summed by get_emb()
EXTEND_SUBWORDS = True  # enables the subword-extension branch in candidates()
MAX_SUBWORDS = 1  # largest number of consecutive masks tried per blank
DEBUG_OUTPUT = True  # print intermediate results
DISTRACTORS_FROM_TEXT = False  # if True, only accept distractors that occur in the passage
MIN_SENT_WORDS = 7  # minimum word count of the sentence-only context around the blank

# Frequencies are used to decide if a distractor candidate might be a subword
# (NOTE(review): `freq` is presumably consumed by code outside this section).
stemmer = PorterStemmer()
freq = FreqDist(i.lower() for i in gutenberg.words())
print(freq.most_common(5))

# Load the word lists with context managers so the file handles are closed.
with open('dict-unix.txt') as fh:
    words_unix = {line.strip() for line in fh}
with open('dict-info.txt') as fh:
    words_info = {line.strip() for line in fh}
words_small = words_unix.intersection(words_info)
words_large = words_unix.union(words_info)

with open('profanity.json') as f:
    profanity = json.load(f)

nlp = stanza.Pipeline(lang='en', processors='tokenize')  # , model_dir='/data/ondovbd/stanza_resources')
nltk.download('punkt') nltk_sent_toker = nltk.data.load('tokenizers/punkt/english.pickle') def is_word(str): '''Check if word exists in dictionary''' splt = str.lower().split("'") if len(splt) > 2: return False elif len(splt) == 2: return is_word(splt[0]) and (splt[1] in ['t','nt','s','ll']) elif '-' in str: for word in str.split('-'): if not is_word(word): return False return True else: return str.lower() in words_unix or str.lower() in words_info def get_emb(snt_toks, tgt_toks, layers=None): '''Embeds a group of subword tokens in place of a mask, using the entire sentence for context. Returns the average of the target token embeddings, which are summed over the hidden layers. snt_toks: the tokenized sentence, including the mask token tgt_toks: the tokens (subwords) to replace the mask token layers (optional): which hidden layers to sum (list of indices)''' mask_idx = snt_toks.index(toker.mask_token_id) snt_toks = snt_toks.copy() while mask_idx + len(tgt_toks)-1 >= 512: # Shift text by 100 words snt_toks = snt_toks[100:] mask_idx -= 100 snt_toks[mask_idx:mask_idx+1] = tgt_toks snt_toks = snt_toks[:512] with T.no_grad(): if T.cuda.is_available(): T.tensor([snt_toks]).cuda() T.tensor([[1]*len(snt_toks)]).cuda() output = model(T.tensor([snt_toks]), T.tensor([[1]*len(snt_toks)]), output_hidden_states=True) layers = CONTEXTUAL_EMBEDDING_LAYERS if layers is None else layers output = T.stack([output.hidden_states[i] for i in layers]).sum(0).squeeze() # Only select the tokens that constitute the requested word return output[mask_idx:mask_idx+len(tgt_toks)].mean(dim=0) def energy(ctx, scaled_dists, scaled_sims, choices, words, ans): #Calculate and add cosine similarity scores '''Cost function to help choose best distractors''' #e = [embs[i] for i in choices] #+ [sem_emb_ans] #w = [words[i] for i in choices] #+ [ans] hm_sim = 0 e_ctx = 0 for i in choices: hm_sim += 1./scaled_sims[i] e_ctx += ctx[i] e_sim = float(len(choices))/hm_sim hm_emb = 0 count = 0 c = choices + 
[len(ctx)] for i in range(len(c)): for j in range(i): d = scaled_dists['%s-%s'%(max(c[i],c[j]), min(c[i], c[j]))] #print(c[i], c[j], d) hm_emb += 1./d count += 1 e_emb = float(count)/hm_emb return float(e_emb), e_ctx, float(e_sim) def anneal(probs_sent_context, probs_para_context, embs, emb_ans, words, k, ans): '''find k distractor indices that are optimally high probability and distant in embedding space''' # probs_sent_context = T.as_tensor(probs_sent_context) / sum(probs_sent_context) m = len(probs_sent_context) # probs_para_context = T.as_tensor(probs_para_context) / sum(probs_para_context) its = 1000 n = len(probs_para_context) choices = list(range(k)) dists = {} embsa = embs + [emb_ans] for i in range(len(embsa)): for j in range(i): dists['%s-%s'%(i,j)] = 1-cos(embsa[i], embsa[j]) # cosine "distance" #print(words[i], words[j], 1-cos(embs[i], embs[j])) dist_min = T.min(T.tensor(list(dists.values()))) dist_max = T.max(T.tensor(list(dists.values()))) for key, dist in dists.items(): dists[key] = (dist - dist_min)/(dist_max-dist_min) sims = T.tensor([cos(emb_ans, emb) for emb in embs]) scaled_sims = (sims - T.min(sims))/(T.max(sims)-T.min(sims)) ctx = T.tensor(probs_sent_context).log()-ALPHA*T.tensor(probs_para_context).log() ctx = (ctx-T.min(ctx))/(T.max(ctx)-T.min(ctx)) e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans) e = e_ctx + BETA * e_emb #e = SIM_ANNEAL_EMB_WEIGHT * e_emb + e_prob for i in range(its): t = 1.-(i)/its mut_idx = random.randrange(k) # which choice to mutate orig = choices[mut_idx] new = orig while (new in choices): # mutate choice until not in current list new = random.randrange(m) choices[mut_idx] = new e_emb, e_ctx, e_sim = energy(ctx, dists, scaled_sims, choices, words, ans) e_new = e_ctx + BETA * e_emb delta = e_new - e exponent = delta/t if exponent < -50: exponent = -50 # avoid underflow if delta > 0 or exp(exponent) > random.random(): e = e_new # accept new state else: choices[mut_idx] = orig if DEBUG_OUTPUT: 
print([words[j] for j in choices] + [ans], "e: %f"%(e)) return choices def get_softmax_logits(toks, n_masks = 1, sub_ids = []): # Tokenize text - Keep length of inpts at or below 512 (including answer token length artifically added at end) msk_idx = toks.index(toker.mask_token_id) toks = toks.copy() toks[msk_idx:msk_idx+1] = [toker.mask_token_id] * n_masks + sub_ids # If the masked_token is over 512 (excluding answer token length artifically added at end) tokens away while msk_idx >= 512: # Shift text by 100 words toks = toks[100:] msk_idx -= 100 toks = toks[:512] # Find the predicted words for the fill-in-the-blank mask term based on sentence-context alone with T.no_grad(): t=T.tensor([toks]) m=T.tensor([[1]*len(toks)]) if T.cuda.is_available(): t.cuda() m.cuda() output = model(t, m) sm = T.softmax(output.logits[0, msk_idx:msk_idx+n_masks, :], dim=1) return sm e=1e-10 def candidates(text, answer): '''Create list of unique distractors that does not include the actual answer''' if DEBUG_OUTPUT: print(text) # Get only sentence with blanked text to tokenize doc = nlp(text) #sents = [sentence.text for sentence in doc.sentences] sents = nltk_sent_toker.tokenize(text) msk_snt_idx = [i for i in range(len(sents)) if toker.mask_token in sents[i]][0] just_masked_sentence = sents[msk_snt_idx] prv_snts = sents[:msk_snt_idx] nxt_snts = sents[msk_snt_idx+1:] if len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and len(prv_snts): just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence]) while len(just_masked_sentence.split(' ')) < MIN_SENT_WORDS and (len(prv_snts) or len(nxt_snts)): if T.rand(1) < 0.5 and len(prv_snts): just_masked_sentence = ' '.join([prv_snts.pop(), just_masked_sentence]) elif len(nxt_snts): just_masked_sentence = ' '.join([just_masked_sentence, nxt_snts.pop(0)]) ctx = just_masked_sentence while len(ctx.split(' ')) < 3 * len(just_masked_sentence.split(' ')) and (len(prv_snts) or len(nxt_snts)): if len(prv_snts): ctx = ' '.join([prv_snts.pop(), 
ctx]) if len(nxt_snts): ctx = ' '.join([ctx, nxt_snts.pop(0)]) # just_masked_sentence = ' '.join([just_masked_sentence.replace('', 'banana'), # just_masked_sentence.replace('', 'banana'), ## just_masked_sentence, # just_masked_sentence.replace('', 'banana'), # just_masked_sentence.replace('', 'banana')]) #just_masked_sentence = ' '.join([just_masked_sentence, just_masked_sentence, just_masked_sentence, just_masked_sentence, just_masked_sentence]) tiled = just_masked_sentence while len(tiled) < len(text): tiled += ' ' + just_masked_sentence just_masked_sentence = tiled if DEBUG_OUTPUT: print(ctx) print(text) print(just_masked_sentence) toks_para = toker.encode(text) toks_sent = toker.encode(just_masked_sentence) # Get softmaxed logits from sentence alone and full-text # sent_sm, sent_pos, sent_ids = get_span_logits(just_masked_sentence, answer) # para_sm, para_pos, para_ids = get_span_logits(text, answer) sent_sms_all = [] para_sms_all = [] para_sms_right = [] for i in range(MAX_SUBWORDS): para_sms = get_softmax_logits(toks_para, i + 1) para_sms_all.append(para_sms) sent_sms = get_softmax_logits(toks_sent, i + 1) sent_sms_all.append(sent_sms) para_sms_right.append(T.exp((sent_sms[i].log()+para_sms[i].log())/2) * (suffix_mask_inv if i == 0 else suffix_mask)) # Create 2 lists: (1) notes highest probability for each token across n-mask lists if token is suffix and (2) notes number of mask terms to add para_sm_best, para_pos_best = T.max(T.vstack(para_sms_right), 0) distractors = [] stems = [] embs = [] sent_probs = [] para_probs = [] ans_stem = stemmer.stem(answer.lower()) emb_ans = get_emb(toks_para, toker(answer)['input_ids'][1:-1]) para_words = text.lower().split(' ') blank_word_idx = [idx for idx, word in enumerate(text.split(' ')) if toker.mask_token in word][0] # Need to remove punctuation if (blank_word_idx - 1) < 0: prev_word = 'beforeanytext' else: prev_word = para_words[blank_word_idx-1] if (blank_word_idx + 1) >= len(para_words): next_word = 'afteralltext' 
else: next_word = para_words[blank_word_idx+1] # Need to check if the token is outside of the tokenizer based on predictions being made at all if len(para_sms_all[0]) > 0: top_ctx = T.topk((sent_sms_all[0][0]*word_mask+e).log() - ALPHA * (para_sms_all[0][0]*word_mask+e).log(), len(para_sms_all[0][0]), dim=0) para_top_ids = top_ctx.indices.tolist() para_top_probs = top_ctx.values.tolist() for i, id in enumerate(para_top_ids): sub_ids = [int(id)] # cumulative list of subword token ids dec = toker.decode(sub_ids).strip() if DEBUG_OUTPUT: print('Trying:', dec) #print(para_pos[id]) #if para_pos_best[id] > 0: # continue if dec.isupper() != answer.isupper(): continue if EXTEND_SUBWORDS and para_pos_best[id] > 0: if DEBUG_OUTPUT: print("Extending %s with %d masks..."%(dec, para_pos_best[id])) ext_ids, _ = extend(toks_sent, toks_para, [id], para_pos_best[id], para_words) sub_ids = ext_ids + sub_ids dec_ext = toker.decode(sub_ids).strip() if DEBUG_OUTPUT: print("Extended %s to %s"%(dec, dec_ext)) if is_word(dec_ext) or (dec_ext != '' and dec_ext in para_words): dec = dec_ext # choose new word else: sub_ids = [int(id)] # reset if len(toker.decode(sub_ids).lower().strip()) < 2: continue if dec[0].isupper() != answer[0].isupper(): continue # Only add distractor if it does not contain punctuation #if any(p in dec for p in punctuation): # pass #continue if dec.lower() in profanity: continue # make sure is a word, either in dict or somewhere else in text if not is_word(dec) and dec.lower() not in para_words: continue # make sure is not the same as an adjacent word if dec.lower() == prev_word or dec.lower() == next_word: continue # Don't add the distractor if stem matches another stem = stemmer.stem(dec).lower() if stem in stems or stem == ans_stem: continue # Only add distractor if it does not contain a number if any(char.isdigit() for char in toker.decode([id])): continue # Only add distractor if the distractor exists in the text already if DISTRACTORS_FROM_TEXT and dec.lower() 
not in para_words: continue #if answer[0].isupper(): # dec = dec.capitalize() # PASSED ALL TESTS; finally add distractor and computations distractors.append(dec) stems.append(stem) sent_logprob = 0 para_logprob = 0 nsubs = len(sub_ids) for j in range(nsubs): sub_id = sub_ids[j] sent_logprob_j = log(sent_sms_all[nsubs-1][j][sub_id]) para_logprob_j = log(para_sms_all[nsubs-1][j][sub_id]) #if j == 0 or sent_logprob_j > sent_logprob: # sent_logprob = sent_logprob_j #if j == 0 or para_logprob_j > para_logprob: # para_logprob = para_logprob_j sent_logprob += sent_logprob_j para_logprob += para_logprob_j sent_logprob /= nsubs para_logprob /= nsubs if DEBUG_OUTPUT: print("%s (p_sent=%f, p_para=%f)"%(dec,sent_logprob,para_logprob)) sent_probs.append(exp(sent_logprob)) para_probs.append(exp(para_logprob)) # sent_probs.append(sent_sms_all[nsubs-1][nsubs-1][sub_id]) # para_probs.append(para_sms_all[nsubs-1][nsubs-1][sub_id]) embs.append(get_emb(toks_para, sub_ids)) if len(distractors) >= K: break if DEBUG_OUTPUT: print('Corresponding Text: ', text) print('Correct Answer: ', answer) print('Distractors created before annealing: ', distractors) #indices = anneal(sent_probs, para_probs, embs, emb_ans, number_of_distractors, distractors, answer) #distractors = [distractors[i] for i in indices] #distractors += [''] * (number_of_distractors - len(distractors)) return sent_probs, para_probs, embs, emb_ans, distractors def create_distractors(text, answer): sent_probs, para_probs, embs, emb_ans, distractors = candidates(text, answer) #print(distractors) indices = anneal(sent_probs, para_probs, embs, emb_ans, distractors, 3, answer) return [distractors[x] for x in indices] st.title("nCloze") st.subheader("Create a multiple-choice cloze test from a passage") st.markdown("Note: this is a free, CPU-only space and will be slow. 
For better performance, clone the space with a GPU-enabled environment.") def blank(tok): if tok == 'a(n)': strp = tok else: strp = tok.strip(punctuation) print(strp, tok.replace(strp, toker.mask_token)) return strp, tok.replace(strp, toker.mask_token) test = """In contrast to necrosis, which is a form of traumatic cell death that results from acute cellular injury, apoptosis is a highly regulated and controlled process that confers advantages during an organism's life cycle. For example, the separation of fingers and toes in a developing human embryo occurs because cells between the digits undergo apoptosis. Unlike necrosis, apoptosis produces cell fragments called apoptotic bodies that phagocytes are able to engulf and remove before the contents of the cell can spill out onto surrounding cells and cause damage to them.""" st.header("Basic options") SPACING = int(st.text_input('Blank spacing', value="7")) OFFSET = int(st.text_input('First word to blank (0 to use spacing)', value="4")) st.header("Advanced options") ALPHA = float(st.text_input('Incorrectness weight', value="0.75")) BETA = float(st.text_input('Distinctness weight', value="0.25")) MODEL_TYPE = st.text_input('Masked Language Model (from HuggingFace)', value="roberta-large") K = 16 model = AutoModelForMaskedLM.from_pretrained(MODEL_TYPE)#, cache_dir=CACHE_DIR) if T.cuda.is_available(): model.cuda() toker = AutoTokenizer.from_pretrained(MODEL_TYPE, add_prefix_space=True) sorted_toker_vocab_dict = sorted(toker.vocab.items(), key=lambda x:x[1]) if toker.mask_token == '[MASK]': # BERT style suffix_mask = T.FloatTensor([1 if (('##' == x[0][:2]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) # 1 means is-suffix and 0 mean not-suffix else: # RoBERTa style suffix_mask = T.FloatTensor([1 if (('Ġ' != x[0][0]) and (re.match("^[A-Za-z0-9']*$", x[0]) is not None)) else 0 for x in sorted_toker_vocab_dict]) # 1 means is-suffix and 0 mean not-suffix suffix_mask_inv = 
suffix_mask * -1 + 1 word_mask = suffix_mask_inv*T.FloatTensor([1 if is_word(x[0][1:]) and x[0][1:].lower() not in profanity else 0 for x in sorted_toker_vocab_dict]) if T.cuda.is_available(): suffix_mask=suffix_mask.cuda() suffix_mask_inv=suffix_mask_inv.cuda() word_mask = word_mask.cuda() st.subheader("Passage") st.text_area('Passage to create a cloze test from:',value=test,key="text", max_chars=1024, height=275) def generate(): ws = st.session_state.text.split() wb = st.session_state.text.split() qs = [] i = OFFSET - 1 if OFFSET > 0 else SPACING - 1 j = 0 while i < len(ws): a, b = blank(ws[i]) while b == '' and i < len(ws)-1: i += 1 a, b = blank(ws[i]) if b != '': w = ws[i] ws[i] = b wb[i] = b while j