"""Gradio demo for visual question answering with Persian input and output.

The question is translated with the Persian-to-English mT5 checkpoint, passed
together with the image to the ViLT VQA checkpoint, and the predicted answer
label is translated with the English-to-Persian mT5 checkpoint.
"""
import tempfile  # kept from the original file-level imports

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,  # kept from the original file-level imports
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Persian to English Translation model
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained(
    "SeyedAli/Persian-to-English-Translation-mT5-V1"
)
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained(
    "SeyedAli/Persian-to-English-Translation-mT5-V1"
).to(device)


def run_fa_en_transaltion_model(input_string, **generator_args):
    """Translate text with the Persian-to-English mT5 checkpoint.

    Args:
        input_string: Text to translate.
        **generator_args: Extra keyword arguments forwarded to ``generate``.

    Returns:
        The list of decoded strings produced by ``batch_decode``.
    """
    # The model lives on `device`; its inputs must be on the same device.
    input_ids = fa_en_translation_tokenizer.encode(
        input_string, return_tensors="pt"
    ).to(device)
    res = fa_en_translation_model.generate(input_ids, **generator_args)
    return fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)


# English to Persian Translation model
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained(
    "SeyedAli/English-to-Persian-Translation-mT5-V1"
)
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained(
    "SeyedAli/English-to-Persian-Translation-mT5-V1"
).to(device)


def run_en_fa_transaltion_model(input_string, **generator_args):
    """Translate text with the English-to-Persian mT5 checkpoint.

    Args:
        input_string: Text to translate.
        **generator_args: Extra keyword arguments forwarded to ``generate``.

    Returns:
        The list of decoded strings produced by ``batch_decode``.
    """
    # The model lives on `device`; its inputs must be on the same device.
    input_ids = en_fa_translation_tokenizer.encode(
        input_string, return_tensors="pt"
    ).to(device)
    res = en_fa_translation_model.generate(input_ids, **generator_args)
    return en_fa_translation_tokenizer.batch_decode(res, skip_special_tokens=True)


# Visual Question Answering model
VQA_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
VQA_model = ViltForQuestionAnswering.from_pretrained(
    "dandelin/vilt-b32-finetuned-vqa"
).to(device)


def VQA(image, text):
    """Answer a Persian question about an image, in Persian.

    Args:
        image: Image as an array accepted by ``PIL.Image.fromarray``
            (this is what the Gradio ``"image"`` input passes in).
        text: The question, in Persian.

    Returns:
        The predicted answer label translated to Persian, as a string.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("An image is required.")

    # Build the PIL image directly from the array; no temporary file is
    # needed. Converting to RGB normalises greyscale / RGBA uploads.
    pil_image = Image.fromarray(image).convert("RGB")

    # The translation helper returns a one-element list for a single input.
    question = run_fa_en_transaltion_model(text)[0]

    # Prepare inputs on the same device as the model.
    encoding = VQA_processor(pil_image, question, return_tensors="pt").to(device)

    # Forward pass; inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = VQA_model(**encoding)
    idx = outputs.logits.argmax(-1).item()

    # id2label[idx] is a single label string: translate it as a whole
    # rather than iterating over its characters.
    answer = VQA_model.config.id2label[idx]
    return run_en_fa_transaltion_model(answer)[0]


iface = gr.Interface(fn=VQA, inputs=["image", "text"], outputs="text")
iface.launch(share=False)