Spaces:
Runtime error
Runtime error
File size: 10,089 Bytes
854a552 2485275 854a552 2485275 854a552 2485275 854a552 2485275 854a552 2485275 854a552 2485275 854a552 281974e 854a552 2485275 854a552 2485275 854a552 2f4e364 854a552 2485275 854a552 2485275 854a552 23a2b0d 854a552 2485275 a1dc9c7 2485275 90a2a07 23a2b0d 854a552 23a2b0d 2485275 a1dc9c7 90a2a07 b842440 a1dc9c7 854a552 2485275 854a552 edf6d87 281974e 2485275 23a2b0d 96c85ad 6d54c78 98c32e8 6d54c78 981bdcf c79d9d8 2485275 854a552 edf6d87 23a2b0d 2485275 23a2b0d 96c85ad 98c32e8 348b1aa 6d54c78 981bdcf 3094195 2485275 854a552 2485275 854a552 2485275 562b4f7 90a2a07 2485275 90a2a07 2485275 854a552 2485275 1cc4362 44fbf6a 1cc4362 854a552 90a2a07 854a552 2485275 854a552 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import time
import streamlit as st
import torch
import string
from annotated_text import annotated_text
from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import BertTokenizer, BertForMaskedLM
import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json
# Passed to bd.BatchInference when the BERT model is first constructed
# in perform_inference().
DEFAULT_TOP_K = 20

# Suffix appended to words in the pre-tagged example sentences. When an input
# already contains it, perform_inference() skips the POS-tagging step.
SPECIFIC_TAG = ":__entity__"
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def POS_get_model(model_name):
    """Load the flair SequenceTagger identified by ``model_name``.

    The call is wrapped in ``st.cache`` so Streamlit can reuse the result
    instead of loading the model again.
    """
    return SequenceTagger.load(model_name)
def getPos(s: Sentence):
    """Collect token texts and their first label value per annotation layer.

    For every token in ``s.tokens`` and every annotation layer present on
    that token, one entry is appended to each of the two parallel lists:
    the token text and the ``value`` of the first label in that layer.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists.
    """
    texts, labels = [], []
    for token in s.tokens:
        for layer in token.annotation_layers:
            texts.append(token.text)
            labels.append(token.get_labels(layer)[0].value)
    return texts, labels
def getDictFromPOS(texts, labels):
    """Pair texts with labels as five-element rows.

    Each row has the form ``["dummy", text, label, "dummy", "dummy"]``.
    Pairing stops at the shorter of the two inputs.
    """
    rows = []
    for text, label in zip(texts, labels):
        rows.append(["dummy", text, label, "dummy", "dummy"])
    return rows
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted vocabulary ids into a newline-separated token list.

    Args:
        tokenizer: object exposing ``decode(id) -> str``.
        pred_idx: iterable of predicted ids, best first.
        top_clean: maximum number of tokens kept after filtering.

    Returns:
        The first ``top_clean`` surviving tokens joined by ``"\\n"``.
        Single punctuation characters, ``[PAD]`` and empty tokens are
        dropped, and the ``##`` word-piece marker is removed.
    """
    # Exact-match set. The previous substring test against one long string
    # also discarded legitimate tokens such as "A", "P", "D" or "PAD".
    ignore_tokens = set(string.punctuation) | {'[PAD]', ''}
    tokens = []
    for w in pred_idx:
        # Collapse any internal whitespace the tokenizer inserts.
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence containing ``<mask>`` and locate the mask.

    ``<mask>`` is replaced by the tokenizer's own mask token. When the mask
    is the final whitespace-separated word, `` .`` is appended so the model
    is not pushed towards predicting punctuation.

    Returns:
        ``(input_ids, mask_idx)``: a tensor holding one row of token ids,
        and the column index of the first mask token in that row.
    """
    mask_token = tokenizer.mask_token
    prepared = text_sentence.replace('<mask>', mask_token)
    last_word = prepared.split()[-1]
    if last_word == mask_token:
        prepared = prepared + ' .'
    token_ids = tokenizer.encode(prepared, add_special_tokens=add_special_tokens)
    input_ids = torch.tensor([token_ids])
    mask_columns = torch.where(input_ids == tokenizer.mask_token_id)[1]
    mask_idx = mask_columns.tolist()[0]
    return input_ids, mask_idx
