"""Gradio demo: image classification with a DINOv2 backbone and a custom MLP head.

Loads the fine-tuned checkpoint ``lombardata/dino-base-2023_11_27-with_custom_head``
from the Hugging Face Hub and serves a web UI that shows the top-5 predicted
classes for an uploaded image (judging by the class names, seabed / coral-reef
imagery).
"""
import os

import numpy as np
import gradio as gr
import torch
import torch.nn as nn
from transformers import (
    AutoImageProcessor,
    Dinov2Config,
    Dinov2ForImageClassification,
    Dinov2Model,
)


def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build the MLP classification head used by the fine-tuned checkpoint.

    The width is halved twice (``num_features -> num_features // 2 ->
    num_features // 4``). Each hidden stage is ``Linear -> activation ->
    BatchNorm1d -> Dropout``; dropout is omitted when ``dropout_prob == 0``. A
    final ``Linear`` projects to ``number_classes`` logits.

    Keep the layer order exactly as is. PyTorch names ``nn.Sequential`` children
    by position, so this order must match the architecture the checkpoint was
    trained with, or its weights will not load.

    Args:
        num_features: Width of the input feature vector.
        number_classes: Number of output logits.
        dropout_prob: Dropout probability for the hidden stages; 0 disables dropout.
        activation_func: Activation module class, instantiated once per stage.

    Returns:
        The head as an ``nn.Sequential``.
    """
    features_lst = [num_features, num_features // 2, num_features // 4]
    layers = []
    for in_f, out_f in zip(features_lst[:-1], features_lst[1:]):
        layers.append(nn.Linear(in_f, out_f))
        layers.append(activation_func())
        layers.append(nn.BatchNorm1d(out_f))
        if dropout_prob != 0:
            layers.append(nn.Dropout(dropout_prob))
    layers.append(nn.Linear(features_lst[-1], number_classes))
    return nn.Sequential(*layers)


class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """``Dinov2ForImageClassification`` whose linear classifier is replaced by ``create_head``.

    The parent constructor already builds ``self.dinov2`` and sets
    ``self.num_labels``, so only the classifier is swapped here. The head's input
    width is ``hidden_size * 2`` because the parent's forward pass feeds the
    classifier the CLS token concatenated with the mean of the patch tokens.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)
        # Replace the default nn.Linear head. The weights are loaded from the
        # checkpoint by from_pretrained.
        self.classifier = create_head(config.hidden_size * 2, config.num_labels)


# ---------------------------------------------------------------------------
# Model and labels
# ---------------------------------------------------------------------------
checkpoint_name = "lombardata/dino-base-2023_11_27-with_custom_head"

# Display names, indexed by the model's output position.
classes_names = ["Acropore_branched", "Acropore_digitised", "Acropore_tabular",
                 "Algae_assembly", "Algae_limestone", "Algae_sodding", "Dead_coral",
                 "Fish", "Human_object", "Living_coral", "Millepore",
                 "No_acropore_encrusting", "No_acropore_massive",
                 "No_acropore_sub_massive", "Rock", "Sand", "Scrap", "Sea_cucumber",
                 "Syringodium_isoetifolium", "Thalassodendron_ciliatum", "Useless"]
id2label = dict(enumerate(classes_names))
label2id = {v: k for k, v in id2label.items()}

model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name)
# from_pretrained already returns the model in eval mode; this call makes it
# explicit. The head's BatchNorm1d raises on a batch of one in training mode.
model.eval()

if model.config.num_labels != len(classes_names):
    # Fail fast rather than silently attaching the wrong names to the scores.
    raise ValueError(
        f"Checkpoint has {model.config.num_labels} labels but "
        f"{len(classes_names)} class names are defined."
    )

# Resize/crop/normalisation settings for the checkpoint. This assumes the Hub
# repo includes a preprocessor config (TODO confirm).
image_processor = AutoImageProcessor.from_pretrained(checkpoint_name)


def predict(path):
    """Classify one image and return per-class scores for ``gr.Label``.

    Args:
        path: The uploaded image as delivered by ``gr.Image(type="pil")``, i.e. a
            ``PIL.Image``. The parameter name is kept for compatibility. It is
            ``None`` when the user submits without an image.

    Returns:
        Mapping of class name to score, rounded to 2 decimals. Scores are
        sigmoid values if the checkpoint config declares
        ``multi_label_classification``, and softmax probabilities otherwise.
        The mapping is empty when no image was given.
    """
    if path is None:
        return {}
    # Convert to RGB so RGBA / grayscale uploads match the 3-channel input the
    # processor normalises.
    inputs = image_processor(images=path.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    if model.config.problem_type == "multi_label_classification":
        scores = torch.sigmoid(logits)
    else:
        scores = torch.softmax(logits, dim=-1)
    # zip stops at the shorter sequence, so this can never index past the
    # model's outputs.
    return {name: round(score, 2) for name, score in zip(classes_names, scores.tolist())}


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): these filenames look like leftovers from a tree-species demo and
# do not match the classes above; replace them with representative images.
# Files missing from the working directory are skipped.
_EXAMPLE_FILES = [
    "Dalbergia oliveri.JPG",
    "Eucalyptus.JPG",
    "Khaya senegalensis.JPG",
    "Syzygium nervosum.JPG",
]
examples = [f for f in _EXAMPLE_FILES if os.path.isfile(f)] or None

demo = gr.Interface(
    fn=predict,
    # The image processor handles resizing, so the Gradio component no longer
    # needs `shape` (which was removed in Gradio 4).
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=examples,
)
demo.launch()