ICLR 2021 的一篇文章提出了基于 KNN 方法的機器翻譯(kNN-MT),可以將 kNN 方法添加到現有的神經機器翻譯模型(NMT)上,從而進一步提升推理表現。該方法幫助當時的 SOTA 德語-英語翻譯模型提升了 1.5 BLEU 分數,并且還可以適應跨領域及零樣本傳輸。
本次要分享的論文則是針對 kNN-MT 推理速度過慢的不足,提出了蒸餾方法(kNN-KD)。從而在保持 kNN-MT 表現的情況下,將推理速度提升到了與一般 NMT 模型推理速度相當的水平。Nearest Neighbor Knowledge Distillation for Neural Machine Translation收錄會議:
論文鏈接:
https://arxiv.org/abs/2205.00479Methods
2.1 kNN-MT
KNN-MT 方法有兩個步驟:
1. Datastore creation:
根據訓練集每一條樣本離線構建的鍵值對組合,如下公式所示。其中 表示樣本的源語言句和目標語言句, 為翻譯過程中第 步時已經推理出來的文本, 表示第 步需要推理的目標語言 token。 表示 經過模型 decoder 編碼得到的高維向量。2. Generation:
推理階段的每一步時,首先根據 NMT 模型給出下一個 token 的輸出概率 ,然后根據 kNN 方法給出下一個 token 的輸出概率 ,最終的輸出概率為 。kNN 輸出概率如下:按照構造 Datastore 的方式,根據當前的測試樣本先構建當前步驟的 key,然后遍歷 Datastore 找到 距離最近的 個結果,將其距離進行一系列操作后,轉化為對應 value 的輸出概率,如下圖所示:在一般訓練 NMT 模型時,通常使用 模型預測結果 和 grount-truth 的交叉熵(CE)進行訓練。但在自然語言中,一個句子通常有多種表達,如果模型預測出一個合理但偏離 grount-truth 的詞,CE損失也會將其視為錯誤并懲罰模型,導致模型泛化性變差,這就是所謂的 overcorrection。 而在 KNN-MT 中,在解碼階段綜合考慮了其他可能的合理解釋,在一定程度上緩解了該問題,所以表現有了明顯提升。2.2 kNN-KD
針對 kNN-MT 推理速度很慢的劣勢,本文作者提出了 kNN-KD 方法,步驟如下:
1. Datastore creation:與 kNN-MT 相同
2. Distillation:
對于教師模型,在訓練前針對每一條訓練樣本的每一步驟,都按照類似 kNN-MT 中的方法輸出下一 token 的生成概率 。對于學生模型,針對每一條訓練樣本的每一步驟,都正常輸出下一 token 的生成概率 。訓練過程中,蒸餾損失為教師模型和學生模型表現的交叉熵:3. Generation:在最終的推理階段,就不需要再進行 kNN 搜索了,只要按照正常的 NMT 模型進行翻譯即可。
本文使用 IWSLT'14 德語-英語(De-En,160k 訓練樣本)、IWSLT'15 英語-越南語(En-Vi,113k 訓練樣本)和多域翻譯數據集(De-En,733k 訓練樣本)進行了實驗。使用 tst2012 作為驗證集,使用 tst2013 作為測試集,分別包含 和 個句子。本文所提出的 kNN-KD 是一種無架構方法,可應用于任意 Seq2Seq 模型,可以與其他提升性能的工作同時應用。因此,作者主要將 kNN-KD 與 kNN-MT 以及一些典型的 KD 方法進行比較,包括但不限于 Word-KD、Seq-KD、BERT-KD 和 Selective-KD 等。
實驗中所有算法都利用 pytorch 中的 fairseq 工具包實現,在 個 NVIDIA GTX 1080Ti GPU 上進行。實驗模型選取 層 Transformer。對于 IWSLT'14 和 IWSLT'15 模型,配置 embedding size 為 ,feed-forward size 為 ,attention heads 為 。針對跨領域數據集,配置 embedding size 為 ,feed-forward size 為 ,attention heads 為 。作者提前對 和 (歸一化溫度)進行了網格搜索,并選取了驗證集上的最佳 BLEU 分數對應的超參數 ,如下表所示,其中 表示 Datastore 中數據個數:3.2 Results
在 IWSLT 數據集上的實驗結果如下表所示,KNN-KD 超過了所有其它強 baseline,比 Transformer 取得了 和 的 BLEU 分數提升。在跨領域數據集上,kNN-KD 同樣超過了其他 baseline,如下表所示。在各領域中,kNN-KD 均可以超過 kNN-MT 的表現,且推理速度顯著提升。同樣,作者也進一步研究了 kNN-KD 的泛化性:在特定領域訓練了一個 NMT 模型,并在 out-of-domain 的測試集上進行了測試,實驗結果如下表所示,kNN-KD 的泛化性明顯優(yōu)于僅靠標準 CE 訓練的 Transformer。Conclusion
在本文中,作者提出了 kNN-KD,它提取通過 kNN 檢索得到的知識,以緩解基礎 NMT 模型過度校正的問題。實驗表明,kNN-KD 可以改進普通 kNN-MT 和其他baseline,而無需任何額外的訓練和解碼成本。
本站僅提供存儲服務,所有內容均由用戶發(fā)布,如發(fā)現有害或侵權內容,請
點擊舉報。