请选择 进入手机版 | 继续访问电脑版
MSIPO技术圈 首页 IT技术 查看内容

手搓GPT系列之 - 通过理解LSTM的反向传播过程,理解LSTM解决梯度消失的原理 - 逐条解释LSTM创始论文全部推导公式,配超多图帮助理解(下篇)

2023-07-13

本文承接上篇上篇在此和中篇中篇在此,继续就Sepp Hochreiter 1997年的开山大作 Long Short-term Memory 中APPENDIX A.1和A.2所载的数学推导过程进行详细解读。希望可以帮助大家理解了这个推导过程,进而能顺利理解为什么那几个门的设置可以解决RNN里的梯度消失和梯度爆炸的问题。中篇介绍了各个权重的误差更新算法。本篇将继续说明梯度信息在LSTM的记忆单元中经过一定的时间步之后如何变化,并由此证明LSTM可实现CEC(Constant Error Carousel)。本篇为整个文章的终章,也是最关键的一篇,因为此篇正是理解LSTM实现CEC的关键。一家之言,若有任何错漏欢迎大家评论区指正。好了,Dig in!

6. 误差流

我们将计算误差值在记忆单元上流过 q q q时间步之后(也称误差流error flow)的变化情况。

6.1 记忆单元输出点的误差值计算

已知记忆单元的计算公式:
s c j ( t ) = s c j ( t − 1 ) + g ( n e t c j ( t ) ) y i n j ( t ) s_{c_j}(t) = s_{c_j}(t-1) + g(net_{c_j}(t)) y^{in_j}(t) scj(t)=scj(t1)+g(netcj(t))yinj(t)
我们使用截断求导规则来计算误差在时间步 t − k t-k tk t − k − 1 t-k-1 tk1之间的变化情况:
∂ s c j ( t − k ) ∂ s c j ( t − k − 1 ) = 1 + ∂ g ( n e t c j ( t − k ) ) y i n j ( t − k ) ∂ s c j ( t − k − 1 ) = 1 + ∂ y i n j ( t − k ) ∂ s c j ( t − k − 1 ) g ( n e t c j ( t − k ) ) + ∂ g ( n e t c j ( t − k ) ) ∂ s c j ( t − k − 1 ) y i n j ( t − k ) = 1 + ∑ u [ ∂ y i n j ( t − k ) ∂ y u ( t − k − 1 ) ∂ y u ( t − k − 1 ) ∂ s c j ( t − k − 1 ) ] g ( n e t c j ( t − k ) ) + y i n j ( t − k ) g ′ ( n e t c j ( t − k ) ) ∑ u [ ∂ n e t c j ( t − k ) ∂ y u ( t − k − 1 ) ∂ y u ( t − k − 1 ) ∂ s c j ( t − k − 1 ) ] ≈ t r 1. (30) \begin{aligned} \frac{\partial s_{c_j}(t-k)}{\partial s_{c_j}(t-k-1)} &= 1 + \frac{\partial g(net_{c_j}(t-k))y^{in_j}(t-k)}{\partial s_{c_j}(t-k-1)}\\ &=1+ \frac{\partial y^{in_j}(t-k)}{\partial s_{c_j}(t-k-1)}g(net_{c_j}(t-k)) + \frac{\partial g(net_{c_j}(t-k))}{\partial s_{c_j}(t-k-1)}y^{in_j}(t-k)\\ &=1 + \sum_u[\frac{\partial y^{in_j}(t-k)}{\partial y^u(t-k-1)}\frac{\partial y^u(t-k-1)}{\partial s_{c_j}(t-k-1)}]g(net_{c_j}(t-k)) \\ &\quad + y^{in_j}(t-k)g'(net_{c_j}(t-k))\sum_u [\frac{\partial net_{c_j}(t-k)}{\partial y^u(t-k-1)}\frac{\partial y^u(t-k-1)}{\partial s_{c_j}(t-k-1)}]\\ &\approx_{tr} 1.\tag{30} \end{aligned} scj(tk1)scj(tk)=1+scj(tk1)g(netcj(tk))yinj(tk)=1+scj(tk1)yinj(tk)g(netcj(tk))+scj(tk1)g(netcj(tk))yinj(tk)=1+u[yu(tk1)yinj(tk)scj(tk1)yu(tk1)]g(netcj(tk))<

相关阅读

手机版|MSIPO技术圈 皖ICP备19022944号-2

Copyright © 2023, msipo.com

返回顶部