ashhadahsan commited on
Commit
e1d0c26
β€’
1 Parent(s): ca3e8cf

Create pages/1_πŸ“ˆ_predict.py

Browse files
Files changed (1) hide show
  1. pages/1_πŸ“ˆ_predict.py +561 -0
pages/1_πŸ“ˆ_predict.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from transformers import pipeline
4
+ from stqdm import stqdm
5
+ from simplet5 import SimpleT5
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from transformers import BertTokenizer, TFBertForSequenceClassification
8
+ from datetime import datetime
9
+ import logging
10
+ from transformers import TextClassificationPipeline
11
+ import gc
12
+ from datasets import load_dataset
13
+ from utils.openllmapi.api import ChatBot
14
+ from utils.openllmapi.exceptions import *
15
+ import time
16
+ from typing import List
17
+ from collections import OrderedDict
18
+
19
# Tokenization settings shared by both BERT classification pipelines:
# every review is truncated/padded to at most 128 tokens.
tokenizer_kwargs = dict(
    max_length=128,
    truncation=True,
    padding=True,
)
# Seconds to pause between HuggingFace chat calls (crude rate limiting).
SLEEP = 2
25
+
26
+
27
def cleanMemory(obj: TextClassificationPipeline):
    """Best-effort release of a classification pipeline after use.

    NOTE(review): `del obj` only removes this function's local binding; the
    caller's own reference still keeps the pipeline alive, so gc.collect()
    may not actually free it — confirm whether callers should also drop
    their reference.
    """
    del obj
    gc.collect()
30
+
31
+
32
@st.cache_data
def getAllCats():
    """Return the distinct theme labels already present in the training data,
    excluding the "Unknown" placeholder."""
    frame = load_dataset("ashhadahsan/amazon_theme")["train"].to_pandas()
    unique_labels = set(frame.iloc[:, 1].values.tolist())
    return [label for label in unique_labels if label != "Unknown"]
39
+
40
+
41
@st.cache_data
def getAllSubCats():
    """Return the distinct sub-theme labels already present in the data,
    excluding the "Unknown" placeholder.

    Bug fix: this previously loaded the *theme* dataset ("amazon_theme"),
    making it an exact duplicate of getAllCats().  The sub-theme labels
    come from the companion sub-theme dataset.
    NOTE(review): confirm the "ashhadahsan/amazon_subtheme" dataset id.
    """
    data = load_dataset("ashhadahsan/amazon_subtheme")
    data = data["train"].to_pandas()
    labels = [x for x in list(set(data.iloc[:, 1].values.tolist())) if x != "Unknown"]
    del data
    return labels
48
+
49
+
50
def assignHF(bot, what: str, to: str, old: List):
    """Ask the HF chat bot to coin a new *what* label (theme/subtheme) for
    the one-line summary *to*, given the already-known labels *old*.

    Returns the label text, or "" when the chat call fails with ChatError.
    """
    try:
        old = ", ".join(old)
        message_content = bot.chat(
            f"""'Assign a one-line {what} to this summary of the text of a review
            {to}
            already assigned themes are , {old}
            theme""",
        )
        try:
            # Replies usually look like "<prefix>: <label>" — keep the label.
            return message_content.split(":")[1].strip()
        except IndexError:
            # No colon in the reply: return the whole reply, stripped.
            return message_content.strip()
    except ChatError:
        return ""
65
+
66
+
67
@st.cache_resource
def loadZeroShotClassification():
    """Return a cached zero-shot classifier backed by facebook/bart-large-mnli."""
    model_name = "facebook/bart-large-mnli"
    return pipeline("zero-shot-classification", model=model_name)
73
+
74
+
75
def assignZeroShot(zero, to: str, old: List):
    """Return the best-scoring candidate label for text *to* among *old*,
    using the zero-shot classification pipeline *zero*.

    Fix: removed leftover debug print() calls and replaced the sort-then-take
    pattern with a direct max() — same result (first label with the top score).
    """
    assigned = zero(to, old)
    scores_by_label = dict(zip(assigned["labels"], assigned["scores"]))
    return max(scores_by_label, key=scores_by_label.get)
