Spaces:
Runtime error
Runtime error
File size: 1,067 Bytes
8af7698 6a79179 8af7698 420c089 8af7698 420c089 8af7698 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import numpy as np
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hugging Face Hub checkpoint to serve (presumably an AutoNLP lyrics classifier,
# judging by the repo name — confirm against the model card).
repo_name = 'juliensimon/autonlp-song-lyrics-18753417'

# Tokenizer and sequence-classification model both come from the same repo;
# the tuple is built left to right, so the tokenizer is still fetched first.
tokenizer, model = (
    AutoTokenizer.from_pretrained(repo_name),
    AutoModelForSequenceClassification.from_pretrained(repo_name),
)

# Class-index -> label-name mapping stored in the model's config; echoed to the
# logs at startup so the available classes are visible.
labels = model.config.id2label
print(labels)
def predict(lyrics):
    """Classify song lyrics and describe the most likely labels in a sentence.

    Uses the module-level ``tokenizer``, ``model`` and ``labels`` globals.

    Args:
        lyrics: Raw lyrics text (one string, as supplied by the Gradio textbox).

    Returns:
        A sentence naming up to three labels, most probable first, each with its
        softmax probability as a percentage, e.g.
        "These lyrics are 70.00% A, 20.00% B and 10.00% C."
        Models with fewer than three labels list only the labels they have.
    """
    inputs = tokenizer(lyrics, padding=True, truncation=True, return_tensors="pt")
    # Inference only: disable autograd so no graph is built for each request
    # (this is also why no `.detach()` is needed before `.numpy()` below).
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # One input string, so row 0 holds this example's class probabilities.
    predictions = predictions.numpy()[0] * 100
    print(predictions)
    # Indices of the (up to) three most probable classes, highest first. This
    # picks the same elements, in the same order, as the old [-1], [-2], [-3]
    # lookups, but slicing cannot raise IndexError on models with < 3 labels.
    top = np.argsort(predictions)[::-1][:3]
    parts = ["{:.2f}% {}".format(predictions[i], labels[int(i)]) for i in top]
    if len(parts) > 1:
        # "a, b and c" — identical wording to the original three-label format.
        described = ", ".join(parts[:-1]) + " and " + parts[-1]
    else:
        described = parts[0]
    return "These lyrics are {}.".format(described)
# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4, where it
# raises AttributeError at startup. Prefer the top-level component and fall back
# to the legacy namespace only on old Gradio releases that lack `gr.Textbox`.
textbox_cls = gr.Textbox if hasattr(gr, "Textbox") else gr.inputs.Textbox
# Named `lyrics_box` rather than `input` to avoid shadowing the builtin.
lyrics_box = textbox_cls(lines=20)
iface = gr.Interface(fn=predict, inputs=lyrics_box, outputs="text")
iface.launch()
|