"""ConfliBERT demo: a Gradio app exposing QA, NER and two classifiers."""

import html
import logging
import re

import torch
import tensorflow as tf
from tf_keras import models, layers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TFAutoModelForQuestionAnswering,
)
import gradio as gr

logger = logging.getLogger(__name__)

# Check if GPU is available and use it if possible (PyTorch models only;
# the QA model below is a TensorFlow model and manages its own placement).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the models and tokenizers
qa_model_name = 'salsarra/ConfliBERT-QA'
qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)

ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)

clf_model_name = 'eventdata-utd/conflibert-binary-classification'
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)

# Define the class names for text classification
class_names = ['Negative', 'Positive']
multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]  # Updated labels

# Define the NER labels and colors used when highlighting entities
ner_labels = {
    'Organisation': 'blue',
    'Person': 'red',
    'Location': 'green',
    'Quantity': 'orange',
    'Weapon': 'purple',
    'Nationality': 'cyan',
    'Temporal': 'magenta',
    'DocumentReference': 'brown',
    'MilitaryPlatform': 'yellow',
    'Money': 'pink',
}

# Error shapes that carry the offending input size and the model limit:
# a PyTorch tensor-size mismatch and a TensorFlow index-out-of-range error.
_TORCH_SIZE_MISMATCH = re.compile(
    r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
)
_TF_INDEX_OUT_OF_RANGE = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")


def handle_error_message(e, default_limit=512):
    """Turn an inference exception into a user-facing error string.

    If the exception text matches one of the two size/index error shapes
    above, the sizes it mentions are reported. Any other exception falls
    back to a generic over-limit message using ``default_limit``.

    The exception (with traceback) is always logged. The generic fallback
    would otherwise hide unrelated failures behind an "over limit" message.
    """
    logger.error("Inference failed", exc_info=e)
    error_message = str(e)
    for pattern in (_TORCH_SIZE_MISMATCH, _TF_INDEX_OUT_OF_RANGE):
        match = pattern.search(error_message)
        if match:
            input_size, model_limit = match.groups()
            return f"Error: Text Input is over limit where inserted text size {input_size} is larger than model limits of {model_limit}"
    return f"Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}"


# Define the functions for each task
def question_answering(context, question):
    """Extract the answer span for ``question`` from ``context``.

    Returns the answer text, HTML-escaped because the output component
    renders HTML, or an error string from ``handle_error_message``.
    """
    try:
        inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
        outputs = qa_model(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer_ids = inputs['input_ids'].numpy()[0][answer_start:answer_end]
        answer = qa_tokenizer.convert_tokens_to_string(qa_tokenizer.convert_ids_to_tokens(answer_ids))
        return html.escape(answer)
    except Exception as e:
        return handle_error_message(e)


def replace_unk(tokens):
    """Return ``tokens`` with every ``[UNK]`` occurrence replaced by an apostrophe."""
    return [token.replace('[UNK]', "'") for token in tokens]


def named_entity_recognition(text):
    """Tag entities in ``text`` and return an HTML rendering of them.

    WordPiece continuation tokens (``##...``) are merged into the preceding
    word, which keeps that word's label. Tokenizer special tokens are not
    shown.

    Returns a ``<div>`` of the text, with labelled words coloured per
    ``ner_labels``, followed by a legend of the labels found. On failure,
    returns an error string from ``handle_error_message``.
    """
    try:
        # Inputs must live on the same device as the model.
        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            outputs = ner_model(**inputs)
        # Batch of one: take row 0 rather than squeeze(), which would
        # collapse a single-token sequence to a scalar.
        predictions = outputs.logits.argmax(dim=2)[0].tolist()
        tokens = replace_unk(ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()))
        special_tokens = set(ner_tokenizer.all_special_tokens)

        entities = []
        seen_labels = set()
        for token, label_id in zip(tokens, predictions):
            if token in special_tokens:
                continue  # e.g. [CLS]/[SEP] added by the tokenizer
            # Strip any B-/I- style prefix, keeping only the entity type.
            label = ner_model.config.id2label[label_id].split('-')[-1]
            if token.startswith('##'):
                if entities:
                    entities[-1][0] += token[2:]
            else:
                entities.append([token, label])
                if label != 'O':
                    seen_labels.add(label)

        highlighted_parts = []
        for token, label in entities:
            safe_token = html.escape(token)
            if label != 'O':
                color = ner_labels.get(label, 'black')
                highlighted_parts.append(f"<span style='color: {color}; font-weight: bold;'>{safe_token}</span>")
            else:
                highlighted_parts.append(safe_token)
        highlighted_text = " ".join(highlighted_parts)

        legend_items = []
        for label in sorted(seen_labels):
            color = ner_labels.get(label, 'black')
            legend_items.append(f"<li style='color: {color}; font-weight: bold;'>{html.escape(label)}</li>")
        legend = f"<div><strong>NER Tags Found:</strong><ul>{''.join(legend_items)}</ul></div>"

        return f"<div>{highlighted_text}</div>{legend}"
    except Exception as e:
        return handle_error_message(e)


