Branch-Merge 蒸馏:大语言模型压缩的革命性突破


大语言模型(LLMs)在自然语言处理领域取得了显著成就,但其庞大的参数规模给部署和应用带来了挑战。现有的模型蒸馏和迁移学习方法难以达到高精度要求,存在数据选择繁琐、梯度冲突等问题。 为解决这些难题,QiYuan Tech 联合北大提出了 Branch-Merge 蒸馏方法。
该方法将模型训练分为分支和合并两个阶段。在分支阶段,通过领域特定的监督微调,使学生模型在特定领域内具备专业能力。在合并阶段,利用 Arcee Fusion 技术,将多个专业学生模型进行深度融合,确保关键知识的保留和融合。

Branch-Merge 蒸馏方法显著提高了模型的准确性,接近教师模型的性能水平,同时大幅降低了计算成本和时间。

实验结果表明,通过将 DeepSeek-R1 作为教师模型,DeepSeek-R1-Distill-Qwen-32B 作为学生模型,由此产生的合并模型 TinyR1-32B-Preview 在多个基准测试中优于 DeepSeek-R1-Distill-Qwen-32B,包括数学(高 5.5 分)、编程(高 4.4 分)和科学(高 2.9 分),同时在 2024 年 AIME 测试中性能接近 DeepSeek-R1。

image.png

一、引言

1. 大语言模型压缩的挑战


随着大语言模型(LLMs)的快速发展,其在自然语言处理领域取得了显著成就。然而,这些模型通常具有庞大的参数规模,如数十亿甚至上百亿参数,这使得它们在部署和实际应用中面临诸多挑战。
  • 首先,庞大的模型尺寸需要大量的存储空间,这对于资源有限的设备和应用场景来说是一个巨大的负担。
  • 其次,大模型的推理过程需要消耗大量的计算资源和时间,导致推理成本高昂,难以满足实时性要求。此外,大模型的训练和优化也变得更加复杂和困难,需要更多的计算资源和时间来完成。

因此,如何在保持模型性能的前提下,有效减小模型的规模,降低其计算成本,成为当前大语言模型研究领域的一个关键挑战。

2. 现有方法的局限性

为了应对大语言模型压缩的挑战,现有的方法如模型蒸馏和迁移学习等被广泛研究和应用。然而,这些方法在实际应用中往往难以达到高精度的要求。

  • 一方面,模型蒸馏方法通常需要精心选择最相关的数据/领域,并调整它们在联合训练中的比例,这一过程不仅耗时费力,而且容易出错。
  • 另一方面,同时优化多个领域时,可能会导致梯度冲突,不同任务之间相互干扰,阻碍整体学习进度,从而限制了模型在多个领域的性能提升。

因此,现有的模型压缩方法在有效性和效率方面存在一定的局限性,难以满足实际应用中对高精度和高效压缩的需求。

3. Branch-Merge 蒸馏方法的提出

为了解决现有方法的局限性,我们提出了一种新的模型压缩方法——Branch-Merge 蒸馏方法。

该方法通过两个阶段来增强模型的压缩效果:

  1. 分支阶段(Branch Phase),将大型教师模型的知识选择性地蒸馏到多个专业学生模型中,通过领域特定的监督微调(SFT)实现;
  2. 合并阶段(Merge Phase),将这些学生模型合并,实现跨领域的知识转移,同时保留其原始的专业能力。

Branch-Merge 蒸馏方法通过解耦训练领域并在之后进行整合,直接解决了数据选择和梯度冲突的问题,为创建更小、性能更高且计算成本和时间更低的大语言模型提供了一种可扩展的解决方案。

Image

通过将 DeepSeek-R1 作为教师模型,DeepSeek-R1-Distill-Qwen-32B 作为学生模型,来验证我们的蒸馏方法。由此产生的合并模型 TinyR1-32B-Preview 在多个基准测试中优于 DeepSeek-R1-Distill-Qwen-32B,包括数学(高 5.5 分)、编程(高 4.4 分)和科学(高 2.9 分),同时在 2024 年 AIME 测试中性能接近 DeepSeek-R1。

