_
File size: 7,831 Bytes
da3eeba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import platform
from functools import partial

import torch

from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
from ia_check_versions import ia_check_versions
from ia_config import IAConfig
from ia_devices import devices
from ia_logging import ia_logging
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
from mobile_sam import SamPredictor as SamPredictorMobile
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq


def check_bfloat16_support() -> bool:
    """Report whether the current CUDA device is treated as bfloat16-capable.

    The check is based only on the device's compute capability: a major
    version of 8 or higher counts as supported. The outcome is also written
    to the debug log.

    Returns:
        bool: True when CUDA is available and the current device's compute
        capability major version is at least 8, otherwise False.
    """
    if not torch.cuda.is_available():
        ia_logging.debug("CUDA is not available")
        return False

    current_device = torch.cuda.current_device()
    major_version = torch.cuda.get_device_capability(current_device)[0]
    supported = major_version >= 8
    if supported:
        ia_logging.debug("The CUDA device supports bfloat16")
    else:
        ia_logging.debug("The CUDA device does not support bfloat16")
    return supported


def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
    """Bind trailing positional arguments and default keyword arguments to func.

    Unlike ``functools.partial``, the fixed positional arguments are appended
    after the positional arguments given at call time.

    Args:
        func: callable to wrap
        *fixed_args: positional arguments placed after the call-time ones
        **fixed_kwargs: keyword arguments used unless the caller passes the
            same keyword, in which case the caller's value wins

    Returns:
        A wrapper that forwards to func with the combined arguments.
    """
    def wrapper(*args, **kwargs):
        # Start from the fixed keywords so call-time keywords override them.
        merged_kwargs = dict(fixed_kwargs)
        merged_kwargs.update(kwargs)
        positional = args + fixed_args
        return func(*positional, **merged_kwargs)
    return wrapper


def rename_args(func, arg_map):
    """Wrap func so that selected keyword arguments are passed under new names.

    Args:
        func: callable to wrap
        arg_map (dict): maps a keyword name used by the caller to the keyword
            name expected by func; keywords not listed are passed unchanged

    Returns:
        A wrapper that forwards positional arguments untouched and keyword
        arguments under their translated names.
    """
    def wrapper(*args, **kwargs):
        translated = {}
        for name, value in kwargs.items():
            translated[arg_map.get(name, name)] = value
        return func(*args, **translated)
    return wrapper


# The registries in this module are called with ``checkpoint=...``; this map
# renames that keyword to ``ckpt_path`` before it reaches build_sam2.
arg_map = {"checkpoint": "ckpt_path"}
rename_build_sam2 = rename_args(build_sam2, arg_map)

# Keyword defaults applied to every SAM2 builder below; a caller-supplied
# keyword of the same name takes precedence (see partial_from_end).
end_kwargs = {
    "device": "cpu",
    "mode": "eval",
    "hydra_overrides_extra": [],
    "apply_postprocessing": False,
}

# Model type -> builder with the config file bound as first positional argument.
sam2_model_registry = {
    model_name: partial(partial_from_end(rename_build_sam2, **end_kwargs), config_file)
    for model_name, config_file in (
        ("sam2_hiera_large", "sam2_hiera_l.yaml"),
        ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
        ("sam2_hiera_small", "sam2_hiera_s.yaml"),
        ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
    )
}


