AI学习者 · 3月27日

LSTM模型分析

1. LSTM模型是什么

LSTM是一种RNN模型。RNN和CNN可以是DL的两种重要模型。CNN主要处理空间结构数据,RNN主要处理时间序列数据。但也不是绝对,本文中用作profiling的例子(代码来源参考文献[1])就是LSTM处理图片,即空间结构数据。

RNN(Recurrent neural network,循环神经网络)是一系列能够处理序列数据的神经网络的总称。RNN特性是隐单元间的连接是循环的;如果输入是一个时间序列,可以将其展开。其中的每一个单元,除了处理当前时间点的输入数据外,还要处理前一个单元的输出,最终输出单一的预测。基本的RNN模型只处理前一个单元的输出,这样距离远的单元的输出,因为中间经过多次处理,影响就逐渐消失。这带来一个问题就是Long term dependency。

为了说明Long term dependency,下面看两个例子。第一个例子中,前文内容是“蓝天上朵朵”,希望预测的结果是“白云”。我们看到“朵朵”就自然会预测下一个词“白云”,因为这两个词中间没有间隔。

image.png

但如果是另外一种情况,比如下面的例子。“他出生在中国,去过很多国家,…,能说流利的”,这时候模型就不能准确预测“流利”后面应该接哪个词。因为“出生在中国”这个关键信息经过模型的多层处理,对输出预测的影响已经很小。

image.png

为了尽可能将有用的关键信息保留下来,前人提出了很多方法,LSTM就是其中一种。LSTM算法全称为Long short-term memory,是一种特定形式的RNN,模型结构如下图所示。LSTM的输出和RNN不同,除原有输出h外,还增加了一条,最上面贯穿整个network的通路,称为cell state,单元状态。注意到在cell state通路上没有非线性单元,只有乘法和加法。所以cell state才可能只经过比较少的改变就传给下一个单元。LSTM的巧妙之处在于通过增加各种门控,如,输入门,遗忘门和输出门,这使得自循环的权重是变化的,控制了上一单元有多少信息可以通过(sigmoid输出0-1),当前单元有哪些信息可以添加到cell state传递给下一个单元。

image.png

2. LSTM实现细节分析

以下是参照tensorflow的源码实现对LSTM实现细节的分析,包括分析其结构和各部分的作用。

image.png

Forget gate,输入是前一个单元输出h(t-1)和当前时间输入数据x(t)。注意这里的sigmoid模块不仅只有sigmoid激活函数,而实际是sigmoid layer,包含了和weight的matrix乘和add bias。输出f(t)输出的取值范围是 0-1,决定了允许什么样的信息通过,什么信息保留。如果f(t)输出是0,则上一单元的状态信息被全部遗忘,不会被输入到当前单元;如果f(t)输出是1,则上一单元的状态信息被当前单元全部保留。如果以上面的“能说流利的中文”为例,为了能正确预测下一个词是“中文”,“他出生在中国”这种关键信息就应该尽可能完整的保留在C(t-1)中,并逐级传递给预测单元。多多白云,会将相关的信息多保留,不相关的信息多forget。但不是不通过,有一定比例的保留。

input gate,决定哪些新输入的信息可以保留到cell state。其中第一个sigmoid layer称为input gate,输出i(t)是一个控制信号,控制哪些新的输入信息可以被更新到cell state。旁边的tanh layer产生的就是新信息。比如在前面语言预测的例子中,重要的关键词就多输出一些,不重要的“地得啊呀”这类不重要的信息经过i(t)的控制可以少输出到cell state。而对于关键信息,如“他出生在中国”,i(t)应该控制其尽可能保存到cell state中。i(t)和c(t) head相乘的结果就决定了什么样的信息可以更新到cell state中。

update gate,将新信息和需要保留信息的cell state拼接,得到新的cell state。

Output gate,决定输出什么信息。o(t)控制哪些信息需要输出给下一单元,输出信息由cell state决定。

image.png