83
+
84
+
85
# Date stamp (YYYY-MM-DD) used to tag the downloadable CSV filenames.
date = datetime.now().strftime(r"%Y-%m-%d")
86
+
87
+
88
@st.cache_resource
def load_t5() -> "tuple[AutoModelForSeq2SeqLM, AutoTokenizer]":
    """Load the vanilla t5-base seq2seq model and its tokenizer (cached once
    per session by Streamlit)."""
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    return model, tokenizer
94
+
95
+
96
@st.cache_resource
def summarizationModel():
    """Return the summarization pipeline loaded from the locally fine-tuned
    checkpoint directory."""
    local_checkpoint = "my_awesome_sum/"
    return pipeline("summarization", model=local_checkpoint)
99
+
100
+
101
@st.cache_data
def convert_df(df: pd.DataFrame) -> bytes:
    """Serialize *df* to UTF-8 CSV bytes for st.download_button.

    Fix: use st.cache_data instead of st.cache_resource — the result is plain
    serializable data, which is exactly what cache_data is for; caching still
    prevents recomputation on every Streamlit rerun.
    """
    return df.to_csv(index=False).encode("utf-8")
105
+
106
+
107
def load_one_line_summarizer(model):
    """Load the pretrained one-line-summary checkpoint into *model*
    (a SimpleT5 instance)."""
    checkpoint = "snrspeaks/t5-one-line-summary"
    return model.load_model("t5", checkpoint)
109
+
110
+
111
@st.cache_resource
def classify_theme() -> TextClassificationPipeline:
    """Build the fine-tuned BERT pipeline that labels a review with its theme.

    top_k=1 makes the pipeline return only the best label, so callers read
    result[0][0]["label"] and result[0][0]["score"].

    Fix: the local result variable was named `pipeline`, shadowing the
    imported transformers `pipeline` factory; renamed to avoid confusion.
    """
    checkpoint = "ashhadahsan/amazon-theme-bert-base-finetuned"
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    model = TFBertForSequenceClassification.from_pretrained(checkpoint)
    theme_pipeline = TextClassificationPipeline(
        model=model, tokenizer=tokenizer, top_k=1, **tokenizer_kwargs
    )
    return theme_pipeline
123
+
124
+
125
@st.cache_resource
def classify_sub_theme() -> TextClassificationPipeline:
    """Build the fine-tuned BERT pipeline that labels a review's sub-theme.

    top_k=1 makes the pipeline return only the best label, so callers read
    result[0][0]["label"] and result[0][0]["score"].

    Fix: the local result variable was named `pipeline`, shadowing the
    imported transformers `pipeline` factory; renamed to avoid confusion.
    """
    checkpoint = "ashhadahsan/amazon-subtheme-bert-base-finetuned"
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    model = TFBertForSequenceClassification.from_pretrained(checkpoint)
    subtheme_pipeline = TextClassificationPipeline(
        model=model, tokenizer=tokenizer, top_k=1, **tokenizer_kwargs
    )
    return subtheme_pipeline
137
+
138
+
139
# Page chrome: wide layout plus an uploader for the review spreadsheet
# (Excel or CSV).  The uploader's `type` list restricts selectable files.
st.set_page_config(layout="wide", page_title="Amazon Review | Summarizer")
st.title("Amazon Review Summarizer")

uploaded_file = st.file_uploader("Choose a file", type=["xlsx", "xls", "csv"])
143
+
144
# Authenticate the HuggingFace chat bot used to propose brand-new labels.
# Fix: on ChatBotInitError the name `bot` used to be left undefined, so the
# processing code below crashed later with NameError; it is now pre-bound to
# None (label assignment then fails per-row instead), and the failure is
# logged with a traceback instead of print().
bot = None
try:
    bot = ChatBot(
        cookies={
            "hf-chat": st.secrets["hf-chat"],
            "token": st.secrets["token"],
        }
    )
except ChatBotInitError:
    logging.exception("ChatBot initialisation failed")
153
+
154
# --- Run-configuration widgets -------------------------------------------
# Which summarization backend to run; the first entry is the locally
# fine-tuned checkpoint (see summarizationModel()).
summarizer_option = st.selectbox(
    "Select Summarizer",
    ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
)
col1, col2, col3 = st.columns([1, 1, 1])

