Spaces:
Runtime error
Runtime error
File size: 7,831 Bytes
da3eeba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import os
import platform
from functools import partial
import torch
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
from ia_check_versions import ia_check_versions
from ia_config import IAConfig
from ia_devices import devices
from ia_logging import ia_logging
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
from mobile_sam import SamPredictor as SamPredictorMobile
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
def check_bfloat16_support() -> bool:
    """Report whether the current CUDA device is treated as bfloat16-capable.

    The check is based solely on the major compute capability of the
    current CUDA device: a major version of 8 or higher counts as supported.
    The outcome is also written to the debug log.

    Returns:
        bool: True if CUDA is available and the device's major compute
        capability is at least 8, otherwise False.
    """
    if not torch.cuda.is_available():
        ia_logging.debug("CUDA is not available")
        return False

    # get_device_capability returns (major, minor); only the major part is used.
    major_version = torch.cuda.get_device_capability(torch.cuda.current_device())[0]
    supported = major_version >= 8
    if supported:
        ia_logging.debug("The CUDA device supports bfloat16")
    else:
        ia_logging.debug("The CUDA device does not support bfloat16")
    return supported
def partial_from_end(func, /, *fixed_args, **fixed_kwargs):
    """Bind trailing positional arguments and default keyword arguments to ``func``.

    Args:
        func: Callable to wrap.
        *fixed_args: Positional arguments appended AFTER the positional
            arguments supplied at call time.
        **fixed_kwargs: Keyword arguments used as defaults; keyword arguments
            supplied at call time take precedence over them.

    Returns:
        A wrapper callable that forwards to ``func`` and returns its result.
    """
    def wrapper(*args, **kwargs):
        # Start from the bound defaults, then let call-time keywords win.
        merged_kwargs = dict(fixed_kwargs)
        merged_kwargs.update(kwargs)
        positional = args + fixed_args
        return func(*positional, **merged_kwargs)

    return wrapper
def rename_args(func, arg_map):
    """Wrap ``func`` so that selected keyword argument names are translated.

    Args:
        func: Callable to wrap.
        arg_map (dict): Maps a keyword name accepted by the wrapper to the
            keyword name that is passed on to ``func``. Keywords that are
            not in the mapping are forwarded under their original name.

    Returns:
        A wrapper callable; positional arguments are forwarded unchanged.
    """
    def wrapper(*args, **kwargs):
        translated = {}
        for name, value in kwargs.items():
            # Fall back to the caller's own name when no alias is registered.
            translated[arg_map.get(name, name)] = value
        return func(*args, **translated)

    return wrapper
# build_sam2 is called with "ckpt_path", while callers of the registry pass
# "checkpoint"; this mapping drives the keyword translation.
arg_map = {"checkpoint": "ckpt_path"}
rename_build_sam2 = rename_args(build_sam2, arg_map)

# Keyword arguments supplied to every SAM2 builder unless the caller overrides them.
end_kwargs = dict(device="cpu", mode="eval", hydra_overrides_extra=[], apply_postprocessing=False)

