# -*- coding: utf-8 -*-
"""
Japanese NER demo (Hugging Face Space).

Created on Tue Sep 17 19:03:17 2024
@author: SABARI
"""
# Standard library
import os

# Third-party
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
from datasets import Dataset
import spacy
from spacy.tokens import Doc, Span
from spacy import displacy

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class JapaneseNER():
    """Token-level Japanese NER with a fine-tuned XLM-RoBERTa classifier.

    The tokenizer is loaded from the base hub model (`model_name`); the
    fine-tuned token-classification weights are loaded from `model_path`.
    `predict` returns one tag per subword token.
    """

    def __init__(self, model_path, model_name="xlm-roberta-base"):
        # Fixed tag inventory the classification head was trained with.
        self._index_to_tag = {
            0: 'O',
            1: 'PER',
            2: 'ORG',
            3: 'ORG-P',
            4: 'ORG-O',
            5: 'LOC',
            6: 'INS',
            7: 'PRD',
            8: 'EVT',
        }
        self._tag_to_index = {v: k for k, v in self._index_to_tag.items()}
        self._tag_feature_num_classes = len(self._index_to_tag)
        self._model_name = model_name
        self._model_path = model_path
        # The config carries the label maps so the head size matches the tag set.
        xlmr_config = AutoConfig.from_pretrained(
            self._model_name,
            num_labels=self._tag_feature_num_classes,
            id2label=self._index_to_tag,
            label2id=self._tag_to_index,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        self.model = (RobertaForTokenClassification
                      .from_pretrained(self._model_path, config=xlmr_config)
                      .to(device))
        # Explicitly disable dropout for deterministic inference.
        self.model.eval()

    def prepare(self):
        """Smoke-test the model on two hard-coded sentences and print the
        predicted label ids for the first one."""
        sample_encoding = self.tokenizer(
            [
                "鈴木は4月の陽気の良い日に、鈴をつけて熊本県の阿蘇山に登った",
                "中国では、中国共産党による一党統治が続く",
            ],
            truncation=True,
            padding=True,  # pad to a common length within the batch
            max_length=512,
            return_tensors="pt",
        )
        sample_encoding = {k: v.to(device) for k, v in sample_encoding.items()}
        with torch.no_grad():
            output = self.model(**sample_encoding)
        # Per-token argmax over the label logits; keep the first sample only.
        predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]
        print("Predicted labels:", predicted_label_id)

    def predict(self, text):
        """Return (tokens, predictions) for a single input string.

        `tokens` are the tokenizer's subword tokens (including special
        tokens such as <s>/</s>); `predictions` is the tag name for each
        token, aligned one-to-one with `tokens`.
        """
        encoding = self.tokenizer(
            [text], truncation=True, padding=True, max_length=512,
            return_tensors="pt",
        )
        encoding = {k: v.to(device) for k, v in encoding.items()}
        with torch.no_grad():
            output = self.model(**encoding)
        # Highest-scoring label id per token.
        predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
        # Map numeric label ids back to tag names.
        predictions = [self._index_to_tag[label_id] for label_id in predicted_label_id]
        return tokens, predictions
# Instantiate the NER model once at module load so every Gradio request
# reuses the same loaded weights.
model_path = "./trained_ner_classifier_jp/"
ner_model = JapaneseNER(model_path)
ner_model.prepare()  # quick smoke test; prints predicted label ids
# Bridge between the classifier and spaCy displacy for visualization.
def ner_inference(text):
    """Run NER on `text` and return displacy HTML highlighting entities.

    Each non-'O' subword token becomes its own single-token entity span;
    adjacent tokens with the same tag are not merged, mirroring the raw
    token-classification output.
    """
    tokens, predictions = ner_model.predict(text)
    # A blank Japanese pipeline: we only need its vocab to build a Doc.
    nlp = spacy.blank("ja")
    doc = Doc(nlp.vocab, words=tokens)
    # One Span per non-'O' token, aligned by token position.
    ents = [
        Span(doc, i, i + 1, label=label)
        for i, label in enumerate(predictions)
        if label != 'O'
    ]
    doc.ents = ents
    # Render the entity view as standalone HTML for Gradio.
    return displacy.render(doc, style="ent", jupyter=False)
# Sample text pre-filled in the input box for demonstration.
sample_text = "鈴木一朗は2020年に引退した。女優の石原さとみは多くの映画で主演している。"

import gradio as gr

# Gradio UI: Japanese text in, displacy entity HTML out.
iface = gr.Interface(
    fn=ner_inference,  # prediction function
    inputs=gr.Textbox(lines=5, placeholder="Enter Japanese text for NER...", value=sample_text),
    outputs="html",  # displacy produces HTML markup
    title="Japanese Named Entity Recognition (NER)",
    description="Enter Japanese text and see the named entities highlighted in the output.",
)

# Launch the web interface (blocks until the server stops).
iface.launch()