import torch
from torch import nn


def fuse_conv_and_bn(conv, bn):
    """Fuse a Conv2d layer and its following BatchNorm2d layer into a single equivalent Conv2d layer."""
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Fold the BatchNorm scale into the convolution weights
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

    # Fold the BatchNorm shift into the convolution bias (zero bias if the conv had none)
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
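

# Illustrative helper, a sketch added here rather than part of the original module: it checks
# that the fused layer reproduces the Conv2d + BatchNorm2d pair on random input, with both
# modules in eval mode. The helper name and the 32x32 spatial size are assumptions for the example.
def _verify_fusion(conv, bn, atol=1e-5):
    conv, bn = conv.eval(), bn.eval()
    fused = fuse_conv_and_bn(conv, bn)
    x = torch.randn(1, conv.in_channels, 32, 32, device=conv.weight.device)
    with torch.no_grad():
        return torch.allclose(bn(conv(x)), fused(x), atol=atol)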


def copy_attr(a, b, include=(), exclude=()):
    """Copy public attributes from object b onto object a.

    If `include` is non-empty, only those attribute names are copied; names listed in
    `exclude` or starting with an underscore are always skipped.
    """
    for k, v in b.__dict__.items():
        if (include and k not in include) or k.startswith("_") or k in exclude:
            continue
        setattr(a, k, v)
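

# Illustrative usage sketch for copy_attr (the SimpleNamespace demo below is an assumption
# added for this edit, not part of the original module): copy public attributes from one
# object onto another, honouring the exclude filter and skipping underscore-prefixed names.
if __name__ == "__main__":
    from types import SimpleNamespace

    src = SimpleNamespace(stride=32, names={0: "person"}, _private="never copied")
    dst = SimpleNamespace()
    copy_attr(dst, src, exclude=("names",))
    print(dst.stride)                # 32
    print(hasattr(dst, "names"))     # False: explicitly excluded
    print(hasattr(dst, "_private"))  # False: underscore-prefixed names are skipped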