def get_all_predictions(text_sentence, top_clean=5):
    """Predict fill-ins for the ``<mask>`` position of ``text_sentence``.

    Args:
        text_sentence: sentence containing a ``<mask>`` placeholder.
        top_clean: number of cleaned predictions to keep.

    Returns:
        ``{'bert': str}`` where the value is the newline-joined predictions
        produced by ``decode``.

    NOTE(review): ``bert_tokenizer`` and ``bert_model`` are read as
    module-level globals but are not defined in the visible file; they must
    be provided before this function can run.
    """
    # ========================= BERT =================================
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        predict = bert_model(input_ids)[0]
    # The original referenced an undefined name `top_k` here. Use the module
    # default, and never request fewer candidates than the caller wants kept,
    # since decode() filters some of them out.
    candidate_count = max(DEFAULT_TOP_K, top_clean)
    top_ids = predict[0, mask_idx, :].topk(candidate_count).indices.tolist()
    bert = decode(bert_tokenizer, top_ids, top_clean)
    return {'bert': bert}
def get_bert_prediction(input_text,top_k):
    """Predict the word that would follow ``input_text``.

    `` <mask>`` is appended to the text and the result of
    ``get_all_predictions`` is returned.

    Args:
        input_text: text to extend.
        top_k: number of predictions to keep; converted with ``int()``.

    Returns:
        The prediction dict, or ``None`` when anything fails. Failures stay
        non-fatal, but are now reported instead of silently discarded.
    """
    try:
        return get_all_predictions(input_text + ' <mask>', top_clean=int(top_k))
    except Exception as error:
        # Best-effort by design: report the cause so it can be diagnosed.
        print(f"get_bert_prediction failed: {error!r}")
        return None
def load_pos_model():
    """Return the ``flair/pos-english`` tagger via ``POS_get_model``."""
    return POS_get_model("flair/pos-english")
def init_session_states():
    """Seed ``st.session_state`` with defaults for keys that are missing.

    Existing entries are left untouched. The model slots start as ``None``
    and are filled on first use by the inference helpers.
    """
    defaults = (
        ('top_k', 20),
        ('pos_model', None),
        ('phi_model', None),
        ('ner_phi', None),
        ('aggr', None),
    )
    for key, initial in defaults:
        if key not in st.session_state:
            st.session_state[key] = initial
def get_pos_arr(input_text,display_area):
    """Tag ``input_text`` with the POS model and return the rows.

    The POS model is loaded on first use and kept in
    ``st.session_state['pos_model']``; a progress message is written to
    ``display_area`` while it loads.

    Returns:
        The list of rows built by ``getDictFromPOS``.
    """
    state = st.session_state
    if state['pos_model'] is None:
        display_area.text("Loading model 2 of 2.Loading POS model...")
        state['pos_model'] = load_pos_model()
    sentence = Sentence(input_text)
    state['pos_model'].predict(sentence)
    return getDictFromPOS(*getPos(sentence))
