RNN(Recurrent Neural Network)是用于处理序列数据的神经网络,它的网络结构如下图所示
该网络的计算过程可表示为,其中为权重矩阵,为网络的输出
权重矩阵的计算使用BPTT算法,它的本质还是BP算法,只不过要加上基于时间的反向传播,以下图一个简单的网络为例,其中表示输出,表示实际值,表示损失函数
根据链式求导法则:
(1)
(2)
(3)
由上述公式可以很容易看出时间步长间隔越多,在梯度计算中累乘的项数就越多,激活函数的导数相乘的次数就越多,越容易出现梯度消失现象,即当前时刻与多个时间步长之前的时刻之间的依赖关系在计算过程中被丢弃了
- 激活函数选取relu,右侧导数恒为1,可以较好地解决梯度消失问题;但是若W没有很好地初始化,容易产生梯度爆炸问题,需使用梯度裁剪(如果梯度的范数大于某个给定值,将梯度同比收缩)解决
- RNN网络的一些变种(例如LSTM、GRU)可以较好地解决梯度消失问题
长短时记忆网络LSTM(Long Short Term Memory network)是常规RNN网络的一个变体,可以学习长时间间隔的依赖关系,以下总结主要参考Understanding LSTM Networks
- 一个LSTM单元有相应的cell state()
-
遗忘门(forget layer)表示前一个单元的cell state有多少进入到当前单元
-
输入门(input gate)表示有多少新的信息进入当前单元
-
计算当前单元的cell state
-
输出门(output gate)表示当前单元的输出值