批归一化(Batch Normalization)因其可以加速神经网络训练、使网络训练更稳定,而且还有一定的正则化效果,所以得到了非常广泛的应用。但是,在推理阶段,BN层一般是可以完全融合到前面的卷积层的,而且丝毫不影响性能。
作者:元峰
Batch Normalization是谷歌研究员于2015年提出的一种归一化方法,其思想非常简单,一句话概括就是,对一个神经元(或者一个卷积核)的输出减去统计得到的均值,除以标准差,然后乘以一个可学习的系数,再加上一个偏置,这个过程就完成了。
下面我们简单介绍一下BN训练时怎么做,推理的时候为什么可以融合,以及怎么样融合。
一. BN训练时如何做
二. BN推理时怎么做
三. 在框架中如何融合
下面是来自博文[1]中的一个PyTorch例子,将ResNet18中一个卷积+BN层融合后,融合前后输出的差值为-6.10425390790148e-11,也就是误差在百亿分之一,基本就是0了。
import torch
import torchvision
def fuse(conv, bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# setting weights
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
# setting bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros( conv.weight.size(0) )
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return fused
# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
因为这么一个线性操作的转换,如果有误差,那才真是见鬼了呢。
关于其他框架,如Keras、Caffe、TensorFlow的操作,与PyTorch基本一个原理,大家可以自己试验一下。
笔者在测试时候,发现融合掉BN后,会有大概5%的提速,而且还可以减小显存消耗,又丝毫不影响误差,何乐而不为呢。
但是,融合BN仅限于Conv+BN或者是BN+Conv结构,中间不能加非线性层,例如Conv+ReLu+BN那就不行了。当然,一般结构都是Conv+BN+ReLu结构。
推荐阅读
关于作者
我是元峰,互联网+AI领域的创业者,欢迎在微信搜索“AIZOO”关注我们的公众号AIZOO。
如果您是有算法需求,例如目标检测、人脸识别、缺陷检测、行人检测的算法需求,欢迎添加我们的微信号AIZOOTech与我们交流,我们团队是一群算法工程师的创业团队,会以高效、稳定、高性价比的产品满足您的需求。
如果您是算法或者开发工程师,也可以添加我们的微信号AIXZOOTech,请备注学校or公司名称-研究方向-昵称,例如“西电-图像算法-元峰”,元峰会拉您进我们的算法交流群,一起交流算法和开发的知识,以及对接项目。
更多算法模型相关请关注AIZOO专栏