# Copyright (c) OpenMMLab. All rights reserved.
from torch.autograd import Function
from torch.nn import functional as F


class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.

    This implementation, with an analytical gradient function, substitutes
    the autograd function of ``(x.sigmoid() * y.sigmoid()).sqrt()``. The
    original implementation produces NaN during gradient backpropagation if
    both x and y are very small values (the product of the sigmoids
    underflows to 0 and the derivative of ``sqrt`` at 0 is infinite).
    """

    @staticmethod
    def forward(ctx, x, y):
        """Compute ``sqrt(sigmoid(x) * sigmoid(y))``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            x (Tensor): The first input.
            y (Tensor): The second input.

        Returns:
            Tensor: The geometric mean of the two sigmoid activations.
        """
        x_sigmoid = x.sigmoid()
        y_sigmoid = y.sigmoid()
        z = (x_sigmoid * y_sigmoid).sqrt()
        # Only the activations and the output are required by the
        # analytical gradient, so the raw inputs are not saved.
        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradients w.r.t. ``x`` and ``y``.

        With ``z = sqrt(sigmoid(x) * sigmoid(y))`` the derivative is
        ``dz/dx = z * (1 - sigmoid(x)) / 2`` (and symmetrically for ``y``).
        This form contains no division, so it stays finite when ``z`` is 0.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_output (Tensor): Gradient w.r.t. the output ``z``.

        Returns:
            tuple[Tensor, Tensor]: Gradients w.r.t. ``x`` and ``y``.
        """
        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
        grad_x = grad_output * z * (1 - x_sigmoid) / 2
        grad_y = grad_output * z * (1 - y_sigmoid) / 2
        return grad_x, grad_y


sigmoid_geometric_mean = SigmoidGeometricMean.apply


def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w).

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            A false value is not forwarded for the modes that do not accept
            this argument (``'nearest'``, ``'area'``, ``'nearest-exact'``).
            Default: False.

    Returns:
        Tensor: The interpolated source Tensor. It has the same number of
        dimensions as the input `source`.
    """
    assert len(target.shape) >= 2

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        if target_h != source_h or target_w != source_w:
            # F.interpolate() raises a ValueError if `align_corners` is not
            # None for these modes, which made the default
            # `align_corners=False` unusable with e.g. mode='nearest'.
            # An explicit `align_corners=True` is still forwarded so that the
            # misuse keeps being reported by F.interpolate().
            modes_without_align_corners = ('nearest', 'area', 'nearest-exact')
            if mode in modes_without_align_corners and not align_corners:
                align_corners = None
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        # F.interpolate() needs a channel dimension for 2D resizing: add a
        # dummy one and strip it again afterwards.
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    else:
        return _interpolate_as(source, target, mode, align_corners)