通用數據集上蒸餾,和特定數據集上蒸餾,并且在特定數據集上做數據增加效果會(huì )更好
TinyBERT中蒸餾的整體過(guò)程:首先進(jìn)行通用蒸餾,然后用數據增強后的數據,在特定任務(wù)上進(jìn)行蒸餾,本文主要進(jìn)行了第二階段的蒸餾,模型是利用第一階段得到的通用小模型tinybert-6l-768d-v2進(jìn)行初始化。
知識的蒸餾通常是通過(guò)讓學(xué)生模型學(xué)習相關(guān)的蒸餾相損失函數實(shí)現,在本實(shí)驗中,蒸餾的學(xué)習目標由兩個(gè)部分組成,分別是中間層的蒸餾損失和預測層的蒸餾損失。其中,中間層的蒸餾包括對Embedding層的蒸餾、對每個(gè)Transformer layer輸出的蒸餾、以及對每個(gè)Transformer中attention矩陣(softmax之前的結果)的蒸餾,三者均采用的是均方誤差損失函數。而預測層蒸餾的學(xué)習目標則是學(xué)生模型輸出的logits和教師模型輸出的logits的交叉熵損失。
蒸餾層的映射由于教師模型是12層,學(xué)生模型的層數少于教師模型的層數,因此需要選擇一種layer mapping的方式。論文中采用了一種固定的映射方式,當學(xué)生模型的層數為教師模型的1/2時(shí),學(xué)生第i層的attention矩陣,需要學(xué)習教師的第2i+1層的attention矩陣,Transformer layer輸出同理。