CS224N笔记(三) Lecture 6~7——深入理解循环神经网络RNN模型

NLP连载系列:

  1. NLP入门:文本特征的表示方式
  2. 命名实体分类NER初识
  3. n-gram的原理
  4. CS224N笔记(一) Lecture1~2——word2vec超详细解析
  5. CS224N笔记(二) Lecture1~2 :深入理解Glove原理
  6. CS224N笔记(三) Lecture 6~7——深入理解循环神经网络RNN模型
  7. CS224N笔记(四) Lecture 7:循环神经网络RNN的进阶——LSTM与GRU
  8. CS224N笔记(五) Lecture8 机器翻译、Seq2Seq以及Attention注意力机制

本文将从语言模型的概念出发,引出循环神经网络RNN的概念,对RNN的结构进行描述,详细推导了梯度计算过程,并解释RNN容易出现梯度消失、梯度爆炸的原因。文章的最后对RNN的应用场景进行了简单的介绍。

一、背景知识

1. 语言模型

在讲解循环神经网络RNN之前,先来回顾一下什么是统计语言模型,之前说到统计语言模型就是用以计算一个句子的概率的模型,简而言之就是判断一句话“是不是正常人说的”,常常会挖空语料中的某些位置,要求预测该位置填什么词最合适。其模型可以归结为:

p(W)=p(w1T)=p(w1,w2,...,wT)=p(w1)p(w2w1)p(w3w12)p(wTw1T1)p(W)=p(w_1^T)=p(w_1,w_2,...,w_T)=p(w_1) \cdot p(w_2|w_1) \cdot p(w_3|w_1^2) \dots p(w_T|w_1^{T-1})

上面这条公式中,W表示一个句话,它是由单词w1,w2,...,wTw_1,w_2,...,w_T按顺序排列而构成的,再次强调见到w1Tw_1^T不要认为是T幂次,它就是表示首单词为w1w_1,长度为T,末尾单词为wTw_T的一句话。

2.n-gram

上面的公式其实可以表示一切语言模型,如果我们是以统计信息来构建模型,那么就叫做统计语言模型,比如n-gram,其表述形式为:

p(wkw1k1)count(w1k)count(w1k1)p(w_k|w_1^{k-1})\approx\frac{count(w_1^k)}{count(w_1^{k-1})}

在深度学习未应用与NLP时,n-gram模型是非常流行的,但它存在以下问题

  1. 稀疏性问题:随着n增大,语料库中出现特定的连续词语组合的可能性会越小,看你从词表中找不到特定的连续词语组合。
  2. 存储问题:需要存储一个非常大的共现矩阵,且还会随着n增大而增大。
  3. n-gram通常无法捕捉深层的语义信息。

3. 固定窗口神经语言模型

为了解决这一问题,提出了固定窗口神经语言模型,它是通过将一个窗口内的词语的词向量拼接起来,送入全连接神经网络,最后通过softmax函数预测概率。这里的词向量one-hot,但其实用word2vec的词向量应该也是可以的。

RNn/Untitled.png

这一模型的好处在于:

  1. 不存在稀疏性问题
  2. 不需要存储所有的n-gram

但是它也有缺点:

  1. 固定窗口不能太小也不能太大,太小捕捉不到上下文信息,太大拼接的向量会异常大,对机器要求很高。
  2. 虽说每个词向量拼接在了一起,但是它们分别对应参数矩阵W的一条向量,参数不共享。这样不一定合理,因为它们其实都在做一件事,我们希望参数具有泛化性,对任意位置的处理是一致。
  3. 窗口大小是固定的,一旦确定不能更改。

二、基本结构

上面说了固定窗口神经语言模型的不足之处,我们希望找到了一个更强大的模型,它应当具备以下特定:

  1. 不受窗口限制,能够处理不定长的输入
  2. 对于每一个位置上的词都可以用同一套参数进行处理

循环神经网络可以满足以上要求,它的模型结构如下:

RNn/Untitled%201.png

模型的真正输入是各个词的词向量,模型中有一个隐状态hh,输出是预测词的概率y^\hat{y}。模型的参数为输入层WeW_e、隐含层WhW_h、输出层UU,注意EE不算是可学习参数,它负责从将词的one-hot向量映射到词向量,本质上是个查表的工作,相当于tensorflow中tf.embedding_lookup操作。

