Happy · 6 天前

89.77%!谷歌大脑Quoc V.Le团队提出CoAtNet:将卷积与自注意力纳入同一模块

首发:AIWalker
作者:HappyAIWalker

image.png

本文是谷歌研究院Quov V.Le团队在卷积与自注意力组合方面的探索,将深度卷积与自注意力集成统一到一个计算模块中,并从实验角度对卷积与自注意力的组合方式进行了论证,进而确定了CoAtNet的架构。所提方案在ImageNet数据集上取得了大幅超越其他ConvNet与Transformer的性能。比如,无需额外数据,CoAtNet在ImageNet上取得了86%的top1精度;额外引入JFT预训练后,模型进一步提升提升到89.77%,超越了之前最佳的EfficientNetV2与NFNet。

Abstract

Transformer在计算机视觉领域受到 了越多越多的关注,但他们的性能仍落后于优秀的CNN。在这篇文章中,我们将表明:尽管Transformer具有非常大的模型容量,但由于缺乏正确的归纳偏置导致其泛化性能不如CNN

为有效取两者之长,我们提出了CoAtNet,它基于以下两个关键点而构建的混合模型:

  • 深度卷积与自注意力可以通过简单的相对注意力进行统一;
  • 垂直叠加卷积与注意力层对于提升泛化性能、容量以及效率非常有效。

实验结果表明:在不同数据集、不同资源约束下,所提CoAtNet均取得了SOTA性能。比如,无需额外数据,CoAtNet在ImageNet上取得了86.0%top1精度;当引入额外的JFT后其性能可以进一步提升到89.77%。更值得注意的是,当在ImageNet21K上预训练后,CoAtNet可以达到88.56%top1精度,这与JFT上预训练的ViT-huge精度相当,而CoAtNet需要的训练数据比ViT-huge少23倍。

Method

接下来,我们将聚焦于“如何最优化组合卷积与Transformer?”粗略的讲,我们将该问题分为两部分:

  • 如何将卷积与自注意力组合到一个基础计算模块中?
  • 如何垂直堆叠不同类型的计算模块以构建一个完整网络?

Merging Convolution and Self-Attention

对于卷积而言,我们主要聚焦于MBConv,它采用深度卷积捕获空间交互关系。进行该选择的一个关键原因:Transformer中的FFN与MBConv均采用了“Inverted Bottleneck”的设计思想(先对输入通道数扩张,然后在收缩到原始通道数以支持残差连接)

除了“Inverted Bottleneck”的相似外,我们还注意到:深度卷积与自注意力均可以表达成预定义感受野范围内值的加权和。具体来说,卷积依赖于固定核在局部感受野内收集信息:
image.png
作为对比,自注意力使得感受野覆盖到整个空间位置并基于成对数据计算权值后加权:
image.png
在正式介绍如何对其进行最佳组合之前,我们先来看一下两者之间的相对优弱所在,这有助于我们理解需要保留的优异属性。

  • 首先,对于深度卷积来说,它的卷积核是输入无关的静态值;而对于自注意力而言,它的注意力权值是输入相关的动态值。因此,自注意力更易于捕获不同位置间复杂的相关性,而这个特性是我们在处理高级概念时所需要的。然后,自注意力的这种特性也带来的过拟合问题,尤其当数据非常有限时;
  • 其次,注意到:对于任意位置对,卷积的权值仅关注了相关偏移,而未关注特定的值。这种属性一般特指平移不变性,它有助于提升小数据集的泛化性。由于引入了绝对位置嵌入,Transformer并不具备这种特性。这就是为什么小数据集上ConvNet的性能通常优于Transformer的原因。
  • 最后,感受野的尺寸是卷积与自注意力的最关键区别。一般来讲,更大的感受野可以提供更多的上下文信息,进而导致更高的模型容量。因此,全局感受野是在视觉领域使用自注意的一个关键驱动力。然而,更大的感受野同时也带来了更多的计算量。

image.png

基于上述比较,一个理想的模型应当能组合上表中的三个属性。类似深度卷积与自注意力,一个最直接的方式:全局静态卷积核与自适应注意力矩阵组合,即:
image.png
有意思的是,尽管这个想法看起来过于简单,但是预归一化版本对应了相对自注意力的一个特定变种。这种情况下,注意力权值由平移不变权值与输入自适应联合确定。更重要的是,为引入全局卷积核且不会导致大量的参数量,我们重新加载为标量而非向量,进而只会引入非常少的计算量。接下来,我们将采用带预归一化相对注意力变种的Transformer作为CoAtNet的核心模块

Vertical Layout Design

前面找到了一种组合卷积与注意力的简单方法,接下来,我们将考虑如何进行堆叠以构建一个完整网络。

