LSTM - 长短期记忆网络

Posted by RAIS on 2021-02-06

循环神经网络(RNN)

人们不是每一秒都从头开始思考,就像你阅读本文时,不会从头去重新学习一个文字,人类的思维是有持续性的。传统的卷积神经网络没有记忆,不能解决这一个问题,循环神经网络(Recurrent Neural Networks)可以解决这一个问题,在循环神经网络中,通过循环可以解决没有记忆的问题,如下图:

RNN-rolled

看到这里,你可能还是不理解为什循环神经网络就可以有记忆。我们把这个图展开:

RNN-unrolled

可以看出,我们输入 $X_0$ 后,首先警告训练,得到输出 $h_0$,同时会把这个输出传递给下一次训练 $X_1$,普通的神经网络是不会这样做的,这时对 $X_1$ 进行训练时,输入就包括了 $X_1$ 本身和 训练 $X_0$ 的输出,前面的训练对后面有印象,同样的道理,之后的每一次训练都收到了前面的输出的影响(对 $X_1$ 训练的输出传递给训练 $X_2$ 的过程,$X_0$ 对 $X_2$ 的影响是间接的)。

遇到的问题

循环神经网络很好用,但是还有一些问题,主要体现在没办法进行长期记忆。我们可以想象(也有论文证明),前期的某一次输入,在较长的链路上传递时,对后面的影响越来越小,相当于网络有一定的记忆能力,但是记忆力只有 7 秒,很快就忘记了,如下图 $X_0$ 和 $X_1$ 对 $h_{t+1}$ 的影响就比较小了(理论上通过调整参数避免这个问题,但是寻找这个参数太难了,实践中不好应用,因此可以近似认为不可行),LSTM 的提出就是为了解决这个问题的。

RNN-longtermdependencies

LSTM

LSTM(Long Short Term Memory)本质还是一种 RNN,只不过其中的那个循环,上图中的那个 A被重新设计了,目的就是为了解决记忆时间不够长的问题,其他神经网络努力调整参数为的是使记忆力更好一点,结果 LSTM 天生过目不忘,简直降维打击!

普通的 RNN 中的 A 如下图,前一次的输入和本次的输入,进行一次运算,图中用的是 tanh:

LSTM3-SimpleRNN

相比较起来,LSTM 中的 A 就显得复杂了好多,不是上图单一的神经网络层,而是有四层,如下图,并且似乎这么看还有点看不懂,这就是本文需要重点分析的内容,仔细认真读下去,定会有收获:

LSTM3-chain

定义一些图形的含义,黄色方框是简单的神经网络层;粉色的代表逐点操作,如加法乘法;还有合并和分开(拷贝)操作:

LSTM2-notation

核心思想

首先看下图高亮部分,前一次的输出,可以几乎没有阻碍的一直沿着这条高速公路流动,多么简单朴素的思想,既然希望前面的训练不被遗忘,那就一直传递下去:

LSTM3-C-line

当然,为了让这种传递更加有意义,需要加入一些门的控制,这种门具有选择性,可以完全通过,可以完全不通过,可以部分通过,S 函数(Sigmoid)可以达到这样的目的,下面这样就是一个简单的门:

LSTM3-gate

总结一下,我们构造 LSTM 网络,这个网络有能力让前面的数据传递到最后,网络具有长期记忆的能力,同时也有门的控制,及时舍弃那些无用的记忆。

详细分析

有了这样的核心思想,再看这个网络就简单了好多,从左到右第一层是“选择性忘记”。我们根据前一次的输出和本次的输入,通过 Sigmoid 判断出前一次哪些记忆需要保留和忘记:

LSTM3-focus-f

第二部分又分为了两个部分,一个部分是“输入门层”,用 Sigmoid 决定哪些信息需要进行更新,另一个部分是创建候选值向量,即本次输入和上次输出进行初步计算后的中间状态:

LSTM3-focus-i

经过前面的计算,我们可以更新单元格的状态了。第一步,前一个的单元格哪些数据需要传递,哪些数据需要忘记;第二步,本次的哪些数据需要更新,乘以本次计算的中间状态可以得到本次的更新数据;再把前两步的数据相加,就是新的单元格状态,可以继续向后传递。

LSTM3-focus-C

这一步需要决定我们的输出:第一步,我们用 Sigmoid 来判断我们需要输出的部分;第二步,把上面计算得到的单元格状态通过 tanh 计算将数据整理到 -1 到 1 的区间内;第三步,把第一步和第二步的数据相乘,就得到了最后的输出:

LSTM3-focus-o

总结一下我们刚刚做了什么:我们首先通过本次的输入和上次的输出,判断出上次单元格状态有哪些数据需要保留或舍弃,再根据本次的输入进行网络训练,进一步得到本次训练的单元格状态和输出,并将单元格状态和本次的输出继续往后传递。

这里有一个疑问,为什么需要舍弃?举个例子,翻译一篇文章,一篇文章前一段介绍某一个人的详细信息和背景,下一段介绍今天发生的某个故事,两者的关系是弱耦合的,需要及时舍弃前面对人背景信息的记忆,才会更好的翻译下面的故事。

其他一些基于 LSTM 修改版的网络,本质是一样的,只不过把某些地方打通了,有论文验证过,一般情况下对训练的结果影响很小,这里不展开介绍,大同小异,修内功而不是那些奇奇怪怪的招式:

LSTM3-var-peepholes

LSTM3-var-tied

LSTM3-var-GRU

总结

本文介绍了长短期记忆网络,在大多数情况下,若在某个领域用 RNN 取得了比较好的效果,其很可能就是使用的 LSTM。这是一篇好文,本文图片来自Understanding-LSTMs,值得一读。

  • 本文首发自: RAIS