分支合并蒸馏方法为创建更小、性能更高且计算成本和时间更低的大语言模型提供了一种可扩展的解决方案。

二、Branch-Merge 蒸馏方法概述

1. 方法的核心思想

Branch-Merge 蒸馏方法的核心思想是通过两个阶段来实现大语言模型的有效压缩,同时保持高性能。具体如下:

  • 分支阶段(Branch Phase):在这个阶段,首先为不同的领域(如数学、科学、编程等)分别构建数据集。然后,使用这些数据集对大型教师模型(如 DeepSeek-R1 671B)进行领域特定的监督微调(SFT),从而得到多个专业学生模型。每个学生模型在特定领域内具有较强的专业能力,能够有效地捕捉和学习该领域的知识和模式。
  • 合并阶段(Merge Phase):在这一阶段,使用 Arcee Fusion 方法将不同领域的专业学生模型进行合并。通过计算参数的重要性分数、动态选择阈值以及选择性集成等步骤,将多个模型的知识进行融合,形成一个统一的模型。合并后的模型不仅保留了各个学生模型在特定领域的专业能力,还实现了跨领域的知识转移,从而提高了模型的泛化能力和整体性能。

这种方法通过将训练过程分解为分支和合并两个阶段,巧妙地解决了传统模型蒸馏方法中数据选择和梯度冲突的问题。分支阶段允许模型在特定领域内进行深入学习,而合并阶段则通过有效的模型整合策略,将不同领域的知识融合到一个模型中,实现了模型的高效压缩和性能提升。

2. 方法的优势与特点

Branch-Merge 蒸馏方法具有以下显著的优势与特点:

  • 显著提升模型准确性:实验结果表明,Branch-Merge 蒸馏方法能够显著提高模型的准确性,使其接近教师模型的性能水平。以蒸馏后的 Qwen-32B 模型为例,其数学准确性接近原始的 R1 教师模型,整体性能相比传统蒸馏方法有了质的飞跃。在多个基准测试中,如数学、编程和科学等领域,该方法蒸馏得到的模型均表现出色,超越了相同大小的其他蒸馏模型,展现出其在不同领域的广泛适用性和强大的知识融合能力。
  • 简单且成本低廉:Branch-Merge 蒸馏方法在实现高性能的同时,还具有简单易行和成本低廉的特点。与传统方法相比,该方法在合并阶段显著减少了时间和计算成本。例如,在合并过程中,仅需使用 4 个 H800 GPU 在 0.5 小时内即可完成,而传统方法则需要使用 32 个 H800 GPU 耗费 23 小时进行合并数据的重新训练。这种高效的模型合并方式,不仅节省了大量的计算资源和时间,还降低了模型蒸馏的门槛,使得更多的研究者和开发者能够轻松地应用该方法来创建高性能的小模型。
  • 开放性:该方法的提出者秉持开源精神,致力于将研究成果回馈给开源社区。他们将公开发布蒸馏后的模型以及所有相关数据、训练代码、评估代码和日志等资源,以便任何人都可以方便地复现实验结果。这种开放性的做法,不仅有助于推动学术研究的进一步发展,还能够促进大语言模型技术在更广泛领域的应用和创新,为整个 AI 社区的发展做出积极贡献。

三、Branch 阶段:知识的分支化蒸馏

1. 数据集的构建

在 Branch 阶段,首先需要为不同的领域构建专门的数据集,这些数据集将用于训练专业学生模型,使其在特定领域内具有强大的专业能力。

(1)数学领域数据集

数学领域数据集的构建基于 NuminaMath1.5 数据集。从该数据集中筛选出 58k 个样本,筛选过程主要考虑三个方面:问题类型(question_type)、来源(source)以及数学验证的正确性(correctness_math_verify)。通过这些标准,确保所选样本在数学领域具有代表性和高质量,能够有效支持学生模型在数学领域的学习和训练。

(2)编程领域数据集

