Nathanwit commited on
Commit
4878f59
β€’
1 Parent(s): 7254bd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,13 +1,15 @@
1
  import torch
2
  import re
3
  import gradio as gr
4
- from transformers import AutoTokenizer, ViTFeatureExtractor, AutoModel
5
 
6
  device='cpu'
7
  encoder_checkpoint = "Thibalte/captionning_project"
 
 
8
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
9
- tokenizer = AutoTokenizer.from_pretrained("Thibalte/captionning_project")
10
- model = AutoModel.from_pretrained("Thibalte/captionning_project")
11
 
12
 
13
  def predict(image,max_length=24, num_beams=4):
 
1
  import torch
2
  import re
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
  device='cpu'
7
  encoder_checkpoint = "Thibalte/captionning_project"
8
+ decoder_checkpoint = "Thibalte/captionning_project"
9
+ model_checkpoint = "Thibalte/captionning_project"
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
+ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
 
14
 
15
  def predict(image,max_length=24, num_beams=4):