英文题目:Meta-KD: A Meta Knowledge Distillation Framework for Language Model Compression across Domains

中文题目:Meta-KD: 跨领域语言模型压缩的元知识蒸馏框架

论文地址:http://export.arxiv.org/pdf/2012.01266v1.pdf

领域:自然语言处理, 知识蒸馏

发表时间:2020.12

作者:Haojie Pan,阿里团队

出处:ACL

被引量:1

代码和数据:https://github.com/alibaba/EasyNLP(集成于 EasyNLP)

阅读时间:2022-09-17

读后感

结合元学习和蒸馏学习:元学习使得模型获取调整超参数的能力,使其可以在已有知识的基础上快速学习新任务。

介绍

预训练的自然语言模型虽然效果好,但占空间大,预测时间长,使模型不能应用于实时预测任务。典型的方法是使用基于老师/学生模型的知识蒸馏。而模型一般面向单一领域,忽略了不同领域知识的知识转移。本文提出元蒸馏算法,致力于基于元学习的理论,让老师模型具有更大的转移能力,尤其对 few-shot 和 zero-shot 任务效果更好。

如图 -1 所示,一个学物理的学生如果跟数学老师学习了数学方程知识,可能有助于他更好地理解物理方程。相近领域的数据可能提升模型的能力,但其它领域模型也可能转移一些无关的知识,从而影响性能。另外,当前研究证明:使用多任务精调也未必能提升所有任务的性能。由此,文中提出需要让老师模型消化不同领域的知识,并可针对具体领域,将知识转移到学生模型。在图 -1(c) 中,如果有万能的科学老师(元学习),它既会数学也会物理,则可以更好地教导学生。

如图 -2 所示,模型包含两部分:元老师和元蒸馏:

首先利用多领域数据集训练元老师,通过引入破坏域损失来获取跨域知识,然后针对具体领域,用领域相关数据集引导元老师,以提升学生的蒸馏能力。

文章贡献

  • 第一次提出基于元学习的预训练自然语言模型压缩算法。
  • 提出 Meta-KD 框架训练跨领域的老师模型,包含元老师和元蒸馏两部分
  • 实验证明模型的有效性

方法

概览

定义:设有 K 个领域的 K 个数据集参与训练,D 为数据集,M 为大模型,S 为蒸馏后的学习模型。

模型训练分为两个场景:

  • 训练一个学习了 K 个领域知识的元老师模型 M,模型消化了各领域知识且有针对不同领域很好的泛化能力。
  • 在元蒸馏过程中,利用领域数据集 DK 和元模型 M,训练学生模型 SK。

如果某一个领域的实例很少,如 few-shot 或 zero-shot 问题,通过知识转移训练该领域模型。

元老师学习

将 BERT 模型作为基础模型。

基于原型实例加权

学习过程中对每个实例 X 计算原型得分 t,假设处理分类问题,共 m 个类别,计算所有第 K 领域中实例属于每个类别的概率均值(请参考图 -3 左侧的实心多边形):

计算原型得分如下:

此处 cos 用于计算相似度,α是超参数,公式的前半部分计算了该实体与它所在的领域的关系(在嵌入空间与同类实体的一致性),后半部分计算了与其它领域的关系。这样模型就同时学习了同一领域的知识和其它领域的知识。

域破坏

除了交叉熵损失,还加入了域破坏损失以提升元老师转移学习的能力。对于每个实例,学习一个与 h 维度相同的域嵌入,记作 ED(epsilon D)。

在 BERT 以外,又加入了一个子网络,对网络输出进一步处理:

针对域破坏的损失函数定义为:

其中σ(sigma) 表示域类别,它是一个指示函数,只有 0/1 两个取值,这里最大化元教师对域标签做出错误预测的可能性。

我理解,这里的损失函数是让实例最终能识别它所在的域类别 k。

损失函数

最终的损失定义为:使用得分 t 加权针对所有领域的交叉损失;同时,加入了域破坏损失作为辅助,以训练模型转移知识的能力。

这里的γ1(gamma) 是超参数,用于设定域破坏损失的贡献。

元蒸馏

使用小型的 BERT 作为学生模型,蒸馏网络结构如图 -3 所示:

目标由五个部分组成:输入嵌入 Lembd,隐藏层状态 Lhidn,注意力矩阵 Lattn,输出 ligit 和知识转移。其中 Lembd,Lhidn,Lattn 的蒸馏方法与 TinyBERT 一样。又加入了 Lpred 对输出层使用软交叉熵损失。

另外,考虑到特定领域的知识转移,下面公式又加入了域相关的损失:

以此鼓励学生模型学习更多的该领域相关知识。我理解这里的 hM 是指对该领域的老师模型获得的编码。

又引入λk 参数,它是领域相关的权重:

其中 y^是预测的类别标签,当预测准确,或者 t 比较大时,λ值也相应变大,它反应的是老师在特定任务上监督学生的能力。

整体蒸馏损失计算方法如下:

实验

使用自然语言推理(MNLI)和情绪分析(Amazon Reviews)两个任务评价模型。

表 -2 和 3 展示了主实验结果:

得出三个结论:

  • Meta-KD 模型优于之前模型,它比基线模型小 7.5 倍,效果仅差 0.5%
  • Meta-teacher 模型效果很好,这表明元老师有能力学习更多可转移的知识来帮助学生。
  • 一般情况下,Meta-KD 对小数据集数据效果更明显。

图 -4 也说明在 few-shot 情况下,实例越少,Meta-KD 效果越明显: