梁德澎 · 2023年02月16日

一文理解“上下文学习”----大语言模型突现能力

前言

=====

最近几年大语言模型(LLM) 获得了越来越多的关注,其中最知名的当属 GPT-3[6] 模型。GPT-3 模型展现了一些大模型才具备的突现能力(就是模型规模必须得增大到一定程度才会显现的能力,比如至少百亿级),其中一项能力就是上下文学习(In-Context Learning)[1,2,3,4,5]

该能力简单来说就是,对于一个预训练好的大语言模型,迁移到新任务上的时候,只需要给模型输入几个示例(示例输入和示例输出对),模型就能为新输入生成正确输出而不需要对模型做 fine-tuning。

这也引发了研究人员对该能力产生原因的思考和探索。本文会首先给读者介绍什么是上下文学习,接着解读一篇最近由微软研究院发布的探索 LLM 上下文学习能力来源的文章[7]

什么是上下文学习(In-Context Learning)?

GPT-n 系列的模型都属于自回归类的语言模型,所谓自回归模型就是根据当前输入预测下一个词,然后将预测结果和输入拼接再当做模型的输入再预测下一个词,这样循环往复。

而自回归模型的训练目标也很简单,就是从超大规模语料库中采样训练样本,模型根据输入输出一个概率向量(概率向量包含所有词的预测概率,对于GPT-3 模型来说,维度约1千多万),而因为文本数据自带标注所以我们是知道真实的下一个词,所以损失函数就采用得交叉熵。

然后研究人员发现,预训练好的 GPT-3 模型拥有一项神奇的能力,后来被称为:上下文学习(In-Context Learning)。

这项能力简单来说就是,预训练好的 GPT-3 模型在迁移到新任务上的时候并不需要重新训练,而只需要提供任务描述(这个任务描述是可选项)接着提供几个示例(任务查询和对应答案,以一对对的形式组织),最后加上要模型回答的查询。将以上内容打包一起作为模型的输入,则模型就能正确输出最后一个查询对应的答案。

举个例子:

比如现在想用 GPT-3 来做个翻译任务,翻译英文为法文。输入的格式如下:

image.png

首先第一行是对任务描述,告诉模型要做翻译,接下来三行就是示例,英文单词和对应的法文单词对,最后一行就是待翻译的英文单词。将以上内容整体作为 GPT-3 的输入,让模型去补全输出就能得到 cheese 对应的法文单词。

上下文学习非常的灵活,除了上面展示的翻译任务,还可以做语法修饰甚至写代码。而神奇的地方就在于,在 GPT-3 的训练过程中是并没有显式的提供,类似测试阶段任务描述加示例这样的训练数据。

当然 GPT-3 的训练数据量非常巨大(比如包含了 wiki, 书本期刊,reddit 上的讨论等等),或许里面就已经就包含了各种任务类似结构的数据,GPT-3 模型容量足够大能够将所有训练数据都记了下来。

对于上下文学习能力的成因,目前还是一个开放性的问题。为什么只有大规模的语言模型才会具备该能力?或许只有模型参数量大还不够,还必须要训练数据量也足够大,模型才能显现出该能力?

深入探究上下文学习

简单的拷贝输出

首先来看一个很简单的任务,就是让模型直接复制输入的内容。

首先示例个数设置为 5 个,每个示例输入包含 5 个不同的小写单词(从字母表前 8 个小写字母中随机选5个得到),这些单词用逗号分隔,输出直接拷贝的输入,比如:

Input: g, c, b, h, d  
Output: g, c, b, h, d  
Input: b, g, d, h, a  
Output: b, g, d, h, a  
Input: f, c, d, e, h  
Output: f, c, d, e, h  
Input: c, f, g, h, d  
Output: c, f, g, h, d  
Input: e, f, b, g, d  
Output: e, f, b, g, d  
Input: a, b, c, d, e  
Output:  

期待模型的输出是:

a, b, c, d, e  

接着对于5个字母顺序的所有可能情况 (8!/3!=6720,从8个样本中选5个总的组合数)也就是最后 input 的位置将 6720 个情况都测试了,GPT-3 模型的准确率是 100%

接着用 GPT-3 系列最小的模型 text-ada-001 来做这个任务,获得了 6705/6720 = 99.78% 的准确率,一定程度上证明了模型规模的重要性。

格式化日期

接着来看 GPT-3 在更复杂一些的任务上的表现。

这个任务是对日期做格式化,将 年-月-日 格式的输入格式转化成 !月!日!年!,其中年份四位数,月份和日子是两位数,比如:

image.png

上面这个例子中,示例个数是3,最后是待测试的日期 2005-07-23

为什么选择日期格式化这个任务呢?

首先足够简单,日期包含三个随机变量(年月日),它们长度都是固定的,而且设定的输出格式也不是正常的格式,所以训练数据中不太可能包含类似的样本,也排除了模型可能只是将训练数据都记忆了下来。

接下来看看测试结果,我们测试了 GPT-3 全系列的模型 [8],包括text-ada-001,text-babbage-001,text-curie-001text-davinci-003,模型参数量依次从小到大排列。

并通过设置不同的上下文示例个数(对于每个示例个数的设置,都有2000个测试样本),记录各个模型的预测准确率,测试结果如下:

image.png

从图表展示的结果来看,固定横坐标示例个数,则模型越大准确率也越高,模型越大准确率曲线也就更加的陡峭。而对于每个模型来说,增加上下文的示例个数也能有效提升准确率。

不过仔细观察图表可以发现,即使增大示例个数和模型,模型的精确度也只是无限接近 100% 但还是达不到。

接下来我们分析一下,GPT-3 预测错误的样本都包含哪些类型。

image.png

这里我们选取了前10个最常见的错误类型,其中图标中的 DD 表示两位数的日子,MM 表示两位数字的月份,mm 一位数的月份,YYYY 则是四位数的年份,YY 是两位数的年份,** 则是其他的两位数。

从实验结果上看,随着上下文示例个数的增加,预测错误的样本个数也在下降。

而模型预测错误最多的格式是,将日期放在月份前面,这也能理解,因为训练数据中常见的日期格式都是先日期,再月份,最后年份。

继续分析模型预测错误的样本,发现一个有趣的结果:

image.png

就是对于 2019 年份的输入,模型是最容易预测错误的,这也能理解因为训练数据中 2019 年份的数据不多。

标签重映射

这个测试任务就是将实体做一个不正常的重新分类,比如:


`volleyball: animal  
onions: sport  
broccoli: sport  
hockey: animal  
kale: sport  
beet: sport  
golf: animal  
horse: plant/vegetable  
corn: sport  
football: animal  
luge: animal  
bowling: animal  
beans: sport  
archery: animal  
sheep: plant/vegetable  
zucchini: sport  
goldfish: plant/vegetable  
duck: plant/vegetable  
leopard: plant/vegetable  
lacrosse: animal  
badminton: animal  
lion: plant/vegetable  
celery: sport  
porcupine: plant/vegetable  
wolf: plant/vegetable  
lettuce: sport  
camel: plant/vegetable  
billiards: animal  
zebra: plant/vegetable  
radish: sport  
`

输入示例中包含了 [animal(动物), plant/vegetable(植物/蔬菜), sport(运动)] 三种类型标签。现在将它们原来的标签映射打乱,将动物映射为植物(duck: plant/vegetable),将运动映射为动物(golf: animal),将植物映射为运动(beans: sport)。

接着测试 GPT-3 能否根据仅有的示例学会预测新的映射,下面是测试结果:

`llama: plant/vegetable ✓  
cat: plant/vegetable ✓  
elephant: plant/vegetable ✓  
monkey: plant/vegetable ✓  
panda: plant/vegetable ✓  
cucumber: sport ✓  
peas: sport ✓  
tomato: sport ✓  
spinach: sport ✓  
carrots: sport ✓  
rugby: animal ✓  
cycling: animal ✓  
baseball: animal ✓  
tennis: animal ✓  
judo: animal ✓  
`

