import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
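
# MobileViT-small image encoder with a DeepLabV3 segmentation head, fine-tuned on Pascal VOC (21 classes).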
model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
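
# One RGB color per Pascal VOC class, in the same order as the `labels` list below.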
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8,
)
labels = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# Draw the labels as colored chips. Light colors use black text, dark colors use white text.
inverted = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20]
labels_colored = []
for i in range(len(labels)):
    r, g, b = palette[i]
    label = labels[i]
    color = "white" if i in inverted else "black"
    text = "<span style='background-color: rgb(%d, %d, %d); color: %s'>&nbsp;%s&nbsp;</span>" % (r, g, b, color, label)
    labels_colored.append(text)
labels_text = " ".join(labels_colored)
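# labels_text is now one HTML string of colored chips, roughly:
# "<span style='background-color: rgb(0, 0, 0); color: white'>&nbsp;background&nbsp;</span> ..."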
title = "Semantic Segmentation with MobileViT and DeepLabV3"
description = """
The input image is resized and center cropped to 512×512 pixels. The segmentation output is 32×32 pixels. <br>
This model has been trained on Pascal VOC. <br>
The classes are: <br>
""" + labels_text + """
<br><br>
Sources: <br>
📜 <a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a> <br>
🏋️ Original pretrained weights from this GitHub repo <br>
🏙 Example images from this dataset <br>
"""
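
# A minimal prediction function and Gradio interface sketch. The `predict` name,
# the nearest-neighbor upscaling of the 32×32 mask, and the plain color-mask output
# are illustrative choices; a demo could instead blend the mask over the input image.
def predict(image: Image.Image) -> Image.Image:
    # The feature extractor resizes and center-crops the image to 512×512 and normalizes it.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Logits have shape (1, 21, 32, 32): one score per class for each output pixel.
    classes = outputs.logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    # Map each class index to its palette color and upscale the mask back to 512×512.
    return Image.fromarray(palette[classes]).resize((512, 512), Image.NEAREST)


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()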