BatchNorm已成为主流DCNN的标配,它不仅可以加速模型训练,还可以提升模型效果。但是,众所周知:BN对batch size极为敏感,batch size越小,性能越差。为解决该问题,已有多种方法被提出,比如LayerNorm、InstanceNorm、GroupNorm等等。但是这些方法在大batch size下无法超越BN的性能,这无疑是这些方法的弊端所在。
谷歌与Face++的研究员分别在这方面取得了突破性的进展:全面超越BatchNorm。
首发知乎:https://zhuanlan.zhihu.com/p/106247003
文章作者: Happy
FRN
链接:https://arxiv.org/abs/1911.09737
import torch
import torch.nn as nn
class FRN(nn.Module):
    """Filter Response Normalization followed by a Thresholded Linear Unit (TLU).

    For an input ``x`` of shape ``(N, C, *spatial)`` (at least one spatial
    dimension) this computes, per sample and per channel::

        nu2   = mean(x ** 2) over the spatial dimensions
        x_hat = x / sqrt(nu2 + |eps|)
        out   = max(gamma * x_hat + beta, tau)

    Args:
        num_features: number of channels ``C``.
        eps: initial value of the stability constant (its absolute value is used).
        learnable_eps: if ``True``, ``eps`` is trained together with the other
            parameters; otherwise it is a frozen parameter.
    """

    def __init__(self, num_features, eps=1e-6, learnable_eps=False):
        super().__init__()
        shape = (1, num_features, 1, 1)
        self.eps = nn.Parameter(torch.ones(*shape) * eps)
        if not learnable_eps:
            self.eps.requires_grad_(False)
        # Allocated uninitialized; real values are set in reset_parameters().
        self.gamma = nn.Parameter(torch.Tensor(*shape))
        self.beta = nn.Parameter(torch.Tensor(*shape))
        self.tau = nn.Parameter(torch.Tensor(*shape))
        self.reset_parameters()

    def forward(self, x):
        """Normalize ``x`` by its spatial second moment and apply the TLU."""
        avg_dims = tuple(range(2, x.dim()))
        # Parameters are stored as (1, C, 1, 1); reshape them to (1, C, 1, ..., 1)
        # so they broadcast against inputs of any spatial rank.
        param_shape = (1, -1) + (1,) * len(avg_dims)
        eps = self.eps.view(param_shape)
        gamma = self.gamma.view(param_shape)
        beta = self.beta.view(param_shape)
        tau = self.tau.view(param_shape)
        nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True)
        x = x * torch.rsqrt(nu2 + torch.abs(eps))
        return torch.max(gamma * x + beta, tau)

    def reset_parameters(self):
        """Initialize gamma to 1 and beta/tau to 0.

        With beta = tau = 0 the layer starts as ReLU(normalized x). The
        previous all-ones initialization clipped every response below 1.
        """
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)
        nn.init.zeros_(self.tau)