def perform_inference(text,display_area):
    """Run the NER pipeline on ``text`` and return the aggregated result.

    Models are created on first use and kept in ``st.session_state``
    (``phi_model``, ``ner_phi``, ``aggr``). Progress messages are written to
    ``display_area`` along the way.

    Args:
        text: input sentence; when it already contains ``SPECIFIC_TAG`` the
            POS-tagging step is skipped and ``None`` is passed instead.
        display_area: Streamlit placeholder used for status messages.

    Returns:
        Whatever the aggregator's ``fetch_all`` returns.
    """
    state = st.session_state

    if state['phi_model'] is None:
        display_area.text("Loading model 1 of 2. BERT model...")
        state['phi_model'] = bd.BatchInference("bbc/desc_bbc_config.json",'bert-base-cased',False,False,DEFAULT_TOP_K,True,True, "bbc/","bbc/bbc_labels.txt",False)

    # POS tags are only needed when the caller has not pre-tagged entities.
    pos_arr = None if SPECIFIC_TAG in text else get_pos_arr(text, display_area)

    if state['ner_phi'] is None:
        display_area.text("Initializing BERT module...")
        state['ner_phi'] = ner.UnsupNER("bbc/ner_bbc_config.json")

    if state['aggr'] is None:
        display_area.text("Initializing Aggregation modeule...")
        state['aggr'] = aggr.AggregateNER("./ensemble_config.json")

    display_area.text("Getting predictions from BERT model...")
    descriptors = state['phi_model'].get_descriptors(text, pos_arr)

    display_area.text("Computing NER results...")
    display_area.text("Consolidating responses...")
    tagged_json = state['ner_phi'].tag_sentence_service(text, descriptors)
    tagged = json.loads(tagged_json)

    # The same parsed result is supplied twice to the aggregator.
    # NOTE(review): presumably fetch_all expects one entry per model - confirm.
    return state['aggr'].fetch_all(text, [tagged, tagged])
# Example sentences offered in the pull-down. Each entry corresponds by index
# to the pre-tagged variant in sent_arr_masked.
sent_arr = [
    "Washington resigned from Washington and flew out of Washington",
    "John Doe flew from New York to Rio De Janiro ",
    "Stanford called",
    "I met my girl friends at the pub ",
    "I met my New York friends at the pub",
    "I met my XCorp friends at the pub",
    "I met my two friends at the pub",
    "The sky turned dark in advance of the storm that was coming from the east ",
    "She loves to watch Sunday afternoon football with her family ",
    "The United States has the largest prison population in the world, and the highest per-capita incarceration rate",
    "He went to a local theater and watched Jaws before the Covid-19 lockdown",
    "He converted to Christianity towards the end of his life after being a Buddhist",
    "Dr Fyodor Dostovetsky advocates wearing masks in public places to reduce the risk of contracting covid",
    "He graduated from Stanford with a master's degree in Physics and Astronomy",
    "The Seahawks had a tough year losing almost all the games",
    "In 2020 , John participated in the Winter Olympics and came third in Ice hockey",
    "Paul Erdos died at 83 ",
]
# Pre-tagged variants of the sentences in sent_arr (same order): the words to
# be classified carry the ":__entity__" suffix.
sent_arr_masked = [
    "Washington:__entity__ resigned from Washington:__entity__ and flew out of Washington:__entity__",
    "John:__entity__ Doe:__entity__ flew from New:__entity__ York:__entity__ to Rio:__entity__ De:__entity__ Janiro:__entity__ ",
    "Stanford:__entity__ called",
    "I met my girl:__entity__ friends at the pub ",
    "I met my New:__entity__ York:__entity__ friends at the pub",
    "I met my XCorp:__entity__ friends at the pub",
    "I met my two:__entity__ friends at the pub",
    "The sky turned dark:__entity__ in advance of the storm that was coming from the east ",
    "She loves to watch Sunday afternoon football:__entity__ with her family ",
    "The United:__entity__ States:__entity__ has the largest prison population in the world, and the highest per-capita incarceration:__entity__ rate:__entity__",
    "He went to a local theater and watched Jaws:__entity__ before the Covid-19 lockdown",
    "He converted to christianity:__entity__ towards the end of his life after being a buddhist:__entity__",
    "Dr:__entity__ Fyodor:__entity__ Dostovetsky:__entity__ advocates wearing masks:__entity__ in public places to reduce the risk of contracting covid",
    "He graduated from Stanford:__entity__ with a master's degree in Physics:__entity__ and Astronomy:__entity__",
    "The Seahawks:__entity__ had a tough year losing almost all the games",
    "In 2020:__entity__ , John:__entity__ participated in the Winter:__entity__ Olympics:__entity__ and came third:__entity__ in Ice:__entity__ hockey:__entity__",
    "Paul:__entity__ Erdos:__entity__ died at 83:__entity__ ",
]
def init_selectbox():
    """Render the example-sentence pull-down and return the selected entry.

    The widget is registered under the session key ``my_choice`` and lists
    the sentences in ``sent_arr``.
    """
    prompt = 'Choose any of the sentences in pull-down below'
    return st.selectbox(prompt, sent_arr, key='my_choice')
def on_text_change():
    """Callback: run inference on the text stored under ``my_text``.

    Reads ``st.session_state.my_text``, logs it to stdout and calls
    ``perform_inference``. The result is not used here.

    NOTE(review): no widget in the visible file registers the key
    ``my_text`` or installs this callback - confirm it is still needed.
    """
    text = st.session_state.my_text
    print("in callback: " + text)
    # perform_inference requires a placeholder for its status messages; the
    # original call omitted it and would raise TypeError.
    perform_inference(text, st.empty())
def main():
    """Render the Streamlit page and run inference when the form is submitted.

    The page shows a header, a form with the example pull-down and a free
    text box, the timed JSON result of ``perform_inference``, and footer
    notes. When the text box is empty, the pre-tagged variant of the
    selected example sentence is used as input. Any exception is printed to
    stdout and shown on the page with ``st.exception``.

    NOTE(review): the visible source carries no indentation, so placing the
    status placeholders and the submit handling inside the ``st.form``
    block is a reconstruction - confirm against the deployed layout.
    """
    try:
        init_session_states()
        st.markdown("<h4 style='text-align: center;'>NER of PERSON,LOCATION,ORG etc.</h4>", unsafe_allow_html=True)
        st.markdown("<h5 style='text-align: center;'>Using a pretrained BERT model with <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html'>no fine tuning</a><br/><br/></h5>", unsafe_allow_html=True)
        st.write("This app uses 2 models. Bert-base-cased(**no fine tuning**) and a POS tagger")

        with st.form('my_form'):
            chosen_sentence = init_selectbox()
            typed_text = st.text_area(label='Type any sentence below', value="")
            submitted = st.form_submit_button('Submit')

            input_status_area = st.empty()
            display_area = st.empty()
            if submitted:
                started_at = time.time()
                # Nothing typed: fall back to the tagged form of the example.
                if not typed_text:
                    typed_text = sent_arr_masked[sent_arr.index(chosen_sentence)]
                input_status_area.text("Input sentence: " + typed_text)
                results = perform_inference(typed_text, display_area)
                # Clear the progress messages before showing the result.
                display_area.empty()
                with display_area.container():
                    st.text(f"prediction took {time.time() - started_at:.2f}s")
                    st.json(results)

        # Footer notes. The string bodies stay at column 0 so the rendered
        # text is unchanged.
        st.markdown("""
<small style="font-size:16px; color: #8f8f8f; text-align: left"><i><b>Note:</b> The example sentences in the pull-down above largely tests PHI entities. Biomedical entities are not tested since this model does not perform well on biomedical entities. To see valid predictions for both biomedical and PHI entities <a href='https://huggingface.co/spaces/ajitrajasekharan/NER-Biomedical-PHI-Ensemble' target='_blank'>use this ensemble app</a></i></small>
""", unsafe_allow_html=True)
        st.markdown("""
<small style="font-size:16px; color: #7f7f7f; text-align: left"><br/><br/>Models used: <br/>(1) Bert-base-cased (for PHI entities - Person/location/organization etc.)<br/>(2) Flair POS tagger</small>
#""", unsafe_allow_html=True)
        st.markdown("""
<h3 style="font-size:16px; color: #9f9f9f; text-align: center"><b> <a href='https://huggingface.co/spaces/ajitrajasekharan/Qualitative-pretrained-model-evaluation' target='_blank'>App link to examine pretrained models</a> used to perform NER without fine tuning</b></h3>
""", unsafe_allow_html=True)
        st.markdown("""
<h3 style="font-size:16px; color: #9f9f9f; text-align: center">Github <a href='http://github.com/ajitrajasekharan/unsupervised_NER' target='_blank'>link to same working code </a>(without UI) as separate microservices</h3>
""", unsafe_allow_html=True)
    except Exception as e:
        print("Some error occurred in main")
        st.exception(e)
# Script entry point: build the page only when executed directly.
if __name__ == "__main__":
    main()
|