SeyedAli's picture
Update app.py
a5cd4ab
raw
history blame contribute delete
No virus
2.69 kB
import gradio as gr
import tempfile
from transformers import MT5ForConditionalGeneration, MT5Tokenizer,ViltProcessor, ViltForQuestionAnswering, AutoTokenizer
import torch
from PIL import Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Persian -> English translation model
# NOTE(review): the original comment said "English to Persian", but the
# checkpoint name below is Persian-to-English; comment corrected to match.
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)
def run_fa_en_transaltion_model(input_string, **generator_args):
    """Translate a Persian string to English.

    Args:
        input_string: Persian source text.
        **generator_args: keyword arguments forwarded to ``model.generate``
            (e.g. ``max_length``, ``num_beams``).

    Returns:
        list[str]: decoded translation(s); callers take index 0 for the text.
    """
    input_ids = fa_en_translation_tokenizer.encode(input_string, return_tensors="pt")
    # Bug fix: the model was moved to `device` (CUDA when available) but the
    # inputs were left on CPU, which makes generate() raise a device-mismatch
    # error on any GPU machine. Move the inputs to the same device.
    input_ids = input_ids.to(device)
    res = fa_en_translation_model.generate(input_ids, **generator_args)
    return fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
# English -> Persian translation model
# NOTE(review): the original comment said "Persian to English", but the
# checkpoint name below is English-to-Persian; comment corrected to match.
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)
def run_en_fa_transaltion_model(input_string, **generator_args):
    """Translate an English string to Persian.

    Args:
        input_string: English source text.
        **generator_args: keyword arguments forwarded to ``model.generate``
            (e.g. ``max_length``, ``num_beams``).

    Returns:
        list[str]: decoded translation(s); callers take index 0 for the text.
    """
    input_ids = en_fa_translation_tokenizer.encode(input_string, return_tensors="pt")
    # Bug fix: the model was moved to `device` (CUDA when available) but the
    # inputs were left on CPU, which makes generate() raise a device-mismatch
    # error on any GPU machine. Move the inputs to the same device.
    input_ids = input_ids.to(device)
    res = en_fa_translation_model.generate(input_ids, **generator_args)
    return en_fa_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
# Visual Question Answering model (English-only ViLT fine-tuned on VQAv2);
# Persian questions/answers are bridged through the translation models above.
VQA_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
VQA_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
# Gradio widgets: labels are Persian ("input image", "Persian question",
# "answer"); text areas are right-to-left for Persian script.
image_input =gr.Image(label="عکس ورودی")
text_input = gr.TextArea(label="سوال فارسی",text_align="right",rtl=True,type="text")
text_output = gr.TextArea(label="پاسخ",text_align="right",rtl=True,type="text")
def VQA(image, text):
    """Answer a Persian question about an image.

    Args:
        image: numpy array delivered by the Gradio image widget
            (presumably H x W x C uint8 — TODO confirm widget config).
        text: the question, in Persian.

    Returns:
        str: the model's answer, translated back to Persian.
    """
    # Convert the numpy array straight to a PIL image. The original code
    # round-tripped through a temporary PNG on disk, which is unnecessary
    # and fails on Windows (a NamedTemporaryFile cannot be reopened while
    # still open); the direct conversion yields the same pixels.
    pil_image = Image.fromarray(image)
    # The VQA model is English-only: translate the Persian question first.
    # The translator returns a list of decoded strings; take the single one.
    question_en = run_fa_en_transaltion_model(text)[0]
    # Prepare model inputs.
    encoding = VQA_processor(pil_image, question_en, return_tensors="pt")
    # Bug fix: move the inputs to the model's device (CUDA when available);
    # otherwise the forward pass crashes with a device-mismatch error.
    encoding = encoding.to(device)
    # Forward pass; pick the highest-scoring answer class.
    outputs = VQA_model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    # Translate the English answer label back to Persian for display.
    return run_en_fa_transaltion_model(VQA_model.config.id2label[idx])[0]
# Wire the widgets to the VQA pipeline and start the (local-only) Gradio app.
iface = gr.Interface(fn=VQA, inputs=[image_input,text_input], outputs=text_output)
iface.launch(share=False)