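"""Gradio demo for Persian visual question answering (VQA).

The Persian question is translated to English with an mT5 translation model,
answered over the input image by ViLT (dandelin/vilt-b32-finetuned-vqa), and
the English answer is translated back to Persian for display.
"""
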
import gradio as gr
import torch
from PIL import Image
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Persian-to-English translation model
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)

def run_fa_en_translation_model(input_string, **generator_args):
    # Tokenize, move the tensors to the model's device, generate, and decode.
    input_ids = fa_en_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = fa_en_translation_model.generate(input_ids, **generator_args)
    return fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)

# English-to-Persian translation model
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)

def run_en_fa_translation_model(input_string, **generator_args):
    # Tokenize, move the tensors to the model's device, generate, and decode.
    input_ids = en_fa_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = en_fa_translation_model.generate(input_ids, **generator_args)
    return en_fa_translation_tokenizer.batch_decode(res, skip_special_tokens=True)

# Visual Question Answering model
VQA_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
VQA_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)

# Gradio components with Persian, right-to-left labels:
# image_input = "input image", text_input = "Persian question", text_output = "answer"
image_input = gr.Image(label="عکس ورودی")
text_input = gr.TextArea(label="سوال فارسی", text_align="right", rtl=True, type="text")
text_output = gr.TextArea(label="پاسخ", text_align="right", rtl=True, type="text")

def VQA(image, text):
    # Gradio passes the uploaded image as a NumPy array; convert it to a PIL image.
    image = Image.fromarray(image).convert("RGB")
    # Translate the Persian question to English and prepare the ViLT inputs.
    question = run_fa_en_translation_model(text)[0]
    encoding = VQA_processor(image, question, return_tensors="pt").to(device)
    # Forward pass: pick the highest-scoring answer from ViLT's answer vocabulary.
    with torch.no_grad():
        outputs = VQA_model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    # Translate the English answer back to Persian.
    return run_en_fa_translation_model(VQA_model.config.id2label[idx])[0]
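
# Illustrative direct call (hypothetical image file), bypassing the web UI:
#   import numpy as np
#   answer = VQA(np.array(Image.open("example.jpg")), "این چیست؟")  # Persian for "What is this?"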

iface = gr.Interface(fn=VQA, inputs=[image_input, text_input], outputs=text_output)
iface.launch(share=False)