import gradio as gr
import tempfile
from transformers import MT5ForConditionalGeneration, MT5Tokenizer, ViltProcessor, ViltForQuestionAnswering
import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
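
# Pipeline: a Persian question and an image come in; the question is translated to
# English with mT5, answered by an English ViLT VQA model, and the English answer is
# translated back to Persian before being returned to the Gradio interface.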
# Persian to English Translation model
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)

def run_fa_en_translation_model(input_string, **generator_args):
    # Translate a Persian string to English and return the decoded candidates
    input_ids = fa_en_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = fa_en_translation_model.generate(input_ids, **generator_args)
    output = fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
    return output

# English to Persian Translation model
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)

def run_en_fa_translation_model(input_string, **generator_args):
    # Translate an English string to Persian and return the decoded candidates
    input_ids = en_fa_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = en_fa_translation_model.generate(input_ids, **generator_args)
    output = en_fa_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
    return output

# Visual Question Answering model
VQA_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
VQA_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
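
# Answer a Persian question about an image: translate the question to English,
# run ViLT VQA on the image, then translate the predicted answer back to Persian.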
def VQA(image, text):
    with tempfile.NamedTemporaryFile(suffix=".png") as temp_image_file:
        # Copy the contents of the uploaded image to the temporary file
        Image.fromarray(image).save(temp_image_file.name)
        # Load the image file using Pillow
        image = Image.open(temp_image_file.name)
        # Translate the Persian question to English and prepare the model inputs
        question = run_fa_en_translation_model(text)[0]
        encoding = VQA_processor(image, question, return_tensors="pt").to(device)
        # Forward pass
        outputs = VQA_model(**encoding)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        # Map the predicted class index to its English answer and translate it back to Persian
        answer = VQA_model.config.id2label[idx]
        return run_en_fa_translation_model(answer)[0]

iface = gr.Interface(fn=VQA, inputs=["image", "text"], outputs="text")
iface.launch(share=False)