File size: 3,717 Bytes
863aaff
 
9e66188
7092e6d
 
 
d73ed72
863aaff
28f1ee4
c7179a0
28f1ee4
 
fae65ac
 
 
 
 
 
 
 
 
 
28f1ee4
863aaff
fae65ac
 
 
 
 
 
 
 
 
4f4630f
 
 
 
 
 
 
 
 
 
fae65ac
 
 
 
 
 
 
 
 
 
 
4f4630f
863aaff
95ddf62
 
 
7989a0a
95ddf62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb85cd6
 
 
fc6fbc6
95ddf62
863aaff
 
 
eb85cd6
 
1acb5c4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import functools
import json
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import Dinov2Config, Dinov2Model, Dinov2ForImageClassification, AutoImageProcessor

# Fine-tuned model identifier and the Hugging Face Hub repository id built from it.
model_name = "dinov2-large-2024_01_24-with_data_aug_batch-size32_epochs93_freeze"
checkpoint_name = f"lombardata/{model_name}"

# CREATE CUSTOM MODEL
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    """Build an MLP classification head that halves its width twice.

    The head is ``num_features -> num_features//2 -> num_features//4 ->
    number_classes``. Each hidden stage is Linear, activation, BatchNorm1d,
    and (unless ``dropout_prob`` is 0) Dropout.

    Args:
        num_features: Width of the input feature vector.
        number_classes: Number of output logits.
        dropout_prob: Dropout probability for the hidden stages; 0 disables dropout.
        activation_func: Activation module class, instantiated once per hidden stage.

    Returns:
        An ``nn.Sequential`` containing the stacked layers.
    """
    widths = [num_features, num_features // 2, num_features // 4]
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        stage = [nn.Linear(fan_in, fan_out), activation_func(), nn.BatchNorm1d(fan_out)]
        if dropout_prob != 0:
            stage.append(nn.Dropout(dropout_prob))
        modules.extend(stage)
    # Final projection from the narrowest hidden width to the class logits.
    modules.append(nn.Linear(widths[-1], number_classes))
    return nn.Sequential(*modules)
from transformers import Dinov2Config, Dinov2Model

class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    """DINOv2 image classifier whose classifier is replaced by an MLP head.

    The backbone and ``num_labels`` come from the parent constructor. Only
    ``self.classifier`` is swapped for the deeper head built by ``create_head``.
    """

    def __init__(self, config: Dinov2Config) -> None:
        # The parent constructor already sets self.num_labels and builds
        # self.dinov2 from this same config. Rebuilding the backbone here would
        # construct a second full DINOv2 model only to discard the first.
        super().__init__(config)

        # Classifier head.
        # NOTE(review): hidden_size * 2 is assumed to match the feature width the
        # parent's forward() feeds to self.classifier; confirm against the
        # installed transformers version.
        self.classifier = create_head(config.hidden_size * 2, config.num_labels)

        # A submodule was replaced after the parent's initialisation, so run the
        # transformers post-initialisation hook again, as done for custom heads.
        self.post_init()

# Load the fine-tuned checkpoint (backbone + custom head) from the Hub.
model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name)

# IMPORT MODEL CONFIG PARAMETERS
# hf_hub_download only returns the local path of the cached file, so the JSON
# must be parsed explicitly before any of its fields can be read.
config_path = hf_hub_download(repo_id=checkpoint_name, filename="config.json")
with open(config_path, encoding="utf-8") as config_file:
    config = json.load(config_file)

# JSON object keys are always strings. predict() looks labels up by integer
# class index, so convert the keys back to int.
id2label = {int(idx): label for idx, label in config["id2label"].items()}
label2id = config["label2id"]
image_size = config["image_size"]
classes_names = list(label2id.keys())

def sigmoid(_outputs):
    """Apply the logistic function element-wise: ``1 / (1 + exp(-x))``.

    Accepts a scalar or any array-like that supports unary minus and
    ``np.exp``. Each logit is mapped independently into (0, 1).
    """
    exp_neg = np.exp(-_outputs)
    denominator = 1.0 + exp_neg
    return 1.0 / denominator
    
@functools.lru_cache(maxsize=1)
def _load_image_processor():
    """Load the checkpoint's image processor once and reuse it across requests."""
    return AutoImageProcessor.from_pretrained(checkpoint_name)


def predict(input_image):
    """Return the labels whose independent sigmoid score exceeds 0.5.

    Each logit is scored separately with ``sigmoid`` (multi-label setup), so
    several labels may be returned for one image.

    Args:
        input_image: Image passed in by the Gradio ``gr.Image`` input. It may be
            ``None`` when no image has been provided.

    Returns:
        A dict mapping label name to score, containing only scores > 0.5. The
        dict is empty when ``input_image`` is ``None``.
    """
    # Gradio calls the function with None when the image input is empty.
    if input_image is None:
        return {}

    # Loading the processor can hit the Hub, so it is cached instead of being
    # reloaded on every request.
    image_processor = _load_image_processor()
    inputs = image_processor(input_image, return_tensors="pt")
    with torch.no_grad():
        model_outputs = model(**inputs)
    # Logits for the single image in the batch.
    logits = model_outputs["logits"][0]
    scores = sigmoid(logits)

    result = {}
    for class_idx, score in enumerate(scores):
        probability = float(score)
        label = id2label[class_idx]
        if probability > 0.5:
            result[label] = probability
    return result
    
# Define style
title = "DinoVd'eau image classification"
description = f"This is a prototype application that demonstrates how artificial intelligence-based systems can recognize what object(s) is present in an underwater image. To use it, simply upload your image, or click one of the example images to load them. For predictions, we use the open-source model {checkpoint_name}"

# Example images offered in the UI; paths are resolved relative to the working directory.
example_images = [
    "GOPR0106.JPG",
    "session_2021_08_30_Mayotte_10_image_00066.jpg",
    "session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG",
    "session_2023_06_28_caplahoussaye_plancha_body_v1B_00_GP1_3_1327.jpeg",
]

# NOTE(review): gr.Image(shape=...) exists in Gradio 3.x but was removed in 4.x;
# confirm the pinned Gradio version.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(224, 224)),
    outputs="label",
    title=title,
    description=description,
    examples=example_images,
)
demo.launch()