Spaces:
Runtime error
Runtime error
import re

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    ViTFeatureExtractor,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)
# Hub repository holding the fine-tuned captioning checkpoint; the tokenizer,
# image processor and model are all loaded from this same path.
model_path = "Thibalte/captionning_project"

# Tokenizer: used by predict() to decode generated token ids into text.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Image processor: used by predict() to turn a PIL image into `pixel_values`.
# Per-image preprocessing must happen inside predict(), not at import time,
# because no image exists when the module is loaded.
feature_extractor = ViTImageProcessor.from_pretrained(model_path)

# Encoder-decoder captioning model, loaded once and shared by all requests.
model = VisionEncoderDecoderModel.from_pretrained(model_path)

# NOTE(review): an earlier comment here said "use GPT2's eos_token as the pad
# as well as eos token", but no code configures that. Confirm the checkpoint's
# generation config already sets pad/eos token ids.
def predict(image, max_length=24, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image was supplied.
        max_length: Maximum generated sequence length, forwarded to
            ``model.generate``.
        num_beams: Beam-search width, forwarded to ``model.generate``.

    Returns:
        The decoded caption string, or ``""`` when ``image`` is ``None``.
    """
    if image is None:
        # Gradio may invoke fn with None when the image input is empty.
        return ""
    # Normalize to 3 channels (e.g. RGBA PNGs or grayscale inputs).
    image = image.convert("RGB")
    # Preprocess the image passed to *this* call.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Keep inputs on the same device as the model's weights.
    pixel_values = pixel_values.to(model.device)
    # Run generation once, honouring the caller's arguments.
    sequences = model.generate(
        pixel_values, num_beams=num_beams, max_length=max_length
    )
    captions = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    # One image in -> one sequence out (assumes the default
    # num_return_sequences=1), so return the single caption.
    return captions[0].strip()
# Web UI: one image in (file upload or webcam snapshot, delivered to predict()
# as a PIL image), one caption out, with bundled example1..example6 images.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        type="pil",
        sources=["upload", "webcam"],
        label="Select image for captioning",
    ),
    outputs=[gr.Textbox(label="Image Caption")],
    examples=[f"example{n}.jpg" for n in range(1, 7)],
    title="Image Captioning with our model",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    gradio_app.launch()