Tensorflow中的LSTM源码实现和paper中的公式略有不同。Paper中各个门的输入数据和权值矩阵分别做矩阵乘法。源码中是将输入与self._kernel做矩阵乘,self._kernel是四个门的权值矩阵拼接得到。再通过split将矩阵乘结果分为四部分,分别对应四个门的输出(即代码中的I,j,f,o)。

3. LSTM profiling分析

LSTM profiling分析分为三部分。第一,data layout,第二Timeline分析,第三输入输出分析。

前文已经提到,这个例子是用LSTM做手写数字识别。每张数字图片大小是28 * 28pixel。首先用tf.transpose交换输入数据的第0维和第1维,维度由原来的[batch_size, 28, 28]变为[28,batch_size, 28]。也就是说,变换后的第一块数据的第一行数据来自第一张图片的第一行pixel,第一块数据的第二行数据来自第二张图片的第一行pixel,第一块数据的第三行数据来自第三张图片的第一行pixel,以此类推。

tf.reshape将三维数据拼接成二维数据,总数据量保持不变。

tf.split将大的二维数据再且分为28块,每块对应一个LSTM单元输入。

image.png

为什么要对数据做上述处理,而不是直接将手写数字图片输入给LSTM做预测?这里可以通过下图说明。

如果将数字图片直接输入给LSTM,预测精度只能到达10%左右,且不会提高。因为如果每个LSTM单元输入的是不同的数字图片,各个单元之间的cell state连接其实对预测结果没有助益,因为两次输入(如手写数字1和3)之间根本没有内在关联,单元之间的循环连接不起作用,当前单元无法利用前一单元的预测信息。即使有再多的输入数据,预测结果也等同于在10个数字中的随机猜测。

经过变换后的数据,是经过上述处理切分的28块数据。第一块数据的第一行数据来自第一张图片的第一行pixel,第二行数据来自第二张图片的第一行pixel,第三行数据来自第三张图片的第一行pixel,以此类推。第一块数据输入给第一个LSTM单元后,LSTM做预测得到输出并将cell state传给第二个单元。第二块数据输入给第二个LSTM单元后,同样做预测得到输出并将cell state传给第三个单元。以此类推,一直到第28个单元做预测后输出整体特征。这个特征经过全连接操作得到预测结果。

image.png

经过变换后的数据,经过大约10轮循环训练,精度就可达到90%以上。因为不同数据块包含相同数字图片的不同行的数据,相当于每次给不同的LSTM单元看的信息逐渐增加。模型在这个过程中通过cell state利用前一单元的预测结果不断学习。最后的LSTM单元可以利用之前单元的全部学习结果,得到准确的预测。

以上是输入数据的变换。下图是LSTM单元内部计算过程中,数据shape的变换。

image.png
image.png

输入输出和Timeline分析需要用到Chrome tracing,界面如下图所示。Tracing 可以直观地看到operation的执行时间和调用关系以及时序等。Operation起始时候和结束时候都记录下时间并保存到 json 文件中,调用栈等进行可视化表示。在Chrome的地址栏中敲入chrome://tracing/,通过load 按钮加载json文件。在tracing界面中选中关心的operation,可以看到operation的输入输出。

image.png

为了生成json文件,需要在python代码中增加如下RunOptions和timeline写操作。

from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level= tf.RunOptions.FULL_TRACE)

run_metadata = tf.RunMetadata()

sess.run(xxx, options=run_options, run_metadata=run_metadata)))

tl = timeline.Timeline(run_metadata.step_stats)

ctf = tl.generate_chrome_trace_format()

with open('lstmtimeline.json', 'w') as f:

f.write(ctf)

下图是LSTM的operation pattern,从中可以看到完整的LSTM operation序列,如matmul, biasadd等。

image.png

下图是tracing中operation的细节,及其与代码的对应关系。

image.png
image.png
image.png
image.png

参考文献

【1】 https://colah.github.io/posts/2015-08-Understanding-LSTMs

【2】 https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/python/ops/rnn_cell_impl.py

【3】 https://www.chromium.org/developers

作者:Frank Wang
文章来源:知乎

推荐阅读

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
3467
内容数
22
搭载安谋科技“周易”NPU的芯擎AI开发板,支持各种AI模型部署。
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息