Skip to content

RNN 循环神经网络

Recurrent Neural Network

是一种用于处理序列数据的神经网络架构。与传统神经网络不同,RNN 具有记忆功能,能够捕捉序列中的时间依赖关系。

RNN 的核心在于循环结构,这里以最简单的 Elman RNN 为例说明。对于一个序列输入{x1,x2,...,xT} ,在每个时间步 t,RNN 会学习到一个隐藏状态 ht

ht=f(Whht1+Uxt+b)

其中:

  • ht 是当前时刻的隐藏状态
  • ht1 是前一时刻的隐藏状态,xtht1共同决定
  • xt 是当前时刻的输入
  • WhU 是权重矩阵
  • b 是偏置项
  • f 是激活函数(通常为 tanh

结构

RNN

  • ot是时刻的输出,例如我们希望预测一个句子的下一个单词,则输出希望是我们字典中所有词的概率组成的向量ot=softmax(Vht)

TIP

传统的神经网络在每一层采用不同的参数,而 RNN 在所有步中采用共同的参数(U,V,W),这表示我们在每一步执行相同的任务,仅仅是输入不同而已。这样会缩减需要学习的参数数量

RNN 的优缺点

优点

  • 处理序列数据:RNN 能够处理任意长度的序列数据
  • 记忆能力:RNN 能够记住之前的信息,从而捕捉时间依赖关系

缺点

  • 梯度消失/爆炸问题:RNN 在训练过程中容易出现梯度消失或爆炸的问题,导致难以训练长序列
  • 计算效率低:RNN 的计算是逐步进行的,无法并行化处理

前向传播

  1. 隐藏状态:(f通常为 tanh

    ht=f(Whht1+Uxt+b)
  2. 输出:

    ot=Vht+c
  3. 预测输出:(σ通常为 softmax

    y^t=σ(ot)

反向传播

推导

对任意t时刻:

LtU=LtotoththtU=Ltotothththt1ht1U==k=1tLtototht(j=k1thjhj1)hkU

假设损失函数为交叉熵损失:

Ltot=y^tyt

交叉熵损失

Lt=iyt,ilogy^t,iy^t,i=eot,ijeot,iy^t,iot,k=y^t,i(δiky^t,k)

其中,δik 为 kronecker delta 函数

δij={1如果 i=j0如果 ij

有:

Ltot,k=iyt,ilogy^t,iot,k=iyt,i1y^t,iy^t,i(δiky^t,k)=iyt,i(δiky^t,k)

由于iyt,i=1

Ltot,k=(yt,ky^t,kiyt,i)=yt,kyt,k

所以:

Ltot=y^tyt

设:ak=Whhk1+Uxk+b

Ltat=Lththtat=Lthtf(at)

例子

h(t)=tanh(at)

假设:

Ltht=[δ1δ2]at=[a1a2]

则:

Ltat=[δ1(1tanh2(a1))δ2(1tanh2(a2))]
atU=xt

定义序列索引 t 位置的隐藏状态的梯度为:

δt=Ltht

定义最后的序列索引位置为 τ

δτ=Lhτ=V(Ltot)=V(y^τyτ)

时间 tU 的梯度贡献为:

LU|t=δtf(at)xt

U 的总梯度为:

LU=t=1τδtf(at)xt

同样的,W 的总梯度为:

LW=t=1τδtf(at)ht1

V 的总梯度为:

LV=t=1τ(y^tyt)ht

b 的总梯度为:

Lb=t=1τδtf(at)

例子

h(t)=tanh(at)

htat=diag(1ht2)Ltat=δtdiag(1ht2)LU=t=1τδtdiag(1ht2)xt(LU=t=1τδt(1ht2)xt)