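# NOTE: the following lines are Hugging Face page residue captured along with
# the file; they are kept as comments so the module parses.
# ConfliBERT-Demo / app.py
# salsarra's picture
# Update app.py
# 3ee3346 verified
# raw
# history blame contribute delete
# No virus
# 11.8 kB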
import html
import re

import torch
import tensorflow as tf
from tf_keras import models, layers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering
import gradio as gr
# Run the PyTorch models on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# By default TensorFlow reserves nearly all memory of every visible GPU once it
# initialises one, which can starve the PyTorch models moved to `device` below.
# Switch TF to on-demand growth before any model is loaded. This setting can
# only be changed before TF initialises the GPUs, so failure is tolerated
# (best effort) rather than aborting start-up.
for _gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError:
        pass

# Extractive question answering. This is a TensorFlow model, so it is not
# moved to `device`.
qa_model_name = 'salsarra/ConfliBERT-QA'
qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)

# Token-level named entity recognition (PyTorch).
ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)

# Binary relevance classification (PyTorch).
clf_model_name = 'eventdata-utd/conflibert-binary-classification'
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

# Multi-label event-type classification (PyTorch).
multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)
# Display names for the binary classifier's two classes, in class-id order
# (text_classification treats class id 1 as the positive class).
class_names = ['Negative', 'Positive']

# One display name per output of the multilabel model. multilabel_classification
# pairs these with the model's sigmoid outputs by position, so this order
# presumably matches the model's label ids -- TODO confirm against the model config.
multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]

# CSS colour used to highlight each NER entity type. The key is the model label
# with its B-/I- prefix removed; named_entity_recognition falls back to black for
# types missing from this table.
ner_labels = dict(
    Organisation='blue',
    Person='red',
    Location='green',
    Quantity='orange',
    Weapon='purple',
    Nationality='cyan',
    Temporal='magenta',
    DocumentReference='brown',
    MilitaryPlatform='yellow',
    Money='pink',
)
def handle_error_message(e, default_limit=512):
    """Convert an inference exception into a red HTML error banner.

    Two exception-message shapes carry concrete sizes, which are quoted in the
    banner:

    * "The size of tensor a (N) must match the size of tensor b (M)"
    * "indices[0,N] = ... is not in [0, M)"

    Any other exception produces the same banner with ``default_limit`` as the
    quoted limit and no input size.

    Args:
        e: The caught exception; only ``str(e)`` is inspected.
        default_limit: Limit shown when the message carries no sizes.

    Returns:
        An HTML ``<span>`` string.
    """
    message = str(e)
    # (inserted size, model limit) are the two capture groups in each pattern,
    # tried in this order.
    size_patterns = (
        r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)",
        r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)",
    )
    for size_pattern in size_patterns:
        found = re.search(size_pattern, message)
        if found is not None:
            inserted_size, model_limit = found.groups()
            return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {inserted_size} is larger than model limits of {model_limit}</span>"
    return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
# Task handlers: each returns a string that the UI renders in its HTML output pane.
def question_answering(context, question):
    """Extract the answer to ``question`` from ``context`` with the QA model.

    The answer is the token span from the arg-max of the start logits to the
    arg-max of the end logits (inclusive). Tokenizer special tokens ([CLS],
    [SEP], [UNK], ...) are dropped from the span, and the remaining text is
    HTML-escaped before being wrapped in a green ``<span>``.

    Args:
        context: Passage to search. The question/context pair is truncated to
            the tokenizer's maximum length.
        question: Question to answer.

    Returns:
        An HTML string with the answer. When the predicted span is empty (end
        before start, or only special tokens), a red "no answer" notice is
        returned instead. If inference raises, the error banner from
        ``handle_error_message`` is returned.
    """
    try:
        inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
        outputs = qa_model(inputs)
        answer_start = int(tf.argmax(outputs.start_logits, axis=1).numpy()[0])
        # +1 turns the predicted end position into an exclusive slice bound.
        answer_end = int(tf.argmax(outputs.end_logits, axis=1).numpy()[0]) + 1
        answer_ids = inputs['input_ids'].numpy()[0][answer_start:answer_end].tolist()
        # The arg-max may land on [CLS]/[SEP]; never show those as answer text.
        answer_tokens = qa_tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)
        answer = qa_tokenizer.convert_tokens_to_string(answer_tokens).strip()
        if not answer:
            return "<span style='color: red; font-weight: bold;'>No answer found in the provided context.</span>"
        # The answer is copied from user-supplied text, so escape it before
        # embedding it in HTML.
        return f"<span style='color: green; font-weight: bold;'>{html.escape(answer)}</span>"
    except Exception as e:
        return handle_error_message(e)
def replace_unk(tokens):
    """Return a new list in which every "[UNK]" inside each token becomes "'".

    NOTE(review): this looks like a heuristic for apostrophe variants the
    tokenizer vocabulary does not cover -- confirm that is the intent.
    """
    unknown_marker, apostrophe = '[UNK]', "'"
    return list(map(lambda piece: piece.replace(unknown_marker, apostrophe), tokens))
def named_entity_recognition(text):
    """Tag named entities in ``text`` and render them as colour-coded HTML.

    Word-piece continuations (tokens starting with ``##``) are glued onto the
    preceding token and inherit its label. Tokenizer-added markers such as
    [CLS]/[SEP] are not rendered; [UNK] pieces are kept (as "'") because they
    stand for real input text. Each entity token is coloured with its
    ``ner_labels`` colour (black when the type has no entry). A legend lists
    the entity types found, in order of first appearance.

    Args:
        text: Raw input text; truncated to the tokenizer's maximum length.

    Returns:
        An HTML string. If anything raises, the error banner from
        ``handle_error_message`` is returned instead.
    """
    try:
        # The model was moved to `device`; its inputs must live there too.
        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            outputs = ner_model(**inputs)
        # Batch of one: take row 0 rather than squeeze(), which would collapse
        # a length-1 sequence into a bare int.
        predicted_ids = outputs.logits.argmax(dim=2)[0].tolist()
        token_ids = inputs['input_ids'][0].tolist()
        tokens = replace_unk(ner_tokenizer.convert_ids_to_tokens(token_ids))

        # Skip tokenizer-added special tokens, but keep [UNK], which replaces
        # real characters of the input.
        skip_ids = set(ner_tokenizer.all_special_ids) - {ner_tokenizer.unk_token_id}

        entities = []  # [token_text, entity_type] pairs, in input order
        seen_labels = {}  # dict used as an insertion-ordered set (stable legend)
        for token_id, token, label_id in zip(token_ids, tokens, predicted_ids):
            if token_id in skip_ids:
                continue
            if token.startswith('##'):
                if entities:
                    entities[-1][0] += token[2:]
                continue
            # Strip the B-/I- prefix so both map to the same entity type.
            label = ner_model.config.id2label[label_id].split('-')[-1]
            entities.append([token, label])
            if label != 'O':
                seen_labels[label] = None

        highlighted_parts = []
        for token, label in entities:
            safe_token = html.escape(token)  # tokens come from user text
            if label != 'O':
                color = ner_labels.get(label, 'black')
                highlighted_parts.append(f"<span style='color: {color}; font-weight: bold;'>{safe_token}</span> ")
            else:
                highlighted_parts.append(f"{safe_token} ")
        highlighted_text = "".join(highlighted_parts)

        legend_items = "".join(
            f"<li style='color: {ner_labels.get(label, 'black')}; font-weight: bold;'>{html.escape(label)}</li>"
            for label in seen_labels
        )
        legend = (
            "<div><strong>NER Tags Found:</strong>"
            "<ul style='list-style-type: disc; padding-left: 20px;'>"
            f"{legend_items}</ul></div>"
        )
        return f"<div>{highlighted_text}</div>{legend}"
    except Exception as e:
        return handle_error_message(e)
def text_classification(text):
    """Classify whether ``text`` relates to conflict, violence, or politics.

    Class id 1 is reported as Positive and any other id as Negative. The
    confidence is the largest softmax probability, shown as a percentage.

    Args:
        text: Raw input text; truncated to the tokenizer's maximum length.

    Returns:
        A green (Positive) or red (Negative) HTML ``<span>``. If inference
        raises, the error banner from ``handle_error_message`` is returned.
    """
    try:
        inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = clf_model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
        if predicted_class == 1:  # Positive class
            return f"<span style='color: green; font-weight: bold;'>Positive: The text is related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
        return f"<span style='color: red; font-weight: bold;'>Negative: The text is not related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
    except Exception as e:
        return handle_error_message(e)
def multilabel_classification(text):
    """Score ``text`` against every event type in ``multi_class_names``.

    Each label gets an independent sigmoid probability. Labels at or above 0.5
    are shown in green and the rest in red, each with its confidence
    percentage. The results are joined with " / ".

    Args:
        text: Raw input text; truncated to the tokenizer's maximum length.

    Returns:
        An HTML string. If the model's output count differs from
        ``len(multi_class_names)``, a plain-text mismatch message is returned.
        If inference raises, the error banner from ``handle_error_message`` is
        returned.
    """
    try:
        inputs = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = multi_clf_model(**inputs)
        # Batch of one: index row 0 instead of squeeze(), which would turn a
        # single-output head into a bare float (and break len() below).
        probabilities = torch.sigmoid(outputs.logits)[0].tolist()
        if len(probabilities) != len(multi_class_names):
            return f"Error: Number of predicted classes ({len(probabilities)}) does not match number of class names ({len(multi_class_names)})."
        results = []
        for class_name, probability in zip(multi_class_names, probabilities):
            color = 'green' if probability >= 0.5 else 'red'
            results.append(f"<span style='color: {color}; font-weight: bold;'>{class_name} (Confidence: {probability * 100:.2f}%)</span>")
        return " / ".join(results)
    except Exception as e:
        return handle_error_message(e)
# Dispatcher used by the Gradio submit callback: validates inputs per task.
def chatbot(task, text=None, context=None, question=None):
    """Run the handler for ``task`` after checking that its inputs are present.

    Question Answering needs both ``context`` and ``question``. The other three
    tasks need ``text``. A missing input (None or empty string) or an
    unrecognised task yields a plain prompt string instead of a model result.
    """
    if task == "Question Answering":
        if not (context and question):
            return "Please provide both context and question for the Question Answering task."
        return question_answering(context, question)
    if task == "Named Entity Recognition":
        return named_entity_recognition(text) if text else "Please provide text for the Named Entity Recognition task."
    if task == "Text Classification":
        return text_classification(text) if text else "Please provide text for the Text Classification task."
    if task == "Multilabel Classification":
        return multilabel_classification(text) if text else "Please provide text for the Multilabel Classification task."
    return "Please select a valid task."
# Stylesheet passed to gr.Blocks. Gradio renders a component's `elem_id` as its
# HTML id, so the header Row (elem_id="header") and the Submit button
# (elem_id="gr-button") are matched by the `#header` / `#gr-button` selectors.
# The original class selectors never match them; they are kept in case other
# elements carry those classes.
css = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1 {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
}
.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}
.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}
.header, #header {
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 10px;
    flex-wrap: wrap;
}
.header-title-center a {
    font-size: 4em;
    font-weight: bold;
    color: darkorange;
    text-align: center;
    display: block;
}
.gr-button, #gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}
.gr-button:hover, #gr-button:hover {
    background-color: #ff4500;
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: #666;
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
"""
# Gradio UI: a task picker, task-specific input boxes, and an HTML output pane.
with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id="header"):
        gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT</a></div>", elem_id="header-title-center")
    gr.Markdown("<span style='color: black;'>Select a task and provide the necessary inputs:</span>")
    task = gr.Dropdown(choices=["Question Answering", "Named Entity Recognition", "Text Classification", "Multilabel Classification"], label="Select Task")
    with gr.Row():
        text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
        context_input = gr.Textbox(lines=5, placeholder="Enter the context here...", label="Context", visible=False)
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question", visible=False)
    output = gr.HTML(label="Output")

    def update_inputs(task):
        """Show Context/Question for Question Answering, and Text for every other task.

        Returns visibility updates for (text_input, context_input, question_input).
        """
        if task == "Question Answering":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    task.change(fn=update_inputs, inputs=task, outputs=[text_input, context_input, question_input])

    def chatbot_interface(task, text, context, question):
        """Submit callback: forward the current input values to ``chatbot``."""
        return chatbot(task, text, context, question)

    submit_button = gr.Button("Submit", elem_id="gr-button")
    submit_button.click(fn=chatbot_interface, inputs=[task, text_input, context_input, question_input], outputs=output)
    gr.Markdown("<div class='footer'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a></div>")
    gr.Markdown("<div class='footer'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a></div>")

# Launch only when executed as a script (`python app.py`, as Spaces does), so
# importing the module, e.g. under `gradio app.py` reload mode, does not start
# a second server.
if __name__ == "__main__":
    demo.launch(share=True)