SeyedAli commited on
Commit
49c2537
1 Parent(s): 0b4638a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
- # English to Persian model
10
  fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
11
  fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)
12
 
@@ -16,7 +16,7 @@ def run_fa_en_transaltion_model(input_string, **generator_args):
16
  output = fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
17
  return output
18
 
19
- # Persian to English model
20
  en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
21
  en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)
22
 
@@ -38,13 +38,13 @@ def VQA(image,text):
38
  # Load the image file using Pillow
39
  image = Image.open(temp_image_file.name)
40
  # prepare inputs
41
- encoding = processor(image, run_fa_en_transaltion_model(text), return_tensors="pt")
42
  # forward pass
43
- outputs = model(**encoding)
44
  logits = outputs.logits
45
  idx = logits.argmax(-1).item()
46
  output=[]
47
- for item in model.config.id2label[idx]:
48
  output.append(run_en_fa_transaltion_model(item))
49
  return output
50
 
 
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
+ # English to Persian Translation model
10
  fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1")
11
  fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device)
12
 
 
16
  output = fa_en_translation_tokenizer.batch_decode(res, skip_special_tokens=True)
17
  return output
18
 
19
+ # Persian to English Translation model
20
  en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1")
21
  en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/English-to-Persian-Translation-mT5-V1").to(device)
22
 
 
38
  # Load the image file using Pillow
39
  image = Image.open(temp_image_file.name)
40
  # prepare inputs
41
+ encoding = VQA_processor(image, run_fa_en_transaltion_model(text), return_tensors="pt")
42
  # forward pass
43
+ outputs = VQA_model(**encoding)
44
  logits = outputs.logits
45
  idx = logits.argmax(-1).item()
46
  output=[]
47
+ for item in VQA_model.config.id2label[idx]:
48
  output.append(run_en_fa_transaltion_model(item))
49
  return output
50