可以看到 GPT-3 能正确输出映射关系。而即使将标签改成无意义的符号比如 [^*, #@#, !!~],模型同样可以输出正确的预测。

上下文学习能力的成因

经过上面对上下文学习的介绍,相信读者也能体会到其神奇之处。

为什么 LLM 能够具备该能力?上下文学习的原理究竟是怎样的呢?

接下来解读一篇最近微软研究院发布的文章[7],对于上下文学习能力来源的探究。

文章中提出,关键在于 LLM 中的注意力层(attention layers),在推理过程实现了一个隐式的参数优化过程,这和 fine-tuning 的时候通过梯度下降法显式优化参数的过程是类似的。

基于梯度下降法的优化过程和注意力层的联系

文章[7]中提出,一个线性的注意力层其实和基于梯度下降法优化的全连接层是互为对偶的形式,具体怎么理解呢?

image.png

首先文章中定义,全连接层的初始参数矩阵为 W0,参数的梯度矩阵为△W,维度为 dout × din。还有当前输入向量 x ,维度为 din。则经过一次梯度下降法优化的全连接层可以表示为:

image.png

其中 △W 由上一次的输入 x' 和上一次全连接层的输出梯度 e 计算得到:

image.png

怎么理解这个梯度的计算公式呢,我们画个图:

image.png
接下来看基于梯度下降法优化的全连接和线性注意力层是怎么联系起来的,

image.png

我们关注红框部分,参数梯度矩阵 △W 是上一次输入和上一次输出梯度的外积求和,这部分可以等价变换为,首先让上一次输入xi'T 和当前输入 x 做内积,接着再和 ei 做内积最后再求和。

接着如果我们将

  • 上一次输出梯度 ei 看做是一个 value 向量,Evalue 矩阵
  • 上一次输入 xi'T 看做是一个 key 向量,X'key 矩阵
  • 当前输入 x 看做是一个 query 向量

其实就等价于是一个线性的注意力层。

上一次输入xi'Tx 先做内积,就是相当于 key 矩阵和当前 query 向量做乘法,得到每个 value 向量的权值,然后每个 ei 和权值相乘再相加,就是所有 value 向量加权求和。

上下文学习怎么实现隐式 finetuning

文章中定义,将上下文学习输入的最后一个词表示定义为 query token ,维度是 d

image.png

则输入到注意力层之后的 query 向量计算公式如下:
image.png

则对于最后一个 token 来说,经过一个注意力头操作的输出公式如下:

image.png

其中 WvWKWQ 都是变换矩阵,维度是 d' × dX' 是输入中示例部分的 token 向量表示,而 X 则表示输入中示例部分之后又在最后一个词之前的所有的 token 的向量表示。[X';X] 表示矩阵拼接。

image.png

然后论文中简化了下公式,将注意力计算中的 softmax 操作去掉了,就得到了上面新的公式。

我们关注上公式的第二到第三行的变换,上图解释变换过程:

image.png

接着文章中将,输入中示例部分之后又在最后一个词之前的所有的 token 的 valuekey 相乘部分定义为 Wzsl(zsl 表示 Zero-shot Learning,0样本学习)当做是初始的权值:

image.png

Wzsl * q 就相当于是一个0样本学习的 attention 结果,因为没有加上前面示例部分的 attention 结果。接着就是根据前面全连接层和 attention 互转的公式可得:

image.png

我们看右边红框部分的变换,我们将示例部分的 token attention 操作中的

  • Wv*X' 看做是对应前面全连接上一次计算的输出梯度
  • Wk*X' 看作是对应前面全连接上一次计算的输入
  • q 看作是当前的输入

然后就可以把推理得示例部分的 token attention 操作部分看做是对应 Wzsl 初始权值的更新梯度 △Wicl(icl 表示 In-Context Learning)。

这就是为什么说 LLM 中的注意力层在推理过程中实现了隐式的参数优化过程。所以这也是上下文学习能 work 的原因。

但是有个疑问就是 attention 机制不管模型规模大小都是一样的操作,为什么模型规模得增加到一定程度上下文学习才能显现呢?

我感觉还是回到模型规模和训练数据上,首先 LLM 中的key, query, value变换矩阵的维度 d ' x d 足够大,其次预训练的数据量也大,所以初始权值 Wzsl 足够好只需要少量的示例梯度 △Wicl 更新参数之后就能 work 了,其实感觉就和 Few-Shot Learning 没什么区别。

作者: 梁德澎
文章来源:GiantPandaCV

推荐阅读

更多嵌入式AI干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。
推荐阅读
关注数
18835
内容数
1370
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息