with col1:
    # NOTE(review): "Summrization" is a user-facing typo, left unchanged here
    # because a doc-only edit must not alter runtime strings.
    summary_yes = st.checkbox("Summrization", value=False)

with col2:
    classification = st.checkbox("Classify Category", value=True)

with col3:
    sub_theme = st.checkbox("Sub theme classification", value=True)

# Confidence cut-off: predictions scoring at or below this value get a fresh
# label proposed by the chat bot / zero-shot models.  (The misspelled name
# "treshold" is kept as-is because the processing code below refers to it.)
treshold = st.slider(
    label="Model Confidence value",
    min_value=0.1,
    max_value=0.8,
    step=0.1,
    value=0.6,
    help="If the model has a confidence score below this number , then a new label is assigned (0.6) means 60 percent and so on",
)

# Placeholder container (currently unused downstream).
ps = st.empty()
179
+
180
def _read_reviews(upload) -> pd.DataFrame:
    """Load the uploaded review spreadsheet and lower-case its column names.

    Bug fix: the old CSV check compared the bare extension against ".csv"
    (with a leading dot), which never matched, so CSV uploads were silently
    ignored and `df` was left undefined.
    """
    extension = upload.name.split(".")[-1].lower()
    if extension in ("xls", "xlsx"):
        frame = pd.read_excel(upload, engine="openpyxl")
    elif extension == "csv":
        frame = pd.read_csv(upload)
    else:
        raise ValueError(f"Unsupported file type: {extension}")
    frame.columns = [column.lower() for column in frame.columns]
    return frame


def _assign_labels(pipe, what, known_labels, colour, texts, one_liner, zero_shot, chat_bot, threshold):
    """Classify every review with *pipe*; propose new labels for weak rows.

    Returns three parallel lists: the predicted label per review, the label
    proposed by the HF chat bot (or ""), and the zero-shot fallback label
    (or "").  A row falls back when its top score, rounded to 2 decimals,
    is <= *threshold*.  Fresh lists are created on every call (the old
    inline loops reused a stale list in the sub-theme passes, producing
    column-length mismatches).
    """
    predicted, chat_labels, zero_labels = [], [], []
    for review in stqdm(
        texts,
        desc=f"Assigning {what.capitalize()}s ...",
        total=len(texts),
        colour=colour,
    ):
        top = pipe(review)[0][0]  # top_k=1 pipelines return [[{label, score}]]
        predicted.append(top["label"])
        if round(top["score"], 2) <= threshold:
            # Low confidence: compress to one line, then ask for a new label.
            one_line = one_liner.predict(review)[0]
            time.sleep(SLEEP)  # throttle remote chat calls
            chat_labels.append(
                assignHF(bot=chat_bot, what=what, to=one_line, old=known_labels)
            )
            zero_labels.append(
                assignZeroShot(zero=zero_shot, to=one_line, old=known_labels)
            )
        else:
            chat_labels.append("")
            zero_labels.append("")
    return predicted, chat_labels, zero_labels


def _apply_theme_columns(outputdf, texts, themes, colour, one_liner, zero_shot):
    """Run theme classification and write the three theme columns."""
    theme_pipe = classify_theme()
    predicted, chat_new, zero_new = _assign_labels(
        theme_pipe, "theme", themes, colour, texts, one_liner, zero_shot, bot, treshold
    )
    outputdf["Review Theme"] = predicted
    outputdf["Review Theme-issue-new"] = chat_new
    # Bug fix: this column used to be written as "Review SubTheme-issue-zero",
    # which the sub-theme pass then overwrote, losing the theme data.
    outputdf["Review Theme-issue-zero"] = zero_new
    cleanMemory(theme_pipe)


def _apply_subtheme_columns(outputdf, texts, subthemes, one_liner, zero_shot):
    """Run sub-theme classification and write the three sub-theme columns."""
    subtheme_pipe = classify_sub_theme()
    predicted, chat_new, zero_new = _assign_labels(
        subtheme_pipe, "subtheme", subthemes, "green", texts, one_liner, zero_shot, bot, treshold
    )
    outputdf["Review SubTheme"] = predicted
    outputdf["Review SubTheme-issue-new"] = chat_new
    outputdf["Review SubTheme-issue-zero"] = zero_new
    cleanMemory(subtheme_pipe)


