Happy · 3月26日

超越BatchNorm你不可不知的两种方案:FRN与MABN

BatchNorm已成为主流DCNN的标配,它不仅不可以加速模型训练还是提升模型效果。但是,众所周知:BN对于batch size极为敏感,越小性能越差。为解决该问题,已有各种方法被提出用于解决上述问题,比如LayerNorm, InstanceNor, GroupNorm等等。但是这些方法在大batch size下无法超越BN的性能,这无疑是这些方法的弊端所在。

谷歌与Face++的研究员在这方面取得了突破性的进展:全面超越BatchNorm。

首发知乎:https://zhuanlan.zhihu.com/p/106247003
文章作者: Happy

FRN

链接:https://arxiv.org/abs/1911.09737
v2-858e490c58e39fab6c31b2e756e98d72_720w.png

import torch
import torch.nn as nn

class FRN(nn.Module):
    def __ini__(self, num_features, eps=1e-6, learnable_eps=False):
        super().__init__()
        shape = (1, num_features, 1, 1) 
        self.eps = nn.Parameter(torch.ones(*shape) * eps)
        if not learnable_eps:
            self.eps.requires_grad_(Flase)
        self.gamma = nn.Parameter(torch.Tensor(*shape))
        self.beta = nn.Parameter(torch.Tensor(*shape))
        self.tau = nn.Parameter(torch.Tensor(*shape))
        self.reset_parameters()

    def forward(self, x):
        avg_dims = tuple(range(2, x.dim()))
        nu2 = torch.pow(x, 2).mean(dim=avg_dims, keepdim=True)
        x = x * torch.rsqrt(nu2 + torch.abs(self.eps))
        return torch.max(self.gamma * x + self.beta, self.tau)

    def reset_parameters():
        nn.init.ones_(self.gamma)
        nn.init.ones_(self.beta)
        nn.init.ones_(self.tau)

MABN

论文链接:https://openreview.net/forum?id=SkgGjRVKDS¬eId=BJeCWt3KiH

Image

import torch
import torch.nn as nn
import torch.nn.functional as F

class MABNFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias,
                running_var, eps, momentum,
                sta_matrix, pre_x2, pre_gz, iters
               ):
        ctx.eps = eps
        current_iter = iters.item()
        ctx.iter = current_iter
        N, C, H, W = x.size()

        x = x.view(N//2, 2, C, H, W)
        x2 = (x * x).mean(dim=4).mean(dim=3).mean(dim=1)
        var = torch.cat([pre_x2, x2], dim=0)

        var = torch.mm(sta_matrix, var)
        var = var.view(N//2, 1, C, 1, 1)

        if current_iter == 1:
            var = x2.view(N//2, 1, C, 1, 1)

        z = x /(var + eps).sqrt()
        r = (var + eps).sqrt() / (running_var.view(1, 1, C, 1, 1) + eps).sqrt()
        if current_iter < 100:
            r = torch.clamp(r, 1, 1)
        else:
            r = torch.clamp(r, 1/5, 5)
        y = r * z
        ctx.save_for_backward(z, var, weight, sta_matrix, pre_gz, r)

        if current_iter == 1:
            running_var.copy_(var.mean(dim=0).view(-1,))
        running_var.copy_(momentum*running_var + (1-momentum)*var.mean(dim=0).view(-1,))
        pre_x2.copy_(x2)
        y = weight.view(1,C,1,1) * y.view(N, C, H, W) + bias.view(1,C,1,1)

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        current_iter = ctx.iter
        N, C, H, W = grad_output.size()
        z, var, weight, sta_matrix, pre_gz, r  = ctx.saved_variables
        y = r * z
        g = grad_output * weight.view(1, C, 1, 1)
        g = g.view(N//2, 2, C, H, W) * r
        gz = (g * z).mean(dim=4).mean(dim=3).mean(dim=1)

        mean_gz = torch.cat([pre_gz, gz], dim=0)
        mean_gz = torch.mm(sta_matrix, mean_gz)
        mean_gz = mean_gz.view(N//2, 1, C, 1, 1)

        if current_iter == 1:
            mean_gz = gz.view(N//2, 1, C, 1, 1)
        gx = 1. / torch.sqrt(var + eps) * (g - z * mean_gz)
        gx = gx.view(N, C, H, W)
        pre_gz.copy_(gz)

        return gx, (grad_output * y.view(N, C, H, W)).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0),  None, None, None, None, None, None, None

class MABN2d(nn.Module):

    def __init__(self, channels, eps=1e-5, momentum=0.98, buffer_size=16):
        """
            buffer_size: Moving Average Batch Size / Normalization Batch Size
            running_var: EMA statistics of x^2
            buffer_x2: batch statistics of x^2 from last several iters
            buffer_gz: batch statistics of phi from last several iters
            iters: current iter
        """
        super(MABN2d, self).__init__()
        self.B = buffer_size
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.register_buffer('running_var', torch.ones(channels))
        self.register_buffer('sta_matrix', torch.ones(self.B, 2 *self.B)/self.B)
        self.register_buffer('pre_x2', torch.ones(self.B, channels))
        self.register_buffer('pre_gz', torch.zeros(self.B, channels))
        self.register_buffer('iters', torch.zeros(1,))
        self.eps = eps
        self.momentum = momentum
        self.init()

    def init(self):
        for i in range(self.sta_matrix.size(0)):
            self.sta_matrix[i][:i+1] = 0
            self.sta_matrix[i][self.B+i+1:] = 0

    def forward(self, x):
        if self.training:
            self.iters.copy_(self.iters + 1)
            x = MABNFunction.apply(x, self.weight, self.bias,
                                   self.running_var, self.eps, 
                                   self.momentum, self.sta_matrix, 
                                   self.pre_x2, self.pre_gz, 
                                   self.iters)
            return x
        else:
            N, C, H, W = x.size()
            var = self.running_var.view(1, C, 1, 1)
            x = x / (var + self.eps).sqrt()

        return self.weight.view(1,C,1,1) * x + self.bias.view(1,C,1,1)

class CenConv2d(nn.Module):
    """Conv2d layer with Weight Centralization
    """
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super(CenConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight = nn.Parameter(torch.randn(out_planes, in_planes//groups, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_planes))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

分析

为什么将两者放到一起呢?这是因为两者具有异曲同工之妙。

Image

对比两者的实现方式,是不是发现两者都采用了了类似的计算?区别在于:FRN采用的实时计算,而MABN则采用在训练集中统计的方式。

FRN与MABN的另一点区别:FRN采用了新的激活函数,而MABN则采用了新的卷积计算方式(权值去中心化)。总体来看,MABN的通用性更佳,当然实际情况需要靠时间去检验,期待各位大神的更多尝试。



推荐阅读:

本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏深度学习从入门到精通
1 阅读 22
推荐阅读
0 条评论
关注数
6
文章数
16
夯实深度学习知识基础, 涵盖动态滤波,超分辨,轻量级框架等
目录
qrcode
关注微信服务号
实时接收回答提醒和评论通知