编程领域数据集来源于 OpenThoughts 数据集。经过筛选,形成了包含 20k 个编程解决方案轨迹的数据集。在筛选过程中,对原始数据集中的“<|begin_of_thought|>”替换为“”,并将“<|end_of_solution|>”替换为“”,以适应模型训练的需求,使学生模型能够更好地理解和生成编程相关的思考和解决方案。

(3)科学领域数据集

科学领域数据集的构建涉及多个数据源。首先,从 data_ablation_full59k 的科学和健康科学子集中选取 2.7k 个种子示例;其次,从 S1k 数据集中选取 1.0k 个示例;最后,从 OpenThoughts 的科学子集中选取 4.9k 个示例。通过 DeepSeek-R1 模型为每个种子示例生成 1 个 CoT(Chain of Thought,思维链)轨迹,最终得到 8.6k 个 CoT 轨迹,构成科学领域数据集。这些数据涵盖了广泛的科学知识和思维过程,有助于学生模型在科学领域进行深入学习。

2. 领域特定的监督微调(SFT)

在构建好各领域的数据集后,使用领域特定的监督微调(SFT)方法,将大型教师模型的知识选择性地蒸馏到专业学生模型中。SFT 是一种针对特定领域进行微调的技术,通过在特定领域的数据上对预训练模型进行进一步训练,使其更好地适应该领域的任务和需求。在 Branch-Merge 蒸馏方法中,SFT 用于将教师模型在不同领域的知识分别传递给对应的学生模型,使学生模型在特定领域内具备专业的知识和能力。

3. 专家模型的训练细节

为了确保专业学生模型在各自领域的高性能,训练过程采用了不同的训练细节和参数设置,以适应各领域的特点和需求。

  • 数学专家模型:数学专家模型的训练采用了 5 个训练轮次(epochs),批次大小(batch size)设置为 96,学习率(learning rate)保持恒定为 1e-5。这种设置有助于模型在数学领域数据上充分学习,掌握数学问题的解题思路和方法,同时避免过拟合,确保模型具有良好的泛化能力。
  • 科学专家模型:科学专家模型的训练轮次同样为 5 次,但批次大小减小至 32,并采用了 neat packing 机制。学习率采用余弦退火策略,初始值为 1e-5。neat packing 机制能够有效地利用内存,提高训练效率,而余弦退火学习率策略则有助于模型在训练过程中更好地收敛,提升模型在科学领域的性能。
  • 编程专家模型:编程专家模型的训练轮次增加到 15 次,批次大小为 96,并使用 neat packing 机制。学习率同样保持恒定为 1e-5。较多的训练轮次和较大的批次大小,使得模型能够更深入地学习编程领域的知识和模式,掌握编程问题的解决方法和技巧,从而在编程任务中表现出色。

通过以上训练细节的精心设计,各个专业学生模型在各自领域内得到了充分的训练,具备了较强的专业能力和性能,为后续的合并阶段奠定了坚实的基础。

四、Merge 阶段:模型的合并与知识融合

image.png

合并过程分为三个步骤:

1. Arcee Fusion 方法的原理

Arcee Fusion 方法是 Merge 阶段的核心技术,通过以下步骤实现模型参数的融合:

(1)计算重要性分数

image.png

(2)计算动态选择

image.png

(3)选择性集成

image.png

2. 多领域模型的合并流程

在实际应用中,我们涉及三个模型(数学、编程、科学)的合并。具体的合并流程如下:

  1. 首先,将数学和编程领域的模型进行合并,得到一个中间合并模型。
  2. 然后,将该中间合并模型与科学领域的模型进行合并,最终得到一个包含三个领域知识的统一模型。
  3. 在每次合并过程中,都按照上述 Arcee Fusion 方法的步骤,计算重要性分数、动态选择阈值,并进行选择性集成,确保每次合并都能有效地融合两个模型的知识。