def _offer_download(outputdf):
    """Offer the finished output frame as a CSV download."""
    st.download_button(
        label="Download output as CSV",
        data=convert_df(outputdf),
        file_name=f"{summarizer_option}_{date}_df.csv",
        mime="text/csv",
        use_container_width=True,
    )


if st.button("Process", type="primary"):
    themes = getAllCats()
    subthemes = getAllSubCats()

    # Shared one-line summarizer and zero-shot classifier used as fallbacks
    # for low-confidence predictions in every branch.
    oneline = SimpleT5()
    load_one_line_summarizer(model=oneline)
    zeroline = loadZeroShotClassification()

    cancel_button = st.empty()
    cancel_button2 = st.empty()
    cancel_button3 = st.empty()
    if uploaded_file is not None:
        df = _read_reviews(uploaded_file)
        outputdf = pd.DataFrame()
        try:
            # Only the first 100 reviews are processed, to bound runtime.
            text = df["text"].values.tolist()[0:100]
            outputdf["text"] = text

            if summarizer_option == "Custom trained on the dataset":
                if summary_yes:
                    model = summarizationModel()
                    summary = []
                    for x in stqdm(range(len(text))):
                        if cancel_button.button("Cancel", key=x):
                            del model
                            break
                        try:
                            summary.append(
                                model(
                                    f"summarize: {text[x]}",
                                    max_length=50,
                                    early_stopping=True,
                                )[0]["summary_text"]
                            )
                        except Exception:
                            # Best effort: skip reviews the model chokes on.
                            pass
                    outputdf["summary"] = summary
                    del model
                if classification:
                    _apply_theme_columns(outputdf, text, themes, "#BF1A1A", oneline, zeroline)
                if sub_theme:
                    _apply_subtheme_columns(outputdf, text, subthemes, oneline, zeroline)
                _offer_download(outputdf)

            if summarizer_option == "t5-base":
                if summary_yes:
                    model, tokenizer = load_t5()
                    summary = []
                    for x in stqdm(range(len(text))):
                        if cancel_button2.button("Cancel", key=x):
                            del model, tokenizer
                            break
                        tokens_input = tokenizer.encode(
                            "summarize: " + text[x],
                            return_tensors="pt",
                            max_length=tokenizer.model_max_length,
                            truncation=True,
                        )
                        summary_ids = model.generate(
                            tokens_input,
                            min_length=80,
                            max_length=150,
                            length_penalty=20,
                            num_beams=2,
                        )
                        summary.append(
                            tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                        )
                    del model, tokenizer
                    outputdf["summary"] = summary

                if classification:
                    _apply_theme_columns(outputdf, text, themes, "red", oneline, zeroline)
                if sub_theme:
                    # Bug fix: the old inline loop here never reset the
                    # chat-label list, appending to the stale theme list;
                    # the helper always starts with fresh lists.
                    _apply_subtheme_columns(outputdf, text, subthemes, oneline, zeroline)
                _offer_download(outputdf)

            if summarizer_option == "t5-one-line-summary":
                if summary_yes:
                    model = SimpleT5()
                    load_one_line_summarizer(model=model)
                    summary = []
                    for x in stqdm(range(len(text))):
                        if cancel_button3.button("Cancel", key=x):
                            del model
                            break
                        try:
                            summary.append(model.predict(text[x])[0])
                        except Exception:
                            pass
                    outputdf["summary"] = summary
                    del model

                if classification:
                    _apply_theme_columns(outputdf, text, themes, "red", oneline, zeroline)
                if sub_theme:
                    _apply_subtheme_columns(outputdf, text, subthemes, oneline, zeroline)
                _offer_download(outputdf)

        except KeyError:
            st.error(
                "Please Make sure that your data must have a column named text",
                icon="🚨",
            )
            st.info("Text column must have amazon reviews", icon="ℹ️")
        except BaseException:
            logging.exception("An exception was occurred")