# Ultralytics YOLO 🚀, AGPL-3.0 license
""" | |
Generate predictions using the Segment Anything Model (SAM). | |
SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. | |
This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation | |
using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image | |
segmentation tasks. | |
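
Example (illustrative sketch; the checkpoint name "sam_b.pt" and the image path are placeholders for files that would
need to be available locally or auto-downloaded):

    from ultralytics import SAM

    model = SAM("sam_b.pt")  # high-level wrapper around this Predictor
    results = model("path/to/image.jpg", bboxes=[100, 100, 400, 400])  # prompted segmentation with one XYXY box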
""" | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torchvision | |
from ultralytics.data.augment import LetterBox | |
from ultralytics.engine.predictor import BasePredictor | |
from ultralytics.engine.results import Results | |
from ultralytics.utils import DEFAULT_CFG, ops | |
from ultralytics.utils.torch_utils import select_device | |
from .amg import ( | |
batch_iterator, | |
batched_mask_to_box, | |
build_all_layer_point_grids, | |
calculate_stability_score, | |
generate_crop_boxes, | |
is_box_near_crop_edge, | |
remove_small_regions, | |
uncrop_boxes_xyxy, | |
uncrop_masks, | |
) | |
from .build import build_sam | |


class Predictor(BasePredictor):
    """
    Predictor class for the Segment Anything Model (SAM), extending BasePredictor.

    The class provides an interface for model inference tailored to image segmentation tasks. With advanced
    architecture and promptable segmentation capabilities, it facilitates flexible and real-time mask generation.
    The class is capable of working with various types of prompts, such as bounding boxes, points, and
    low-resolution masks.

    Attributes:
        cfg (dict): Configuration dictionary specifying model and task-related parameters.
        overrides (dict): Dictionary containing values that override the default configuration.
        _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
        args (namespace): Namespace to hold command-line arguments or other operational variables.
        im (torch.Tensor): Preprocessed input image tensor.
        features (torch.Tensor): Extracted image features used for inference.
        prompts (dict): Collection of various prompt types, such as bounding boxes and points.
        segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
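
    Example (illustrative usage sketch; "mobile_sam.pt" and the asset path are assumed checkpoint/image names):
        >>> from ultralytics.models.sam import Predictor
        >>> predictor = Predictor(overrides=dict(conf=0.25, model="mobile_sam.pt"))  # task/mode/imgsz are set in __init__
        >>> predictor.set_image("ultralytics/assets/zidane.jpg")
        >>> results = predictor(bboxes=[439, 437, 524, 709])  # prompt with a single XYXY box
        >>> results = predictor(points=[900, 370], labels=[1])  # prompt with one foreground point
        >>> predictor.reset_image()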
""" | |

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the Predictor with configuration, overrides, and callbacks.

        The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
        initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.

        Args:
            cfg (dict): Configuration dictionary.
            overrides (dict, optional): Dictionary of values to override default configuration.
            _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
        """
        if overrides is None:
            overrides = {}
        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
        super().__init__(cfg, overrides, _callbacks)
        self.args.retina_masks = True
        self.im = None
        self.features = None
        self.prompts = {}
        self.segment_all = False

    def preprocess(self, im):
        """
        Preprocess the input image for model inference.

        The method prepares the input image by applying transformations and normalization. It supports both
        torch.Tensor and list of np.ndarray as input formats.

        Args:
            im (torch.Tensor | List[np.ndarray]): Input in BCHW tensor format or as a list of HWC numpy arrays.

        Returns:
            (torch.Tensor): The preprocessed image tensor.
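
        Example (shape sketch only; assumes setup_model() has already populated self.mean, self.std and self.model.fp16,
        and that the default imgsz=1024 is used):
            >>> import numpy as np
            >>> batch = [np.zeros((480, 640, 3), dtype=np.uint8)]  # one HWC BGR image, as read by cv2
            >>> im = predictor.preprocess(batch)  # -> torch.Tensor of shape (1, 3, 1024, 1024) after letterbox and normalization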
""" | |
if self.im is not None: | |
return self.im | |
not_tensor = not isinstance(im, torch.Tensor) | |
if not_tensor: | |
im = np.stack(self.pre_transform(im)) | |
im = im[..., ::-1].transpose((0, 3, 1, 2)) | |
im = np.ascontiguousarray(im) | |
im = torch.from_numpy(im) | |
im = im.to(self.device) | |
im = im.half() if self.model.fp16 else im.float() | |
if not_tensor: | |
im = (im - self.mean) / self.std | |
return im | |

    def pre_transform(self, im):
        """
        Perform initial transformations on the input image for preprocessing.

        The method applies transformations such as resizing to prepare the image for further preprocessing.
        Currently, batched inference is not supported; hence the list length should be 1.

        Args:
            im (List[np.ndarray]): List containing images in HWC numpy array format.

        Returns:
            (List[np.ndarray]): List of transformed images.
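
        Example (illustrative; mirrors the LetterBox call used in the method body):
            >>> import numpy as np
            >>> from ultralytics.data.augment import LetterBox
            >>> letterbox = LetterBox(1024, auto=False, center=False)
            >>> out = letterbox(image=np.zeros((480, 640, 3), dtype=np.uint8))  # resized and bottom/right padded to 1024x1024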
""" | |
assert len(im) == 1, "SAM model does not currently support batched inference" | |
letterbox = LetterBox(self.args.imgsz, auto=False, center=False) | |
return [letterbox(image=x) for x in im] | |

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """
        Perform image segmentation inference based on the given input cues, using the currently loaded image.

        This method leverages SAM's (Segment Anything Model) architecture, consisting of an image encoder, prompt
        encoder, and mask decoder, for real-time and promptable segmentation tasks.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

        Returns:
            (tuple): Contains the following elements.
                - pred_masks (torch.Tensor): The output masks with shape (C, H, W), where C is the number of generated masks.
                - pred_scores (torch.Tensor): Quality scores of length C predicted by the model for each mask.
                - pred_bboxes (torch.Tensor, optional): Bounding boxes for the masks, returned only when no prompts are
                  given and the whole image is segmented via generate().
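
        Example (illustrative prompt formats; coordinates are arbitrary placeholders and assume set_image() was called):
            >>> results = predictor(bboxes=[[50, 60, 300, 400]])  # one box in XYXY pixel coordinates
            >>> results = predictor(points=[[200, 200]], labels=[1])  # one foreground point
            >>> results = predictor()  # no prompts: falls back to generate() and segments everything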
""" | |
# Override prompts if any stored in self.prompts | |
bboxes = self.prompts.pop("bboxes", bboxes) | |
points = self.prompts.pop("points", points) | |
masks = self.prompts.pop("masks", masks) | |
if all(i is None for i in [bboxes, points, masks]): | |
return self.generate(im, *args, **kwargs) | |
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) | |

    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
        """
        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.

        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixel coordinates.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 for foreground and 0 for background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions. Shape should be (N, H, W). For SAM, H=W=256.
            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. Defaults to False.

        Returns:
            (tuple): Contains the following two elements.
                - pred_masks (torch.Tensor): Low-resolution mask logits of shape (C, H, W), where C is the number of
                  generated masks and H=W=256.
                - pred_scores (torch.Tensor): Quality scores of length C predicted by the model for each mask.
""" | |
features = self.model.image_encoder(im) if self.features is None else self.features | |
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] | |
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) | |
# Transform input prompts | |
if points is not None: | |
points = torch.as_tensor(points, dtype=torch.float32, device=self.device) | |
points = points[None] if points.ndim == 1 else points | |
# Assuming labels are all positive if users don't pass labels. | |
if labels is None: | |
labels = np.ones(points.shape[0]) | |
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) | |
points *= r | |
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) | |
points, labels = points[:, None, :], labels[:, None] | |
if bboxes is not None: | |
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) | |
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes | |
bboxes *= r | |
if masks is not None: | |
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) | |
points = (points, labels) if points is not None else None | |
# Embed prompts | |
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) | |
# Predict masks | |
pred_masks, pred_scores = self.model.mask_decoder( | |
image_embeddings=features, | |
image_pe=self.model.prompt_encoder.get_dense_pe(), | |
sparse_prompt_embeddings=sparse_embeddings, | |
dense_prompt_embeddings=dense_embeddings, | |
multimask_output=multimask_output, | |
) | |
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) | |
# `d` could be 1 or 3 depends on `multimask_output`. | |
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) | |

    def generate(
        self,
        im,
        crop_n_layers=0,
        crop_overlap_ratio=512 / 1500,
        crop_downscale_factor=1,
        point_grids=None,
        points_stride=32,
        points_batch_size=64,
        conf_thres=0.88,
        stability_score_thresh=0.95,
        stability_score_offset=0.95,
        crop_nms_thresh=0.7,
    ):
        """
        Perform image segmentation using the Segment Anything Model (SAM).

        This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
        and real-time performance capabilities. It can optionally work on image crops for finer segmentation.

        Args:
            im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
            crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
                Each layer produces 2**i_layer number of image crops.
            crop_overlap_ratio (float): Determines the extent of overlap between crops. Scaled down in subsequent layers.
            crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
            point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0, 1].
                Used in the nth crop layer.
            points_stride (int, optional): Number of points to sample along each side of the image.
                Exclusive with 'point_grids'.
            points_batch_size (int): Batch size for the number of points processed simultaneously.
            conf_thres (float): Confidence threshold [0, 1] for filtering based on the model's mask quality prediction.
            stability_score_thresh (float): Stability threshold [0, 1] for mask filtering based on mask stability.
            stability_score_offset (float): Offset value for calculating stability score.
            crop_nms_thresh (float): IoU cutoff for Non-Maximum Suppression (NMS) to remove duplicate masks between crops.

        Returns:
            (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
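
        Example (illustrative "segment everything" call; the source path is a placeholder and keyword values are arbitrary):
            >>> results = predictor(source="path/to/image.jpg", crop_n_layers=1, points_stride=64)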
""" | |
self.segment_all = True | |
ih, iw = im.shape[2:] | |
crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) | |
if point_grids is None: | |
point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) | |
pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] | |
for crop_region, layer_idx in zip(crop_regions, layer_idxs): | |
x1, y1, x2, y2 = crop_region | |
w, h = x2 - x1, y2 - y1 | |
area = torch.tensor(w * h, device=im.device) | |
points_scale = np.array([[w, h]]) # w, h | |
# Crop image and interpolate to input size | |
crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) | |
# (num_points, 2) | |
points_for_image = point_grids[layer_idx] * points_scale | |
crop_masks, crop_scores, crop_bboxes = [], [], [] | |
for (points,) in batch_iterator(points_batch_size, points_for_image): | |
pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) | |
# Interpolate predicted masks to input size | |
pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] | |
idx = pred_score > conf_thres | |
pred_mask, pred_score = pred_mask[idx], pred_score[idx] | |
stability_score = calculate_stability_score( | |
pred_mask, self.model.mask_threshold, stability_score_offset | |
) | |
idx = stability_score > stability_score_thresh | |
pred_mask, pred_score = pred_mask[idx], pred_score[idx] | |
# Bool type is much more memory-efficient. | |
pred_mask = pred_mask > self.model.mask_threshold | |
# (N, 4) | |
pred_bbox = batched_mask_to_box(pred_mask).float() | |
keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) | |
if not torch.all(keep_mask): | |
pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] | |
crop_masks.append(pred_mask) | |
crop_bboxes.append(pred_bbox) | |
crop_scores.append(pred_score) | |
# Do nms within this crop | |
crop_masks = torch.cat(crop_masks) | |
crop_bboxes = torch.cat(crop_bboxes) | |
crop_scores = torch.cat(crop_scores) | |
keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS | |
crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) | |
crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) | |
crop_scores = crop_scores[keep] | |
pred_masks.append(crop_masks) | |
pred_bboxes.append(crop_bboxes) | |
pred_scores.append(crop_scores) | |
region_areas.append(area.expand(len(crop_masks))) | |
pred_masks = torch.cat(pred_masks) | |
pred_bboxes = torch.cat(pred_bboxes) | |
pred_scores = torch.cat(pred_scores) | |
region_areas = torch.cat(region_areas) | |
# Remove duplicate masks between crops | |
if len(crop_regions) > 1: | |
scores = 1 / region_areas | |
keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) | |
pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] | |
return pred_masks, pred_scores, pred_bboxes | |

    def setup_model(self, model, verbose=True):
        """
        Initializes the Segment Anything Model (SAM) for inference.

        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
        parameters for image normalization and other Ultralytics compatibility settings.

        Args:
            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
            verbose (bool): If True, prints selected device information.

        Attributes:
            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
            device (torch.device): The device to which the model and tensors are allocated.
            mean (torch.Tensor): The mean values for image normalization.
            std (torch.Tensor): The standard deviation values for image normalization.
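
        Example (illustrative; assumes self.args.model points to a valid SAM checkpoint such as "sam_b.pt"):
            >>> predictor.setup_model(model=None)  # builds the model from self.args.model via build_sam()
            >>> predictor.setup_model(model=build_sam("sam_b.pt"), verbose=False)  # or pass a prebuilt SAM module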
""" | |
device = select_device(self.args.device, verbose=verbose) | |
if model is None: | |
model = build_sam(self.args.model) | |
model.eval() | |
self.model = model.to(device) | |
self.device = device | |
self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) | |
self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) | |
# Ultralytics compatibility settings | |
self.model.pt = False | |
self.model.triton = False | |
self.model.stride = 32 | |
self.model.fp16 = False | |
self.done_warmup = True | |

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.

        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
        The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.

        Args:
            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
            img (torch.Tensor): The processed input image tensor.
            orig_imgs (list | torch.Tensor): The original, unprocessed images.

        Returns:
            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
        """
        # (N, 1, H, W), (N, 1)
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, masks in enumerate([pred_masks]):
            orig_img = orig_imgs[i]
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            masks = masks > self.model.mask_threshold  # to bool
            img_path = self.batch[0][i]
            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results

    def setup_source(self, source):
        """
        Sets up the data source for inference.

        This method configures the data source from which images will be fetched for inference. The source could be a
        directory, a video file, or other types of image data sources.

        Args:
            source (str | Path): The path to the image data source for inference.
        """
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """
        Preprocesses and sets a single image for inference.

        This function sets up the model if not already initialized, configures the data source to the specified image,
        and preprocesses the image for feature extraction. Only one image can be set at a time.

        Args:
            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.

        Raises:
            AssertionError: If more than one image is set.
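
        Example (illustrative; the image path is a placeholder):
            >>> import cv2
            >>> predictor.set_image("path/to/image.jpg")  # set with an image file
            >>> predictor.set_image(cv2.imread("path/to/image.jpg"))  # or with an np.ndarray read by cv2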
""" | |
if self.model is None: | |
model = build_sam(self.args.model) | |
self.setup_model(model) | |
self.setup_source(image) | |
assert len(self.dataset) == 1, "`set_image` only supports setting one image!" | |
for batch in self.dataset: | |
im = self.preprocess(batch[1]) | |
self.features = self.model.image_encoder(im) | |
self.im = im | |
break | |

    def set_prompts(self, prompts):
        """Set prompts in advance."""
        self.prompts = prompts

    def reset_image(self):
        """Resets the image and its features to None."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
        Suppression (NMS) to eliminate any newly created duplicate boxes.

        Args:
            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
                the number of masks, H is height, and W is width.
            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.

        Returns:
            (tuple([torch.Tensor, List[int]])):
                - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
                - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
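
        Example (illustrative; `masks` stands for any non-empty (N, H, W) tensor of predicted masks):
            >>> new_masks, keep = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)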
""" | |
if len(masks) == 0: | |
return masks | |
# Filter small disconnected regions and holes | |
new_masks = [] | |
scores = [] | |
for mask in masks: | |
mask = mask.cpu().numpy().astype(np.uint8) | |
mask, changed = remove_small_regions(mask, min_area, mode="holes") | |
unchanged = not changed | |
mask, changed = remove_small_regions(mask, min_area, mode="islands") | |
unchanged = unchanged and not changed | |
new_masks.append(torch.as_tensor(mask).unsqueeze(0)) | |
# Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing | |
scores.append(float(unchanged)) | |
# Recalculate boxes and remove any new duplicates | |
new_masks = torch.cat(new_masks, dim=0) | |
boxes = batched_mask_to_box(new_masks) | |
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) | |
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep | |