跳转至

Sharpness-Aware Pretraining Mitigates Catastrophic Forgetting

论文基本信息

  • 作者: Ishaan Watts, Catherine Li, Sachin Goyal, Jacob Mitchell Springer, Aditi Raghunathan
  • arXiv: https://arxiv.org/abs/2605.02105
  • 领域: cs.LG, cs.CL
  • 类别: 持续学习 → 灾难性遗忘缓解 → 预训练优化

摘要(翻译)

预训练优化器通常被调整为产生尽可能强大的基座模型,隐含假设是更强的起点在后续微调(如 post-training、量化)后仍会产生更强的模型。然而,这种思路忽视了基座模型的几何性质——它决定了有多少原有能力能在后续参数更新后存活下来。本文研究了三种偏向更平坦极小值的预训练优化策略:Sharpness-Aware Minimization(SAM)、大学习率和缩短的学习率退火周期。在 20M 到 150M 参数的模型规模上,这些干预措施在五个常见数据集的 post-training 后持续提升下游性能,遗忘减少高达 80%。规模化实验验证了这些原则的鲁棒性:在一个已有 OLMo-2-1B 检查点上应用短期 SAM 中期训练,在 MetaMath post-training 后遗忘减少 31%,在 4-bit 量化后减少 40%。

核心贡献

  1. 揭示了预训练优化器选择对后续能力保留的影响:证明了传统预训练目标(仅最大化基座性能)忽略了"遗忘敏感性"这一关键维度。
  2. 三种平坦化预训练策略的系统比较:SAM、大学习率、缩短退火周期,提供了可操作的实践指导。
  3. 规模化验证:在 20M-150M 参数范围验证了方法有效性,并给出了 OLMo-2-1B 检查点的具体收益(31%/40% 遗忘减少)。
  4. 即插即用的检查点提升:无需改变训练流程,仅在预训练中途插入短期 SAM 阶段即可显著改善下游适应性。
  5. 对量化的鲁棒性:方法同时减少了 post-training 和量化引入的遗忘,为端侧部署提供了更稳定的模型。

研究背景与问题

灾难性遗忘在 Post-training 中的表现

当前大模型的标准流程是:预训练 → Post-training(指令微调/RLHF)→ 量化压缩部署。在每个阶段,模型都会经历参数更新,而这种更新会损害前序阶段习得的能力: - Post-training 后,模型在预训练阶段学到的通用知识部分退化 - 量化(尤其是低比特如 4-bit)后,模型能力进一步衰减

传统方法的局限

现有方法主要关注 post-training 或量化本身的优化(如蒸馏、正则化),而忽视了基座模型本身对遗忘的敏感性。不同的预训练策略会产生对参数更新有不同"抗性"的模型几何——有的模型即使小幅度更新也会大幅损害原有能力,有的则对扰动更鲁棒。

核心洞察

平坦的损失曲面(flat minima)对应于对参数扰动更鲁棒的模型。通过在预训练阶段就偏向平坦极小值,可以产生"对遗忘天然抵抗"的基座模型。

核心方法

策略一:Sharpness-Aware Minimization (SAM)

SAM 通过在损失函数的邻域内最大化损失来显式寻找平坦极小值: $$\min_\theta \max_{\epsilon \in B} L(\theta + \epsilon)$$ 其中 $B$ 是半径为 $\rho$ 的球形邻域。这鼓励解位于损失平面上宽阔的极小值区域。

策略二:大学习率

直觉:更大的学习率倾向于探索更宽的损失盆地,产生更平坦的解。

策略三:缩短学习率退火周期

标准预训练在学习率衰减前有一个长的平坦期(保持高学习率)。作者发现跳过或大幅缩短这个退火阶段能产生更平坦的极小值——因为后期退火会将参数驱动到更尖锐的极小值。

实验设计

  • 模型:Pythia、OLMo 系列,20M-150M 参数
  • Post-training 数据集:5 个常见指令微调数据集
  • 评估指标:遗忘率(相对于预训练基座的性能衰减)

为什么重要

  1. 从源头解决遗忘问题:不是等 post-training 时再打补丁,而是在预训练阶段就塑造对遗忘鲁棒的模型
  2. 工程可行性强:无需改变训练框架,仅需调整超参数或在训练中途插入 SAM 阶段
  3. 对端侧部署有直接价值:减少量化后遗忘意味着可以用更低的比特数达到相同的下游性能

与移动端/端侧相关性

高度相关。这是对端侧部署最友好的研究之一:

  • 低比特量化友好:4-bit 量化后仍保持 40% 更少的遗忘,意味着可以在端侧使用更激进的量化(如 2-3 bit)而不损失太多能力
  • 预训练成本一次性承担:SAM 等优化在预训练时执行,部署时的计算成本不变——非常适合资源受限的端侧设备
  • 检查点复用价值:已有的开源基座模型(如 OLMo)只需短暂 SAM 训练即可获得更鲁棒的版本,降低了实际应用成本

关键词:灾难性遗忘、平坦极小值、Post-training 鲁棒性、低比特量化、端侧部署、模型编辑