3. 合并过程中的关键技术细节

  • 参数更新的稳定性:通过选择性集成,Arcee Fusion 方法避免了过度更新参数,保持了模型的稳定性。这种方法确保了在合并过程中,模型不会因为参数的剧烈变化而性能下降。
  • 两两合并策略:尽管 Arcee Fusion 方法每次只能合并两个模型,但通过合理的合并顺序,可以有效地将多个模型的知识融合到一个模型中。在我们的工作中,通过先合并数学和编程模型,再与科学模型合并,取得了较好的效果。
  • 高效性:与传统的数据混合蒸馏方法相比,Arcee Fusion 方法在合并阶段显著减少了计算成本和时间。例如,TinyR1-32B-Preview 的模型合并时间仅为 4 小时,而传统的数据混合方法需要 740 小时的 GPU 时间。这使得 Branch-Merge 蒸馏方法在实际应用中更加高效和可行。

五、实验与结果分析

1. 实验设置

(1)训练细节

在实验中,我们使用 DeepSeek-R1-Distill-Qwen-32B 作为基础模型,并基于 360-LlamaFactory 训练框架开发了三个领域的专家模型。以下是各领域专家模型的训练细节:

  • 数学专家模型:训练 5 轮,每轮的批次大小为 96,学习率设置为 1e-5(固定)。
  • 科学专家模型:训练 5 轮,每轮的批次大小为 32,并采用 neat packing 机制以提高训练效率,学习率采用余弦退火策略,初始值为 1e-5。
  • 编程专家模型:训练 15 轮,每轮的批次大小为 96,同样采用 neat packing 机制,学习率固定为 1e-5。

在合并阶段,我们使用 Arcee Fusion 方法,设置 θ=1.5 和阈值 THR=0.5,将三个领域的专家模型合并为一个统一的模型。

(2)评估细节

为了评估模型的性能,我们选择了三个基准数据集:AIME 2024(数学)、LiveCodeBench(编程)和 GPQA-Diamond(科学)。模型的准确率是通过计算在这些基准测试中通过的题目比例(pass@1)来衡量的。为了确保结果的可靠性,我们在每个基准测试中分别进行了 16 次、4 次和 4 次独立试验,并计算了平均值作为最终的准确率。

在评估过程中,我们没有使用贪婪解码方式,而是将最大输出标记数设置为 32768,并以 Temperature=0.6 和 Top-p=0.95 的参数进行模型评估,这与 DeepSeek-R1 的推荐设置一致。此外,我们尝试了多种开源评估框架,并最终选择了 FuseAI 提供的评估代码,该代码利用 vLLM 实现,能够复现 DeepSeek-R1 及其蒸馏模型的效果。

2. 主要实验结果

(1)与基线模型的对比

表 1 展示了 TinyR1-32B-Preview 模型与其他基线模型的性能对比。

Image

从结果可以看出,TinyR1-32B-Preview 在数学、编程和科学领域的准确率分别比其基础模型 DeepSeek-R1-Distill-Qwen-32B 高出 5.5、4.4 和 2.9 个百分点。此外,它在数学和编程领域的表现也优于 DeepSeek-R1-Distill-Llama-70B,分别高出 8.1 和 4.1 个百分点,尽管在科学领域的表现略低 0.2 个百分点。

总体而言,TinyR1-32B-Preview 的性能接近 DeepSeek-R1,仅在数学、编程和科学领域分别低 1.7、4.3 和 6.5 个百分点。

(2)不同领域任务的表现

在不同领域的任务中,TinyR1-32B-Preview 表现出色。

  • 在数学领域,它达到了 78.1%的准确率,接近 DeepSeek-R1 的 79.8%;
  • 在编程领域,其准确率为 61.6%,比 DeepSeek-R1 的 65.9%略低,但仍然优于其他蒸馏模型;
  • 在科学领域,准确率为 65.0%,与 DeepSeek-R1 的 71.5%相比有一定差距,但仍然优于其他蒸馏模型。

这些结果表明,Branch-Merge 蒸馏方法能够有效地提升模型在多个领域的性能,尤其是在数学和编程领域。

3. 消融研究

(1)与领域专家模型的对比

表 2 展示了 TinyR1-32B-Preview 与各领域专家模型的性能对比。

Image

