Nathanwit committed
Commit ba44e57 • Parent(s): bd38d64

Update app.py

Files changed (1)
  app.py +20 -8
app.py CHANGED
@@ -2,23 +2,35 @@ import torch
 import re
 import gradio as gr
-from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
-
+from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
+'''
 device='cpu'
 encoder_checkpoint = "Thibalte/captionning_project"
 decoder_checkpoint = "Thibalte/captionning_project"
 model_checkpoint = "Thibalte/captionning_project"
 feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
 tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
 model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
+'''
+# Load the trained model from the local training output directory
+model_path = "./image-captioning-output"
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+# Load image processor (ViTImageProcessor replaces the deprecated ViTFeatureExtractor)
+feature_extractor = ViTImageProcessor.from_pretrained(model_path)
+
+# Load model
+model = VisionEncoderDecoderModel.from_pretrained(model_path)
 
 
 def predict(image,max_length=24, num_beams=4):
-    image = image.convert('RGB')
-    image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
-    clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
-    caption_ids = model.generate(image, max_length = max_length)[0]
-    caption_text = clean_text(tokenizer.decode(caption_ids))
-    return caption_text
+    image = image.convert('RGB')
+    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+    # use GPT2's eos_token as the pad as well as eos token
+    sequences = model.generate(pixel_values, num_beams=num_beams, max_length=max_length, pad_token_id=tokenizer.eos_token_id)
+    captions = tokenizer.batch_decode(sequences, skip_special_tokens=True)
+    return captions[0]
 
 
 # Gradio Interface
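
For reference, a minimal sketch of how the updated predict function could be wired into the "# Gradio Interface" section that follows this hunk. The interface code itself is outside the diff, so the gr.Interface arguments below are assumptions, not the committed code:

# Hypothetical wiring for the Gradio interface below this hunk;
# component choices are assumptions, not the committed code.
interface = gr.Interface(
    fn=predict,                           # predict(image, max_length=24, num_beams=4) defined above
    inputs=gr.Image(type="pil"),          # hand predict a PIL image so image.convert('RGB') works
    outputs=gr.Textbox(label="Caption"),
    title="Image captioning demo",
)
interface.launch()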