"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch


def fma(a, b, c):
    """Return `a * b + c`, broadcasting the inputs like `torch.addcmul()` does."""
    return _FusedMultiplyAdd.apply(a, b, c)

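
# The backward pass relies on the elementwise derivatives of out = a * b + c:
#   d(out)/da = b,  d(out)/db = a,  d(out)/dc = 1.
# Each incoming gradient is therefore multiplied accordingly and then reduced
# back to the corresponding input's shape with _unbroadcast(), which undoes
# any broadcasting that happened in the forward pass.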
class _FusedMultiplyAdd(torch.autograd.Function):  # out = a * b + c
    @staticmethod
    def forward(ctx, a, b, c):
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)  # a and b are needed to form the gradients
        ctx.c_shape = c.shape        # only c's shape is needed, not its values
        return out

    @staticmethod
    def backward(ctx, dout):
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc


def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that it ends up with `shape`."""
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [
        i for i in range(x.ndim)
        if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims + 1:])
    assert x.shape == shape
    return x
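
# Illustrative shapes for _unbroadcast() (explanatory example, not part of the
# original module): a gradient of shape (4, 3, 8) flowing back into an input of
# shape (1, 3, 8) is summed over dim 0 with keepdim; flowing back into shape
# (3, 8), the extra leading broadcast dim is summed out and dropped entirely.
#
#     _unbroadcast(torch.ones(4, 3, 8), (1, 3, 8)).shape  # torch.Size([1, 3, 8])
#     _unbroadcast(torch.ones(4, 3, 8), (3, 8)).shape     # torch.Size([3, 8])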
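

# Minimal self-check sketch (not part of the original module). The shapes below
# are arbitrary examples chosen to exercise broadcasting in all three inputs.
if __name__ == "__main__":
    a = torch.randn(4, 1, 8, dtype=torch.double, requires_grad=True)
    b = torch.randn(1, 3, 8, dtype=torch.double, requires_grad=True)
    c = torch.randn(3, 8, dtype=torch.double, requires_grad=True)

    # The forward pass should match the plain expression.
    assert torch.allclose(fma(a, b, c), a * b + c)

    # The custom backward pass should match numerical gradients.
    torch.autograd.gradcheck(fma, (a, b, c))
    print("fma() forward and backward checks passed.")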