Clinical adapters / assertion classification
Model description
This is an adapter for clinical assertion classification, trained on the 2010 i2b2/VA data. Clinical assertion classification is an NLP task that determines the level of certainty a clinical text expresses about a medical problem.
Assertions are attributes of medical problem concepts (entities) that describe whether a medical problem is present, absent, possible, conditional, hypothetical, or not associated with the patient. The task is to classify each identified medical problem from the context in which it appears. This model classifies entities into three of these classes: PRESENT, ABSENT, and POSSIBLE.
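The three classes correspond to the integer ids produced by the classification head. A minimal sketch, assuming the same id order as the usage code (0 = PRESENT, 1 = ABSENT, 2 = POSSIBLE), of keeping the forward and inverse mappings consistent. `ID2LABEL`, `LABEL2ID`, and `decode_label` are hypothetical names for illustration, not part of the model's API:

```python
# Hypothetical constants mirroring the label order used in the usage example
ID2LABEL = {0: "PRESENT", 1: "ABSENT", 2: "POSSIBLE"}
# Invert the mapping so labels can also be turned back into class ids
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}


def decode_label(class_id: int) -> str:
    """Map a predicted class id to its assertion label."""
    if class_id not in ID2LABEL:
        raise ValueError(f"unknown class id {class_id}; expected one of {sorted(ID2LABEL)}")
    return ID2LABEL[class_id]


print(LABEL2ID)
print(decode_label(1))
```

Deriving `LABEL2ID` from `ID2LABEL` rather than writing it out by hand avoids the two mappings drifting apart.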
How to use the model
```python
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

model_adapter = "clinical-adapters/n2c2-pfieffer-adapter-bert-ast"
pretrained_model_name_or_path = "bert-base-uncased"

# Load the base tokenizer and register "[entity]" as an additional special token
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, model_max_length=150)
special_tokens_dict = {"additional_special_tokens": ["[entity]"]}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict, replace_additional_special_tokens=False)

# Load the base BERT model with a three-class label mapping
model = AutoAdapterModel.from_pretrained(
    pretrained_model_name_or_path,
    num_labels=3,
    id2label={0: "PRESENT", 1: "ABSENT", 2: "POSSIBLE"},
)
# Grow the embedding matrix so it covers the newly added "[entity]" token
model.resize_token_embeddings(len(tokenizer))

# Load the assertion adapter together with its classification head, then activate it
adapter_name = model.load_adapter(model_adapter, with_head=True)
model.active_adapters = adapter_name
```
To classify a medical problem, wrap its mention in the special indicator token "[entity]" on both sides, then pass the whole sentence to the model.
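If your entity mentions come as character offsets (for example, from an upstream entity-recognition step), a small helper can insert the markers for you. This is a minimal sketch: `mark_entity` and its offset convention (`start` inclusive, `end` exclusive, as in Python slicing) are hypothetical and not part of the model or the `adapters` library. It reproduces the marker format of the examples, with one space on each side of the mention:

```python
ENTITY_MARKER = "[entity]"


def mark_entity(sentence: str, start: int, end: int, marker: str = ENTITY_MARKER) -> str:
    """Wrap sentence[start:end] in marker tokens, separated by single spaces."""
    if not 0 <= start < end <= len(sentence):
        raise ValueError(f"invalid span ({start}, {end}) for sentence of length {len(sentence)}")
    before = sentence[:start].rstrip()
    mention = sentence[start:end].strip()
    after = sentence[end:].lstrip()
    # Rebuild the sentence as "<before> [entity] <mention> [entity] <after>",
    # skipping empty pieces so no stray spaces appear at either end
    parts = [before, marker, mention, marker, after]
    return " ".join(p for p in parts if p)


text = "had abnormal ett and referred for cath"
start = text.find("abnormal ett")
print(mark_entity(text, start, start + len("abnormal ett")))
```

Punctuation directly after the mention ends up separated by a space (e.g. `[entity] .`); BERT's tokenizer splits punctuation into its own token either way.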
```python
import torch

id2label = {0: "PRESENT", 1: "ABSENT", 2: "POSSIBLE"}
sentences = [
    "Patient denies [entity] SOB [entity]",
    "Patient does not have [entity] fever [entity]",
    "had [entity] abnormal ett [entity] and referred for cath",
    "The patient recovered during the night and now denies any [entity] shortness of breath [entity].",
    "Patient with [entity] severe fever [entity].",
    "Patient should abstain from [entity] painkillers [entity]",
]

model.to("cpu")
model.eval()  # disable dropout for inference

with torch.no_grad():
    for s in sentences:
        tokenized_input = tokenizer(s, return_tensors="pt", truncation=True)
        outputs = model(**tokenized_input)
        # Pick the highest-scoring class for this single sentence
        predicted_label = torch.argmax(outputs.logits, dim=1)
        print(f"{id2label[predicted_label.item()]}: {s}")
```
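The loop above reports only the arg-max class. If you also want a relative score for the chosen class, you can apply a softmax to the logits. A minimal sketch in plain Python; `top_label` is a hypothetical helper, and it assumes you pass one sentence's logits as a list, for example `outputs.logits[0].tolist()`. Softmax scores from a classifier are not necessarily calibrated probabilities, so treat them as relative confidence only:

```python
import math
from typing import Dict, List, Tuple


def top_label(logits: List[float], id2label: Dict[int, str]) -> Tuple[str, float]:
    """Return the highest-scoring label and its softmax score."""
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the largest probability (first one wins on ties)
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]


id2label = {0: "PRESENT", 1: "ABSENT", 2: "POSSIBLE"}
label, score = top_label([0.1, 2.3, -1.0], id2label)
print(label, round(score, 3))
```

Inside the inference loop, this would be called as `label, score = top_label(outputs.logits[0].tolist(), id2label)`.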