AI学习者 · 2022年07月12日

一种融合卷积的ViT模型

自从Vision Transformer网络面世以来,Transformer模型在CV领域的应用也逐渐开始崭露头角。然而如图1所示,原始的ViT网络模型在小数据集上的表现差强人意,相较于传统的CNN网络并没有明显的优势;并且ViT和相似尺寸的CNN网络相比所需要的训练数据集更大,种种这些现象无疑减弱了ViT模型在小数据集的推理应用。本周将与大家分享一篇来自ICCV的文章:CvT: Introducing Convolutions to Vision Transformers,该论文提出了一种结合卷积神经网络和Transformer的思路,可以有效地提升ViT模型在小数据集的预测精度。

图1 CNN_based model vs ViT_based model 的预测精度对比

原因分析

文章作者认为,ViT之所以在小数据集中的表现不如CNN的原因主要有以下两点:

  1. CNN的卷积操作具备诸如local receptive fields, shared weights, and spatial subsampling的特性,这样可以实现对图片局部关联特征的有效捕捉,并且也具备了一定程度上的平移,缩放和旋转不变性。然而基于Transformer的ViT并不具备这样的特性,ViT通过将图片分成不同的patch,再转换成一维的序列输入进行attention的计算,本质上计算了单个patch与全局不同位置处的图片关联性,因此在训练过程中往往需要更多的数据进行学习。
  2. 多层级的CNN网络结构有助于提取不同层次的信息。例如低层级的卷积核可以提取图片边缘和纹理信息,而高层级的卷积核可以获取更为丰富的语义信息;这些特性都可以使得CNN网络更适宜处理图像方面的任务。因此,一个自然而然的想法就是通过在ViT中引入卷积的相关操作,从而实现对ViT网络整体性能的改善。

融合策略

作者团队提出的CvT模型架构pipeline如图2(a)所示, 相比ViT网络架构,包含两个核心变化:

  1. 参考CNN的多层分级结构,将Transformer模型做类似的多层分级设计。每一个stage的第一层均采用Convolutional token embedding将2D的image或token map的特征进行提取同时进行空间下采样,以获取更多维度的特征图谱,如CNN的卷积操作。
  2. 将Transformer中的linear projection替换成convolutional projection,这样有利于模型进一步捕捉局部空间的关联信息,减少注意力机制中的语义歧义;同时在注意力的Q/K/V计算中,作者团队尝试了不同的stride的取值,以牺牲一点精度为代价,进一步压缩了模型的大小,提升了计算的效率。

文章指出这样的结构设计既能够充分利用CNN的特性,如local receptive fields, shared weights, and spatial subsampling,,也能够保留transformer模型的优良特性:dynamic attention, global context fusion, and better generalization。

图2(a)CvT网络结构图,(b)convolutional transformer block的细节图

效果分析

CvT模型与其他模型的在ImageNet, ImageNet Real和ImageNet V2的测试对比结果如下表所示,可以看出,相比基于Transformer 的模型,CvT模型通过更少的学习参数和计算消耗获得了更高的预测精度。相比于传统的ResNet,也具备更高的预测精度。

图3. 不同模型的效果对比结果表

总结

文章提供了一种CNN和ViT耦合的新模型CvT, 该模型既能利用CNN对图像特征的有效的学习处理,又可以保留Transformer的动态注意力机制和全局信息感知,使得当训练数据集较少时,能够明显地提升基于Transformer模型的图像预测精度。

Reference

Wu H, Xiao B, Codella N, et al. Cvt: Introducing convolutions to vision transformers[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021: 22-31.

作者:于璠
文章来源:知乎

推荐阅读

更多嵌入式AI相关技术干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。
推荐阅读
关注数
18808
内容数
1352
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息