0%

LSTM 的原理与推导

长短期记忆网络(Long Short-Term Memory,也称 LSTM)是一种时间递归神经网络。LSTM 适合基于时间序列数据进行分类,处理和预测。因为在时间序列中的重要事件之间,可能存在未知持续时间的滞后。LSTM 的提出是为了缓解 RNN 的梯度爆炸和梯度消失的问题。

1. LSTM 结构与原理

从图中可以看出,LSTM 内部是由遗忘门、输入门、输出门和单元状态组成。其中 \(\sigma\) 为 sigmoid 激活函数,\(*\) 为矩阵点乘。

  • 遗忘门

遗忘门决定上一个单元状态 \(C_{t-1}\) 中哪些信息进入下一个单元状态。遗忘门的计算公式: \[ \begin{aligned} f_t & = \sigma (W_f \cdot [h_{t-1},x_t] + b_f) \end{aligned} \]

  • 输入门

输入门确定哪些新信息能够被存放到 \(\tilde{C_t}\) 中 。输入门计算公式: \[ \begin{aligned} i_t &= \sigma (W_i \cdot [h_{t-1}, x_t] + b_i) \\ \tilde{C_t} &= \tanh (W_C \cdot [h_{t-1}, x_t] + b_C) \end{aligned} \]

  • 单元状态

单元状态是由遗忘门和输入门组成,它决定了 \(C_{t-1}\)\(\tilde{C_t}\) 中哪些信息能更新到 \(C_t\) 中。单元状态计算公式: \[ \begin{aligned} C_t = f_t * C_{t-1} + i_t * \tilde{C_t} \end{aligned} \]

  • 输出门

输出门决定单元状态 \(C_t\) 中哪些信息能够进入到隐藏层状态 \(h_t\) 中。输出门的计算公式: \[ \begin{aligned} o_t &= \sigma (W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t &= o_t * \tanh (C_t) \end{aligned} \]

2. LSTM 如何缓解 RNN 的梯度爆炸和梯度消失

  • 引用门控机制
    • 遗忘门:控制继续保持长期状态 \(C\)
    • 输入门:控制把即使状态输入到长期状态 \(C\)
    • 输出门:控制是否把长期状态 \(C\) 作为当前的 LSTM 的输出
  • 原理:门实际上就是一层全连接层,它的输入是一个向量,输出是一个 0 到 1 之间的实数向量。

3. 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class LSTMTagger(nn.Module):

def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
super(LSTMTagger, self).__init__()
self.hidden_dim = hidden_dim

self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

# The LSTM takes word embeddings as inputs, and outputs hidden states
# with dimensionality hidden_dim.
self.lstm = nn.LSTM(embedding_dim, hidden_dim)

# The linear layer that maps from hidden state space to tag space
self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

def forward(self, sentence):
embeds = self.word_embeddings(sentence)
lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
tag_scores = F.log_softmax(tag_space, dim=1)
return tag_scores