MABN
论文链接:https://openreview.net/forum?id=SkgGjRVKDS&noteId=BJeCWt3KiH
import torch
import torch.nn as nn
import torch.nn.functional as F
class MABNFunction(torch.autograd.Function):
    """Autograd function implementing the training-time MABN normalization.

    The input batch of ``N`` samples is viewed as ``N // 2`` groups of two
    samples each. For every group, the per-channel mean of ``x^2`` is
    computed. It is then combined with the previous call's statistics
    (``pre_x2``) through ``sta_matrix``, and the input is divided by the
    square root of the result. A renormalization factor ``r`` relative to
    ``running_var`` is applied on top. ``backward`` uses the same
    combination for the gradient statistic ``mean(g * z)``, buffered in
    ``pre_gz``.

    Shape requirements (implied by the ``view``/``mm`` calls below): ``N``
    must be even, and ``N // 2`` must equal ``sta_matrix.size(0)``. The
    ``(B, 2B)`` matrix multiplies the concatenation of ``B`` buffered rows
    with the ``N // 2`` current rows.
    """

    @staticmethod
    def forward(ctx, x, weight, bias,
                running_var, eps, momentum,
                sta_matrix, pre_x2, pre_gz, iters
                ):
        """Normalize ``x`` and update ``running_var`` / ``pre_x2`` in place.

        Args:
            x: input of shape ``(N, C, H, W)``.
            weight: per-channel scale of length ``C``.
            bias: per-channel shift of length ``C``.
            running_var: per-channel EMA of the ``x^2`` statistic; updated in place.
            eps: constant added before every ``sqrt``.
            momentum: EMA coefficient used to update ``running_var``.
            sta_matrix: ``(B, 2B)`` matrix applied to ``cat([pre_x2, x2])``.
            pre_x2: ``(B, C)`` statistics from the previous call; overwritten in place.
            pre_gz: ``(B, C)`` gradient statistics from the previous backward;
                saved for ``backward``.
            iters: one-element tensor holding the current iteration count.

        Returns:
            Tensor of shape ``(N, C, H, W)``.
        """
        ctx.eps = eps
        current_iter = iters.item()
        ctx.iter = current_iter
        N, C, H, W = x.size()
        # Group the batch into N // 2 pairs of samples.
        x = x.view(N // 2, 2, C, H, W)
        # Mean of x^2 over (pair, H, W) for every group and channel: (N // 2, C).
        x2 = (x * x).mean(dim=4).mean(dim=3).mean(dim=1)
        # Combine buffered and current statistics; each row of sta_matrix
        # averages a window of the concatenation [pre_x2; x2].
        var = torch.cat([pre_x2, x2], dim=0)
        var = torch.mm(sta_matrix, var)
        var = var.view(N // 2, 1, C, 1, 1)
        if current_iter == 1:
            # First iteration: use the current statistics only.
            var = x2.view(N // 2, 1, C, 1, 1)
        z = x / (var + eps).sqrt()
        # Renormalization factor of the batch statistic relative to the EMA.
        r = (var + eps).sqrt() / (running_var.view(1, 1, C, 1, 1) + eps).sqrt()
        if current_iter < 100:
            # Warm-up: clamping to [1, 1] forces r == 1 (no renormalization).
            r = torch.clamp(r, 1, 1)
        else:
            r = torch.clamp(r, 1 / 5, 5)
        y = r * z
        ctx.save_for_backward(z, var, weight, sta_matrix, pre_gz, r)
        if current_iter == 1:
            # Seed the EMA with the first statistic before the regular update.
            running_var.copy_(var.mean(dim=0).view(-1,))
        running_var.copy_(momentum * running_var + (1 - momentum) * var.mean(dim=0).view(-1,))
        pre_x2.copy_(x2)
        y = weight.view(1, C, 1, 1) * y.view(N, C, H, W) + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``x``, ``weight`` and ``bias``; update ``pre_gz``.

        The remaining seven inputs (buffers, scalars and the iteration
        counter) receive ``None``.
        """
        eps = ctx.eps
        current_iter = ctx.iter
        N, C, H, W = grad_output.size()
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        z, var, weight, sta_matrix, pre_gz, r = ctx.saved_tensors
        y = r * z
        g = grad_output * weight.view(1, C, 1, 1)
        g = g.view(N // 2, 2, C, H, W) * r
        # Mean of g * z over (pair, H, W) for every group and channel: (N // 2, C).
        gz = (g * z).mean(dim=4).mean(dim=3).mean(dim=1)
        mean_gz = torch.cat([pre_gz, gz], dim=0)
        mean_gz = torch.mm(sta_matrix, mean_gz)
        mean_gz = mean_gz.view(N // 2, 1, C, 1, 1)
        if current_iter == 1:
            # First iteration: mirror forward and use the current statistic only.
            mean_gz = gz.view(N // 2, 1, C, 1, 1)
        gx = 1. / torch.sqrt(var + eps) * (g - z * mean_gz)
        gx = gx.view(N, C, H, W)
        pre_gz.copy_(gz)
        grad_weight = (grad_output * y.view(N, C, H, W)).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        # No gradients for running_var, eps, momentum, sta_matrix, pre_x2, pre_gz, iters.
        return (gx, grad_weight, grad_bias,
                None, None, None, None, None, None, None)
class MABN2d(nn.Module):
    """Moving Average Batch Normalization for 4-D inputs of shape ``(N, C, H, W)``.

    In training mode, the iteration counter is advanced and the work is
    delegated to ``MABNFunction``. In eval mode, the input is divided by
    ``sqrt(running_var + eps)`` and then scaled by ``weight`` and shifted
    by ``bias``.
    """

    def __init__(self, channels, eps=1e-5, momentum=0.98, buffer_size=16):
        """
        Args:
            channels: number of channels ``C``.
            eps: constant added to the variance statistic before ``sqrt``.
            momentum: EMA coefficient for ``running_var``.
            buffer_size: Moving Average Batch Size / Normalization Batch Size.

        Buffers:
            running_var: EMA statistics of x^2.
            sta_matrix: ``(B, 2B)`` averaging matrix (see ``init``).
            pre_x2: batch statistics of x^2 from the last iteration(s).
            pre_gz: batch statistics of the gradient term from the last iteration(s).
            iters: current iteration count.
        """
        super(MABN2d, self).__init__()
        self.B = buffer_size
        # Assigning nn.Parameter attributes registers them exactly like
        # register_parameter, in the same order.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.register_buffer('running_var', torch.ones(channels))
        self.register_buffer('sta_matrix', torch.ones(self.B, 2 * self.B) / self.B)
        self.register_buffer('pre_x2', torch.ones(self.B, channels))
        self.register_buffer('pre_gz', torch.zeros(self.B, channels))
        self.register_buffer('iters', torch.zeros(1,))
        self.eps = eps
        self.momentum = momentum
        self.init()

    def init(self):
        """Keep, in row ``i`` of ``sta_matrix``, only columns ``i+1 .. B+i`` (value 1/B).

        All other entries of the row are zeroed.
        """
        window = self.B
        for row in range(self.sta_matrix.size(0)):
            self.sta_matrix[row, :row + 1] = 0
            self.sta_matrix[row, window + row + 1:] = 0

    def forward(self, x):
        """Apply MABN to ``x``; uses ``MABNFunction`` while training."""
        if not self.training:
            _, channels, _, _ = x.size()
            broadcast = (1, channels, 1, 1)
            denom = (self.running_var.view(broadcast) + self.eps).sqrt()
            normed = x / denom
            return self.weight.view(broadcast) * normed + self.bias.view(broadcast)
        self.iters.copy_(self.iters + 1)
        return MABNFunction.apply(
            x, self.weight, self.bias,
            self.running_var, self.eps,
            self.momentum, self.sta_matrix,
            self.pre_x2, self.pre_gz,
            self.iters,
        )
class CenConv2d(nn.Module):
    """Conv2d layer with Weight Centralization.

    Before every convolution, the weight of each output channel is shifted
    to zero mean over its ``(in_planes // groups, kH, kW)`` entries. The
    layer's output is therefore invariant to adding a per-output-channel
    constant to ``weight``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: an int for a square kernel or a ``(kH, kW)`` pair.
        stride, padding, dilation, groups: forwarded to ``F.conv2d``.
        bias: if ``True``, a learnable per-output-channel bias is added.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super(CenConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        kh, kw = self.kernel_size
        self.weight = nn.Parameter(torch.randn(out_planes, in_planes // groups, kh, kw))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_planes))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Convolve ``x`` with the per-output-channel centered weight."""
        weight = self.weight
        # A single reduction over (in_channels, kH, kW) per output channel.
        weight = weight - weight.mean(dim=(1, 2, 3), keepdim=True)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
分析
为什么将两者放到一起呢?这是因为两者具有异曲同工之妙。
对比两者的实现方式,是不是发现两者都采用了类似的计算(都用$x^2$的均值而非减去均值后的方差来做归一化)?区别在于:FRN采用的是实时计算(仅使用当前样本自身的统计量),而MABN则在训练过程中对最近若干次迭代的统计量进行滑动平均。
FRN与MABN的另一点区别:FRN采用了新的激活函数(TLU),而MABN则采用了新的卷积计算方式(权值中心化)。总体来看,MABN的通用性更佳,当然实际情况需要靠时间去检验,期待各位大神的更多尝试。
推荐阅读:
本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏深度学习从入门到精通。