# Copyright (c) OpenMMLab. All rights reserved. import mmcv import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from mmdet.core import InstanceData, mask_matrix_nms, multi_apply from mmdet.core.utils import center_of_mass, generate_coordinate from mmdet.models.builder import HEADS, build_loss from mmdet.utils.misc import floordiv from .base_mask_head import BaseMaskHead @HEADS.register_module() class SOLOHead(BaseMaskHead): """SOLO mask head used in `SOLO: Segmenting Objects by Locations. `_ Args: num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. feat_channels (int): Number of hidden channels. Used in child classes. Default: 256. stacked_convs (int): Number of stacking convs of the head. Default: 4. strides (tuple): Downsample factor of each feature map. scale_ranges (tuple[tuple[int, int]]): Area range of multiple level masks, in the format [(min1, max1), (min2, max2), ...]. A range of (16, 64) means the area range between (16, 64). pos_scale (float): Constant scale factor to control the center region. num_grids (list[int]): Divided image into a uniform grids, each feature map has a different grid value. The number of output channels is grid ** 2. Default: [40, 36, 24, 16, 12]. cls_down_index (int): The index of downsample operation in classification branch. Default: 0. loss_mask (dict): Config of mask loss. loss_cls (dict): Config of classification loss. norm_cfg (dict): dictionary to construct and config norm layer. Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True). train_cfg (dict): Training config of head. test_cfg (dict): Testing config of head. init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__( self, num_classes, in_channels, feat_channels=256, stacked_convs=4, strides=(4, 8, 16, 32, 64), scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)), pos_scale=0.2, num_grids=[40, 36, 24, 16, 12], cls_down_index=0, loss_mask=None, loss_cls=None, norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), train_cfg=None, test_cfg=None, init_cfg=[ dict(type='Normal', layer='Conv2d', std=0.01), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_mask_list')), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_cls')) ], ): super(SOLOHead, self).__init__(init_cfg) self.num_classes = num_classes self.cls_out_channels = self.num_classes self.in_channels = in_channels self.feat_channels = feat_channels self.stacked_convs = stacked_convs self.strides = strides self.num_grids = num_grids # number of FPN feats self.num_levels = len(strides) assert self.num_levels == len(scale_ranges) == len(num_grids) self.scale_ranges = scale_ranges self.pos_scale = pos_scale self.cls_down_index = cls_down_index self.loss_cls = build_loss(loss_cls) self.loss_mask = build_loss(loss_mask) self.norm_cfg = norm_cfg self.init_cfg = init_cfg self.train_cfg = train_cfg self.test_cfg = test_cfg self._init_layers() def _init_layers(self): self.mask_convs = nn.ModuleList() self.cls_convs = nn.ModuleList() for i in range(self.stacked_convs): chn = self.in_channels + 2 if i == 0 else self.feat_channels self.mask_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, norm_cfg=self.norm_cfg)) chn = self.in_channels if i == 0 else self.feat_channels self.cls_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, norm_cfg=self.norm_cfg)) self.conv_mask_list = nn.ModuleList() for num_grid in self.num_grids: self.conv_mask_list.append( nn.Conv2d(self.feat_channels, num_grid**2, 1)) self.conv_cls = nn.Conv2d( self.feat_channels, self.cls_out_channels, 3, padding=1) def resize_feats(self, feats): """Downsample the first feat and upsample last feat in feats.""" out = [] for i in range(len(feats)): if i == 0: out.append( F.interpolate( feats[0], size=feats[i + 1].shape[-2:], mode='bilinear', align_corners=False)) elif i == len(feats) - 1: out.append( F.interpolate( feats[i], size=feats[i - 1].shape[-2:], mode='bilinear', align_corners=False)) else: out.append(feats[i]) return out def forward(self, feats): assert len(feats) == self.num_levels feats = self.resize_feats(feats) mlvl_mask_preds = [] mlvl_cls_preds = [] for i in range(self.num_levels): x = feats[i] mask_feat = x cls_feat = x # generate and concat the coordinate coord_feat = generate_coordinate(mask_feat.size(), mask_feat.device) mask_feat = torch.cat([mask_feat, coord_feat], 1) for mask_layer in (self.mask_convs): mask_feat = mask_layer(mask_feat) mask_feat = F.interpolate( mask_feat, scale_factor=2, mode='bilinear') mask_pred = self.conv_mask_list[i](mask_feat) # cls branch for j, cls_layer in enumerate(self.cls_convs): if j == self.cls_down_index: num_grid = self.num_grids[i] cls_feat = F.interpolate( cls_feat, size=num_grid, mode='bilinear') cls_feat = cls_layer(cls_feat) cls_pred = self.conv_cls(cls_feat) if not self.training: feat_wh = feats[0].size()[-2:] upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) mask_pred = F.interpolate( mask_pred.sigmoid(), size=upsampled_size, mode='bilinear') cls_pred = cls_pred.sigmoid() # get local maximum local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) keep_mask = local_max[:, :, :-1, :-1] == cls_pred cls_pred = cls_pred * keep_mask mlvl_mask_preds.append(mask_pred) mlvl_cls_preds.append(cls_pred) return mlvl_mask_preds, mlvl_cls_preds def loss(self, mlvl_mask_preds, mlvl_cls_preds, gt_labels, gt_masks, img_metas, gt_bboxes=None, **kwargs): """Calculate the loss of total batch. Args: mlvl_mask_preds (list[Tensor]): Multi-level mask prediction. Each element in the list has shape (batch_size, num_grids**2 ,h ,w). mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element in the list has shape (batch_size, num_classes, num_grids ,num_grids). gt_labels (list[Tensor]): Labels of multiple images. gt_masks (list[Tensor]): Ground truth masks of multiple images. Each has shape (num_instances, h, w). img_metas (list[dict]): Meta information of multiple images. gt_bboxes (list[Tensor]): Ground truth bboxes of multiple images. Default: None. Returns: dict[str, Tensor]: A dictionary of loss components. """ num_levels = self.num_levels num_imgs = len(gt_labels) featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds] # `BoolTensor` in `pos_masks` represent # whether the corresponding point is # positive pos_mask_targets, labels, pos_masks = multi_apply( self._get_targets_single, gt_bboxes, gt_labels, gt_masks, featmap_sizes=featmap_sizes) # change from the outside list meaning multi images # to the outside list meaning multi levels mlvl_pos_mask_targets = [[] for _ in range(num_levels)] mlvl_pos_mask_preds = [[] for _ in range(num_levels)] mlvl_pos_masks = [[] for _ in range(num_levels)] mlvl_labels = [[] for _ in range(num_levels)] for img_id in range(num_imgs): assert num_levels == len(pos_mask_targets[img_id]) for lvl in range(num_levels): mlvl_pos_mask_targets[lvl].append( pos_mask_targets[img_id][lvl]) mlvl_pos_mask_preds[lvl].append( mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...]) mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten()) mlvl_labels[lvl].append(labels[img_id][lvl].flatten()) # cat multiple image temp_mlvl_cls_preds = [] for lvl in range(num_levels): mlvl_pos_mask_targets[lvl] = torch.cat( mlvl_pos_mask_targets[lvl], dim=0) mlvl_pos_mask_preds[lvl] = torch.cat( mlvl_pos_mask_preds[lvl], dim=0) mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0) mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0) temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute( 0, 2, 3, 1).reshape(-1, self.cls_out_channels)) num_pos = sum(item.sum() for item in mlvl_pos_masks) # dice loss loss_mask = [] for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets): if pred.size()[0] == 0: loss_mask.append(pred.sum().unsqueeze(0)) continue loss_mask.append( self.loss_mask(pred, target, reduction_override='none')) if num_pos > 0: loss_mask = torch.cat(loss_mask).sum() / num_pos else: loss_mask = torch.cat(loss_mask).mean() flatten_labels = torch.cat(mlvl_labels) flatten_cls_preds = torch.cat(temp_mlvl_cls_preds) loss_cls = self.loss_cls( flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1) return dict(loss_mask=loss_mask, loss_cls=loss_cls) def _get_targets_single(self, gt_bboxes, gt_labels, gt_masks, featmap_sizes=None): """Compute targets for predictions of single image. Args: gt_bboxes (Tensor): Ground truth bbox of each instance, shape (num_gts, 4). gt_labels (Tensor): Ground truth label of each instance, shape (num_gts,). gt_masks (Tensor): Ground truth mask of each instance, shape (num_gts, h, w). featmap_sizes (list[:obj:`torch.size`]): Size of each feature map from feature pyramid, each element means (feat_h, feat_w). Default: None. Returns: Tuple: Usually returns a tuple containing targets for predictions. - mlvl_pos_mask_targets (list[Tensor]): Each element represent the binary mask targets for positive points in this level, has shape (num_pos, out_h, out_w). - mlvl_labels (list[Tensor]): Each element is classification labels for all points in this level, has shape (num_grid, num_grid). - mlvl_pos_masks (list[Tensor]): Each element is a `BoolTensor` to represent whether the corresponding point in single level is positive, has shape (num_grid **2). """ device = gt_labels.device gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (gt_bboxes[:, 3] - gt_bboxes[:, 1])) mlvl_pos_mask_targets = [] mlvl_labels = [] mlvl_pos_masks = [] for (lower_bound, upper_bound), stride, featmap_size, num_grid \ in zip(self.scale_ranges, self.strides, featmap_sizes, self.num_grids): mask_target = torch.zeros( [num_grid**2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device) # FG cat_id: [0, num_classes -1], BG cat_id: num_classes labels = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device) + self.num_classes pos_mask = torch.zeros([num_grid**2], dtype=torch.bool, device=device) gt_inds = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten() if len(gt_inds) == 0: mlvl_pos_mask_targets.append( mask_target.new_zeros(0, featmap_size[0], featmap_size[1])) mlvl_labels.append(labels) mlvl_pos_masks.append(pos_mask) continue hit_gt_bboxes = gt_bboxes[gt_inds] hit_gt_labels = gt_labels[gt_inds] hit_gt_masks = gt_masks[gt_inds, ...] pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] - hit_gt_bboxes[:, 0]) * self.pos_scale pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] - hit_gt_bboxes[:, 1]) * self.pos_scale # Make sure hit_gt_masks has a value valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0 output_stride = stride / 2 for gt_mask, gt_label, pos_h_range, pos_w_range, \ valid_mask_flag in \ zip(hit_gt_masks, hit_gt_labels, pos_h_ranges, pos_w_ranges, valid_mask_flags): if not valid_mask_flag: continue upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4) center_h, center_w = center_of_mass(gt_mask) coord_w = int( floordiv((center_w / upsampled_size[1]), (1. / num_grid), rounding_mode='trunc')) coord_h = int( floordiv((center_h / upsampled_size[0]), (1. / num_grid), rounding_mode='trunc')) # left, top, right, down top_box = max( 0, int( floordiv( (center_h - pos_h_range) / upsampled_size[0], (1. / num_grid), rounding_mode='trunc'))) down_box = min( num_grid - 1, int( floordiv( (center_h + pos_h_range) / upsampled_size[0], (1. / num_grid), rounding_mode='trunc'))) left_box = max( 0, int( floordiv( (center_w - pos_w_range) / upsampled_size[1], (1. / num_grid), rounding_mode='trunc'))) right_box = min( num_grid - 1, int( floordiv( (center_w + pos_w_range) / upsampled_size[1], (1. / num_grid), rounding_mode='trunc'))) top = max(top_box, coord_h - 1) down = min(down_box, coord_h + 1) left = max(coord_w - 1, left_box) right = min(right_box, coord_w + 1) labels[top:(down + 1), left:(right + 1)] = gt_label # ins gt_mask = np.uint8(gt_mask.cpu().numpy()) # Follow the original implementation, F.interpolate is # different from cv2 and opencv gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride) gt_mask = torch.from_numpy(gt_mask).to(device=device) for i in range(top, down + 1): for j in range(left, right + 1): index = int(i * num_grid + j) mask_target[index, :gt_mask.shape[0], :gt_mask. shape[1]] = gt_mask pos_mask[index] = True mlvl_pos_mask_targets.append(mask_target[pos_mask]) mlvl_labels.append(labels) mlvl_pos_masks.append(pos_mask) return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks def get_results(self, mlvl_mask_preds, mlvl_cls_scores, img_metas, **kwargs): """Get multi-image mask results. Args: mlvl_mask_preds (list[Tensor]): Multi-level mask prediction. Each element in the list has shape (batch_size, num_grids**2 ,h ,w). mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element in the list has shape (batch_size, num_classes, num_grids ,num_grids). img_metas (list[dict]): Meta information of all images. Returns: list[:obj:`InstanceData`]: Processed results of multiple images.Each :obj:`InstanceData` usually contains following keys. - scores (Tensor): Classification scores, has shape (num_instance,). - labels (Tensor): Has shape (num_instances,). - masks (Tensor): Processed mask results, has shape (num_instances, h, w). """ mlvl_cls_scores = [ item.permute(0, 2, 3, 1) for item in mlvl_cls_scores ] assert len(mlvl_mask_preds) == len(mlvl_cls_scores) num_levels = len(mlvl_cls_scores) results_list = [] for img_id in range(len(img_metas)): cls_pred_list = [ mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels) for lvl in range(num_levels) ] mask_pred_list = [ mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels) ] cls_pred_list = torch.cat(cls_pred_list, dim=0) mask_pred_list = torch.cat(mask_pred_list, dim=0) results = self._get_results_single( cls_pred_list, mask_pred_list, img_meta=img_metas[img_id]) results_list.append(results) return results_list def _get_results_single(self, cls_scores, mask_preds, img_meta, cfg=None): """Get processed mask related results of single image. Args: cls_scores (Tensor): Classification score of all points in single image, has shape (num_points, num_classes). mask_preds (Tensor): Mask prediction of all points in single image, has shape (num_points, feat_h, feat_w). img_meta (dict): Meta information of corresponding image. cfg (dict, optional): Config used in test phase. Default: None. Returns: :obj:`InstanceData`: Processed results of single image. it usually contains following keys. - scores (Tensor): Classification scores, has shape (num_instance,). - labels (Tensor): Has shape (num_instances,). - masks (Tensor): Processed mask results, has shape (num_instances, h, w). """ def empty_results(results, cls_scores): """Generate a empty results.""" results.scores = cls_scores.new_ones(0) results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2]) results.labels = cls_scores.new_ones(0) return results cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(mask_preds) results = InstanceData(img_meta) featmap_size = mask_preds.size()[-2:] img_shape = results.img_shape ori_shape = results.ori_shape h, w, _ = img_shape upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) score_mask = (cls_scores > cfg.score_thr) cls_scores = cls_scores[score_mask] if len(cls_scores) == 0: return empty_results(results, cls_scores) inds = score_mask.nonzero() cls_labels = inds[:, 1] # Filter the mask mask with an area is smaller than # stride of corresponding feature level lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0) strides = cls_scores.new_ones(lvl_interval[-1]) strides[:lvl_interval[0]] *= self.strides[0] for lvl in range(1, self.num_levels): strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= self.strides[lvl] strides = strides[inds[:, 0]] mask_preds = mask_preds[inds[:, 0]] masks = mask_preds > cfg.mask_thr sum_masks = masks.sum((1, 2)).float() keep = sum_masks > strides if keep.sum() == 0: return empty_results(results, cls_scores) masks = masks[keep] mask_preds = mask_preds[keep] sum_masks = sum_masks[keep] cls_scores = cls_scores[keep] cls_labels = cls_labels[keep] # maskness. mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks cls_scores *= mask_scores scores, labels, _, keep_inds = mask_matrix_nms( masks, cls_labels, cls_scores, mask_area=sum_masks, nms_pre=cfg.nms_pre, max_num=cfg.max_per_img, kernel=cfg.kernel, sigma=cfg.sigma, filter_thr=cfg.filter_thr) mask_preds = mask_preds[keep_inds] mask_preds = F.interpolate( mask_preds.unsqueeze(0), size=upsampled_size, mode='bilinear')[:, :, :h, :w] mask_preds = F.interpolate( mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0) masks = mask_preds > cfg.mask_thr results.masks = masks results.labels = labels results.scores = scores return results @HEADS.register_module() class DecoupledSOLOHead(SOLOHead): """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations. `_ Args: init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__(self, *args, init_cfg=[ dict(type='Normal', layer='Conv2d', std=0.01), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_mask_list_x')), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_mask_list_y')), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_cls')) ], **kwargs): super(DecoupledSOLOHead, self).__init__( *args, init_cfg=init_cfg, **kwargs) def _init_layers(self): self.mask_convs_x = nn.ModuleList() self.mask_convs_y = nn.ModuleList() self.cls_convs = nn.ModuleList() for i in range(self.stacked_convs): chn = self.in_channels + 1 if i == 0 else self.feat_channels self.mask_convs_x.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, norm_cfg=self.norm_cfg)) self.mask_convs_y.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, norm_cfg=self.norm_cfg)) chn = self.in_channels if i == 0 else self.feat_channels self.cls_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, norm_cfg=self.norm_cfg)) self.conv_mask_list_x = nn.ModuleList() self.conv_mask_list_y = nn.ModuleList() for num_grid in self.num_grids: self.conv_mask_list_x.append( nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) self.conv_mask_list_y.append( nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) self.conv_cls = nn.Conv2d( self.feat_channels, self.cls_out_channels, 3, padding=1) def forward(self, feats): assert len(feats) == self.num_levels feats = self.resize_feats(feats) mask_preds_x = [] mask_preds_y = [] cls_preds = [] for i in range(self.num_levels): x = feats[i] mask_feat = x cls_feat = x # generate and concat the coordinate coord_feat = generate_coordinate(mask_feat.size(), mask_feat.device) mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1) mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1) for mask_layer_x, mask_layer_y in \ zip(self.mask_convs_x, self.mask_convs_y): mask_feat_x = mask_layer_x(mask_feat_x) mask_feat_y = mask_layer_y(mask_feat_y) mask_feat_x = F.interpolate( mask_feat_x, scale_factor=2, mode='bilinear') mask_feat_y = F.interpolate( mask_feat_y, scale_factor=2, mode='bilinear') mask_pred_x = self.conv_mask_list_x[i](mask_feat_x) mask_pred_y = self.conv_mask_list_y[i](mask_feat_y) # cls branch for j, cls_layer in enumerate(self.cls_convs): if j == self.cls_down_index: num_grid = self.num_grids[i] cls_feat = F.interpolate( cls_feat, size=num_grid, mode='bilinear') cls_feat = cls_layer(cls_feat) cls_pred = self.conv_cls(cls_feat) if not self.training: feat_wh = feats[0].size()[-2:] upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) mask_pred_x = F.interpolate( mask_pred_x.sigmoid(), size=upsampled_size, mode='bilinear') mask_pred_y = F.interpolate( mask_pred_y.sigmoid(), size=upsampled_size, mode='bilinear') cls_pred = cls_pred.sigmoid() # get local maximum local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) keep_mask = local_max[:, :, :-1, :-1] == cls_pred cls_pred = cls_pred * keep_mask mask_preds_x.append(mask_pred_x) mask_preds_y.append(mask_pred_y) cls_preds.append(cls_pred) return mask_preds_x, mask_preds_y, cls_preds def loss(self, mlvl_mask_preds_x, mlvl_mask_preds_y, mlvl_cls_preds, gt_labels, gt_masks, img_metas, gt_bboxes=None, **kwargs): """Calculate the loss of total batch. Args: mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction from x branch. Each element in the list has shape (batch_size, num_grids ,h ,w). mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction from y branch. Each element in the list has shape (batch_size, num_grids ,h ,w). mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element in the list has shape (batch_size, num_classes, num_grids ,num_grids). gt_labels (list[Tensor]): Labels of multiple images. gt_masks (list[Tensor]): Ground truth masks of multiple images. Each has shape (num_instances, h, w). img_metas (list[dict]): Meta information of multiple images. gt_bboxes (list[Tensor]): Ground truth bboxes of multiple images. Default: None. Returns: dict[str, Tensor]: A dictionary of loss components. """ num_levels = self.num_levels num_imgs = len(gt_labels) featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x] pos_mask_targets, labels, \ xy_pos_indexes = \ multi_apply(self._get_targets_single, gt_bboxes, gt_labels, gt_masks, featmap_sizes=featmap_sizes) # change from the outside list meaning multi images # to the outside list meaning multi levels mlvl_pos_mask_targets = [[] for _ in range(num_levels)] mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)] mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)] mlvl_labels = [[] for _ in range(num_levels)] for img_id in range(num_imgs): for lvl in range(num_levels): mlvl_pos_mask_targets[lvl].append( pos_mask_targets[img_id][lvl]) mlvl_pos_mask_preds_x[lvl].append( mlvl_mask_preds_x[lvl][img_id, xy_pos_indexes[img_id][lvl][:, 1]]) mlvl_pos_mask_preds_y[lvl].append( mlvl_mask_preds_y[lvl][img_id, xy_pos_indexes[img_id][lvl][:, 0]]) mlvl_labels[lvl].append(labels[img_id][lvl].flatten()) # cat multiple image temp_mlvl_cls_preds = [] for lvl in range(num_levels): mlvl_pos_mask_targets[lvl] = torch.cat( mlvl_pos_mask_targets[lvl], dim=0) mlvl_pos_mask_preds_x[lvl] = torch.cat( mlvl_pos_mask_preds_x[lvl], dim=0) mlvl_pos_mask_preds_y[lvl] = torch.cat( mlvl_pos_mask_preds_y[lvl], dim=0) mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0) temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute( 0, 2, 3, 1).reshape(-1, self.cls_out_channels)) num_pos = 0. # dice loss loss_mask = [] for pred_x, pred_y, target in \ zip(mlvl_pos_mask_preds_x, mlvl_pos_mask_preds_y, mlvl_pos_mask_targets): num_masks = pred_x.size(0) if num_masks == 0: # make sure can get grad loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0)) continue num_pos += num_masks pred_mask = pred_y.sigmoid() * pred_x.sigmoid() loss_mask.append( self.loss_mask(pred_mask, target, reduction_override='none')) if num_pos > 0: loss_mask = torch.cat(loss_mask).sum() / num_pos else: loss_mask = torch.cat(loss_mask).mean() # cate flatten_labels = torch.cat(mlvl_labels) flatten_cls_preds = torch.cat(temp_mlvl_cls_preds) loss_cls = self.loss_cls( flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1) return dict(loss_mask=loss_mask, loss_cls=loss_cls) def _get_targets_single(self, gt_bboxes, gt_labels, gt_masks, featmap_sizes=None): """Compute targets for predictions of single image. Args: gt_bboxes (Tensor): Ground truth bbox of each instance, shape (num_gts, 4). gt_labels (Tensor): Ground truth label of each instance, shape (num_gts,). gt_masks (Tensor): Ground truth mask of each instance, shape (num_gts, h, w). featmap_sizes (list[:obj:`torch.size`]): Size of each feature map from feature pyramid, each element means (feat_h, feat_w). Default: None. Returns: Tuple: Usually returns a tuple containing targets for predictions. - mlvl_pos_mask_targets (list[Tensor]): Each element represent the binary mask targets for positive points in this level, has shape (num_pos, out_h, out_w). - mlvl_labels (list[Tensor]): Each element is classification labels for all points in this level, has shape (num_grid, num_grid). - mlvl_xy_pos_indexes (list[Tensor]): Each element in the list contains the index of positive samples in corresponding level, has shape (num_pos, 2), last dimension 2 present (index_x, index_y). """ mlvl_pos_mask_targets, mlvl_labels, \ mlvl_pos_masks = \ super()._get_targets_single(gt_bboxes, gt_labels, gt_masks, featmap_sizes=featmap_sizes) mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero() for item in mlvl_labels] return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes def get_results(self, mlvl_mask_preds_x, mlvl_mask_preds_y, mlvl_cls_scores, img_metas, rescale=None, **kwargs): """Get multi-image mask results. Args: mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction from x branch. Each element in the list has shape (batch_size, num_grids ,h ,w). mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction from y branch. Each element in the list has shape (batch_size, num_grids ,h ,w). mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element in the list has shape (batch_size, num_classes ,num_grids ,num_grids). img_metas (list[dict]): Meta information of all images. Returns: list[:obj:`InstanceData`]: Processed results of multiple images.Each :obj:`InstanceData` usually contains following keys. - scores (Tensor): Classification scores, has shape (num_instance,). - labels (Tensor): Has shape (num_instances,). - masks (Tensor): Processed mask results, has shape (num_instances, h, w). """ mlvl_cls_scores = [ item.permute(0, 2, 3, 1) for item in mlvl_cls_scores ] assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores) num_levels = len(mlvl_cls_scores) results_list = [] for img_id in range(len(img_metas)): cls_pred_list = [ mlvl_cls_scores[i][img_id].view( -1, self.cls_out_channels).detach() for i in range(num_levels) ] mask_pred_list_x = [ mlvl_mask_preds_x[i][img_id] for i in range(num_levels) ] mask_pred_list_y = [ mlvl_mask_preds_y[i][img_id] for i in range(num_levels) ] cls_pred_list = torch.cat(cls_pred_list, dim=0) mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0) mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0) results = self._get_results_single( cls_pred_list, mask_pred_list_x, mask_pred_list_y, img_meta=img_metas[img_id], cfg=self.test_cfg) results_list.append(results) return results_list def _get_results_single(self, cls_scores, mask_preds_x, mask_preds_y, img_meta, cfg): """Get processed mask related results of single image. Args: cls_scores (Tensor): Classification score of all points in single image, has shape (num_points, num_classes). mask_preds_x (Tensor): Mask prediction of x branch of all points in single image, has shape (sum_num_grids, feat_h, feat_w). mask_preds_y (Tensor): Mask prediction of y branch of all points in single image, has shape (sum_num_grids, feat_h, feat_w). img_meta (dict): Meta information of corresponding image. cfg (dict): Config used in test phase. Returns: :obj:`InstanceData`: Processed results of single image. it usually contains following keys. - scores (Tensor): Classification scores, has shape (num_instance,). - labels (Tensor): Has shape (num_instances,). - masks (Tensor): Processed mask results, has shape (num_instances, h, w). """ def empty_results(results, cls_scores): """Generate a empty results.""" results.scores = cls_scores.new_ones(0) results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2]) results.labels = cls_scores.new_ones(0) return results cfg = self.test_cfg if cfg is None else cfg results = InstanceData(img_meta) img_shape = results.img_shape ori_shape = results.ori_shape h, w, _ = img_shape featmap_size = mask_preds_x.size()[-2:] upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) score_mask = (cls_scores > cfg.score_thr) cls_scores = cls_scores[score_mask] inds = score_mask.nonzero() lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0) num_all_points = lvl_interval[-1] lvl_start_index = inds.new_ones(num_all_points) num_grids = inds.new_ones(num_all_points) seg_size = inds.new_tensor(self.num_grids).cumsum(0) mask_lvl_start_index = inds.new_ones(num_all_points) strides = inds.new_ones(num_all_points) lvl_start_index[:lvl_interval[0]] *= 0 mask_lvl_start_index[:lvl_interval[0]] *= 0 num_grids[:lvl_interval[0]] *= self.num_grids[0] strides[:lvl_interval[0]] *= self.strides[0] for lvl in range(1, self.num_levels): lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ lvl_interval[lvl - 1] mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ seg_size[lvl - 1] num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ self.num_grids[lvl] strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \ self.strides[lvl] lvl_start_index = lvl_start_index[inds[:, 0]] mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]] num_grids = num_grids[inds[:, 0]] strides = strides[inds[:, 0]] y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids y_inds = mask_lvl_start_index + y_lvl_offset x_inds = mask_lvl_start_index + x_lvl_offset cls_labels = inds[:, 1] mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...] masks = mask_preds > cfg.mask_thr sum_masks = masks.sum((1, 2)).float() keep = sum_masks > strides if keep.sum() == 0: return empty_results(results, cls_scores) masks = masks[keep] mask_preds = mask_preds[keep] sum_masks = sum_masks[keep] cls_scores = cls_scores[keep] cls_labels = cls_labels[keep] # maskness. mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks cls_scores *= mask_scores scores, labels, _, keep_inds = mask_matrix_nms( masks, cls_labels, cls_scores, mask_area=sum_masks, nms_pre=cfg.nms_pre, max_num=cfg.max_per_img, kernel=cfg.kernel, sigma=cfg.sigma, filter_thr=cfg.filter_thr) mask_preds = mask_preds[keep_inds] mask_preds = F.interpolate( mask_preds.unsqueeze(0), size=upsampled_size, mode='bilinear')[:, :, :h, :w] mask_preds = F.interpolate( mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0) masks = mask_preds > cfg.mask_thr results.masks = masks results.labels = labels results.scores = scores return results @HEADS.register_module() class DecoupledSOLOLightHead(DecoupledSOLOHead): """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by Locations `_ Args: with_dcn (bool): Whether use dcn in mask_convs and cls_convs, default: False. init_cfg (dict or list[dict], optional): Initialization config dict. """ def __init__(self, *args, dcn_cfg=None, init_cfg=[ dict(type='Normal', layer='Conv2d', std=0.01), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_mask_list_x')), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_mask_list_y')), dict( type='Normal', std=0.01, bias_prob=0.01, override=dict(name='conv_cls')) ], **kwargs): assert dcn_cfg is None or isinstance(dcn_cfg, dict) self.dcn_cfg = dcn_cfg super(DecoupledSOLOLightHead, self).__init__( *args, init_cfg=init_cfg, **kwargs) def _init_layers(self): self.mask_convs = nn.ModuleList() self.cls_convs = nn.ModuleList() for i in range(self.stacked_convs): if self.dcn_cfg is not None\ and i == self.stacked_convs - 1: conv_cfg = self.dcn_cfg else: conv_cfg = None chn = self.in_channels + 2 if i == 0 else self.feat_channels self.mask_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=self.norm_cfg)) chn = self.in_channels if i == 0 else self.feat_channels self.cls_convs.append( ConvModule( chn, self.feat_channels, 3, stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=self.norm_cfg)) self.conv_mask_list_x = nn.ModuleList() self.conv_mask_list_y = nn.ModuleList() for num_grid in self.num_grids: self.conv_mask_list_x.append( nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) self.conv_mask_list_y.append( nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) self.conv_cls = nn.Conv2d( self.feat_channels, self.cls_out_channels, 3, padding=1) def forward(self, feats): assert len(feats) == self.num_levels feats = self.resize_feats(feats) mask_preds_x = [] mask_preds_y = [] cls_preds = [] for i in range(self.num_levels): x = feats[i] mask_feat = x cls_feat = x # generate and concat the coordinate coord_feat = generate_coordinate(mask_feat.size(), mask_feat.device) mask_feat = torch.cat([mask_feat, coord_feat], 1) for mask_layer in self.mask_convs: mask_feat = mask_layer(mask_feat) mask_feat = F.interpolate( mask_feat, scale_factor=2, mode='bilinear') mask_pred_x = self.conv_mask_list_x[i](mask_feat) mask_pred_y = self.conv_mask_list_y[i](mask_feat) # cls branch for j, cls_layer in enumerate(self.cls_convs): if j == self.cls_down_index: num_grid = self.num_grids[i] cls_feat = F.interpolate( cls_feat, size=num_grid, mode='bilinear') cls_feat = cls_layer(cls_feat) cls_pred = self.conv_cls(cls_feat) if not self.training: feat_wh = feats[0].size()[-2:] upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) mask_pred_x = F.interpolate( mask_pred_x.sigmoid(), size=upsampled_size, mode='bilinear') mask_pred_y = F.interpolate( mask_pred_y.sigmoid(), size=upsampled_size, mode='bilinear') cls_pred = cls_pred.sigmoid() # get local maximum local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) keep_mask = local_max[:, :, :-1, :-1] == cls_pred cls_pred = cls_pred * keep_mask mask_preds_x.append(mask_pred_x) mask_preds_y.append(mask_pred_y) cls_preds.append(cls_pred) return mask_preds_x, mask_preds_y, cls_preds