词向量依次送入到模型中,每个词向量ee在模型中首先和输入层WeW_e相乘,再加上另一条支路上隐状态hh与隐含层WhW_h的相乘结果,然后再加上一个偏置bb,最终的和送入非线性激活函数σ\sigma;输出的结果作为新的隐状态传递下去,参与下一轮计算,在下一轮的计算是以下一个词向量作为输入,以此不断循环,最终将隐状态hh与输出矩阵相乘,加上偏置后进行softmax计算,输出概率值。

这里每一轮计算的参数WeW_eWhW_hUU是一样的,中间的隐状态h(1)h^{(1)}h(2)h^{(2)}h(3)h^{(3)}如有需要也是可以输出,每一轮都输出一个结果,且这个循环可以一直进行下去,尽管实际操作中不会让它一直循环下去,因为会遇到别的问题——计算量巨大、梯度消失。

三、训练与优化

1. 损失函数

RNn/Untitled%202.png

在RNN的训练过程中,每一轮都会计算损失函数J(1)(θ)J^{(1)}(\theta)J(2)(θ)J^{(2)}(\theta)J(3)(θ)J^{(3)}(\theta)J(4)(θ)J^{(4)}(\theta),最后将他们相加得到最终的损失函数:

J(θ)=1Tt=1TJ(t)(θ)J(\theta) = \frac{1}{T}\sum_{t=1}^TJ^{(t)}(\theta)

每一轮的损失函数J(t)(θ)J^{(t)}(\theta)均采用交叉熵函数,即:

J(t)(θ)=CE(y(t),y^(t))=wV yw(t)logy^w(t)=log y^w(t)J^{(t)}(\theta) = CE(y^{(t)},\hat{y}^{(t)}) = -\sum_{w \isin V}\ y_w^{(t)}log\hat{y}_w^{(t)} = -log\ \hat{y}_w^{(t)}

其中y(t)y^{(t)}表示真实值,它是x(t+1)x^{(t+1)}所对应的one-hot向量,我们希望输入tt时刻前的词,能预测得到t+1t+1时刻的词。

2. 梯度计算

接下来思考怎么计算梯度进行反向传播,这是RNN中的一个难点中的难点,这里课堂上讲得不是很清晰,补充材料notes里的符号又和课件slides里的不一致,很容易让人迷惑,因此下面我会根据课件和补充材料中内容,重新组织语言进行讲解,不完全遵循课件和补充材料中的顺序和符号。

根据前面的损失函数:

J(θ)=1Tt=1TJ(t)(θ)(1)J(\theta) = \frac{1}{T}\sum_{t=1}^TJ^{(t)}(\theta) \tag{1}

很容易可以得到它关于W_h的梯度:

JWh=1Tt=1TJ(t)Wh(2)\frac{\partial J}{\partial W_h} = \frac{1}{T}\sum_{t=1}^T \frac{\partial J^{(t)}}{\partial W_h} \tag{2}

现在问题来了,那么tt时刻的梯度J(t)Wh\frac{\partial J^{(t)}}{\partial W_h}应该怎么计算?结论很简单,它等于J(t)J^{(t)}在1~tt时刻关于WhW_h的梯度之和,即:

J(t)Wh=i=1tJ(t)Whi(3)\frac{\partial J^{(t)}}{\partial W_h} = \sum_{i=1}^t \frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i} \tag{3}

右边每个求和项的下标ii指的是第ii个时刻或者第ii轮,即J(t)J^{(t)}对第ii个时刻的WhW_h的梯度。这个公式是怎么来的呢?由于J(t)J^{(t)}tt时刻前每个时刻的参数矩阵Wh1Wh2...WhtW_h|_1 、 W_h|_2 、 ... 、 W_h|_t都有关,根据多元函数的链式法则,可以得到下面式子:

J(t)Wh=i=1tJ(t)WhiWhiWh=i=1tJ(t)Whi×1(4)\frac{\partial J^{(t)}}{\partial W_h} = \sum_{i=1}^t \frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i} \frac{\partial W_h|_i}{\partial W_h} = \sum_{i=1}^t \frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i} \times 1 \tag{4}

这条式子怎么来的?首先要明确,RNN在训练时对一段语料进行前向传播,如果语料长度为TT就会经历了TT个时刻,之后再将每个时刻tt的损失叠加起来求梯度反向传播,t=1Tt=1~T时刻内,参数矩阵是一直没有更新的,也即是同一个矩阵WhW_h。那么i=1ti=1~t时刻内,因为它们是11TT内的一个子时段,参数矩阵WhiW_h|_i肯定一直不变的,也即Wh1=Wh2=...=Wht=WhW_h|_1 = W_h|_2 = ... = W_h|_t = W_h,因此WhiWh=1\frac{\partial W_h|_i}{\partial W_h}=1,第二项可以被忽略掉。

现在问题转化为J(t)Whi\frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i}要怎么计算?这个就比较复杂了,这里需要用到链式法则:

J(t)Whi=J(t)y^(t)y^(t)h(t)h(t)h(i)h(i)Whi\frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i} = \frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} \frac{\partial \hat{y}^{(t)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial h^{(i)}} \frac{\partial h^{(i)}}{\partial W_h|_i}

上面提到说Whi=WhW_h|_i = W_h,因此下面就都简写成WhW_h,上面的式子可以写作:

J(t)Whi=J(t)y^(t)y^(t)h(t)h(t)h(i)h(i)Wh(5)\frac{\partial J^{(t)}}{\partial W_h} \bigg| _{i} = \frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} \frac{\partial \hat{y}^{(t)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial h^{(i)}} \frac{\partial h^{(i)}}{\partial W_h} \tag{5}

这个公式中第1、2、4项都很容易求得,关键是第三项h(t)h(i)\frac{\partial h^{(t)}}{\partial h^{(i)}}应该怎么求?我们还是可以用链式法则拆开它,但可以注意到它是与时刻ii有关的,可以想象当i=t1i=t-1,那就向前追溯1个时刻,当i=t2i=t-2时,要向前追溯2个时刻,以此类推,i=1i=1的话要向前追溯两个时刻,也即h(t)h(i)\frac{\partial h^{(t)}}{\partial h^{(i)}}需要向前追溯(ti)(t-i)个时刻,写成公式的话可以表示成:

h(t)h(i)=j=i+1th(j)h(j1)(6)\frac{\partial h^{(t)}}{\partial h^{(i)}} = \prod_{j=i+1}^t \frac{\partial h^{(j)}}{\partial h^{(j-1)}} \tag{6}

至此,损失JJWhW_h的梯度可以写成:

JWh=1Tt=1TJ(t)Wh=1Tt=1Ti=1tJ(t)Whi=1Tt=1Ti=1tJ(t)y^(t)y^(t)h(t)h(t)h(i)h(i)Wh=1Tt=1Ti=1t(J(t)y^(t)y^(t)h(t)(j=i+1th(j)h(j1))h(i)Wh)(7)\begin{aligned} \frac{\partial J}{\partial W_h} &= \frac{1}{T}\sum_{t=1}^T \frac{\partial J^{(t)}}{\partial W_h} \\ &= \frac{1}{T}\sum_{t=1}^T\sum_{i=1}^t \frac{\partial J^{(t)}}{\partial W_h} \bigg|_{i}\\ &=\frac{1}{T}\sum_{t=1}^T\sum_{i=1}^t \frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} \frac{\partial \hat{y}^{(t)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial h^{(i)}} \frac{\partial h^{(i)}}{\partial W_h} \\ &=\frac{1}{T}\sum_{t=1}^T\sum_{i=1}^t (\frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} \frac{\partial \hat{y}^{(t)}}{\partial h^{(t)}} (\prod_{j=i+1}^t \frac{\partial h^{(j)}}{\partial h^{(j-1)}} )\frac{\partial h^{(i)}}{\partial W_h}) \end{aligned} \tag{7}

接下来其实可以继续追问h(j)h(j1)\frac{\partial h^{(j)}}{\partial h^{(j-1)}}怎么求解,我们回顾之前的图示:

RNn/Untitled%203.png

从图中我们可以得知h(j)h^{(j)}h(j1)h^{(j-1)}具有直接关系:

h(j)=σ(Whh(j1)+Wee(t)+b1)(8)h^{(j)} = \sigma(W_hh^{(j-1)}+W_ee^{(t)} + b_1) \tag{8}

则它们间的梯度也很容易求得:

h(j)h(j1)=diag(σ(Whh(j1)+Wee(t)+b1))×Wh(9)\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1)) \times W_h \tag{9}

其中diag()diag(*)表示对角矩阵,对角线中的值即为*,这条式子在补充材料notes中有些许不同,但是本质上是一致的。

至此,RNN关于W_h的梯度计算就完成了,关于$$W_e$$的梯度计算也是类似,这里就不再赘述。

3. 梯度消失和梯度爆炸

可以看到(7)(7)式相当复杂,最关键的地方是两个不同时间隐状态间的梯度h(t)h(i)\frac{\partial h^{(t)}}{\partial h^{(i)}},需要从t时刻一直地追溯到i时刻,而且要进行多次这样的追溯。我们结合(6)(6)(8)(8)可将h(t)h(i)\frac{\partial h^{(t)}}{\partial h^{(i)}}写做:

h(t)h(i)=j=i+1th(j)h(j1)=j=i+1tdiag(σ(Whh(j1)+Wee(t)+b1))×Wh(10)\frac{\partial h^{(t)}}{\partial h^{(i)}} = \prod_{j=i+1}^t \frac{\partial h^{(j)}}{\partial h^{(j-1)}} = \prod_{j=i+1}^t diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1)) \times W_h \tag{10}

接下来将结合该式子讲述了为什么会RNN特别容易梯度消失和梯度爆炸,CS224N的课件slides和补充材料notes是从两个角度来进行解释的,下面将分别对两者的思路进行讲解,在推导过程中为了保持本文符号的一致性,可能与课件材料有些出入。

课件中的解释:

为了简化问题,我们假设激活函数σ\sigma为恒等映射即σ(x)=x\sigma(x)=x,则σ=1\sigma'=1,公式(10)(10)可以改写成:

h(t)h(i)=j=i+1tI×Wh=j=i+1tWh=Whti(11)\frac{\partial h^{(t)}}{\partial h^{(i)}} = \prod_{j=i+1}^t I \times W_h = \prod_{j=i+1}^t W_h = W_h^{t-i} \tag{11}

也即是矩阵WhW_h连续自乘了(ti)(t-i)次,由于输入的词向量ee和隐状态hh一般都是保持同样的维度,因此WhW_h一定是方阵,不用担心自乘时维度对不上。假设矩阵WhW_h的特征值和特征向量分别为:

特征值:   λ1,λ2,...,λn特征向量:   q1,q2,...,qn\begin{aligned}\text{特征值:} \ \ \ \lambda_1, \lambda_2,...,\lambda_n \\ \text{特征向量:}\ \ \ q_1,q_2,...,q_n\end{aligned}

根据线性代数的知识,有:

Wh=PΛP1P是特征向量组成的矩阵,Λ是特征值组成的对角矩阵W_h = P\Lambda P^{-1},P是特征向量组成的矩阵,\Lambda是特征值组成的对角矩阵

进一步可以推出:

h(t)h(i)=Whti=PΛtiP1(12)\frac{\partial h^{(t)}}{\partial h^{(i)}}= W_h^{t-i} = P\Lambda^{t-i} P^{-1} \tag{12}

假设WhW_h的特征值全都小于1,那么一旦两个词相隔越远,或说两个时刻(ti)(t-i)相隔越长,对角矩阵Λti\Lambda^{t-i}上的元素会越乘越小,接近于0,那h(t)h(i)\frac{\partial h^{(t)}}{\partial h^{(i)}}自然也接近于零矩阵,再跟其他的矩阵或向量相乘也会存在大量的零,也就是梯度几乎都为零,这就是梯度消失的原因。反过来,如果说WhW_h的特征值全都大于1,对角矩阵Λti\Lambda^{t-i}上的元素会越乘越大,之后其他矩阵内的元素也会变得越来越大,甚至出现NaN值,这就是梯度爆炸的原因。上面在推导前是假设激活函数为恒等映射,但其实换成别的激活函数也一样会出现梯度消失或梯度爆炸,因为公式12中参数矩阵WhW_h的指数形式依然存在。

