# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    # Reduce `x` back to `shape` by summing over every dimension that was
    # broadcast in the forward pass, then collapsing any leading extra dims.
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
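
# Usage sketch (an illustrative addition, not part of the original module):
# checks that `fma()` matches `torch.addcmul()` in both the forward pass and
# the gradients, including a broadcast case that exercises `_unbroadcast()`.
# The shapes and dtype below are arbitrary choices for this check.

if __name__ == '__main__':
    a = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    b = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
    c = torch.randn(1, 3, dtype=torch.float64, requires_grad=True) # broadcast over dim 0.

    # Forward pass agrees with the reference op.
    out_fma = fma(a, b, c)
    out_ref = torch.addcmul(c, a, b)
    assert torch.allclose(out_fma, out_ref)

    # Gradients agree with autograd applied to the reference expression; the
    # gradient w.r.t. `c` must be summed back to shape (1, 3) by `_unbroadcast()`.
    g = torch.randn_like(out_ref)
    for got, ref in zip(torch.autograd.grad(out_fma, (a, b, c), g),
                        torch.autograd.grad(out_ref, (a, b, c), g)):
        assert torch.allclose(got, ref)
    print('fma: forward and backward match torch.addcmul')

#----------------------------------------------------------------------------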