正如前面所提到:全局上下文会带来大量的计算量,与空间尺寸成二次关系。如果我们直接采用上述相对注意力到原始输入图像,计算效率会非常低。因此,为构建一个实际可行的网络,我们主要以下三种候选方案:

  • A: 执行下采样以降低空间尺寸,在达到可接受水平后采用全局相对注意力;
  • B: 采用局部注意力以全局,类似卷积约束自注意力的感受野;
  • C: 采用特定的线性注意力替换二次Softmax注意力,进而将计算复杂度降低到与空间尺寸成线性关系。

我们对方案C进行了简单的实验,但并未得到一个好的结果。对于方案B,局部注意力的实现会涉及大量的形变变换操作,进而需要大量的内存访问。在TPU上,这种操作的计算效率非常低,这不仅与全局注意力的加速相悖,同样会影响模型容量。因此,接下来,我们主要聚焦于方案A。

对于方案,下采样可以通过以下方式得到:

  • 类似ViT,采用卷积下采样到stride=16;
  • 类似ConvNet,采用多阶段逐渐池化的网络.

基于上述选择,我们设计了一个包含5个变种的搜索空间并通过可控实验进行了对比。

  • 当采用ViT Steam时,我们直接堆叠L个Transformer,表示为
  • 当采用多阶段ConvNet,我们模拟ConvNet构建一个包含5个阶段(S0,S1,S2,S3&S4)的网络,空间分辨率逐渐从S0下降到S4。S0为简单的两层卷积,S1采用MBConv与SE。从S2到S4,我们同时考虑MBConv与Transformer ,此时有这样一个约束:卷积要位于Transformer之前。这就引出了4个变种:C-C-C-C,C-C-C-T, C-C-T-T, C-T-T-T。

为系统研究这些设计选择,我们主要考虑了泛化性与模型容量两个引入:

  • 泛化性:我们主要感兴趣训练损失与验证精度之间的差异。如果两个模型具有相同的训练损失,具有更高验证精度的模型具有更好的泛化性。当训练数据比较小时,泛化性对于提高数据效率非常重要。
  • 模型容量:我们用来评价模型对于大数据的拟合能力。当训练数据非常大时(几乎不可能存在过拟合),具有更高精度的模型通常具有更大的容量。

为比较模型泛化性与容量,我们在ImageNet与JFT上分别训练300与3epoch不同变种的混合模型,均未添加任何正则与增广,训练损失与验证精度对比见下图。

image.png

  • 对于ImageNet上的结果,一个重要发现:就泛化性能而言,有以下结论:
    image.png

也就是说,的泛化性最差。我们认为:这与ViT的一次性下采样缺乏底层信息有关。

  • 在模型容量方面,其排名如下:
    image.png

这意味着:具有更多的Transformer模块并不意味着更高的容量。为在前两者之间做一个决策,我们考虑另外一个“迁移能力”:在JFT上预训练,然后在ImageNet微调。结果见下表,我们最终选择了配置

image.png
image.png

考虑到泛化性能、模型容量、迁移性能以及计算效率,我们采用了构建CoAtNet,其架构见上图。

Experiments

接下来,我们在相对公平的设置下对所提CoAtNet与其他方案进行了对比。下表给出了不同配置CoAtNet的参数信息。

image.png

Main Results

image.png

ImageNet-1K 上表给出了仅在ImageNet上训练的模型性能对比,从中可以看到:在相似前提下,CoAtNet不仅优于ViT变种,同时与最佳卷积方案(比如EfficientNetV2、NFNet)的性能相媲美。下图Figure2给出了仅在ImageNet上训练时模型结果的可视化图,可以看到CoAtNet具有比其他注意力模型更佳的性能

image.png

ImageNet-21K 如上述Table4与Fig3所示,当采用ImageNet21K预训练后,CoAtNet的优势更明显,大幅优于其他所有模型。比如,CoAtNet取得了88.56%的精度,与ViT-H/14的88.55%相媲美,同时所需额外的预训练数据更少。

image.png

JFT 上表对比了采用JFT预训练时的模型性能对比,可以看到:CoAtNet-4取得了与之前NFNet-F4相当的性能、相当的参数量,同时具有更快的TPU训练速度;CoAtNet-5则进一步达到了89.77%的top1精度,优于其他同等配置的模型。

全文到此结束,更多消融实验与分析建议查看原文。

推荐阅读

本文章著作权归作者所有,任何形式的转载都请注明出处。更多动态滤波,图像质量,超分辨相关请关注我的专栏深度学习从入门到精通
3 阅读 416
推荐阅读
0 条评论
关注数
2052
内容数
92
夯实深度学习知识基础, 涵盖动态滤波,超分辨,轻量级框架等
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
Arm中国学堂公众号
关注Arm中国学堂
实时获取免费 Arm 教学资源信息
Arm中国招聘公众号
关注Arm中国招聘
实时获取 Arm 中国职位信息