SeyedAli commited on
Commit
c636c78
1 Parent(s): d2bfb89

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer,ViltProcessor, ViltForQuestionAnswering, AutoTokenizer
4
+ import torch
5
+ from PIL import Image
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
# Persian -> English translation model.
# NOTE: tokenizers are plain Python objects, not torch modules — they have no
# .to(); only the model is moved to the device.
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)

def run_fa_en_transaltion_model(input_string, **generator_args):
    """Translate Persian text to English.

    Args:
        input_string: Persian source text.
        **generator_args: extra keyword arguments forwarded to ``generate``.

    Returns:
        list[str]: decoded English translation(s).
    """
    # Move the encoded ids to the same device as the model (required on CUDA).
    input_ids = fa_en_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = fa_en_translation_model.generate(input_ids, **generator_args)
    output = fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
    return output
18
+
19
# English -> Persian translation model.
# NOTE: tokenizers are plain Python objects, not torch modules — they have no
# .to(); only the model is moved to the device.
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)

def run_en_fa_transaltion_model(input_string, **generator_args):
    """Translate English text to Persian.

    Args:
        input_string: English source text.
        **generator_args: extra keyword arguments forwarded to ``generate``.

    Returns:
        list[str]: decoded Persian translation(s).
    """
    # Move the encoded ids to the same device as the model (required on CUDA).
    input_ids = en_fa_translation_tokenizer.encode(input_string, return_tensors="pt").to(device)
    res = en_fa_translation_model.generate(input_ids, **generator_args)
    output = en_fa_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
    return output
28
+
29
# Visual Question Answering model (ViLT fine-tuned on VQAv2).
# NOTE: processors are preprocessing objects, not torch modules — they have no
# .to(); only the model is moved to the device.
VQA_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
VQA_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
32
+
33
+
34
def VQA(image, text):
    """Answer a Persian question about an image.

    Pipeline: translate the Persian question to English, run ViLT visual
    question answering, then translate the top-scoring answer back to Persian.

    Args:
        image: the image as a numpy array (Gradio's default image format).
        text: the question in Persian.

    Returns:
        list[str]: the answer translated to Persian.
    """
    # Convert the numpy array straight to PIL — no need to round-trip the
    # in-memory image through a temporary file on disk.
    pil_image = Image.fromarray(image)
    # The translator returns a list of decoded strings; the processor expects
    # a single question string, so take the first element.
    english_question = run_fa_en_transaltion_model(text)[0]
    # Prepare inputs and move them to the model's device.
    encoding = VQA_processor(pil_image, english_question, return_tensors="pt").to(device)
    # Forward pass; pick the highest-scoring answer class.
    outputs = VQA_model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    # id2label[idx] is one answer string (iterating it would loop over its
    # characters); translate the whole label back to Persian.
    answer = VQA_model.config.id2label[idx]
    return run_en_fa_transaltion_model(answer)
50
+
51
# Wire the VQA pipeline into a simple Gradio UI: an image upload plus a text
# box for the Persian question, answering with Persian text.
iface = gr.Interface(
    fn=VQA,
    inputs=["image", "text"],
    outputs="text",
)
iface.launch(share=False)