sabari committed on
Commit
51de811
1 Parent(s): 8c3cc44

Add application file

Files changed (1)
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Sep 17 19:03:17 2024
+
+ @author: SABARI
+ """
+ import os
+ import torch
+ from transformers import AutoConfig
+ from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
+ from datasets import Dataset
+ from torch.utils.data import DataLoader
+ from transformers import AutoTokenizer
+ import spacy
+ from spacy.tokens import Doc, Span
+ from spacy import displacy
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class JapaneseNER():
+     def __init__(self, model_path, model_name="xlm-roberta-base"):
+         self._index_to_tag = {0: 'O',
+                               1: 'PER',
+                               2: 'ORG',
+                               3: 'ORG-P',
+                               4: 'ORG-O',
+                               5: 'LOC',
+                               6: 'INS',
+                               7: 'PRD',
+                               8: 'EVT'}
+
+         self._tag_to_index = {v: k for k, v in self._index_to_tag.items()}
+         self._tag_feature_num_classes = len(self._index_to_tag)
+         self._model_name = model_name
+         self._model_path = model_path
+
+         xlmr_config = AutoConfig.from_pretrained(
+             self._model_name,
+             num_labels=self._tag_feature_num_classes,
+             id2label=self._index_to_tag,
+             label2id=self._tag_to_index
+         )
+
+         self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
+         self.model = (RobertaForTokenClassification
+                       .from_pretrained(self._model_path, config=xlmr_config)
+                       .to(device))
+
+     def prepare(self):
+         # Run a quick sanity-check prediction on two sample sentences
+         sample_encoding = self.tokenizer([
+             "鈴木は4月の陽気の良い日に、鈴をつけて熊本県の阿蘇山に登った",
+             "中国では、中国共産党による一党統治が続く",
+         ], truncation=True, max_length=512, padding=True, return_tensors="pt")
+
+         sample_encoding = {k: v.to(device) for k, v in sample_encoding.items()}
+
+         # Perform prediction
+         with torch.no_grad():
+             output = self.model(**sample_encoding)
+
+         predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]
+         print("predicted labels", predicted_label_id)
+
+     def predict(self, text):
+         encoding = self.tokenizer([text], truncation=True, max_length=512, return_tensors="pt")
+         encoding = {k: v.to(device) for k, v in encoding.items()}
+
+         # Perform prediction
+         with torch.no_grad():
+             output = self.model(**encoding)
+
+         # Get the predicted label ids
+         predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]
+         # Recover the subword tokens so they stay aligned one-to-one with the predictions
+         tokens = self.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
+
+         # Map the predicted label ids to their corresponding tags
+         predictions = [self._index_to_tag[label_id] for label_id in predicted_label_id]
+
+         return tokens, predictions
+
+ # Instantiate the NER model
+ model_path = "path_to_your_saved_model"
+ ner_model = JapaneseNER(model_path)
+ ner_model.prepare()
+ # Function to integrate with spaCy displacy for visualization
+ def ner_inference(text):
+     # Get tokens and predictions
+     tokens, predictions = ner_model.predict(text)
+
+     # Create a spaCy document to visualize with displacy
+     nlp = spacy.blank("ja")  # Initialize a blank Japanese model in spaCy
+     doc = Doc(nlp.vocab, words=tokens)  # Create a spaCy Doc object with tokens
+
+     # Create entity spans from predictions and add them to the Doc object
+     ents = []
+     for i, label in enumerate(predictions):
+         if label != 'O':  # Skip non-entity tokens
+             span = Span(doc, i, i + 1, label=label)  # Create Span for the token
+             ents.append(span)
+     doc.ents = ents  # Set the entities in the Doc
+
+     # Render using spacy displacy
+     html = displacy.render(doc, style="ent", jupyter=False)  # Generate HTML for entities
+     return html
+
+ # Create Gradio interface
+ import gradio as gr
+
+ iface = gr.Interface(
+     fn=ner_inference,  # The function to call for prediction
+     inputs=gr.Textbox(lines=5, placeholder="Enter Japanese text for NER..."),  # Input widget
+     outputs="html",  # Output will be in HTML format using displacy
+     title="Japanese Named Entity Recognition (NER)",
+     description="Enter Japanese text and see the named entities highlighted in the output."
+ )
+
+ # Launch the interface
+ iface.launch()
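
For reference, a minimal sketch of how the JapaneseNER class above could be exercised outside the Gradio UI. It assumes the class definition has already been run in the current Python session (with the iface.launch() call skipped) and that a fine-tuned checkpoint exists at the hypothetical path "./ner_model":

# Hypothetical smoke test for the JapaneseNER class defined above.
# Assumes the class has been defined in the current session and that
# "./ner_model" points to a fine-tuned token-classification checkpoint.
ner = JapaneseNER("./ner_model")
tokens, labels = ner.predict("鈴木は熊本県の阿蘇山に登った")

# Print only the subword tokens that were tagged as entities.
for token, label in zip(tokens, labels):
    if label != "O":
        print(token, label)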