本文约9700字,建议阅读10分钟
本文介绍了知识蒸馏方法探究。
大型语言模型 (Large Language Models, LLMs) 的发展日新月异。从最初的简单对话系统,到如今能够执行文本生成、语言翻译和代码编写等复杂任务的先进模型,LLM 技术实现了跨越式的进步。
然而这些模型的规模和计算需求也呈指数级增长。它们需要大量的计算资源、专用硬件设施以及可观的能源消耗。对于学术界和工业界中的大多数研究者和开发者而言,尤其是不在大型科技公司的从业者,LLM 模型的庞大规模构成了实际应用的重大挑战。
知识蒸馏 (Knowledge Distillation) 技术应运而生。其核心思想类似于专业技能的传承过程:不是要求学习者直接复制全部细节,而是着重于掌握关键技能和核心方法。在 LLM 领域,知识蒸馏的目标是将大型模型(教师模型)的知识和能力转移到更小、更易管理的模型(学生模型)中。传统知识蒸馏方法虽已存在多年,但在转移过程中往往会损失部分关键能力,导致精简后的模型在推理能力等方面表现欠佳。
Google Research 团队发表的论文《Distilling Step-by-Step!》提出了一种创新的知识蒸馏方法,不仅能有效减小模型规模,还能使学生模型在某些任务上超越其教师模型。这种方法引起了机器学习领域研究者的广泛关注,同时也引发了一些质疑:这种方法的效果是否可靠?我们是否真的能构建更小且更智能的模型?
"Step-by-Step Distillation" 方法的核心创新在于其对推理过程的重视。该方法不再将 LLM 视为简单的输入输出映射器,而是着重提取其解决问题的思维链 (Chain-of-Thought)。这就像在数学教学中,不仅要求学生得到正确答案,更要理解完整的解题步骤。通过提取这种推理过程,该方法为学生模型提供了更深层次的学习指导。
本文将深入剖析 "Step-by-Step Distillation" 方法的技术原理,通过数学推导理解其内在机制,并使用 Python 实现一个简化版本。我们将探讨这种方法的工作原理、成功要素以及潜在局限性。
大型模型的瓶颈:为什么需要知识蒸馏?
大型语言模型的规模是其强大能力的根本来源。庞大的参数数量使它们能够完成各种复杂的任务。这种规模也带来了一些瓶颈, 特别是在实际应用中。这就像拥有一辆一级方程式赛车,虽然性能出色,但并不适合日常通勤。
运行这些大型模型需要大量的计算能力, 通常需要专门的硬件、大量的 GPU 以及高昂的电费。这对基础设施提出了挑战, 并且成本高昂。对于资源有限的小公司、研究人员,或者需要在手机或嵌入式系统等边缘设备上运行这些模型的情况,计算需求是一个巨大的障碍。
除了前期成本之外,延迟(模型生成响应所需的时间)也是一个关键因素。大型模型虽然功能强大,但由于每次推理都涉及大量的计算,因此速度可能会较慢。对于速度至关重要的实时应用程序,这种延迟是不可接受的。
LLM 知识蒸馏旨在解决这些问题。知识蒸馏的本质是知识转移, 其目标是将大型、强大的 LLM(教师模型)的基本知识和能力提炼成更小、更高效的学生模型。这类似于创建一种浓缩提取物, 通过仔细的过程缩小尺寸,同时保留甚至增强其关键品质。
知识蒸馏背后的核心动机是创建可以与大型模型相媲美,但计算成本显著降低且推理时间更快的较小模型。这使得强大的 AI 更易于访问、更易于部署,并且更可持续。能够在手机、智能家居设备或资源受限的应用程序中运行复杂的语言模型将极大地扩展 AI 的应用范围。
传统的知识蒸馏技术已经存在一段时间。诸如知识蒸馏之类的方法,通常使用来自教师的“软标签”或试图模仿中间表示,已经显示出一些成功。这些方法通常旨在训练学生复制教师的输出行为。虽然这些方法确实可以缩小模型并提高效率,但它们通常会损失一些关键要素和智能。这类似于复印一件杰作,虽然得到了一份副本,但细微的差别、深度和原始的活力通常会在翻译中丢失。当涉及到捕捉 LLM 的复杂推理能力时,传统的知识蒸馏方法有时会失败。它们教学生模仿答案,而没有真正理解其背后的推理。
“Distilling Step-by-Step” 提供了一种潜在的解决方案。它不仅仅是使模型更小,而是通过专注于知识蒸馏推理过程,使它们更小更智能。
Distilling Step-by-Step 方法
“Distilling Step-by-Step” 核心创新在于视角上的根本转变。这种方法不仅仅将大型语言模型视为输出答案的黑匣子,而是认识并利用 LLM 的推理能力。
这就像从简单地向老师索要答案,转变为要求他们展示解题步骤,解释他们的思考过程。这些“思考过程”(通过思维链 (CoT) 等提示技术引出)在指导较小的模型更有效地学习方面非常有价值。如果你正在学习下棋,仅仅记住获胜的步骤可能效果有限,但理解这些步骤背后的战略原因,理解象棋的基本原则,会让你成为一个更适应性更强、更优秀的棋手。“Distilling Step-by-Step” 试图将这种更深刻的理解,这种战略推理,传授给较小的学生模型。
首先是推理过程提取阶段(Rationale Extraction Phase)。在此阶段,使用思维链推理提示技术“采访”大型 LLM。这种技术旨在引出 LLM 的一步一步的推理过程。我们不仅仅问“答案是什么?”,而是问“你是如何得出这个答案的?你能一步一步地解释你的想法吗?”。该研究论文利用了“少样本 (few-shot)” 思维链推理提示,这意味着向 LLM 提供一些输入-推理过程-标签三元组的示例来指导其生成。这类似于在要求 LLM 解决一个新问题并解释其方法之前,向 LLM 展示一些已解决的示例。
例如,考虑一个简单的问题:“如果一列火车以 60 英里/小时的速度行驶 2 小时,它行驶了多远?”。传统的 LLM 提示可能只会要求提供答案。但是通过思维链推理提示,鼓励它生成中间推理步骤:“速度 = 60 英里/小时,时间 = 2 小时。距离 = 速度 x 时间。距离 = 60 英里/小时 x 2 小时 = 120 英里。” 输出不仅仅是“120 英里”,还有推理过程:“距离 = 速度 x 时间”和计算步骤。这些推理过程是 LLM 推理的自然语言解释。
此阶段的输出是一个有价值的数据集。对于每个输入,不仅有 LLM 预测的标签(答案),而且还有证明该标签合理的自然语言推理过程。这类似于创建一个丰富的数据集,其中每个示例不仅仅是输入-输出,而是输入-推理-输出。
接下来是 多任务训练阶段(Multi-Task Training Phase)。在此阶段,使用刚刚创建的数据集来训练较小的学生模型。“多任务”体现在我们不仅仅训练学生预测最终标签,而是训练它同时做两件事:预测标签并生成推理过程。
这类似于教象棋学徒不仅要走获胜的步骤,还要解释为什么每个步骤在战略上都是合理的。我们不仅仅要求学生模仿老师的答案,而是要理解和复制老师的推理过程。该论文将此定义为一个多任务学习问题。训练学生模型以最小化两个损失函数:一个用于标签预测准确性,另一个用于推理过程生成。这两个任务都被赋予权重,促使学生学习“什么”(答案)和“为什么”(推理)。
研究人员使用了一种称为“任务前缀”的技巧。在训练期间,当目标是标签时,他们在输入前加上“[label]”,当目标是推理过程时,他们在输入前加上“[rationale]”。这明确地告诉学生:“对于这个输入,我希望你专注于预测标签”,然后,“对于这个输入,我希望你专注于生成推理”。这有助于模型解开这两个任务,并学习有效地执行它们。
这种方法可能比传统的知识蒸馏甚至标准的微调更有效。传统方法通常侧重于模仿教师的表面行为(输出)。“Distilling Step-by-Step” 旨在捕捉更深层次的东西:潜在的推理过程。通过强制较小的模型生成推理过程,我们实际上是在鼓励它学习对任务的更抽象、更可概括的理解。它不仅仅是记住输入-输出对, 而是学习将输入连接到输出的原则。
这种对推理的关注是关键的区别因素,也是潜在数据效率提升的来源。通过学习生成推理过程,较小的模型可能能够从更少的示例中更好地概括,并且可能在某些情况下超过教师 LLM 的性能。
Distilling Step-by-Step 背后的数学原理
本节将介绍 “Distilling Step-by-Step” 背后的数学机制。我们将逐一分解公式,并探讨它们的含义、优势和潜在局限性。
使用思维链 (CoT) 提示提取推理过程
思维链推理提示是一种指导大型语言模型生成过程的方法, 其核心是影响 LLM 输出的概率分布。
将 LLM 视为一个函数,给定一个输入提示 P,它会生成一个输出序列 Y。在标准提示中,我们的目标是最大化给定提示 P 的所需输出 y(标签)的概率。我们可以将其表示为:
CoT 提示改变了这一目标。希望鼓励 LLM 不仅生成最终答案 y,而且生成导致 y 的中间推理过程 r。通过制作展示这种输入-推理过程-标签结构的提示来实现这一点,正如在示例三元组中讨论的那样。
虽然没有 CoT 提示的直接公式,但可以将其视为将 LLM 的生成条件设置为特定风格的输出,即包含显式推理步骤的输出。这促使 LLM 的内部决策过程变得更加透明和循序渐进。
CoT 的有效性在于其涌现行为。大型语言模型在正确提示时,表现出模仿思维链过程的能力,即使它们没有经过明确的训练来以这种方式生成推理过程。
需要注意的是,CoT 提示仍然具有一定的技巧性。制作有效的提示需要直觉和实验。不能保证生成的推理过程总是完美的,甚至完全准确。它们是由一个模型生成的,虽然令人印象深刻,但并非万无一失。这些推理过程的质量会影响后续知识蒸馏过程的有效性。如果教师的推理有缺陷,学生也可能会学到有缺陷的推理。
尽管存在这些警告,CoT 提示提供了一种强大的方式来利用 LLM 的推理能力,并提取有价值的监督信号来训练较小的模型。此阶段的输出是一个输入-推理过程-标签三元组 (xi , r ^i , y ^ i ) 的数据集,可用于下一阶段:多任务学习。
多任务学习目标
本节将介绍 “Distilling Step-by-Step” 的数学核心。我们将从提取推理过程转变为在多任务学习框架中利用它们。核心思想是训练学生模型同时执行两项任务:标签预测和推理过程生成。这是通过组合损失函数来实现的。
第一个任务是标签预测。我们希望学生模型 f 能够准确地预测给定输入 x_i 的正确标签 y ^ i 。为了衡量其表现,我们使用标准的交叉熵损失,表示为 L_label 。对于 N 个示例的数据集,标签预测损失计算如下:
其中,ℓ( f(x_i) , y^i) 表示学生模型预测的标签概率分布 f(xi ) 与目标标签 y ^ i 之间的交叉熵损失。当模型预测的概率偏离正确标签时,此损失函数会惩罚模型。交叉熵损失是分类任务中一个完善的损失函数, 能够有效地引导模型做出准确的预测。然而,交叉熵损失主要侧重于标签准确性,并不直接鼓励模型学习推理或生成解释。
为了鼓励学生模型学习推理,我们引入了第二个任务:推理过程生成。我们希望模型不仅预测标签,而且为输入 x_i 生成一个合理的推理过程 r ^i 。同样使用交叉熵损失,但这次应用于推理过程中的标记序列:
其中,ℓ( f(xi ) , r ^i ) 是在标记序列上计算的交叉熵损失。学生模型不仅会因不正确的标签而受到惩罚,还会因生成与教师生成的基本原理 r ^i 不同的推理过程而受到惩罚。此损失函数直接鼓励模型学习为其预测生成人类可读的解释, 从而将教师 LLM 的“推理风格”注入到较小的学生中。此损失函数基于模仿,要求学生模仿教师的推理过程,这可能并不总是最佳或最有效的推理形式。学生模型可能存在其他替代方法,甚至更好的方法来得出正确的答案,但是此损失函数会使其偏向于复制教师的推理过程。
组合多任务损失:LL
为了训练学生模型同时执行这两项任务,我们将这两个损失函数组合成一个单一的联合损失函数:
这是一个简单而优雅的组合。我们将标签预测损失和推理过程生成损失相加,并乘以一个因子 λ。在研究论文中,通常将 λ 设置为 1,从而使这两项任务同等重要。加权因子 λ 使我们能够调整推理过程生成与标签预测的相对重要性。如果我们想优先考虑标签准确性,可以减小 λ。相反如果想强调学习推理和生成解释,可以增加 λ。在实践中,将 λ 设置为 1 通常可以达到很好的平衡。
此组合损失函数驱动 “Distilling Step-by-Step” 中的多任务学习过程。通过最小化此联合损失,可以激励学生模型精通预测标签和生成推理过程。这是一种巧妙的方式,可以利用教师生成的基本原理中包含的丰富信息来指导较小、更高效且可能更具洞察力的学生模型的训练。
在 Python 中构建 Distilling Step-by-Step
本节将介绍如何使用 Python 代码实现 “Distilling Step-by-Step” 方法。我们将参考 Google 研究人员发布的原始 GitHub 存储库。
首先需要一种加载数据的方法,将使用该方法来知识蒸馏我们的学生模型。以下是 data_utils.py 文件的功能:
-
从各种来源加载数据集:无论是使用 Hugging Face load_dataset 方法还是从自定义 JSON 文件加载,DatasetLoader 基类(及其子类)都会对其进行管理。
-
准备输入和目标:例如,CQADatasetLoader 从问题及其多项选择答案构建一个组合输入字符串。它还会删除不必要的列,以便下游训练仅看到重要的内容。
-
集成 LLM 输出:想要使用来自 LLM 的外部推理过程和标签吗?加载器可以从 JSON 文件(用于 PaLM 或 GPT 预测)中读取这些内容,并将它们解析为结构化列。
class DatasetLoader(object):
def __init__(self, dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=None):
self.data_root = DATASET_ROOT
self.dataset_name = dataset_name
self.source_dataset_name = source_dataset_name
self.dataset_version = dataset_version
self.has_valid = has_valid
self.split_map = split_map
# (Additional setup omitted for brevity…)
def load_from_source(self):
if self.dataset_version is None:
datasets = load_dataset(self.source_dataset_name)
else:
datasets = load_dataset(self.source_dataset_name, self.dataset_version)
return datasets
def to_json(self, datasets):
for k, v in self.split_map.items():
datasets[v].to_json(f'{self.data_root}/{self.dataset_name}/{self.dataset_name}_{k}.json')
# …plus methods for loading from JSON and parsing LLM/GPT outputs.
每个具体的加载器(例如 CQADatasetLoader、SVAMPDatasetLoader 等)都实现特定于数据集的逻辑。例如,CQA 加载器的 _post_process 方法通过将问题与其答案选项组合在一起来构建输入字符串:
def prepare_input(example):
question = example['question']
c_0 = example['choices'][0]
# …other choices...
input = f'{question}\nAnswer Choices:\n(a) {c_0}\n(b) {example["choices"][1]}\n(c) {example["choices"][2]}\n(d) {example["choices"][3]}\n(e) {example["choices"][4]}'
example['input'] = input
example['label'] = example['answer']
return example
这样管道就为多任务训练做好了准备, 模型既要学习预测答案,又要生成推理过程。
metrics.py 文件包含计算文本和方程预测准确性的函数。
-
-
方程准确性:评估字符串表达式(以受控方式使用 Python 的 eval)以查看计算出的答案是否匹配。
def compute_equation_acc(preds, labels):
preds = [eval_equation(pred) for pred in preds]
labels = [eval_equation(label) for label in labels]
return np.mean(np.array(preds) == np.array(labels))
eval_equation 函数尝试安全地计算方程的结果:
def eval_equation(equation):
try:
answer = eval(equation)
except:
answer = np.nan
return answer
当任务不仅仅是分类,而是涉及更复杂的推理(例如数学问题解决)时,此模块至关重要。
3、多任务模型和训练器设置:model_utils.py
model_utils.py 文件包含多任务模型和训练器的设置。该管道通过使用任务前缀来区分预测(回答)和解释(生成推理过程)来支持多任务训练。该代码使用一些 HuggingFace 类作为父类,并在其之上进行扩展。
TaskPrefixDataCollator 采用一批示例,并将其拆分为两个字典: