科技行者報(bào)道
來源:WILDML
作者:Denny Britz
編譯:科技行者
前一部分中,我們介紹了如何在Python和Theano框架下實(shí)現(xiàn)RNN,但還未深入了解時(shí)序反向傳播算法(BPTT)是如何計(jì)算梯度的。
這周,我們將簡單介紹BPTT,并解釋其與傳統(tǒng)反向傳播的區(qū)別。我們還將了解梯度消失問題,這也是推動(dòng)LSTM(長短時(shí)記憶)和GRU(門控循環(huán)單元)(目前在NLP和其他領(lǐng)域最流行且有效模型)發(fā)展的原因。
1991年,梯度消失問題最早由Sepp Hochreiter發(fā)現(xiàn),又因深度框架的廣泛應(yīng)用再次受到關(guān)注。
以下是本系列教程的四個(gè)部分:
1.循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)的基本介紹
2.在Python和Theano框架下實(shí)現(xiàn)RNN
3.基于時(shí)間的反向傳播算法(BPTT)和梯度消失問題(本部分)
4.建立基于門控循環(huán)單元(GRU)或者長短時(shí)記憶(LSTM)的RNN模型
說明:為完全掌握本部分教程,建議您對(duì)偏微分(也稱偏導(dǎo)數(shù))和基本反向傳播的工作原理有所了解,以下是三篇關(guān)于反向傳播算法的教程供大家參考:
http://cs231n.github.io/optimization-2/
http://colah.github.io/posts/2015-08-Backprop/
http://neuralnetworksanddeeplearning.com/chap2.html
-1-
時(shí)序反向傳播算法 (BPTT)
先來快速回憶一下RNN的基本方程。注意,為了和要引用的文獻(xiàn)保持一致,這里我們把o改成了
。同樣,將損失函數(shù)定義為交叉熵?fù)p失函數(shù),如下所示:
在這里,y_t是表示的是時(shí)間步t上的正確標(biāo)簽,
是我們的預(yù)測(cè)。通常我們會(huì)將一個(gè)完整的句子序列視作一個(gè)訓(xùn)練樣本,因此總誤差即為各時(shí)間步(單詞)的誤差之和。▲RNN反向傳播
別忘了,我們的目的是要計(jì)算誤差對(duì)應(yīng)的參數(shù)U、V和W的梯度,然后借助SDG算法來更新參數(shù)。當(dāng)然,我們統(tǒng)計(jì)的不只是誤差,還包括訓(xùn)練樣本在每時(shí)間步的梯度:
▲RNN的結(jié)構(gòu)圖
我們借助導(dǎo)數(shù)的鏈?zhǔn)椒▌t來計(jì)算梯度。從最后一層將誤差向前傳播的思想,即為反向傳播。本文后續(xù)部分將以E3為例繼續(xù)介紹:
由上可知,z_3 =Vs_3,
為兩個(gè)矢量的外積。為了讓大家更好理解,這里我省略了幾個(gè)步驟,你可以試著自己計(jì)算這些導(dǎo)數(shù)。我想強(qiáng)調(diào)的是,的值僅取決于當(dāng)前時(shí)間步的值。有了這些值,計(jì)算參數(shù)V的梯度就是簡單的矩陣相乘了。▲鏈?zhǔn)角髮?dǎo)式子1
其中,s_3 = \tanh(Ux_t + Ws_2) 取決于s_2,而s_2則取決于W和s_1,以此類推。因此,如果要推導(dǎo)參數(shù)W,就不能簡單將s_2視作常量,需要再次應(yīng)用鏈?zhǔn)椒▌t,真正得到的是:
▲鏈?zhǔn)角髮?dǎo)式子2
上面的式子用到了復(fù)合函數(shù)的鏈?zhǔn)角髮?dǎo)法則,將每個(gè)時(shí)間步長對(duì)梯度的貢獻(xiàn)相加。換言之,由于參數(shù)W時(shí)間步長應(yīng)用于想要的輸出,因此需從t=3開始通過所有網(wǎng)絡(luò)路徑到t=0進(jìn)行反向傳播梯度:▲BPTT復(fù)合函數(shù)鏈?zhǔn)角髮?dǎo)
5個(gè)時(shí)間步梯度的遞歸神經(jīng)網(wǎng)絡(luò)展開圖
請(qǐng)注意,這與我們?cè)谏疃壬窠?jīng)網(wǎng)絡(luò)中應(yīng)用的標(biāo)準(zhǔn)反向傳播算法完全一致。主要區(qū)別在于我們對(duì)每時(shí)間步的參數(shù)W的梯度進(jìn)行了求和。傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)(RNN)中,我們不在層與層之間共享參數(shù),也就無需求和。但就我而言,BPTT不過是標(biāo)準(zhǔn)反向傳播在展開RNN上的別稱。好比在反向傳播算法中,可以定義一個(gè)反向傳播的delta矢量,例如:基于z_2 = Ux_2+ Ws_1的
直接實(shí)現(xiàn)BPTT的代碼如下:
▲Tip:點(diǎn)擊圖片可看大圖
該代碼解釋了難以訓(xùn)練RNN的原因:因?yàn)樾蛄校ň渥樱┖荛L,可能由20個(gè)或以上單詞組成,因此需反向傳播多層網(wǎng)絡(luò)。在實(shí)際操作時(shí),許多人會(huì)在反向傳播數(shù)步后進(jìn)行截?cái)喑杀容^長的步驟,正如上面代碼中的bptt_truncate參數(shù)定義的那樣。
-2-
梯度消失問題
本教程的前面章節(jié)提到過RNN中,相隔數(shù)步的單詞間難以形成長期依賴的問題。而英文句子的句意通常取決于相隔較遠(yuǎn)的單詞,例如“The man who wore a wig on his head went inside”的語意重心在于一個(gè)人走進(jìn)屋里,而非男人戴著假發(fā)。但普通的RNN難以捕獲此類信息。那么不妨通過分析上面計(jì)算出的梯度來一探究竟:
別忘了
然而,上述雅克比矩陣中2范數(shù)(可視為絕對(duì)值)的上限是1(此處不做證明)。直觀上,tanh激活函數(shù)將所有的值映射到-1到1這個(gè)區(qū)間,導(dǎo)數(shù)值也小于等于1(sigmoi函數(shù)的導(dǎo)數(shù)值小于等于1/4):
▲tanh及其導(dǎo)數(shù)。圖片源自:http://nn.readthedocs.org/en/rtd/transfer/
可以看到tanh和sigmoid函數(shù)在兩端的導(dǎo)數(shù)均為0,近乎呈直線狀(導(dǎo)數(shù)為0,函數(shù)圖像為直線),此種情況下可稱相應(yīng)的神經(jīng)元已經(jīng)飽和。兩函數(shù)的梯度為0,使前層的其它梯度也趨近于0。由于矩陣元素?cái)?shù)值較小,且矩陣相乘數(shù)次(t - k次)后,梯度值迅速以指數(shù)形式收縮(意思相近于,小數(shù)相乘,數(shù)值收縮,越來越小),最終在幾個(gè)時(shí)間步長后完全消失?!拜^遠(yuǎn)”的時(shí)間步長貢獻(xiàn)的梯度變?yōu)?,這些時(shí)間段的狀態(tài)不會(huì)對(duì)你的學(xué)習(xí)有所貢獻(xiàn):你最終還是無法學(xué)習(xí)長期依賴。梯度消失不僅存在于循環(huán)神經(jīng)網(wǎng)絡(luò),也出現(xiàn)在深度前饋神經(jīng)網(wǎng)絡(luò)中。區(qū)別在于,循環(huán)神經(jīng)網(wǎng)絡(luò)非常深(本例中,深度與句長相同),因此梯度消失問題更為常見。
不難想象,如果雅克比矩陣的值非常大,參照激活函數(shù)及網(wǎng)絡(luò)參數(shù)可能會(huì)出現(xiàn)梯度爆炸,即所謂的梯度爆炸問題。相較于梯度爆炸,梯度消失問題更受關(guān)注,主要有兩個(gè)原因:其一,梯度爆炸現(xiàn)象明顯,梯度會(huì)變成Nan(而并非數(shù)字),并出現(xiàn)程序崩潰;其二,在預(yù)定義閾值處將梯度截?cái)啵ㄔ斍檎?qǐng)見本文章)是一種解決梯度爆炸問題簡單有效的方法。而梯度消失問題更為復(fù)雜,因?yàn)槠洮F(xiàn)象不明顯,且解決方案尚不明確。
幸運(yùn)的是,目前有一些方法可解決梯度消失問題。合理初始化矩陣 W可緩解梯度消失現(xiàn)象。還可采用正則化方法。此外,更好的方法是使用 ReLU,而非tanh或sigmoid激活函數(shù)。ReLU函數(shù)的導(dǎo)數(shù)是個(gè)常量,0或1,因此不太可能出現(xiàn)梯度消失現(xiàn)象。
更常用的方法是借助LSTM或GRU架構(gòu)。1997年,首次提出LSTM ,目前該模型在NLP領(lǐng)域的應(yīng)用極其廣泛。GRU則于2014年問世,是LSTM的簡化版。這些循環(huán)神經(jīng)網(wǎng)絡(luò)旨在解決梯度消失和有效學(xué)習(xí)長期依賴問題。相關(guān)介紹請(qǐng)見本教程下一部分。
聯(lián)系客服