def text_classification(text):
    """Classify ``text`` as conflict-related (Positive) or not (Negative).

    Returns a sentence with the verdict and the softmax confidence as a
    percentage, or an error string from ``handle_error_message``.
    """
    try:
        inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = clf_model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100

        if predicted_class == 1:  # Positive class
            return f"Positive: The text is related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)"
        # Negative class
        return f"Negative: The text is not related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)"
    except Exception as e:
        return handle_error_message(e)


def multilabel_classification(text):
    """Score ``text`` against every label in ``multi_class_names``.

    Each label is scored independently with a sigmoid. Labels scoring at
    least 0.5 are emphasised in bold.

    Returns the entries joined with `` / ``, or an error string if the
    model's output size does not match ``multi_class_names`` or inference
    fails.
    """
    try:
        inputs = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = multi_clf_model(**inputs)
        # Row 0 of the batch; always a list, even for a single-label head.
        probabilities = torch.sigmoid(outputs.logits)[0].tolist()

        if len(probabilities) != len(multi_class_names):
            return f"Error: Number of predicted classes ({len(probabilities)}) does not match number of class names ({len(multi_class_names)})."

        results = []
        for class_name, probability in zip(multi_class_names, probabilities):
            confidence = probability * 100
            entry = f"{class_name} (Confidence: {confidence:.2f}%)"
            if probability >= 0.5:
                entry = f"<strong>{entry}</strong>"
            results.append(entry)
        return " / ".join(results)
    except Exception as e:
        return handle_error_message(e)


# Define the Gradio interface
def chatbot(task, text=None, context=None, question=None):
    """Dispatch a request to the function for ``task``.

    Returns that function's result, or a prompt string when required inputs
    are missing or ``task`` is not recognised.
    """
    if task == "Question Answering":
        if context and question:
            return question_answering(context, question)
        return "Please provide both context and question for the Question Answering task."
    elif task == "Named Entity Recognition":
        if text:
            return named_entity_recognition(text)
        return "Please provide text for the Named Entity Recognition task."
    elif task == "Text Classification":
        if text:
            return text_classification(text)
        return "Please provide text for the Text Classification task."
    elif task == "Multilabel Classification":
        if text:
            return multilabel_classification(text)
        return "Please provide text for the Multilabel Classification task."
    else:
        return "Please select a valid task."


css = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1 {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
}
.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}
.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}
.header {
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 10px;
    flex-wrap: wrap;
}
.header-title-center a {
    font-size: 4em; /* Increased font size */
    font-weight: bold; /* Made text bold */
    color: darkorange; /* Darker orange color */
    text-align: center;
    display: block;
}
.gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #ff4500;
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em; /* Updated font size */
    color: #666;
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id="header"):
        # NOTE(review): the CSS styles `.header-title-center a` and `.footer a`,
        # but no matching markup survives here; restore the header link and
        # footer content if they are wanted.
        gr.Markdown("# ConfliBERT", elem_id="header-title-center")

    gr.Markdown("Select a task and provide the necessary inputs:")

    task = gr.Dropdown(
        choices=["Question Answering", "Named Entity Recognition", "Text Classification", "Multilabel Classification"],
        label="Select Task",
    )

    with gr.Row():
        text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
        context_input = gr.Textbox(lines=5, placeholder="Enter the context here...", label="Context", visible=False)
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question", visible=False)

    output = gr.HTML(label="Output")

    def update_inputs(task):
        """Show context/question boxes for QA, the single text box otherwise."""
        if task == "Question Answering":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    task.change(fn=update_inputs, inputs=task, outputs=[text_input, context_input, question_input])

    def chatbot_interface(task, text, context, question):
        """Adapter from the Gradio input components to ``chatbot``."""
        return chatbot(task, text, context, question)

    submit_button = gr.Button("Submit", elem_id="gr-button")
    submit_button.click(fn=chatbot_interface, inputs=[task, text_input, context_input, question_input], outputs=output)

    gr.Markdown("")
    gr.Markdown("")

demo.launch(share=True)