补充材料中的讲解

这次从公式(6)(6)出发,如果我们考虑矩阵的模,那么从公式(10)(10)可以得知:

h(j)h(j1)diag(σ(Whh(j1)+Wee(t)+b1))×WhβWβh(13)||\frac{\partial h^{(j)}}{\partial h^{(j-1)}}|| \le ||diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1))|| \times ||W_h|| \le \beta_W\beta_h\tag{13}

其中的β\beta只是对diag(σ(Whh(j1)+Wee(t)+b1))||diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1))||Wh||W_h||分别进行简写。在这之后,结合公式(10)(10),可以得到

h(t)h(i)=j=i+1th(j)h(j1)(βWβh)ti(14)||\frac{\partial h^{(t)}}{\partial h^{(i)}}|| = ||\prod_{j=i+1}^t \frac{\partial h^{(j)}}{\partial h^{(j-1)}}|| \le (\beta_W\beta_h)^{t-i} \tag{14}

如果βW\beta_Wβh\beta_h小于1,即WhW_h的模与diag(σ(Whh(j1)+Wee(t)+b1))||diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1))||的乘积小于1,那么由于指数项的作用,h(t)h(i)||\frac{\partial h^{(t)}}{\partial h^{(i)}}||同样也会变得很小,容易出现梯度消失,相反如果它们大于1,那么h(t)h(i)||\frac{\partial h^{(t)}}{\partial h^{(i)}}||会变得相当大,容易出现梯度爆炸。

这里可以注意到两个细节,一方面,根据线性代数额知识,WhW_h的模和它的特征值有密切关系,如果像上面课件中说的那样WhW_h特征值都小于1,那么它的模肯定也很小;另一方面,另外一项βh\beta_h是和激活函数相关的,上面说激活函数不采用恒等映射同样可能出现梯度消失或梯度爆炸,此话没错,但更准确地说,梯度会不会出现问题和激活函数是会存在关系的。为了缓解梯度消失,我们可以选择ReLU激活函数,尽管它相对容易引起梯度爆炸,但是对于梯度爆炸我们好歹有梯度截断方法可以缓解,比起梯度消失更加可控。

还有最后一点需要注意的是,RNN的梯度消失主要是长距离的梯度消失,公式(12)(14)(12)、(14)的指数项(ti)(t-i)需要足够大才有明显地梯度消失效应,短距离的梯度还是正常的,而总梯度是包含了长距离和短距离的梯度,所以总梯度并不是完全为0,只是模型参数的更新方向不受长时约束,这样RNN就失去了捕捉更大范围上下文信息的能力。

四、优缺点

相比于统计语言模型以及固定窗口神经语言模型,RNN的优点在于:

  1. 可以处理不定长序列
  2. 对不同位置的词向量采用同样的参数进行计算,缩减模型参数且增强参数了泛化性
  3. 模型的大小与序列长度无关
  4. 可以更好地捕捉长时信息

但是其缺点也很明显:

  1. 很容易梯度消失,且是致命弱点,普通的RNN对此没有解决办法。梯度消失会造成两个问题:

    • 长时约束的作用减弱,模型几乎只受短时约束
    • 我们无法分辨是真的没有长距离间的两个词是真的没有联系,还是我们没能捕捉到它们间的联系
  2. 很容易梯度爆炸,可以通过梯度截断缓解

    • 模型发散,无法学习到有效信息
    • 梯度截断的实现

    RNn/Untitled%204.png

五、应用与展望

  1. 词性标注,对序列中每一个词预测其词性;类似的还有NER(name entity recognition),即命名实体识别

RNn/Untitled%205.png

  1. 文本分类,比如情感分类

RNn/Untitled%206.png

  1. 作为编码模块,可以应用到QA问答系统、机器翻译等

RNn/Untitled%207.png

  1. 可以用于文本生成,如语音识别、文本翻译、文本梗概,此时的RNN是一个条件语言模型

RNn/Untitled%208.png

六、参考文献

  1. CS224N Lecture 6 slides && notes
打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!

请我喝杯咖啡吧~

支付宝
微信