import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the two-class BERT classifier, then restore the fine-tuned weights
# (strict=False tolerates minor key mismatches in the saved state dict).
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.load_state_dict(torch.load('/content/Fake-news-detection-bert-based-uncased/model_after_train.pt', map_location=torch.device('cpu')), strict=False)
model.to(device)
model.eval()
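
# BERT can only attend to a limited number of tokens at once, so long articles
# are split into 300-word chunks (at most 6) and encoded separately.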
def preprocess_text(text):
    parts = []
    text_len = len(text.split(' '))
    delta = 300      # words per chunk
    max_parts = 5    # cap the number of extra chunks so very long inputs stay cheap
    nb_cuts = int(text_len / delta)
    nb_cuts = min(nb_cuts, max_parts)
    for i in range(nb_cuts + 1):
        text_part = ' '.join(text.split(' ')[i * delta: (i + 1) * delta])
        # truncation=True keeps each encoded chunk within the 500-token budget
        parts.append(tokenizer.encode(text_part, return_tensors="pt",
                                      max_length=500, truncation=True).to(device))
    return parts
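
# Sum the logits of every chunk, then softmax the total to produce a single
# fake/real verdict for the whole article.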
def test(text):
    text_parts = preprocess_text(text)
    overall_output = torch.zeros((1, 2)).to(device)
    try:
        for part in text_parts:
            if part.size(1) > 0:  # skip empty chunks
                overall_output += model(part.reshape(1, -1))[0]
    except RuntimeError:
        print("GPU out of memory, skipping this entry.")

    overall_output = F.softmax(overall_output[0], dim=-1)
    value, result = overall_output.max(0)

    # Index 0 is reported as "real", index 1 as "fake".
    term = "fake"
    if result.item() == 0:
        term = "real"
    return f"{term} at {int(value.item() * 100)} %"
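
# Simple text-in / text-out Gradio interface around the classifier.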
iface = gr.Interface(fn=test, inputs="text", outputs="text")
iface.launch()