import torch
from transformers import AutoImageProcessor, Dinov2ForImageClassification, Dinov2Config
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import torch.nn as nn
import numpy as np

# DEFINE MODEL NAME
model_name = "DinoVdeau-large-2024_04_03-with_data_aug_batch-size32_epochs150_freeze"
checkpoint_name = "lombardata/" + model_name

# Load the model configuration from the Hub
config_path = hf_hub_download(repo_id=checkpoint_name, filename="config.json")
with open(config_path, 'r') as config_file:
    config = json.load(config_file)
id2label = config["id2label"]
label2id = config["label2id"]
image_size = config["image_size"]
num_labels = len(id2label)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DEFINE THE CUSTOM CLASSIFICATION HEAD
def create_head(num_features, number_classes, dropout_prob=0.5, activation_func=nn.ReLU):
    # Small MLP head: Linear -> activation -> BatchNorm -> Dropout, halving the width at each step
    features_lst = [num_features, num_features // 2, num_features // 4]
    layers = []
    for in_f, out_f in zip(features_lst[:-1], features_lst[1:]):
        layers.append(nn.Linear(in_f, out_f))
        layers.append(activation_func())
        layers.append(nn.BatchNorm1d(out_f))
        if dropout_prob != 0:
            layers.append(nn.Dropout(dropout_prob))
    layers.append(nn.Linear(features_lst[-1], number_classes))
    return nn.Sequential(*layers)

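# Dinov2ForImageClassification classifies from the [CLS] token concatenated with the
# mean of the patch tokens, which is why the head receives config.hidden_size * 2 features.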
class NewheadDinov2ForImageClassification(Dinov2ForImageClassification):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        # Classifier head
        self.classifier = create_head(config.hidden_size * 2, config.num_labels)
        
model = NewheadDinov2ForImageClassification.from_pretrained(checkpoint_name)
model.to(device)
model.eval()  # make sure BatchNorm/Dropout in the head run in inference mode

# Load the image processor once at start-up instead of on every prediction call
processor = AutoImageProcessor.from_pretrained(checkpoint_name)

def sigmoid(_outputs):
    return 1.0 / (1.0 + np.exp(-_outputs))

def predict(image, threshold):
    # Preprocess the image with the processor loaded at start-up
    inputs = processor(images=image, return_tensors="pt").to(device)
    
    # Get model predictions
    with torch.no_grad():
        model_outputs = model(**inputs)
    logits = model_outputs.logits[0]
    probabilities = torch.sigmoid(logits).cpu().numpy()  # Convert to probabilities

    # Create a dictionary of label scores
    results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities)}

    # Keep only labels whose probability exceeds the user-selected threshold
    filtered_results = {label: prob for label, prob in results.items() if prob > threshold}

    return filtered_results 

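# Quick local sanity check (commented out; assumes the example image below is present in the working directory):
# print(predict(Image.open("session_GOPR0106.JPG"), 0.5))
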
# Define style
title = "Victor - DinoVd'eau image classification"
model_link = "https://huggingface.co/" + checkpoint_name
description = f"This application showcases the capability of artificial intelligence-based systems to identify objects within underwater images. To utilize it, you can either upload your own image or select one of the provided examples for analysis.\nFor predictions, we use this [open-source model]({model_link})"

iface = gr.Interface(
    fn=predict,
    inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0, maximum=1, value=0.5, label="Threshold")],
    outputs=gr.components.Label(),
    title=title,
    description=description,
    examples=[["session_GOPR0106.JPG"],
              ["session_2021_08_30_Mayotte_10_image_00066.jpg"],
              ["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],
              ["session_2023_06_28_caplahoussaye_plancha_body_v1B_00_GP1_3_1327.jpeg"]])
iface.launch()