# voicelead/nameder.py
from typing import List
from resources import set_start, audit_elapsedtime, entities_list_to_dict
from transformers import BertTokenizer, BertForTokenClassification
import torch
from gliner import GLiNER

# Named-Entity Recognition model
def init_model_ner():
    print("Initiating NER model...")
    start = set_start()
    model = GLiNER.from_pretrained("urchade/gliner_multi")
    audit_elapsedtime(function="Initiating NER model", start=start)
    return model

def get_entity_results(model: GLiNER, text: str, entities_list: List[str]):  # -> Lead_labels:
    print("Initiating entity recognition...")
    start = set_start()
    labels = entities_list
    entities_result = model.predict_entities(text, labels)
    entities_dict = entities_list_to_dict(entities_list)
    for entity in entities_result:
        print(entity["label"], "=>", entity["text"])
        entities_dict[entity["label"]] = entity["text"]
    audit_elapsedtime(function="Retrieving entity labels from text", start=start)
    return entities_dict
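
# A minimal usage sketch for the GLiNER path above. This helper is hypothetical
# (it is not called anywhere in the original module), and the label names and
# sample text are illustrative assumptions, not values from this repository.
def _demo_gliner():
    ner_model = init_model_ner()
    labels = ["person name", "company", "location"]
    result = get_entity_results(
        ner_model,
        "Maria from Acme Corp asked for a product demo in Lisbon next week.",
        labels,
    )
    # label -> last matching span; unmatched labels keep the default from entities_list_to_dict
    print(result)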

def init_model_ner_v2():
    print("Initiating NER model...")
    start = set_start()
    # Load the pre-trained tokenizer and model. The tokenizer must come from the
    # same checkpoint as the fine-tuned NER model so its (cased) vocabulary matches.
    tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    audit_elapsedtime(function="Initiating NER model", start=start)
    return tokenizer, model

def get_entity_results_v2(tokenizer, model, text: str, entities_list: List[str]):  # -> Lead_labels:
    print("Initiating entity recognition...")
    start = set_start()
    labels = entities_list  # e.g. ["Apple Inc.", "American", "Cupertino", "California"]

    # Convert the text to input ids and keep the matching token strings
    # (including [CLS]/[SEP]) so each position lines up with the model output.
    input_ids = tokenizer.encode(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print("tokens:", tokens)
    print("input_ids:", input_ids)

    # Perform NER prediction
    with torch.no_grad():
        outputs = model(input_ids)
    print("outputs:", outputs)

    # Get the predicted label id for each token
    predicted_labels = torch.argmax(outputs.logits, dim=2)[0]
    print("predicted_labels:", predicted_labels)

    # Map predicted labels back to entity strings using the BIO tagging scheme
    entities = []
    current_entity = ""
    for i, label_id in enumerate(predicted_labels):
        label = model.config.id2label[label_id.item()]
        print(f"i[{i}], label[{label}], label_id[{label_id}]")
        token = tokens[i]
        if label.startswith('B-'):  # Beginning of a new entity
            print(token)
            if current_entity:
                entities.append(current_entity.strip())
            current_entity = token
        elif label.startswith('I-'):  # Inside of an entity
            print(token)
            current_entity += " " + token
        else:  # Outside of any entity
            if current_entity:
                entities.append(current_entity.strip())
            current_entity = ""
    if current_entity:  # flush an entity that runs to the end of the text
        entities.append(current_entity.strip())

    # Keep only the entities you are interested in
    filtered_entities = [entity for entity in entities if entity in labels]
    print("filtered_entities:", filtered_entities)
    audit_elapsedtime(function="Retrieving entity labels from text", start=start)
    return filtered_entities
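

# Minimal sketch of running the BERT-based path end to end. This block is not
# part of the original file: the sample text and entity strings are illustrative
# assumptions, and it relies on get_entity_results_v2 only keeping spans whose
# reconstructed text exactly matches an entry in the provided list.
if __name__ == "__main__":
    bert_tokenizer, bert_model = init_model_ner_v2()
    sample_text = "Apple Inc. is headquartered in Cupertino, California."
    wanted = ["Apple Inc.", "Cupertino", "California"]
    found = get_entity_results_v2(bert_tokenizer, bert_model, sample_text, wanted)
    print("entities:", found)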