File size: 1,357 Bytes
cd254a4
1a410aa
d933416
cd254a4
23e6b97
05c382a
cd254a4
05c382a
c95621e
34a7ce6
c95621e
 
34a7ce6
cd254a4
ab4c3a1
2417ee3
e35723c
4d335eb
 
c95621e
 
 
 
 
 
 
cd254a4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gradio as gr
from  PIL import Image
import tempfile
import torch
from torchvision.io import read_image
from transformers import ViTForImageClassification, ViTFeatureExtractor,ViTImageProcessor

# Hugging Face Hub checkpoint that supplies both the classifier weights and
# its matching preprocessing configuration.
MODEL_ID = "SeyedAli/Food-Image-Classification-VIT"

# The classifier itself.
model = ViTForImageClassification.from_pretrained(MODEL_ID)

# ViTImageProcessor is the image *preprocessor*, not a model, so it must never
# be bound to `model` (doing so was the source of the earlier error). It
# replaces the deprecated ViTFeatureExtractor alias. The variable keeps its
# historical name because FoodClassification refers to `feature_extractor`.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)

def FoodClassification(image):
    """Predict the food category shown in an uploaded image.

    Args:
        image: The picture from the Gradio image input. It is presumably an
            RGB NumPy array (Gradio's ``type="numpy"``; TODO confirm if the
            input component changes), or ``None`` when the user submits
            without uploading anything.

    Returns:
        The label string that ``model.config.id2label`` assigns to the
        highest-scoring class index, or ``""`` when ``image`` is ``None``.
    """
    if image is None:
        # An empty submission would otherwise raise during preprocessing;
        # show an empty result instead of an error trace.
        return ""

    # The preprocessor accepts the array directly, so no intermediate file
    # is needed.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # One image in the batch; take the argmax over the per-class scores on
    # the last axis.
    predicted_class_idx = logits.argmax(dim=-1)[0].item()
    return model.config.id2label[predicted_class_idx]

# Web UI: one image in, the predicted label out as text. gr.Image(type="numpy")
# is what the "image" shortcut resolves to; it is spelled out so the array
# type FoodClassification receives is explicit.
iface = gr.Interface(
    fn=FoodClassification,
    inputs=gr.Image(type="numpy"),
    outputs="text",
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch a web server.
if __name__ == "__main__":
    iface.launch(share=False)