Skip to content

LSTM 长短期记忆网络

Long Short-Term Memory

RNN 想把所有的信息记住,不管是有用的信息还是无用的信息。

LSTM 通过引入记忆单元门控机制来控制信息的流动,从而有效地捕捉长期依赖关系。LSTM 的核心思想是通过三个门(输入门、遗忘门、输出门)来决定哪些信息需要保留,哪些信息需要丢弃。

略去每层都有的ot,RNN 的结构可以简化为:

RNN

LSTM 对隐藏结构进行了改进:

LSTM

结构

LSTM

记忆细胞

LSTM

遗忘门

LSTM

遗忘门(forget gate)顾名思义,是控制是否遗忘的,在 LSTM 中即以一定的概率控制是否遗忘上一层的隐藏细胞状态。

ft=σ(Wfht1+Ufxt+bf)

其中,σsigmoid激活函数

输入门

LSTM

输入门(input gate)负责处理当前序列位置的输入。

it=σ(Wiht1+Uixt+bi)C~t=tanh(Wcht1+Ucxt+bc)

这里:C~t 被称为候选记忆元(candidate memory cell)

状态更新

LSTM

Ct=Ct1ft+itC~t

其中,为 Hadamard 积。

输出门

LSTM

ot=σ(Woht1+Uoxt+bo)ht=ottanh(Ct)

当输出门接近 0 时,只保留记忆元内的所有信息,而不需要更新隐状态。

TIP

遗忘门决定了我要抛弃哪些旧知识

输入门决定了我要记住哪些新知识

输出门决定了我要用到哪些知识

前向传播

  1. 更新遗忘门输出:

    f(t)=σ(Wfh(t1)+Ufx(t)+bf)
  2. 更新输入门两部分输出:

    i(t)=σ(Wih(t1)+Uix(t)+bi)C~(t)=tanh(Wch(t1)+Ucx(t)+bc)
  3. 更新细胞状态:

    C(t)=C(t1)f(t)+i(t)C~(t)
  4. 更新输出门输出:

    o(t)=σ(Woh(t1)+Uox(t)+bo)h(t)=o(t)tanh(C(t))
  5. 更新当前序列索引预测输出:

y^(t)=σ(Vh(t)+c)

反向传播

推导

假设损失函数为交叉熵损失,定义:

δh(t)=Lh(t)δC(t)=LC(t)

对于最后的序列索引位置 τ

δh(τ)=L(τ)h(τ)=L(τ)y^(τ)y^(τ)h(τ)=V(y^(τ)y(τ))δC(τ)=L(τ)C(τ)=δh(τ)o(τ)(1tanh2(C(τ)))

交叉熵损失

L(t)=y(t)logy^(t)(1y(t))log(1y^(t))L(t)y^(t)=y^(t)y(t)y^(t)(1y^(t))

对于:

y^(t)=σ(z(t))z(t)=Vh(t)+c

有:

y^(t)z(t)=y^(t)(1y^(t))z(t)h(t)=V

对于 t 时刻 δh(t) 的计算 :

  • 直接梯度

    δh(t,direct)=V(y^(t)y(t))
  • 间接梯度(来自时间步 t+1):

    • 通过输出门:δh(t,o)=Wo[δh(t+1)tanh(C(t+1))o(t+1)(1o(t+1))]
    • 通过遗忘门:δh(t,f)=Wf[δC(t+1)C(t)f(t+1)(1f(t+1))]
    • 通过输入门:δh(t,i)=Wi[δC(t+1)C~(t+1)i(t+1)(1i(t+1))]
    • 通过候选细胞状态:δh(t,C~)=Wc[δC(t+1)i(t+1)(1(C~(t+1))2)]

对于 δC(t) :

δC(t)=δC(t+1)f(t+1)+δh(t)o(t)(1tanh2(C(t)))

下面计算:

LWf=t=1τLf(t)f(t)Wf

有:

Lf(t)=LC(t)C(t)f(t)=δC(t)C(t)f(t)=δC(t)C(t1)

进一步的,令:

zf(t)=Wfh(t1)+Ufx(t)+bf

有:

f(t)Wf=f(t)zf(t)zf(t)Wf

而:

f(t)zf(t)=σ(zf(t))(1σ(zf(t)))=f(t)(1f(t))zf(t)Wf=h(t1)

f(t)的求导

f(t)=σ(z)=11+ezdσ(z)dz=1(1+ez)2(ez)=σ(z)(1σ(z))

可以得出时间 tWf 的梯度贡献为:

LWf|t=[δC(t)C(t1)f(t)(1f(t))](h(t1))

总梯度为:

LWf=t=1τ[δC(t)C(t1)f(t)(1f(t))](h(t1))