
RNN Principles and Derivation

A recurrent neural network (RNN) is a neural network designed specifically for sequential data. RNNs exploit temporal information: they are called recurrent because they apply the same operation to every element of a sequence, and each output depends on the computations that came before it. In this sense an RNN has a form of "memory" that captures information about what has already been computed. In theory an RNN can make use of arbitrarily long sequences, but in practice the length it can effectively remember is limited.

1. RNN Structure and Principle

An RNN consists of an input layer, a hidden layer, and an output layer. The hidden state \(s_t\) at time \(t\) is determined by the input \(x_t\) and the previous hidden state \(s_{t-1}\), and the output \(o_t\) is obtained by multiplying the hidden state by a weight matrix. The hidden layer is computed as: \[ \begin{aligned} s_t & = g_1(U \cdot x_t + W \cdot s_{t-1} + b_1) \end{aligned} \] where \(g_1\) is the hidden-layer activation function, usually \(\tanh\) or ReLU. The output layer is computed as: \[ \begin{aligned} o_t & = g_2(V \cdot s_t + b_2) \end{aligned} \] where \(g_2\) is the output-layer activation function, usually softmax.
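To make the two formulas concrete, here is a minimal NumPy sketch of the forward recurrence. The dimensions, the random weights, and the helper softmax below are illustrative assumptions only; the code simply mirrors the equations for \(s_t\) and \(o_t\).

import numpy as np

def softmax(z):
    z = z - z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def rnn_forward(xs, U, W, V, b1, b2):
    """Run the RNN recurrence over a sequence xs (a list of input vectors)."""
    s = np.zeros(W.shape[0])     # s_0: initial hidden state
    outputs = []
    for x_t in xs:
        s = np.tanh(U @ x_t + W @ s + b1)   # s_t = tanh(U·x_t + W·s_{t-1} + b1)
        outputs.append(softmax(V @ s + b2)) # o_t = softmax(V·s_t + b2)
    return outputs, s

# Example with assumed (hypothetical) sizes
input_dim, hidden_dim, output_dim, T = 4, 8, 3, 5
rng = np.random.default_rng(0)
U = rng.normal(size=(hidden_dim, input_dim))
W = rng.normal(size=(hidden_dim, hidden_dim))
V = rng.normal(size=(output_dim, hidden_dim))
b1, b2 = np.zeros(hidden_dim), np.zeros(output_dim)
xs = [rng.normal(size=input_dim) for _ in range(T)]
outputs, s_T = rnn_forward(xs, U, W, V, b1, b2)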

2. Problems with RNNs

RNNs suffer from vanishing and exploding gradients. First, the total loss of an RNN is the sum of the losses at each time step: \[ \begin{aligned} L & = \sum_{t=0}^{T} L_t \end{aligned} \]

The partial derivative of the loss at time \(t\) with respect to \(W\) is: \[ \begin{aligned} \frac{\partial L_t}{\partial W} & = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial W} + \frac{\partial L_{t}}{\partial o_{t}} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial s_{t-1}}{\partial W} + \cdots + \frac{\partial L_{t}}{\partial o_{t}} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial s_{t-1}}{\partial s_{t-2}} \cdots \frac{\partial s_2}{\partial s_1} \frac{\partial s_1}{\partial W} \\ & = \sum_{k=0}^{t} \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial s_t} \left(\prod_{j=k+1}^{t} \frac{\partial s_j}{\partial s_{j-1}}\right) \frac{\partial s_k}{\partial W} \end{aligned} \]

Similarly, the partial derivative at time \(t\) with respect to \(U\) is: \[ \begin{aligned} \frac{\partial L_t}{\partial U} & = \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial U} + \frac{\partial L_{t}}{\partial o_{t}} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial s_{t-1}}{\partial U} + \cdots + \frac{\partial L_{t}}{\partial o_{t}} \frac{\partial o_t}{\partial s_t} \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial s_{t-1}}{\partial s_{t-2}} \cdots \frac{\partial s_2}{\partial s_1} \frac{\partial s_1}{\partial U} \\ & = \sum_{k=0}^{t} \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial s_t} \left(\prod_{j=k+1}^{t} \frac{\partial s_j}{\partial s_{j-1}}\right) \frac{\partial s_k}{\partial U} \end{aligned} \]

This shows that both exploding and vanishing gradients stem from the product term: \[ \prod_{j=k+1}^{t} \frac{\partial s_j}{\partial s_{j-1}} \]

Suppose the activation function is \(\tanh\); the hidden-state update then becomes: \[ \begin{aligned} s_t & = \tanh(U \cdot x_t + W \cdot s_{t-1} + b_1) \end{aligned} \]

Then: \[ \begin{aligned} \prod_{j=k+1}^{t} \frac{\partial s_j}{\partial s_{j-1}} & = \prod_{j=k+1}^{t} \tanh' \cdot W \end{aligned} \]
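For reference, the derivative of \(\tanh\) can be written in terms of the hidden state itself, which makes the bound on each factor explicit (in the vector case, \(1 - s_j^2\) is taken element-wise): \[ \begin{aligned} \tanh'(z) & = 1 - \tanh^2(z) \in (0, 1], & \frac{\partial s_j}{\partial s_{j-1}} & = \mathrm{diag}(1 - s_j^2) \cdot W \end{aligned} \]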

Since \(\tanh'\) is never greater than 1 (and typically well below 1), treating \(W\) as a scalar gives two regimes, illustrated numerically by the sketch after this list:

  • When \(W > 1\), the product \(\prod_{j=k+1}^{t} \tanh' \cdot W\) grows as \(t\) increases, which leads to exploding gradients
  • When \(0 < W < 1\), the product shrinks toward 0 as \(t\) increases, which leads to vanishing gradients
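A quick numeric sketch makes the two regimes concrete. This is purely illustrative: the scalar g stands in for a typical value of the per-step factor \(\tanh' \cdot W\), and the chosen values of g and t are arbitrary.

# Scalar illustration of the product term over t time steps (illustrative only)
for g in (0.5, 1.2):          # g stands in for a typical value of tanh' * W
    for t in (5, 20, 50):
        print(f"g={g}, t={t}: product = {g ** t:.3e}")
# g=0.5 shrinks rapidly toward 0 (vanishing); g=1.2 blows up (exploding)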

3. Remedies for Exploding and Vanishing Gradients

Remedies for vanishing gradients:

  • Initialize the weights carefully
  • Use ReLU instead of sigmoid or \(\tanh\) as the activation function
  • Use other RNN architectures (e.g., LSTM)

Remedies for exploding gradients:

  • Gradient clipping (set a gradient threshold; when the gradient exceeds this threshold, clip it back), as sketched below
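In PyTorch, for example, clipping is usually applied with torch.nn.utils.clip_grad_norm_ between the backward pass and the optimizer step. The toy model, dummy data, and max_norm value below are placeholder choices used only to show where the call goes in a training step.

import torch
import torch.nn as nn

# Toy setup just to demonstrate where clipping fits in a training step
model = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(2, 10, 4)        # (batch, seq_len, input_size)
target = torch.randn(2, 10, 8)   # dummy target for a regression-style loss

output, h_n = model(x)
loss = nn.functional.mse_loss(output, target)
loss.backward()

# If the total gradient norm exceeds max_norm, all gradients are rescaled proportionally
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()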

4. Code Example

import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # Both layers take the concatenation of the current input and the previous hidden state
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)  # input-to-hidden
        self.i2o = nn.Linear(input_size + hidden_size, output_size)  # input-to-output
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)  # concatenate x_t and s_{t-1}
        hidden = self.i2h(combined)               # new hidden state s_t
        output = self.i2o(combined)               # output o_t (log-probabilities)
        output = self.softmax(output)
        return output, hidden
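A minimal usage sketch follows; the sizes are arbitrary placeholders. Since this module processes one time step at a time, the hidden state has to be carried through a loop manually.

# Minimal usage sketch: process a sequence one step at a time
n_letters, n_hidden, n_categories = 57, 128, 18   # placeholder sizes
rnn = RNN(n_letters, n_hidden, n_categories)

hidden = torch.zeros(1, n_hidden)                 # initial hidden state, batch of 1
sequence = torch.randn(5, 1, n_letters)           # 5 time steps of dummy input

for t in range(sequence.size(0)):
    output, hidden = rnn(sequence[t], hidden)
print(output.shape)                               # (1, n_categories) log-probabilities from the last step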