# Registry key -> builder callable. Each builder has its config file name bound
# as the first positional argument and accepts ``checkpoint=...``.
sam2_model_registry = {
    registry_key: partial(partial_from_end(rename_build_sam2, **end_kwargs), config_name)
    for registry_key, config_name in (
        ("sam2_hiera_large", "sam2_hiera_l.yaml"),
        ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
        ("sam2_hiera_small", "sam2_hiera_s.yaml"),
        ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
    )
}
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
    """Get SAM mask generator.

    The model family (HQ, FastSAM, MobileSAM, SAM2 or the default SAM) is
    chosen from the checkpoint's file name, the model is loaded and moved to
    the selected device, and an automatic mask generator is built around it.

    Args:
        sam_checkpoint (str): SAM checkpoint path
        anime_style_chk (bool): When true, lower prediction-IoU and
            stability-score thresholds are used

    Returns:
        SamAutomaticMaskGenerator or None: SAM mask generator, or None when
        ``sam_checkpoint`` is not an existing file
    """
    checkpoint_name = os.path.basename(sam_checkpoint)

    # Select registry, generator class and batch size from the file name.
    if "_hq_" in checkpoint_name:
        # Registry key is taken from a fixed slice of the file name
        # (presumably names shaped like "sam_hq_vit_h.pth" - verify).
        model_type = checkpoint_name[7:12]
        registry, generator_cls, points_per_batch = sam_model_registry_hq, SamAutomaticMaskGeneratorHQ, 32
    elif "FastSAM" in checkpoint_name:
        model_type = os.path.splitext(checkpoint_name)[0]
        registry, generator_cls, points_per_batch = fast_sam_model_registry, FastSamAutomaticMaskGenerator, None
    elif "mobile_sam" in checkpoint_name:
        model_type = "vit_t"
        registry, generator_cls, points_per_batch = sam_model_registry_mobile, SamAutomaticMaskGeneratorMobile, 64
    elif "sam2_" in checkpoint_name:
        model_type = os.path.splitext(checkpoint_name)[0]
        registry, generator_cls, points_per_batch = sam2_model_registry, SAM2AutomaticMaskGenerator, 128
    else:
        # Fixed slice of the file name again (presumably "sam_vit_h_*.pth" - verify).
        model_type = checkpoint_name[4:9]
        registry, generator_cls, points_per_batch = sam_model_registry, SamAutomaticMaskGenerator, 64

    if anime_style_chk:
        pred_iou_thresh, stability_score_thresh = 0.83, 0.9
    else:
        pred_iou_thresh, stability_score_thresh = 0.88, 0.95

    is_sam2 = "sam2_" in model_type
    if is_sam2:
        # SAM2 runs with lower thresholds than the other model families.
        pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
        stability_score_thresh = round(stability_score_thresh - 0.03, 2)

    if not os.path.isfile(sam_checkpoint):
        return None

    on_macos = platform.system() == "Darwin"
    sam = registry[model_type](checkpoint=sam_checkpoint)

    # Pick the device: on macOS use MPS when available (never for FastSAM),
    # elsewhere honour the "sam_cpu" option.
    if on_macos:
        if "FastSAM" in checkpoint_name or not ia_check_versions.torch_mps_is_available:
            target_device = torch.device("cpu")
        else:
            target_device = torch.device("mps")
    elif IAConfig.global_args.get("sam_cpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been selected)")
        target_device = devices.cpu
    else:
        target_device = devices.device
    sam.to(device=target_device)

    generator_kwargs = {
        "model": sam,
        "points_per_batch": points_per_batch,
        "pred_iou_thresh": pred_iou_thresh,
        "stability_score_thresh": stability_score_thresh,
    }
    if is_sam2:
        # Extra sampling / cropping options that are only passed for SAM2.
        generator_kwargs.update(
            points_per_side=64,
            points_per_batch=points_per_batch,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=2,
        )
        if on_macos:
            # macOS uses a smaller sampling configuration for SAM2.
            generator_kwargs.update(points_per_side=32, points_per_batch=64, crop_n_points_downscale_factor=1)

    return generator_cls(**generator_kwargs)
def get_sam_predictor(sam_checkpoint):
    """Get SAM predictor.

    The model family (HQ, MobileSAM or the default SAM) is chosen from the
    checkpoint's file name, the model is loaded and moved to the selected
    device, and a predictor is built around it.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamPredictor or None: SAM predictor, or None when ``sam_checkpoint``
        is not an existing file

    Raises:
        NotImplementedError: If the checkpoint file name contains "FastSAM".
    """
    checkpoint_name = os.path.basename(sam_checkpoint)

    if "_hq_" in checkpoint_name:
        # Registry key is taken from a fixed slice of the file name
        # (presumably names shaped like "sam_hq_vit_h.pth" - verify).
        model_type = checkpoint_name[7:12]
        sam_model_registry_local = sam_model_registry_hq
        SamPredictorLocal = SamPredictorHQ
    elif "FastSAM" in checkpoint_name:
        raise NotImplementedError("FastSAM predictor is not implemented yet.")
    elif "mobile_sam" in checkpoint_name:
        model_type = "vit_t"
        sam_model_registry_local = sam_model_registry_mobile
        SamPredictorLocal = SamPredictorMobile
    else:
        # NOTE(review): "sam2_*" checkpoints have no dedicated branch here
        # (get_sam_mask_generator has one) and therefore land in this default
        # path - confirm whether SAM2 predictors are meant to be supported.
        model_type = checkpoint_name[4:9]
        sam_model_registry_local = sam_model_registry
        SamPredictorLocal = SamPredictor

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
    if platform.system() == "Darwin":
        # FastSAM checkpoints were already rejected above, so on macOS only
        # MPS availability decides between "mps" and "cpu".
        if ia_check_versions.torch_mps_is_available:
            sam.to(device=torch.device("mps"))
        else:
            sam.to(device=torch.device("cpu"))
    elif IAConfig.global_args.get("sam_cpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been selected)")
        sam.to(device=devices.cpu)
    else:
        sam.to(device=devices.device)

    return SamPredictorLocal(sam)
|