AI学习者 · 2021年11月23日

基于GNN的元学习模型直接预测神经网络参数

深度学习神经网络训练往往需要千万级次梯度迭代才能收敛,参数达到稳定,如何通过元模型预测深度学习模型参数,加速或替代梯度迭代收敛的过程。 近期NIPS2021的一篇论文为这一问题提供了一种基于GNN的建模思路并取得了不错的效果,作者来自加拿大圭尔夫和Facebook。

论文题目:

Parameter Prediction for Unseen Deep Architectures

方案

模型结构图数据

神经网络的计算图天然为图结构数据,图中的节点为算子如卷积、全连接和求和等,边为算子直接的连接,由于神经网络的计算有正向和反向,因此可以视为有向图。
image.png

图1 神经网络的计算图

模型结构空间

One-Shot神经网络结构搜索算法通过训练一张超网,用于评估子网的性能,从中获取符合搜索目标的子网进行从头训练。这种方法对于网络结构的扩展性差,不能预测任意网络的性能,并且往往关注不同网络结构之间的性能排序而不是预测网络的参数。本文基于DARTS扩展了模型结构空间,覆盖了non-separable 2D conv、SE、Transformer类型算子和LN等15种算子,可以涵盖VGG,RESNET和VIT等模型。通过每次均匀采样所有可能的算子即可构建一个计算单元,计算单元的取值范围由定义好的模型深度、宽度和节点数确定。
image.png

图2 模型结构

模型结构图超网

基于超网可以构建一张覆盖模型结构采样空间的大图,每次从图中采样的子图则对应一个神经网络模型结构。用于预测神经网络参数的元模型,包含Gated-GNN、参数生成模块Decoder和神经网络loss计算模块。本文每次采样一个batch的子图进行训练 ,Gated-GNN图神经网络模拟神经网络前向和后向计算进行消息传播更新子图中每个节点的特征,节点特征初始化为算子类型的one -hot编码和参数shape的编码特征,将图节点特征作为Decoder的输入,输出模型参数,Decoder采用了多层感知机输出3D算子参数,不同类型的算子从输出的最大shape中进行切分获取相应shape的参数,生成的参数用于在ImageNet等数据集上进行Loss计算,并将梯度反传回来。本文加强了Gated-GNN网络远距离消息传播并对不同类型的算子进行相应的正则化。

image.png

图3 元模型

效果

元模型GHN-2预测的Resnet50参数在ImageNet数据集top5精度可以达到60%,约为梯度迭代更新2500步的达到的精度,极大的提升了收敛的效率。
image.png

图4 DEEPNETS-1M ImageNet 结果

预测元模型GHN-2训练过程未覆盖的网络结构参数时,预测参数精度随着分布差异增大而衰减,例如预测神经网络信道数超过元模型训练网络覆盖范围时,预测效果随着信道数增加而精度下降。说明这一元模型存在局限。
image.png
图5 泛化性能

总结

正向

利用模型计算图拓扑结构之间的相似性,预测未见模型的参数,可以作为一种模型参数初始化方法替代随机初始化,适用于少样本学习。

1.基于GNN模型利用模型计算图拓扑结构相似性,实现不同模型之间的参数迁移
2.单次前向传播即可预测所有参数。
3.针对未见过的模型仍可预测出比随机参数精度更高的参数。

反向

1.预测参数能否加速全程收敛未在实验结果呈现,是否可能只是局部最优,而后续收敛过程更曲折
2.泛化性问题,针对元模型训练未能覆盖结构,效果有待提升。

文章转载于:知乎
作者: 于璠

推荐阅读

更多嵌入式AI技术相关内容请关注嵌入式AI专栏。
推荐阅读
关注数
18951
内容数
1459
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息