def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
    """Create the automatic mask generator matching a SAM checkpoint.

    The SAM variant is picked from the checkpoint file name, checked in this
    order: "_hq_", "FastSAM", "mobile_sam", "sam2_", and otherwise the
    original SAM registry.

    Args:
        sam_checkpoint (str): SAM checkpoint path
        anime_style_chk (bool): when true, start from the lower thresholds
            (pred_iou_thresh 0.83, stability_score_thresh 0.9) instead of
            0.88 and 0.95

    Returns:
        The selected variant's automatic mask generator, or None when
        sam_checkpoint is not an existing file
    """
    ckpt_name = os.path.basename(sam_checkpoint)
    ckpt_stem = os.path.splitext(ckpt_name)[0]

    # Each branch yields: model type key, model registry, generator class, batch size.
    if "_hq_" in ckpt_name:
        # The model type is taken from a fixed position in the file name.
        selection = (ckpt_name[7:12], sam_model_registry_hq, SamAutomaticMaskGeneratorHQ, 32)
    elif "FastSAM" in ckpt_name:
        selection = (ckpt_stem, fast_sam_model_registry, FastSamAutomaticMaskGenerator, None)
    elif "mobile_sam" in ckpt_name:
        selection = ("vit_t", sam_model_registry_mobile, SamAutomaticMaskGeneratorMobile, 64)
    elif "sam2_" in ckpt_name:
        selection = (ckpt_stem, sam2_model_registry, SAM2AutomaticMaskGenerator, 128)
    else:
        # The model type is taken from a fixed position in the file name.
        selection = (ckpt_name[4:9], sam_model_registry, SamAutomaticMaskGenerator, 64)
    model_type, registry, generator_cls, points_per_batch = selection

    running_on_mac = platform.system() == "Darwin"

    if anime_style_chk:
        pred_iou_thresh, stability_score_thresh = 0.83, 0.9
    else:
        pred_iou_thresh, stability_score_thresh = 0.88, 0.95

    # Extra generator settings that only apply when the model type names a SAM2 model.
    sam2_overrides = {}
    if "sam2_" in model_type:
        pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
        stability_score_thresh = round(stability_score_thresh - 0.03, 2)
        sam2_overrides = {
            "points_per_side": 64,
            "points_per_batch": points_per_batch,
            "pred_iou_thresh": pred_iou_thresh,
            "stability_score_thresh": stability_score_thresh,
            "stability_score_offset": 0.7,
            "crop_n_layers": 1,
            "box_nms_thresh": 0.7,
            "crop_n_points_downscale_factor": 2,
        }
        if running_on_mac:
            # macOS uses a smaller point grid and batch and no crop downscaling.
            sam2_overrides.update(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1)

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = registry[model_type](checkpoint=sam_checkpoint)

    if running_on_mac:
        # FastSAM checkpoints, and setups without MPS, stay on the CPU.
        keep_on_cpu = "FastSAM" in ckpt_name or not ia_check_versions.torch_mps_is_available
        sam.to(device=torch.device("cpu" if keep_on_cpu else "mps"))
    elif IAConfig.global_args.get("sam_cpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been selected)")
        sam.to(device=devices.cpu)
    else:
        sam.to(device=devices.device)

    generator_kwargs = {
        "model": sam,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
    }
    generator_kwargs.update(sam2_overrides)
    return generator_cls(**generator_kwargs)


def get_sam_predictor(sam_checkpoint):
    """Create the predictor matching a SAM checkpoint.

    The SAM variant is picked from the checkpoint file name, checked in this
    order: "_hq_", "FastSAM", "mobile_sam", and otherwise the original SAM
    registry.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        The selected variant's predictor, or None when sam_checkpoint is not
        an existing file

    Raises:
        NotImplementedError: for FastSAM checkpoints, whether or not the file
            exists
    """
    # NOTE(review): unlike get_sam_mask_generator there is no "sam2_" branch
    # here, so such checkpoints fall through to the original SAM branch --
    # confirm this is intended.
    ckpt_name = os.path.basename(sam_checkpoint)

    if "_hq_" in ckpt_name:
        # The model type is taken from a fixed position in the file name.
        model_type = ckpt_name[7:12]
        registry, predictor_cls = sam_model_registry_hq, SamPredictorHQ
    elif "FastSAM" in ckpt_name:
        raise NotImplementedError("FastSAM predictor is not implemented yet.")
    elif "mobile_sam" in ckpt_name:
        model_type = "vit_t"
        registry, predictor_cls = sam_model_registry_mobile, SamPredictorMobile
    else:
        # The model type is taken from a fixed position in the file name.
        model_type = ckpt_name[4:9]
        registry, predictor_cls = sam_model_registry, SamPredictor

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = registry[model_type](checkpoint=sam_checkpoint)

    if platform.system() == "Darwin":
        # FastSAM-named checkpoints, and setups without MPS, stay on the CPU.
        keep_on_cpu = "FastSAM" in ckpt_name or not ia_check_versions.torch_mps_is_available
        sam.to(device=torch.device("cpu" if keep_on_cpu else "mps"))
    elif IAConfig.global_args.get("sam_cpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been selected)")
        sam.to(device=devices.cpu)
    else:
        sam.to(device=devices.device)

    return predictor_cls(sam)