ronghuaiyang · 2021年03月09日

Deepmid的新SOTA图像分类模型:NFNets,不需要BN也能超越EfficientNets

首发:AI公园公众号
作者:Mostafa Ibrahim
编译:ronghuaiyang

导读

不需要bn,通过自适应的梯度裁剪,除了能优化到loss最小,还能优化到loss的锐度最小,从而提高泛化能力。

image.png

我们较小的模型测试精度在ImageNet上可以匹配EfficientNet-B7,同时训练时达到8.7×更快的加速,我们最大的模型达到了新的最先进的最高精度86.5%。

训练一个模型最烦人的事情之一是训练它所花费的时间和容纳数据和模型所需的内存量。由于图像分类是最常见的机器学习任务之一,Deepmind发布了一种新模型,该模型具有与最先进水平(SOTA)相匹配的性能,且具有明显更小的尺寸、更高的训练速度和更少的简化优化技术。

在他们的工作中,他们检查了目前的SOTA模型,如EfficientNets和ResNets。在他们的分析中,他们确定了一些使用大量内存却没有产生显著性能价值的优化技术。他们证明这些网络在没有优化技术的情况下也能达到同样的性能。

虽然提出的模型可能是最有趣的一点,但我还是觉得前期工作的分析非常有趣。仅仅因为这是大部分学习发生的地方,我们开始理解什么可以做得更好,以及为什么新提出的方法/技术比旧的方法/技术是一种改进。

先决条件:Batch Normalisation

论文首先分析了Batch Normalisation,为什么?因为尽管它已经取得了很好的结果,并在大量的SOTA模型中得到了大量的应用,但它也有一些论文概述了它的缺点,例如:

  1. 非常昂贵的计算成本
  2. 引入了许多额外的需要进一步微调的超参数
  3. 在分布式训练中造成了很多实现错误
  4. 在小batch的情况下表现不佳,小batch通常用于训练较大的模型

但首先,在删除Batch Normalisation之前,我们必须理解它给模型带来了什么好处。因为我们想要找到一种更聪明的方式来保持这些好处,但减少缺点。这些好处是:

  1. 它缩小了深ResNets中的残差分支的尺寸。ResNets是应用最广泛的图像分类网络之一。Batch Normalisation通常扩展到数千个层,Batch Normalisation减少了“隐藏的激活”的规模,这些激活常常导致梯度以一种有趣的方式表现(梯度爆炸问题)
  2. 消除了常见的激活函数(如ReLU和GeLU)的均值偏移。在大型网络中,这些激活函数的输出通常趋向于非常大的平均值。这导致网络在某些情况下(如初始化)对所有样本预测相同的标签,从而降低了其性能。Batch Normalisation解决了这个均值偏移问题。

还有一些其他的好处,但是我想你已经明白了主要是关于规范化和平滑的训练过程。

NFNets — Normaliser Free Networks:

image.png
虽然之前有很多论文尝试去移除batch normalization (BN),但结果与SOTA的性能或训练延迟不匹配,并且在大batch的时候似乎失败了,这是本文的主要卖点。他们在不影响性能的情况下成功地删除了(BN),并极大地提高了训练延迟。

为此,他们提出了一种梯度裁剪技术,称为自适应梯度裁剪(AGC)。本质上,梯度裁剪是为了稳定训练模型,不允许梯度超过一定的阈值。这允许使用更大的学习率,从而更快的收敛,而没有爆炸梯度问题。

然而,主要的问题是设置阈值超参数,这是一个相当困难的手工任务。AGC的主要好处是可以删除这个超参数。要做到这一点,我们必须研究梯度规范和参数规范。

虽然我对每个机器学习模型背后的数学很感兴趣,但我知道很多ML爱好者不喜欢阅读一堆长微分方程,这就是为什么我将从理论/直观的角度来解释AGC,而不是从数学严谨的角度。

范数仅仅是向量大小的量度。AGC建立的前提是:

梯度的范数与层的权值的范数的单位比率提供了一个简单的度量方法,可以衡量单个梯度下降步骤对原始权值的改变程度。

但是为什么这个前提有效呢?让我们倒回去一点。一个非常高的梯度会使我们的学习不稳定,如果是这样的话,那么权值矩阵的梯度与权值矩阵的比率就会非常高。

这个权重比例相当于:

学习率 x 梯度与权值矩阵的比值(这是我们的前提)。

所以本质上,这个前提所提出的比率是一个有效的指标来决定我们是否应该剪切梯度。还有一个小调整。他们通过多次实验发现,使用单位级的梯度范数的比例比使用层级的比例要好得多(因为每个层可以有不止一个梯度)。

除了AGC,他们还使用dropout来替代Batch normalization所提供的规范化效果。

他们还使用了一种优化技术,称为锐度感知最小化(SAM)

由于损失的几何形状和泛化之间的联系 —— 包括我们在这里证明的泛化边界 —— 我们引入了一种新颖的、有效的方法来同时最小化损失值和损失的锐度。特别的,我们的处理,锐度感知最小化(SAM),寻找的参数存在于具有一致低损失的区域,这个公式产生了一个最小-最大优化问题,在这个问题上梯度下降可以有效地执行。我们提出的经验结果表明,SAM改善了跨各种基准数据集(例如CIFAR-{10,100}、ImageNet、微调任务)和模型的模型泛化,为若干模型提供了最新的性能。

loss的锐度的概念似乎很有趣,为了简洁起见,我可能会在另一篇文章中探讨它。这里要注意的最后一点是,他们对SAM做了一个小的修改,以减少20-40%的计算成本!而且他们只在两款最大的型号上使用它。看到对这些技术的补充,而不仅仅是开箱即用,总是很棒的。我认为这表明他们在使用它之前已经对它进行了大量的分析(因此能够对它进行一些优化)。

最后的想法和要点

谁会想到,替换一个比如batch normalization的小的优化技术,会在训练延迟方面提高9倍。我认为这传达了一种信息,那就是对到处都在使用的流行优化技术持怀疑态度。平心而论,我这个的受害者,我曾经把所有流行的优化技术在我机器学习项目没有完全检查其利弊。我猜这是阅读机器学习论文的主要好处之一,分析以前的SOTAs !

—END—

英文原文:https://towardsdatascience.co...

推荐阅读

关注图像处理,自然语言处理,机器学习等人工智能领域,请点击关注AI公园专栏
欢迎关注微信公众号
AI公园 公众号二维码.jfif
推荐阅读
关注数
8257
内容数
210
关注图像处理,NLP,机器学习等人工智能领域
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息