从结果可以看出,经过合并的 TinyR1-32B-Preview 在数学和科学领域的性能优于单独的领域专家模型,同时在编程领域的性能也接近专家模型。这表明,Branch-Merge 蒸馏方法能够有效地整合不同领域的知识,提升模型的综合性能。

(2)与数据混合模型的对比

数据混合模型是通过将数学、编程和科学领域的数据混合在一起进行训练得到的。与数据混合模型相比,TinyR1-32B-Preview 在数学和科学领域的性能更高,尽管在编程领域的性能略有下降。这说明,传统的数据混合方法在处理多领域任务时存在一定的局限性,而 Branch-Merge 蒸馏方法能够更好地解决领域之间的冲突,提升模型的泛化能力。

(3)不同合并顺序的影响

我们还研究了不同合并顺序对模型性能的影响。实验结果表明,先合并数学和编程领域的模型,再与科学领域的模型合并(即“数学&编程&科学”顺序)的性能略低于先合并数学和科学领域的模型,再与编程领域的模型合并(即“数学&科学&编程”顺序)的性能。这表明,合并顺序可能会对模型的最终性能产生一定的影响,但总体来说,两种顺序的性能差异较小。

此外,从计算成本的角度来看,TinyR1-32B-Preview 的合并时间仅为 4 小时,而传统的数据混合模型需要 740 小时的 GPU 时间。这表明,Branch-Merge 蒸馏方法不仅在性能上优于传统的数据混合方法,而且在计算成本上也具有显著的优势。

六、相关工作

1. 模型蒸馏技术的发展

模型蒸馏(Knowledge Distillation, KD)是一种将大型预训练模型(教师模型)的知识迁移到小型模型(学生模型)的技术,旨在提高学生模型的性能,同时降低计算和存储成本。自 Hinton 等人在 2015 年首次提出以来,模型蒸馏技术得到了广泛的研究和应用。

  • Logits-based 方法:这是最早被研究的蒸馏方法之一,主要通过将教师模型的输出(logits)作为软目标(soft targets)来指导学生模型的训练。Hinton 等人在 2015 年的工作中展示了这种方法的有效性,证明了通过模仿教师模型的输出,学生模型可以学习到更丰富的知识,从而在较小的模型规模下达到接近教师模型的性能。
  • Feature-based 方法:与 Logits-based 方法不同,Feature-based 方法通过在中间层传递知识来实现蒸馏。Romero 等人在 2015 年提出了 FitNets,这种方法不仅关注输出层,还通过匹配教师和学生模型的中间层特征来提高学生模型的性能。
  • 数据增强与自蒸馏:近年来,数据增强技术被引入到模型蒸馏中,通过生成与特定领域相关的数据来提升学生模型的性能。Taori 等人在 2023 年的工作中展示了如何利用少量的种子知识来引导大型语言模型生成更多领域的数据,从而实现更有效的知识迁移。此外,自蒸馏(self-distillation)技术也逐渐受到关注,其中开源的大型语言模型通过 API 调用的方式,利用自身生成的数据进行自我改进。
  • API-based 蒸馏:这种方法通过 API 调用获取教师模型的输出,从而实现知识的迁移。Yuan 等人在 2024 年的工作中展示了 API-based 蒸馏在缩小开源和闭源模型性能差距方面的潜力。这种方法包括 In-Context Learning(上下文学习)、Chain-of-Thought(思维链)和 Instruction Following(指令跟随)等多种策略。

2. 模型合并方法的研究进展

