TahaRasouli commited on
Commit
2427679
1 Parent(s): 9f1df33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #importing the necessary libraries
2
+ import gradio as gr
3
+ import numpy as np
4
+ import pandas as pd
5
+ import re
6
+ import torch
7
+
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+ from topic_labels import labels
10
+
11
+ #Defining the models and tokenuzer
12
+ model_name = "valurank/distilroberta-topic-classification"
13
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
14
+ #model.to(device)
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+
17
+
18
+ def clean_text(raw_text):
19
+ text = raw_text.encode("ascii", errors="ignore").decode(
20
+ "ascii"
21
+ ) # remove non-ascii, Chinese characters
22
+
23
+ text = re.sub(r"\n", " ", text)
24
+ text = re.sub(r"\n\n", " ", text)
25
+ text = re.sub(r"\t", " ", text)
26
+ text = text.strip(" ")
27
+ text = re.sub(
28
+ " +", " ", text
29
+ ).strip() # get rid of multiple spaces and replace with a single
30
+
31
+ text = re.sub(r"Date\s\d{1,2}\/\d{1,2}\/\d{4}", "", text) #remove date
32
+ text = re.sub(r"\d{1,2}:\d{2}\s[A-Z]+\s[A-Z]+", "", text) #remove time
33
+
34
+ return text
35
+
36
+
37
+ def find_two_highest_indices(arr):
38
+ if len(arr) < 2:
39
+ raise ValueError("Array must have at least two elements")
40
+
41
+ # Initialize the indices of the two highest values
42
+ max_idx = second_max_idx = None
43
+
44
+ for i, value in enumerate(arr):
45
+ if max_idx is None or value > arr[max_idx]:
46
+ second_max_idx = max_idx
47
+ max_idx = i
48
+ elif second_max_idx is None or value > arr[second_max_idx]:
49
+ second_max_idx = i
50
+
51
+ return max_idx, second_max_idx
52
+
53
+
54
+ def predict_topic(text):
55
+ text = clean_text(text)
56
+ dict_topic = {}
57
+
58
+ input_tensor = tokenizer.encode(text, return_tensors="pt", truncation=True)
59
+ logits = model(input_tensor).logits
60
+
61
+ softmax = torch.nn.Softmax(dim=1)
62
+ probs = softmax(logits)[0]
63
+ probs = probs.cpu().detach().numpy()
64
+
65
+ max_index = find_two_highest_indices(probs)
66
+ emotion_1, emotion_2 = labels[max_index[0]], labels[max_index[1]]
67
+ probs_1, probs_2 = probs[max_index[0]], probs[max_index[1]]
68
+ dict_topic[emotion_1] = round((probs_1), 2)
69
+
70
+ #if probs_2 > 0.01:
71
+ dict_topic[emotion_2] = round((probs_2), 2)
72
+
73
+ return dict_topic
74
+
75
+
76
+ #Creating the interface for the radio appdemo = gr.Interface(multi_label_emotions, inputs=gr.Textbox(),
77
+ demo = gr.Interface(predict_topic, inputs=gr.Textbox(),
78
+ outputs = gr.Label(num_top_classes=2),
79
+ title="Topic Classification")
80
+
81
+ if __name__ == "__main__":
82
+ demo.launch(debug=True)