神經(jīng)機(jī)器翻譯(NMT)基于上下文預(yù)測(cè)下一個(gè)詞,依次生成目標(biāo)語(yǔ)句。在訓(xùn)練時(shí),模型以真實(shí)值作為上下文(context)進(jìn)行預(yù)測(cè),而在推理時(shí),模型必須從頭生成整個(gè)序列。這種輸入上下文的差異會(huì)導(dǎo)致錯(cuò)誤累積。此外,單詞級(jí)別(word-level)的訓(xùn)練要求生成的序列與真實(shí)序列嚴(yán)格匹配,這會(huì)導(dǎo)致模型對(duì)不同但合理的翻譯產(chǎn)生過度矯正。為了解決這一問題,研究人員提出不僅從真實(shí)值序列中采樣得到上下文詞(context word),也從模型的預(yù)測(cè)序列中采樣得到上下文詞。實(shí)驗(yàn)結(jié)果表明該方法在多個(gè)數(shù)據(jù)集上取得了顯著的改進(jìn)。
本論文斬獲 ACL 2019 最佳長(zhǎng)論文獎(jiǎng),獲獎(jiǎng)理由如下:?
該論文解決了 seq2seq 中長(zhǎng)期存在的暴露偏差問題
論文所提出的解決方案是:在“基于來(lái)自參考語(yǔ)句的詞”和“基于解碼器輸出的預(yù)選擇詞”之間切換
這個(gè)方法適用于當(dāng)前的 teacher-forcing 訓(xùn)練范式,比 scheduled sampling 有所提升
論文的實(shí)驗(yàn)非常完善,結(jié)果令人信服,該方法可能影響機(jī)器翻譯的未來(lái)
該方法也適用于其他 seq2seq 任務(wù)
大多數(shù) NMT 模型都基于編碼器 - 解碼器框架,這些模型基于之前的文本來(lái)預(yù)測(cè)下一個(gè)詞,得到目標(biāo)詞的語(yǔ)言模型。在訓(xùn)練階段,將真實(shí)詞(ground truth word)用作上下文(context)輸入,而在推理時(shí),由于整個(gè)序列由得到的模型自行生成,所以將模型生成的前一個(gè)詞用作上下文輸入。因此,訓(xùn)練和推理時(shí)的預(yù)測(cè)詞是從不同的分布中提取出來(lái)的:訓(xùn)練時(shí)的預(yù)測(cè)詞是從數(shù)據(jù)分布中提取的,而推理時(shí)的預(yù)測(cè)詞是從模型分布中提取的。這種差異稱為 暴露偏差,導(dǎo)致了訓(xùn)練和推理之間的差距。隨著目標(biāo)序列的增長(zhǎng),誤差會(huì)隨之累積,模型必須在訓(xùn)練時(shí)從未遇到的情況下進(jìn)行預(yù)測(cè)。
為了解決這個(gè)問題,模型的訓(xùn)練和推理應(yīng)該在相同的條件下進(jìn)行。受 Data As Demonstrator 方法的啟發(fā),可以在訓(xùn)練過程中將真實(shí)詞和預(yù)測(cè)詞作為上下文一同輸入網(wǎng)絡(luò)。NMT 模型通常采用交叉熵?fù)p失(cross-entropy loss)作為優(yōu)化目標(biāo),這就要求在預(yù)測(cè)序列和真實(shí)序列在單詞級(jí)別上嚴(yán)格的成對(duì)匹配。一旦模型生成一個(gè)偏離真實(shí)序列的單詞,交叉熵?fù)p失將立即糾正錯(cuò)誤,并將下一次生成拉回真實(shí)序列。然而,這導(dǎo)致了一個(gè)新的問題:一個(gè)句子通常有多個(gè)合理的翻譯,不能因?yàn)槟P彤a(chǎn)生了和真實(shí)值不同的單詞,就說(shuō)這個(gè)模型出錯(cuò)了。
參考語(yǔ)句:We should comply with the rule(我們應(yīng)該遵守規(guī)則)。候選 1:We should abide with the rule(我們應(yīng)該與規(guī)則住在一起)。候選 2:We should abide by the law(我們應(yīng)該遵守法律)。候選 3:We should abide by the rule(我們應(yīng)該尊重規(guī)律)。
一旦模型生成第三個(gè)目標(biāo)詞“abide”,交叉熵?fù)p失會(huì)迫使模型生成第四個(gè)詞“with”(如候選 1),從而具有更大的句子級(jí)別的相似性,并與參考語(yǔ)句一致,但是“by”才是正確的用法。然后,以“with”作為上下文生成“the rule”,從而模型生成的是“abide with the rule(與規(guī)則住在一起)”,這實(shí)際上是錯(cuò)誤的。候選 1 就是一種過度矯正現(xiàn)象。另一個(gè)潛在的錯(cuò)誤是,即使模型在”abide”之后預(yù)測(cè)正確的單詞“by”,在生成后續(xù)翻譯時(shí),它也可能通過輸入“by”而產(chǎn)生“the law”,這也是不恰當(dāng)?shù)模ㄈ绾蜻x 2)。假設(shè)參考語(yǔ)句和訓(xùn)練標(biāo)準(zhǔn)讓模型記住了 “the rule”始終跟在單詞“with”后面的模式。為了幫助模型從這兩種錯(cuò)誤中恢復(fù)并給出正確的翻譯(候選 3),應(yīng)該輸入“with”作為上下文詞,而不是“by”,即使之前預(yù)測(cè)的短語(yǔ)是“abide by”。此解決方案稱為過度矯正恢復(fù)(Overcorrection Recovery, OR)。
這篇論文提出了一種方法彌合訓(xùn)練與推理之間的差距,提高 NMT 過度矯正的恢復(fù)能力。該方法首先從預(yù)測(cè)詞中選擇 oracle 詞,然后從 oracle 詞和真實(shí)詞中采樣得到上下文。作者不僅采用逐詞貪婪搜索(word-by-word greedy search),而且還采用了語(yǔ)句級(jí)別(sentence-level)優(yōu)化來(lái)選擇 oracle 詞。在訓(xùn)練開始時(shí),模型大概率選擇真實(shí)詞作為上下文。隨著模型的逐漸收斂,模型更多選擇 oracle 詞作為上下文。通過這種方式,訓(xùn)練過程從完全指導(dǎo)的方案轉(zhuǎn)變?yōu)檩^少指導(dǎo)的方案。在這種機(jī)制下,模型有機(jī)會(huì)學(xué)習(xí)如何處理推理時(shí)所犯的錯(cuò)誤,也能從替換翻譯(alternative translation)的過度矯正中恢復(fù)過來(lái)。作者使用 RNNSearch 模型和 Transformer 模型進(jìn)行了驗(yàn)證。結(jié)果表明,該方法能顯著提高兩種模型的性能。
作者以基于 RNN 的 NMT 為例介紹該方法。假設(shè)源序列和觀察到的翻譯分別為 x={x1,x2,...}和 y={y1, y2, ...}。
編碼器。 采用雙向門控循環(huán)單元來(lái)獲取兩個(gè)序列的隱狀態(tài)。exi 代表單詞 xi 的嵌入矢量表。
?
注意力。 注意力機(jī)制用于提取源信息(源上下文矢量,source context vector)。在第 j 步,目標(biāo)單詞 yj* 和第 i 個(gè)源單詞之間的相關(guān)性通過源序列進(jìn)行評(píng)估:
解碼器。 解碼器應(yīng)用 GRU 的一個(gè)變體來(lái)解碼目標(biāo)信息。在第 j 步,目標(biāo)隱狀態(tài) sj 由下式得到:
目標(biāo)詞典中所有詞的概率 Pj 即可基于上一個(gè)真實(shí)詞、源上下文矢量和隱狀態(tài)得到:
圖 1 方法框架圖
該方法的主要框架(如圖 1 所示)是以一定的概率將真實(shí)詞或之前預(yù)測(cè)的詞(即 oracle 詞)作為上下文。通過訓(xùn)練模型來(lái)處理測(cè)試期間出現(xiàn)的情況,也許可以減少訓(xùn)練和推理之間的差距。在這里,作者介紹了兩種選擇 oracle 單詞的方法。一種方法是用貪婪搜索算法,在單詞級(jí)別選擇 oracle 單詞,另一種方法是在語(yǔ)句級(jí)別選擇最優(yōu)的 oracle 序列。預(yù)測(cè)第 j 個(gè)目標(biāo)單詞 yj 包括以下步驟:?
在第 j-1 步選擇 oracle 單詞。
從真實(shí)詞 y*(j-1) 中以概率 p 采樣,或從 oracle 詞 yoracle(j-1) 中以概率 1-p 采樣。
使用采樣的單詞作為 y(j-1),并用 y(j-1) 代替公式 6 和 7 中的 y*(j-1),然后繼續(xù)使用基于注意力的 NMT 進(jìn)行后續(xù)的預(yù)測(cè)。
一般情況下,在第 j 步,NMT 模型需要用真實(shí)值 y*(j-1) 作為上下文詞(context word)來(lái)預(yù)測(cè) yj,所以我們可以選擇一個(gè) oracle 詞 yoracle(j-1) 來(lái)近似上下文詞。oracle 詞應(yīng)該與真實(shí)值相似,或者是真實(shí)值的近義詞。選擇 oracle 詞的一個(gè)方法是單詞級(jí)別的貪婪搜索,輸出每一步的 oracle 單詞(word-level oracle,WO)。此外,也可以通過擴(kuò)大搜索空間,對(duì)候選翻譯按語(yǔ)句級(jí)別的衡量標(biāo)準(zhǔn)進(jìn)行排序,例如 BLEU、GLEU、ROUGE 等指標(biāo)。選擇的翻譯即為 oracle 語(yǔ)句,該翻譯中的單詞即為語(yǔ)句級(jí)別的 oracle(sentence-level oracle,SO)。
圖 2 單詞級(jí)別 oracle(不含噪聲)
圖 3 單詞級(jí)別 oracle 加入 Gumbel 噪聲
作者將 Gumbel 噪聲以正則項(xiàng)的形式,加入公式 8 中的 o(j-1),如圖 3 所示,然后經(jīng)過 softmax 函數(shù),y(j-1) 的詞分布可以近似為:
當(dāng)τ趨近于 0 時(shí),softmax 函數(shù)近似為 argmax 函數(shù),當(dāng)τ接近無(wú)窮大時(shí),逐漸變成均勻分布。最佳的單詞級(jí)別 oracle 可由下式得到:
語(yǔ)句級(jí)別的 oracle 能夠通過 n-gram 匹配得到更靈活的翻譯。在這篇文章中,作者采用 BLEU 作為衡量指標(biāo)。為了選擇語(yǔ)句級(jí)別的 oracle,作者首先對(duì)一個(gè) batch 的所有句子進(jìn)行束搜索,假設(shè)束大小為 k,則得到 k 個(gè)最佳的候選翻譯。然后計(jì)算每個(gè)候選翻譯與真實(shí)值之間的 BLEU 分?jǐn)?shù),分?jǐn)?shù)最高的則作為 oracle 語(yǔ)句。將其表示為:
那么在解碼的第 j 步,語(yǔ)句級(jí)別 oracle 詞即可表示為:
但是語(yǔ)句級(jí)別的 oracle 存在一個(gè)問題。當(dāng)模型從真實(shí)詞和語(yǔ)句級(jí)別 oracle 詞中采樣時(shí),兩個(gè)序列應(yīng)該具有同樣數(shù)量的單詞。然而簡(jiǎn)單的束搜索解碼算法不能保證這一點(diǎn)。因此作者引入了強(qiáng)制解碼(force decoding)來(lái)確保兩個(gè)序列的長(zhǎng)度相同。
假設(shè)真實(shí)序列的長(zhǎng)度為|y|,強(qiáng)制解碼的目的是生成一個(gè)長(zhǎng)度為|y|的序列,后面跟著一個(gè)終止語(yǔ)句符號(hào)(EOS)。這樣在束搜索中,當(dāng)一個(gè)候選翻譯的長(zhǎng)度不等于|y*|,卻以 EOS 終結(jié)語(yǔ)句時(shí),強(qiáng)制解碼會(huì)強(qiáng)制它生成|y|個(gè)單詞:?
當(dāng)?shù)?j-1 步,候選翻譯的長(zhǎng)度還沒達(dá)到|y|,但是 EOS 已經(jīng)是第 j 步的首選詞時(shí),則從詞分布 Pj 中選擇第二個(gè)候選詞作為該翻譯的第 j 個(gè)詞。
當(dāng)?shù)趞y|+1 步時(shí),如果 EOS 不是詞分布的首選詞,則讓它成為候選翻譯第|y|+1 個(gè)詞。
這樣,就可確保所有的 k 個(gè)候選翻譯的長(zhǎng)度都為|y|,然后再根據(jù) BLEU 分?jǐn)?shù)對(duì) k 個(gè)候選翻譯進(jìn)行排序,然后選擇第一個(gè)作為 oracle 語(yǔ)句。
作者采用衰減采樣機(jī)制從真實(shí)詞 y(j-1) 和 oracle 詞 yoracle(j-1) 中采樣得到上下文詞 y(j-1)。在訓(xùn)練開始時(shí),由于模型沒有經(jīng)過良好的訓(xùn)練,使用 yoracle(j-1) 作為 y(j-1) 過于頻繁會(huì)導(dǎo)致收斂非常緩慢,甚至陷入局部最優(yōu)。另一方面,在訓(xùn)練結(jié)束時(shí),如果上下文詞 y(j-1) 在很大概率上仍然是從真實(shí)詞 y*(j-1) 中選擇的,則模型不會(huì)完全接觸到推理時(shí)會(huì)遇到的情況,從而不知道如何在推理時(shí)采取行動(dòng)。因此,從真實(shí)詞中選擇的概率 p 是不固定的,但隨著訓(xùn)練的進(jìn)行,它必須逐漸降低。在開始時(shí),p=1,即模型完全基于真實(shí)詞進(jìn)行訓(xùn)練。隨著模型逐漸收斂,模型將更多的從 oracle 詞中選擇上下文詞。
根據(jù)訓(xùn)練 epoch 逐漸衰減采樣概率 p:
用上述方法選擇 y(j-1) 后,可根據(jù)公式(6)、(7)、(8)、(9)得到 yj 的詞分布。目標(biāo)是最大化真實(shí)值序列的概率。因此,通過最小化以下?lián)p失函數(shù)訓(xùn)練模型:
對(duì)于 NIST 中譯英(Zh->EN)任務(wù),作者采用了兩個(gè)基線模型進(jìn)行驗(yàn)證。
表 1 中譯英翻譯任務(wù)實(shí)驗(yàn)結(jié)果
作者對(duì)比了三種對(duì)基于 RNN 的 NMT 模型進(jìn)行增強(qiáng)的方法:Coverage、MRT 和 Distortion。與這三種方法對(duì)比,作者提出的基線系統(tǒng) RNNsearch 的表現(xiàn) 1)超越了 Coverage,2)達(dá)到了與 MRT 和 Distortion 一樣的表現(xiàn)。
作者與其他兩個(gè)解決暴露偏差的方法進(jìn)行了對(duì)比:SS-NMT 和 MIXER。從表 1 中可以看出,SS-NMT 和 MIXER 都能取得一定的提升,但是作者提出的 OR-NMT 不僅超越了 RNNSearch 的基線,并且取得了更大的提升。與其他兩個(gè)方法相比,OR-NMT 在四個(gè)測(cè)試數(shù)據(jù)集上將 BLEU 分?jǐn)?shù)提升了 2.36 分。
作者在 Transformer 模型上測(cè)試了提出的方法。從表 1 可以看出,單詞級(jí)別的 oracle 可以取得 +0.54 BLEU 分的提升,語(yǔ)句級(jí)別的方法可以進(jìn)一步帶來(lái) +1.0 BLEU 分的提升。
作者提出了單詞級(jí)別 oracle、語(yǔ)句級(jí)別 oracle 和在 oracle 選擇中結(jié)合 Gumbel 噪聲這三種方法來(lái)解決過度矯正的問題。表 2 給出了這三種因素的影響。
表 2 中譯英翻譯任務(wù)因素分析實(shí)驗(yàn)
在只采用單詞級(jí)別 oracle 時(shí),模型表現(xiàn)提升了 1.21 BLEU 分?jǐn)?shù)點(diǎn),說(shuō)明輸入之前預(yù)測(cè)的詞作為上下文可以減輕暴露誤差。采用語(yǔ)句級(jí)別 oracle 時(shí),可以進(jìn)一步提升 0.62 BLEU 分?jǐn)?shù)點(diǎn)。說(shuō)明語(yǔ)句級(jí)別 oracle 的表現(xiàn)優(yōu)于單詞級(jí)別 oracle。作者認(rèn)為,這種優(yōu)勢(shì)可能來(lái)自于單詞生成的更大的靈活性,它可以緩解過度矯正的問題。通過在單詞級(jí)別 oracle 和語(yǔ)句級(jí)別 oracle 的生成過程中加入 Gumbel 噪聲,模型的 BLEU 得分分別提高了 0.56 和 0.53。這表明 Gumbel 噪聲可以幫助選擇每個(gè) oracle 詞,證明了 Gumbel-Max 提供了一種從分類分布中進(jìn)行采樣的有效和可靠的方法。
作者研究了不同因素對(duì)收斂性的影響。圖 4 給出了 RNNsearch 以及不同變體的訓(xùn)練損失曲線。圖 5 給出了不同因素的 BLEU 分?jǐn)?shù)值對(duì)比??梢钥闯?,RNNsearch 收斂較快,并且在第 7 個(gè) epoch 達(dá)到最佳結(jié)果,但是第 7 個(gè) epoch 后訓(xùn)練損失依然持續(xù)下降,所以 RNNsearch 的訓(xùn)練可能會(huì)過擬合。圖 4 和圖 5 也顯示出,加入 Gumbel 噪聲會(huì)稍微拖慢收斂速度,但是模型達(dá)到最佳表現(xiàn)后訓(xùn)練損失不會(huì)再繼續(xù)下降。這表明 oracle 采樣和 Gumbel 噪聲能避免過擬合。
圖 4 中譯英翻譯任務(wù)不同因素的訓(xùn)練損失曲線
圖 5 驗(yàn)證集上中譯英翻譯任務(wù)不同因素的 BLEU 分?jǐn)?shù)變化趨勢(shì)
圖 6 MT03 測(cè)試集上中譯英翻譯任務(wù)不同因素的 BLEU 分?jǐn)?shù)變化趨勢(shì)
圖 6 給出了 MT03 數(shù)據(jù)集上的 BLEU 分?jǐn)?shù)曲線。在語(yǔ)句級(jí)別 oracle 加入噪聲時(shí),可以得到最佳模型。沒有噪聲時(shí),模型收斂后的 BLEU 分?jǐn)?shù)較低。這也很好理解,在訓(xùn)練過程中如果沒有正則項(xiàng),只是一直重復(fù)使用模型自己的結(jié)果,容易導(dǎo)致過擬合。
圖 7 給出了在 MT03 測(cè)試集上從不同長(zhǎng)度的源語(yǔ)句中生成翻譯的 BLEU 分?jǐn)?shù)值。從圖中可以看出,論文的方法在所有的區(qū)間都對(duì) baseline 有較大的提升,尤其是 (10,20]、(40,50] 和 (70,80] 區(qū)間。交叉熵?fù)p失需要預(yù)測(cè)序列與真實(shí)值序列完全相同,這對(duì)于較長(zhǎng)的語(yǔ)句來(lái)說(shuō)更難做到,而語(yǔ)句級(jí)別 oracle 可以減輕這種過度矯正。
圖 7 MT03 測(cè)試集不同程度源語(yǔ)句模型表現(xiàn)對(duì)比
為了證明該方法帶來(lái)的提升是由于解決了暴露偏差問題,作者從中譯英訓(xùn)練數(shù)據(jù)中隨機(jī)選擇了 1000 對(duì)句子,然后用預(yù)訓(xùn)練的 RNNSearch 模型和提出的模型對(duì)源語(yǔ)句進(jìn)行解碼。RNNSearch 模型的 BLEU 分?jǐn)?shù)為 24.87,而論文模型提升了 2.18 分。然后作者統(tǒng)計(jì)了論文模型預(yù)測(cè)分布中真實(shí)詞的概率高于基線模型的數(shù)量,記為 N。在參考語(yǔ)句中共有 28266 個(gè)詞,N=18391,比例為 18391/28266=65.06%,證明了該方法帶來(lái)的提升是由于解決了暴露偏差問題。
表 3 英譯德翻譯任務(wù)實(shí)驗(yàn)結(jié)果
作者在 WMT’14 上也驗(yàn)證了所提方法。從表 3 中可以看出,論文提出的方法大大提升了基線模型的表現(xiàn),并且優(yōu)于其他相關(guān)方法。該實(shí)驗(yàn)說(shuō)明論文模型對(duì)不同語(yǔ)言之間的翻譯均有效。
端到端的 NMT 模型訓(xùn)練時(shí)將真實(shí)值單詞作為上下文,而模型推理時(shí)則由模型生成的前一個(gè)單詞作為上下文。為了減少訓(xùn)練和推理之間的差異,在預(yù)測(cè)一個(gè)詞時(shí),作者從真實(shí)值單詞或預(yù)測(cè)詞中抽樣得到一個(gè)詞作為上下文輸入。預(yù)測(cè)詞,即 oracle 詞,可以通過單詞級(jí)別或語(yǔ)句級(jí)別優(yōu)化生成。與單詞級(jí)別 oracle 相比,語(yǔ)句級(jí)別 oracle 可以進(jìn)一步增強(qiáng)模型的過度矯正恢復(fù)能力。為了使模型充分地暴露在推理時(shí)的環(huán)境中,作者采用衰減采樣,從真實(shí)值單詞采樣得到上下文詞。作者用兩個(gè)基線模型和相關(guān)工作在真實(shí)翻譯任務(wù)上進(jìn)行了驗(yàn)證,該方法在所有數(shù)據(jù)集上都有顯著提升。這篇論文很好地解決了 seq2seq 中存在的暴露偏差問題,用充分的實(shí)驗(yàn)證明了方法的有效性。
查看論文原文:
Bridging the Gap between Training and Inference for Neural Machine Translation
https://arxiv.org/abs/1906.02448
你也「在看」嗎???
聯(lián)系客服