# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from ..builder import HEADS
from ..losses import smooth_l1_loss
from .ascend_anchor_head import AscendAnchorHead
from .ssd_head import SSDHead


@HEADS.register_module()
class AscendSSDHead(SSDHead, AscendAnchorHead):
"""Ascend SSD head used in https://arxiv.org/abs/1512.02325.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
stacked_convs (int): Number of conv layers in cls and reg tower.
Default: 0.
feat_channels (int): Number of hidden channels when stacked_convs
> 0. Default: 256.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Default: False.
conv_cfg (dict): Dictionary to construct and config conv layer.
Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: None.
act_cfg (dict): Dictionary to construct and config activation layer.
Default: None.
anchor_generator (dict): Config dict for anchor generator
bbox_coder (dict): Config of bounding box coder.
reg_decoded_bbox (bool): If true, the regression loss would be
applied directly on decoded bounding boxes, converting both
the predicted boxes and regression targets to absolute
coordinates format. Default False. It should be `True` when
using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
train_cfg (dict): Training config of anchor head.
test_cfg (dict): Testing config of anchor head.
init_cfg (dict or list[dict], optional): Initialization config dict.
""" # noqa: W605

    def __init__(self,
                 num_classes=80,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 stacked_convs=0,
                 feat_channels=256,
                 use_depthwise=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 anchor_generator=dict(
                     type='SSDAnchorGenerator',
                     scale_major=False,
                     input_size=300,
                     strides=[8, 16, 32, 64, 100, 300],
                     ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                     basesize_ratio_range=(0.1, 0.9)),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[.0, .0, .0, .0],
                     target_stds=[1.0, 1.0, 1.0, 1.0],
                 ),
                 reg_decoded_bbox=False,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Xavier',
                     layer='Conv2d',
                     distribution='uniform',
                     bias=0)):
        super(AscendSSDHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            feat_channels=feat_channels,
            use_depthwise=use_depthwise,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            reg_decoded_bbox=reg_decoded_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        assert self.reg_decoded_bbox is False, \
            'reg_decoded_bbox only supports False now.'
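        # Decoded-box (IoU-based) regression losses have not been ported to
        # this batched Ascend implementation yet (see the TODO in
        # batch_loss), hence the restriction above.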

    def get_static_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get static anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): Device for returned tensors.

        Returns:
            tuple:
                anchor_list (list[Tensor]): Anchors of each image.
                valid_flag_list (list[Tensor]): Valid flags of each image.
        """
        if not hasattr(self, 'static_anchors') or \
                not hasattr(self, 'static_valid_flags'):
            static_anchors, static_valid_flags = self.get_anchors(
                featmap_sizes, img_metas, device)
            self.static_anchors = static_anchors
            self.static_valid_flags = static_valid_flags
        return self.static_anchors, self.static_valid_flags

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each
                element of the inner list is a tensor of shape
                (num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to return the sampling
                results.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each
                  level.
                - bbox_weights_list (list[Tensor]): BBox weights of each
                  level.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated after the end
                of the returned tuple.
        """
        return AscendAnchorHead.get_targets(
            self,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            img_metas,
            gt_bboxes_ignore_list,
            gt_labels_list,
            label_channels,
            unmap_outputs,
            return_sampling_results,
            return_level,
        )

    def batch_loss(self, batch_cls_score, batch_bbox_pred, batch_anchor,
                   batch_labels, batch_label_weights, batch_bbox_targets,
                   batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                   num_total_samples):
        """Compute loss of all images.

        Args:
            batch_cls_score (Tensor): Box scores for all images with shape
                (num_imgs, num_total_anchors, num_classes).
            batch_bbox_pred (Tensor): Box energies / deltas for all images
                with shape (num_imgs, num_total_anchors, 4).
            batch_anchor (Tensor): Box reference for all images with shape
                (num_imgs, num_total_anchors, 4).
            batch_labels (Tensor): Labels of all anchors with shape
                (num_imgs, num_total_anchors,).
            batch_label_weights (Tensor): Label weights of all anchors with
                shape (num_imgs, num_total_anchors,).
            batch_bbox_targets (Tensor): BBox regression targets of all
                anchors with shape (num_imgs, num_total_anchors, 4).
            batch_bbox_weights (Tensor): BBox regression loss weights of
                all anchors with shape (num_imgs, num_total_anchors, 4).
            batch_pos_mask (Tensor): Positive samples mask in all images.
            batch_neg_mask (Tensor): Negative samples mask in all images.
            num_total_samples (int): If sampling, this equals the total
                number of anchors; otherwise, it is the number of positive
                anchors.

        Returns:
            tuple[Tensor, Tensor]: Classification loss with shape
                (1, num_imgs) and regression loss with shape (num_imgs,).
        """
        num_images, num_anchors, _ = batch_anchor.size()
        batch_loss_cls_all = F.cross_entropy(
            batch_cls_score.view((-1, self.cls_out_channels)),
            batch_labels.view(-1),
            reduction='none').view(
                batch_label_weights.size()) * batch_label_weights
        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
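        # Hard negative mining, as in the original SSD paper: per image,
        # keep at most neg_pos_ratio * num_pos negatives, picking those
        # with the largest classification loss.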
        batch_num_pos_samples = torch.sum(batch_pos_mask, dim=1)
        batch_num_neg_samples = \
            self.train_cfg.neg_pos_ratio * batch_num_pos_samples
        batch_num_neg_samples_max = torch.sum(batch_neg_mask, dim=1)
        batch_num_neg_samples = torch.min(batch_num_neg_samples,
                                          batch_num_neg_samples_max)
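        # topk with k=num_anchors sorts every (masked) negative loss while
        # keeping tensor shapes static; the arange-based mask below then
        # retains only the first batch_num_neg_samples entries per image.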
        batch_topk_loss_cls_neg, _ = torch.topk(
            batch_loss_cls_all * batch_neg_mask, k=num_anchors, dim=1)
        batch_loss_cls_pos = torch.sum(
            batch_loss_cls_all * batch_pos_mask, dim=1)
        anchor_index = torch.arange(
            end=num_anchors, dtype=torch.float,
            device=batch_anchor.device).view((1, -1))
        topk_loss_neg_mask = (anchor_index < batch_num_neg_samples.view(
            -1, 1)).float()
        batch_loss_cls_neg = torch.sum(
            batch_topk_loss_cls_neg * topk_loss_neg_mask, dim=1)
        loss_cls = \
            (batch_loss_cls_pos + batch_loss_cls_neg) / num_total_samples
        if self.reg_decoded_bbox:
            # TODO: support self.reg_decoded_bbox is True
            raise RuntimeError('reg_decoded_bbox=True is not supported yet.')
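        # The regression loss is weighted by batch_bbox_weights, which is
        # zero for non-positive anchors, so only positives contribute.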
        loss_bbox_all = smooth_l1_loss(
            batch_bbox_pred,
            batch_bbox_targets,
            batch_bbox_weights,
            reduction='none',
            beta=self.train_cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        eps = torch.finfo(torch.float32).eps
        loss_bbox = loss_bbox_all.sum(
            tuple(range(1, loss_bbox_all.dim()))) / (num_total_samples + eps)
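        # The leading dim added by loss_cls[None] lets loss() slice out
        # per-image entries as batch_losses_cls[:, i].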
        return loss_cls[None], loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Each item is the ground truth boxes of
                the corresponding image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each
                box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels
        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=True,
            return_level=False)
        if cls_reg_targets is None:
            return None
        (batch_labels, batch_label_weights, batch_bbox_targets,
         batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result,
         num_total_pos, num_total_neg, batch_anchors) = cls_reg_targets
        num_imgs = len(img_metas)
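        # Flatten every level from (N, C, H, W) to (N, H*W*num_anchors, C)
        # and concatenate across levels so that predictions align with the
        # batched targets computed above.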
        batch_cls_score = torch.cat([
            s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels)
            for s in cls_scores
        ], 1)
        batch_bbox_pred = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds
        ], -2)
        batch_losses_cls, batch_losses_bbox = self.batch_loss(
            batch_cls_score, batch_bbox_pred, batch_anchors, batch_labels,
            batch_label_weights, batch_bbox_targets, batch_bbox_weights,
            batch_pos_mask, batch_neg_mask, num_total_pos)
        losses_cls = [
            batch_losses_cls[:, index_imgs] for index_imgs in range(num_imgs)
        ]
        losses_bbox = list(batch_losses_bbox)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
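

# A minimal usage sketch, kept as a comment. Assumptions: the registry-based
# build API of MMDetection 2.x; the config values below are illustrative,
# not taken from a real config file.
#
#     from mmdet.models import build_head
#
#     head = build_head(
#         dict(
#             type='AscendSSDHead',
#             num_classes=80,
#             in_channels=(512, 1024, 512, 256, 256, 256),
#             train_cfg=dict(
#                 neg_pos_ratio=3,      # negatives kept per positive
#                 smoothl1_beta=1.0)))  # beta of the Smooth L1 loss
#
# A full train_cfg additionally carries the assigner and sampler settings
# that get_targets relies on.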