本文约3300字,建议阅读10分钟
本文回顾了在四个分类任务(文本分类、意图分类、关系抽取和命名实体识别)下的表现,并在两种最流行的增量学习设置(类别增量和任务增量)中进行实验。
论文题目:
Learn or Recall? Revisiting Incremental Learning with Pre-trained Language Models
收录会议:
ACL 2024, Long Paper, Oral
论文链接:
https://aclanthology.org/2024.acl-long.794/
增量学习(IL)一直是计算机视觉和自然语言处理(NLP)领域长期存在的问题。近年来,随着大语言模型(Large Language Model, LLM)在各种 NLP 下游任务中取得了显著进展,将 LLMs 作为骨干网络在 NLP 领域的增量学习研究中已成为一种常见做法。大多数研究假设灾难性遗忘是实现优越增量学习性能的最大障碍,并提出了各种技术来克服这一问题。然而,我们发现这一假设存在问题。具体而言,我们回顾了在四个分类任务(文本分类、意图分类、关系抽取和命名实体识别)下的表现,并在两种最流行的增量学习设置(类别增量和任务增量)中进行实验,结果揭示大多数方法严重低估了 LLMs 固有的抗遗忘能力。这些发现促使我们重新审视基于 LLMs 的增量学习,并鼓励未来的研究更加深入地理解 LLMs 中的灾难性遗忘问题。我们利用探测技术 probing 评估模型 backbone 对目标任务的表示能力,实现如图 1 所示。我们在实验中使用生成模型进行类别增量意图分类的观察和探测性能。图 2(a)显示,随着更多新任务的学习,观察到的性能显著下降,从约 98% 降至 10%,这一结果符合我们对灾难性遗忘的理解。然而,图 2(b)描述了一个完全不同的现象。LLMs 在学习第一个任务后就达到了很高的探测性能,并且从第二个任务开始,线性探测性能几乎没有下降。换句话说,即使 LLMs 仅按顺序适应新任务(Sequential fine-tuning,SEQ),它们依然保留了分类所有 15 个任务的知识。这个现象与我们对灾难性遗忘和 SEQ 的理解相矛盾。实际上,探测性能之所以很高,是因为在训练探测分类器时,所有任务的数据都是可用的,而观察到的性能较差,是因为原始分类器仅在当前任务的数据上进行训练。因此,经过探测的实验结果表明大模型在连续学习过程中并没有丢失其知识。新发现2:Probing 性能:Linear > Cosine Linear ≈ Cosine Prototype > Prototype我们发现四个探测指标的排序如下:Linear > Cosine Linear ≈ Cosine Prototype > Prototype。如图 3 所示:首先,我们需要分别理解 LLMs 的特征(即最后的隐藏状态)、词向量和探测分类器中的类别嵌入“是什么样的”。特征、词向量和类别嵌入的 L2 范数和余弦相似度的直方图如图 4。图4 Pythia-410m 的特征和不同嵌入的直方图图 4a 显示,特征在向量空间中占据一个狭窄的圆锥形区域,而不是在所有方向上均匀分布。更令人惊讶的是,图 4b 显示,学习到的(输出)词向量与特征几乎是正交的。我们推测,交叉熵损失函数鼓励除了真实标签外的所有词向量在预训练过程中远离特征。换句话说,交叉熵损失鼓励 logits 之间有较大的差异,并且词向量与特征正交,以便更好地区分 logits。因此,考虑到词向量层本质上是一个线性层,线性探测有最佳表现也就不足为奇。从这个角度来看,原型探测表现较差也就不奇怪,因为原型(类别特征中心)也落在这个狭窄的圆锥空间内,而这对于区分 logits 并不是一个最优的解决方案。那么,为什么余弦归一化会降低线性探测的性能,但能改善原型探测的性能呢?图 4c 和图 4d 展示了特征和词向量的 L2 范数。我们发现,词向量的范数与特征相比存在较大的差异。这表明,词向量的范数包含了来自预训练阶段的先验知识。因此,余弦线性探测忽略了特征范数的差异,因此相比于线性探测,其性能较差。对于原型探测,原型位于一个狭窄的圆锥空间中,原型和特征之间的相似度较大,且接近彼此。在这种情况下,余弦归一化可以消除范数的干扰,从而建立 logits 和特征之间余弦相似度的关系。新发现3:LLMs 抵抗遗忘的关键在于 Transformer 的结构和预训练获取的知识我们评估了在不同预训练步数的检查点上的线性探测性能:{0, 16, 128, 1k, 10k, 143k(最终)}。我们加载预训练的检查点(或在步数为 0 时随机初始化的检查点),并在使用 SEQ 进行增量学习前后评估它们的线性探测性能。图 5 展示了预训练中的两个主要阶段:过拟合和泛化。在第一个阶段(步数0 - 步数 128),模型开始记忆预训练语料库,线性探测性能下降。在第二个阶段(步数 1k - 步数 143k),模型逐渐学习预训练知识,线性探测性能上升。然而,当模型进一步泛化到预训练语料库时(步数 10k - 步数 143k),小型骨干网络(如 Pythia-70m 和 160m)的线性探测性能再次下降,原因是预训练和下游任务之间存在差距。这个差距可以通过适应下游任务来消除。对于较大的骨干网络(如 Pythia-410m、1b 和 1.4b),模型能够直接适应新任务,而不会受到这种差距的影响。此外,我们还有以下有趣的发现:1. 预训练确实改善了增量学习中的线性探测性能(见图 5b 和图 5d)。2. 除了预训练之外,Transformer 的架构也是在 SEQ 过程中获得高线性探测准确率的关键因素。当下游任务相对简单时,例如意图分类,即使是随机初始化的模型也能获得较高的线性探测性能(见图 5b)。而当下游任务较为复杂时,例如关系抽取(见图 5d),预训练则带来了显著的性能提升。3. 更令人惊讶的是,SEQ 提高了几乎所有预训练步骤的模型的线性探测性能(见图 5a 与 5b;图5c 与 5d)。这表明,Transformer 的架构即使仅在新任务上进行顺序微调,也能够逐步吸收新知识。我们观察到,在 SEQ 模型中,新类别的 logits 远大于旧类别的 logits。由于特征和类别嵌入决定了 logits 的大小,而特征占据一个狭窄的圆锥空间,其范数相对接近,因此我们可以推测,遗忘现象的发生是由以下原因之一引起的:(1)类别嵌入的范数,或(2)特征与类别嵌入之间的余弦相似度。对于第一种原因(即类别范数),我们在图 6a 和图 6b 中比较了学习的线性分类器和线性探测分类器之间的类别嵌入范数。令人惊讶的是,在 SEQ 的观察分类器中,新任务的类别嵌入范数并不大于旧任务的类别嵌入范数。这表明,类别范数不是 SEQ 中遗忘现象的主要原因。对于第二个原因(即余弦相似度),我们在图6c和图6d中比较了观察分类器和探测分类器之间类别嵌入的移动距离。任务t的类别嵌入在任务时的移动距离计算如下:1. 当模型完成任务 的训练后,我们计算任务 t 的所有类别嵌入与所有任务的类别特征中心之间的余弦距离,并得到一个余弦相似度矩阵 。2. 当模型完成任务 t+k 的训练后,我们计算任务 t 的所有类别嵌入与所有任务的类别特征中心之间的余弦距离,并得到一个余弦相似度矩阵 。3. 然后,任务 t 的类别嵌入的移动距离计算为余弦相似度矩阵 和 之间的平均绝对差异。移动距离衡量了自学习以来,类别嵌入相对于所有类别特征中心的移动情况。图6 在 SEQ 过程中观察到的线性分类器与线性探测分类器的比较如果分类器没有遗忘某个类别,那么它的类别嵌入到所有类别特征中心的距离应该保持恒定。换句话说,如果分类器没有遗忘如何使用 LLMs 提取的特征来分类该类别,则其移动距离将为零。图 6c 和 6d 显示,观察分类器的类别嵌入相对于探测分类器发生了显著变化。这表明,遗忘现象的发生是因为旧的类别嵌入被推离了其初始和最优位置。最后,我们根据实验发现设计了 SEQ,提出了以下策略来缩小 SEQ 中探测和观察性能之间的差距:(S1)Warm-up 后冻结 LLMs;(S2)在学习新任务时冻结旧分类器;(S3)只有在 CIL 场景中没有旧数据可用的情况下才使用余弦线性分类器。否则,请使用线性分类器;(S4,可选)预先分配未来的分类器。我们将使用上述策略的方法称为 SEQ,如图 7 所示。实验结果如图 8 所示。具体实验情况详见论文:https://aclanthology.org/2024.acl-long.794/图8 在句子级分类任务上 SOTA 方法和 SEQ* 的比较
数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。
新浪微博:@数据派THU
微信视频号:数据派THU
今日头条:数据派THU