模型合并(Model Merging)是另一种提升模型性能的技术,通过将多个模型的参数进行融合,以实现更好的性能和泛化能力。近年来,模型合并方法得到了广泛的研究和探索。

  • 线性插值方法:早期的模型合并方法主要基于线性插值技术,例如权重平均(weight averaging)。这种方法通过计算多个模型参数的算术平均值来生成一个新的模型。虽然这种方法计算效率高,但在模型优化轨迹差异较大时,可能会导致性能下降。
  • 基于不确定性的合并方法:Daheim 等人在 2023 年提出了一种基于不确定性的模型合并方法,通过匹配梯度来减少模型间的干扰。这种方法考虑了模型参数的不确定性,从而在合并过程中更好地保留了每个模型的优势。
  • 稀疏掩码合并方法:Davari 和 Belilovsky 在 2024 年提出了一种使用稀疏掩码的模型合并方法,通过选择性地保留重要参数来减少合并过程中的信息损失。这种方法在保留关键知识的同时,提高了合并模型的性能。
  • 模型融合工具:Goddard 等人在 2024 年提出了 Arcee Fusion 工具,这是一种专门用于合并大型语言模型的工具。Arcee Fusion 通过计算参数的重要性分数,并根据动态阈值选择性地更新参数,从而在合并过程中保持模型的稳定性和性能。
  • 理论基础与实验验证:模型合并的理论基础主要来自于对损失函数几何形状的研究。研究表明,通过平均权重得到的模型通常位于损失函数的平坦区域,这与更好的泛化能力相关。此外,模型合并的实验验证也表明,通过聚合权重得到的模型在分布偏移下通常优于单个模型。

综上所述,模型蒸馏和模型合并技术在提升模型性能、降低计算成本和提高泛化能力方面都取得了显著进展。这些技术的发展为 Branch-Merge 蒸馏方法的提出提供了理论基础和实践指导。

七、结论与未来展望

1. 方法的总结与贡献

本文提出了 Branch-Merge 蒸馏方法,通过两个阶段——分支阶段(Branch Phase)和合并阶段(Merge Phase)——有效地提升了大语言模型的压缩效率和性能。

  • 在分支阶段,我们将大型教师模型的知识选择性地蒸馏到多个专业学生模型中,通过领域特定的监督微调(SFT)使其在特定领域内具备专业能力。
  • 在合并阶段,利用 Arcee Fusion 方法将这些专业模型合并,实现了跨领域的知识转移和性能提升。

实验结果表明,Branch-Merge 蒸馏方法显著提高了模型的准确性,接近教师模型的性能水平,同时大幅降低了计算成本和时间。此外,我们还计划开源模型及相关代码,以促进社区的进一步研究和应用。

Image

2. 未来可能的研究方向

尽管 Branch-Merge 蒸馏方法已经取得了显著的成果,但仍有许多值得进一步探索的方向:

(1)探索其他主干模型

目前,我们使用 DeepSeek-R1 及其蒸馏模型作为基础架构。未来,我们可以探索其他高性能的主干模型,例如 Qwen-Instruct 系列模型。初步实验表明,使用 Qwen-14B-Instruct 和 Qwen-32B-Instruct 作为主干模型进行 SFT,也能在特定任务上取得类似的优异结果。这表明 Branch-Merge 蒸馏方法具有广泛的适用性,可以与多种主干模型结合,进一步提升模型性能。

(2)发布多种尺寸的模型

为了满足不同应用场景的需求,我们计划扩展模型阵容,发布多种尺寸的模型。这些模型将涵盖从小型到大型的不同参数规模,使用户能够根据具体需求选择合适的模型。例如,小型模型适合在资源受限的设备上运行,而大型模型则可以在需要更高性能的场景中使用。通过提供多种尺寸的模型,我们希望能够推动大语言模型在更多领域的应用和普及。

(3)深入研究实验细节的影响

在当前的研究中,我们已经发现实验设置对最终性能有显著影响。未来,我们将进一步深入分析各种实验细节,例如学习率调度、训练轮次、数据集选择等,以优化模型性能。此外,我们还将研究如何更好地平衡模型的准确性和计算效率,以及如何进一步减少模型的参数规模,而不损失性能。通过这些研究,我们希望能够为大语言模型的压缩和优化提供更全面的解决方案。

总之,Branch-Merge 蒸馏方法为大语言模型的压缩和性能提升提供了一种新的思路和方法。我们相信,通过进一步的研究和探索,这一方法将在未来的大语言模型发展中发挥更大的作用。

相关文章

END

作者:三月三
来源:NeuralTalk

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18920
内容数
1431
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息