Spaces:
Sleeping
Sleeping
"""BERT NER Inference.""" | |
from __future__ import absolute_import, division, print_function | |
import json | |
import os | |
import torch | |
import torch.nn.functional as F | |
from torch.nn import CrossEntropyLoss | |
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset | |
from torch.utils.data.distributed import DistributedSampler | |
from tqdm import tqdm, trange | |
import nltk | |
nltk.download('punkt') | |
from nltk import word_tokenize | |
# from transformers import (BertConfig, BertForTokenClassification, | |
# BertTokenizer) | |
from pytorch_transformers import (BertForTokenClassification, BertTokenizer) | |
class BertNer(BertForTokenClassification): | |
def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None): | |
sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0] | |
batch_size,max_len,feat_dim = sequence_output.shape | |
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cpu') | |
# valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu') | |
for i in range(batch_size): | |
jj = -1 | |
for j in range(max_len): | |
if valid_ids[i][j].item() == 1: | |
jj += 1 | |
valid_output[i][jj] = sequence_output[i][j] | |
sequence_output = self.dropout(valid_output) | |
logits = self.classifier(sequence_output) | |
return logits | |
class Ner: | |
def __init__(self,model_dir: str): | |
self.model , self.tokenizer, self.model_config = self.load_model(model_dir) | |
self.label_map = self.model_config["label_map"] | |
self.max_seq_length = self.model_config["max_seq_length"] | |
self.label_map = {int(k):v for k,v in self.label_map.items()} | |
self.device = "cpu" | |
# self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.model = self.model.to(self.device) | |
self.model.eval() | |
def load_model(self, model_dir: str, model_config: str = "model_config.json"): | |
model_config = os.path.join(model_dir,model_config) | |
model_config = json.load(open(model_config)) | |
model = BertNer.from_pretrained(model_dir) | |
tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"]) | |
return model, tokenizer, model_config | |
def tokenize(self, text: str): | |
""" tokenize input""" | |
words = word_tokenize(text) | |
tokens = [] | |
valid_positions = [] | |
for i,word in enumerate(words): | |
token = self.tokenizer.tokenize(word) | |
tokens.extend(token) | |
for i in range(len(token)): | |
if i == 0: | |
valid_positions.append(1) | |
else: | |
valid_positions.append(0) | |
# print("valid positions from text o/p:=>", valid_positions) | |
return tokens, valid_positions | |
def preprocess(self, text: str): | |
""" preprocess """ | |
tokens, valid_positions = self.tokenize(text) | |
## insert "[CLS]" | |
tokens.insert(0,"[CLS]") | |
valid_positions.insert(0,1) | |
## insert "[SEP]" | |
tokens.append("[SEP]") | |
valid_positions.append(1) | |
segment_ids = [] | |
for i in range(len(tokens)): | |
segment_ids.append(0) | |
input_ids = self.tokenizer.convert_tokens_to_ids(tokens) | |
# print("input ids with berttokenizer:=>", input_ids) | |
input_mask = [1] * len(input_ids) | |
while len(input_ids) < self.max_seq_length: | |
input_ids.append(0) | |
input_mask.append(0) | |
segment_ids.append(0) | |
valid_positions.append(0) | |
return input_ids,input_mask,segment_ids,valid_positions | |
def predict_entity(self, B_lab, I_lab, words, labels, entity_list): | |
temp=[] | |
entity=[] | |
for word, (label, confidence), B_l, I_l in zip(words, labels, B_lab, I_lab): | |
if ((label==B_l) or (label==I_l)) and label!='O': | |
if label==B_l: | |
entity.append(temp) | |
temp=[] | |
temp.append(label) | |
temp.append(word) | |
entity.append(temp) | |
# print(entity) | |
entity_name_label = [] | |
for entity_name in entity[1:]: | |
for ent_key, ent_value in entity_list.items(): | |
if (ent_key==entity_name[0]): | |
# entity_name_label.append(' '.join(entity_name[1:]) + ": " + ent_value) | |
entity_name_label.append([' '.join(entity_name[1:]), ent_value]) | |
return entity_name_label | |
def predict(self, text: str): | |
input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text) | |
# print("valid ids:=>", segment_ids) | |
input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device) | |
input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device) | |
segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device) | |
valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device) | |
with torch.no_grad(): | |
logits = self.model(input_ids, segment_ids, input_mask,valid_ids) | |
# print("logit values:=>", logits) | |
logits = F.softmax(logits,dim=2) | |
# print("logit values:=>", logits[0]) | |
logits_label = torch.argmax(logits,dim=2) | |
logits_label = logits_label.detach().cpu().numpy().tolist()[0] | |
# print("logits label value list:=>", logits_label) | |
logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)] | |
logits = [] | |
pos = 0 | |
for index,mask in enumerate(valid_ids[0]): | |
if index == 0: | |
continue | |
if mask == 1: | |
logits.append((logits_label[index-pos],logits_confidence[index-pos])) | |
else: | |
pos += 1 | |
logits.pop() | |
labels = [(self.label_map[label],confidence) for label,confidence in logits] | |
words = word_tokenize(text) | |
entity_list = {'B-PER':'Person', | |
'B-FAC':'Facility', | |
'B-LOC':'Location', | |
'B-ORG':'Organization', | |
'B-ART':'Work Of Art', | |
'B-EVENT':'Event', | |
'B-DATE':'Date-Time Entity', | |
'B-TIME':'Date-Time Entity', | |
'B-LAW':'Law Terms', | |
'B-PRODUCT':'Product', | |
'B-PERCENT':'Percentage', | |
'B-MONEY':'Currency', | |
'B-LANGUAGE':'Langauge', | |
'B-NORP':'Nationality / Religion / Political group', | |
'B-QUANTITY':'Quantity', | |
'B-ORDINAL':'Ordinal Number', | |
'B-CARDINAL':'Cardinal Number'} | |
B_labels=[] | |
I_labels=[] | |
for label, confidence in labels: | |
if (label[:1]=='B'): | |
B_labels.append(label) | |
I_labels.append('O') | |
elif (label[:1]=='I'): | |
I_labels.append(label) | |
B_labels.append('O') | |
else: | |
B_labels.append('O') | |
I_labels.append('O') | |
assert len(labels) == len(words) == len(I_labels) == len(B_labels) | |
output = self.predict_entity(B_labels, I_labels, words, labels, entity_list) | |
print(output) | |
# output = [{"word":word,"tag":label,"confidence":confidence} for word,(label,confidence) in zip(words,labels)] | |
return output | |