论文阅读_BERT知识蒸馏
英文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
中文题目:从 BERT 中蒸馏指定任务知识到简单网络
论文地址:https://arxiv.org/pdf/1903.12136.pdf
领域:自然语言,深度学习
发表时间:2019
作者:Raphael Tang, 滑铁卢大学
被引量:226
代码和数据:https://github.com/qiangsiwei/bert_distill
阅读时间:2022.09.11
读后感
第一次对大型自然语言模型的蒸馏:将 BERT 模型蒸馏成 BiLSTM 模型。
介绍
在自然语言处理方面,随着 BERT,GPT 等大规模预训练模型的发展,浅层的深度学习模型似乎已经过时了。但由于资源的限制,又需要使用小而快的模型。
文章的动机是讨论:浅层模型是否真的不具备对文本的表示能力?并展示了针对于具体的任务,将 BERT 蒸馏成单层 BiLSTM 模型的方法和效果。也通过大模型(起初训练的复杂的模型,后称 Teacher/T)和小模型(蒸馏后的模型,后称 Student/S)完全不同的模型结构展示了蒸馏与模型结构无关。另外,之前蒸馏模型主要应用于图片建模,论文讨论了它在自然语言领域的使用方法。
方法
核心方法包含两部分:增加了 logit 回归目标;重建蒸馏训练数据集使训练更为有效。
模型结构
将 BERT 作为教师模型,使用单层的 BiLSTM 作为学习模型的非线性分类器,针对每一种下游任务使用不同模型。如图 -1 是对单句分类任务设计的学生模型。
图 -2 展示了用于预测句子匹配度的模型,它们的编码层共享同一 BiLSTM 模型。
为了更好地对比效果,在学生模型中,未使用注意力归一化等更多技巧。
蒸馏目标
学生模型的目标是在所有数据上,模拟老师模型的行为。除了最终的标签,老师模型预测出的概率也很重要。比如在情绪分类问题中,一些实例有很强的正面情绪,有一些情绪可能比较中性,所以除了是否,也需要预测程度。
一般预测标签的方法是:
文中使用了 logit 的优化方法,构造了蒸馏目标:用 MSE 来惩罚师生模型间的差异:
其中 z(B) 指的是老师模型 BERT,z(S) 指学生模型,在初步实验中,MSE 比软目标效果更好。
在实际训练时,也使用了传统的交叉熵(对真正目标的预测)和蒸馏损失相结合的方式,最终损失函数如下:
当使用有标签数据训练时,t 是实例的标签;使用无标签数据训练时,使用老师模型打标签。
蒸馏的数据增强
在蒸馏过程中,使用小的数据集不足以让老师模型展示出其所有知识,因此,使用了无标签数据扩充训练数据集,用老师模型对其打标签。
增强 NLP 数据比增强图像数据难度大,没办法使用扭曲等方法,做出的句子可能不够流畅。文中提出了几种数据增强方法:
- 遮蔽:使用类似 BERT 的方法,这种方法能反应句中每个词对标签的贡献。
- 基于词性的词替换:在词袋里找同一词性的词作替换,以保持原始数据的分布。
- n-gram 采样:根据概率,随机采样 n 个连续的词,它是遮蔽方法的增强版。
实验
使用的是 BERT_LARGE 作为老师模型,针对特定任务精调,预测时获取预测的 logit 值,学生模型使用 300 维的 word2vec 作为词嵌入。主实验效果如表 -1 所示:
可以看到同样是使用 BiLSTM 方法,文中方法相较于其它方法有显著提升。
从表 -2 可以看到预测速度也有很大提升: