|
import numpy as np |
|
import gradio as gr |
|
|
|
import torch |
|
from transformers import Dinov2Config, Dinov2Model, Dinov2ForImageClassification, AutoImageProcessor |
|
import torch.nn as nn |
|
import os |
|
|
|
|
|
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build an MLP classification head that halves its width twice.

    The head is two hidden stages, ``num_features -> num_features // 2 ->
    num_features // 4``, each made of Linear, activation, BatchNorm1d and
    (optionally) Dropout, followed by a final Linear projecting to
    ``number_classes``.

    Args:
        num_features: Width of the incoming feature vector.
        number_classes: Number of output logits.
        dropout_prob: Dropout probability after each hidden stage; a value
            of 0 leaves the Dropout layers out entirely.
        activation_func: Zero-argument callable returning the activation
            module (called once per hidden stage).

    Returns:
        An ``nn.Sequential`` holding the layers in the order described above.
    """
    widths = (num_features, num_features // 2, num_features // 4)

    stages = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # Keep this exact order: nn.Sequential names its children by
        # position, so reordering would change the state-dict keys.
        stages.extend(
            [
                nn.Linear(fan_in, fan_out),
                activation_func(),
                nn.BatchNorm1d(fan_out),
            ]
        )
        if dropout_prob != 0:
            stages.append(nn.Dropout(dropout_prob))

    stages.append(nn.Linear(widths[-1], number_classes))
    return nn.Sequential(*stages)
|
|
|
class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """DINOv2 image classifier whose classifier is the MLP from ``create_head``.

    The head receives ``config.hidden_size * 2`` input features and emits
    ``config.num_labels`` logits.
    NOTE(review): the factor 2 presumably matches the parent forward pass
    (CLS token concatenated with the mean of the patch tokens) — confirm.
    """

    def __init__(self, config: Dinov2Config) -> None:
        """Build the stock model, then swap in the custom head.

        Args:
            config: Model configuration; ``hidden_size`` and ``num_labels``
                size the replacement head.
        """
        # The upstream Dinov2ForImageClassification constructor already
        # assigns ``self.num_labels`` and builds ``self.dinov2``, so neither
        # is repeated here: constructing a second Dinov2Model only allocated
        # and initialised a backbone that replaced an identical one.
        # NOTE(review): confirm against the pinned transformers version.
        super().__init__(config)

        # Replace the parent's classifier with the deeper MLP head.
        self.classifier = create_head(config.hidden_size * 2, config.num_labels)
|
|
|
# Hub identifier of the fine-tuned checkpoint (weights and image processor).
checkpoint_name = "lombardata/dino-base-2023_11_27-with_custom_head"

# Class names; a name's position in this list is its class id.
classes_names = [
    "Acropore_branched",
    "Acropore_digitised",
    "Acropore_tabular",
    "Algae_assembly",
    "Algae_limestone",
    "Algae_sodding",
    "Dead_coral",
    "Fish",
    "Human_object",
    "Living_coral",
    "Millepore",
    "No_acropore_encrusting",
    "No_acropore_massive",
    "No_acropore_sub_massive",
    "Rock",
    "Sand",
    "Scrap",
    "Sea_cucumber",
    "Syringodium_isoetifolium",
    "Thalassodendron_ciliatum",
    "Useless",
]

# Class ids 0..N-1 (NumPy integers), parallel to ``classes_names``.
classes_nb = list(np.arange(len(classes_names)))

# Two-way lookup between plain-int class ids and class names.
id2label = {
    int(class_id): class_name
    for class_id, class_name in zip(classes_nb, classes_names)
}
label2id = {class_name: class_id for class_id, class_name in id2label.items()}
|
|
|
# Load the fine-tuned weights once at import time; predict() reuses this
# module-level instance for every request.
model = NewheadDinov2ForImageClassification.from_pretrained(
    pretrained_model_name_or_path=checkpoint_name,
)
|
|
|
def sigmoid(_outputs):
    """Apply the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The value is computed in a numerically stable way: ``exp`` is only ever
    evaluated on non-positive numbers, so large-magnitude negative logits
    no longer overflow (the naive form produced ``inf`` and a RuntimeWarning
    there before settling on 0).

    Args:
        _outputs: Scalar or array-like of real-valued logits. It is
            converted with ``np.asarray``; non-floating input is computed
            in float64, floating input keeps its dtype.

    Returns:
        A NumPy scalar for scalar input, otherwise an ``ndarray`` with the
        same shape as the input, with values in ``[0, 1]``.
    """
    logits = np.asarray(_outputs)
    if not np.issubdtype(logits.dtype, np.floating):
        logits = logits.astype(np.float64)

    # exp(-|x|) lies in (0, 1] and can at worst underflow to 0, never overflow.
    decay = np.exp(-np.abs(logits))

    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) — same function.
    probabilities = np.where(
        logits >= 0,
        1.0 / (1.0 + decay),
        decay / (1.0 + decay),
    )

    # Indexing with () unwraps a 0-d result to a scalar and returns
    # higher-dimensional arrays unchanged.
    return probabilities[()]
|
|
|
_image_processor = None


def _get_image_processor():
    """Return the checkpoint's image processor, loading it on first use only."""
    global _image_processor
    if _image_processor is None:
        _image_processor = AutoImageProcessor.from_pretrained(checkpoint_name)
    return _image_processor


def predict(input_image):
    """Classify one image and return the labels scoring above 0.5.

    Each class is scored independently: the model's logits go through a
    sigmoid and every class whose score is strictly greater than 0.5 is kept.

    Args:
        input_image: The image as delivered by the Gradio input component;
            it is handed unchanged to the checkpoint's image processor.

    Returns:
        A dict mapping class name to its score as a Python float. It is
        empty when no class exceeds 0.5.
    """
    # Reuse one processor across requests instead of calling from_pretrained
    # for every image.
    inputs = _get_image_processor()(input_image, return_tensors="pt")

    with torch.no_grad():
        # Index 0: only a single image is passed in, so take its row.
        logits = model(**inputs)["logits"][0]

    # sigmoid() is NumPy-based, so hand it an explicit ndarray rather than
    # relying on implicit tensor-to-array conversion.
    scores = sigmoid(logits.detach().cpu().numpy())

    probabilities = (float(score) for score in scores)
    return {
        id2label[class_id]: probability
        for class_id, probability in enumerate(probabilities)
        if probability > 0.5
    }
|
|
|
# Web UI: one image input, predict() as the handler, and a label widget
# showing the returned {class name: score} mapping.
# NOTE(review): the example file names below look like tree species rather
# than reef imagery — confirm the files exist next to this script and are
# the intended examples.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(224, 224)),
    outputs="label",
    examples=[
        "Dalbergia oliveri.JPG",
        "Eucalyptus.JPG",
        "Khaya senegalensis.JPG",
        "Syzygium nervosum.JPG",
    ],
)

# Start the server as soon as the script runs.
demo.launch()
|
|