> 技术文档 > 【论文阅读】ARM: Adaptive Reasoning Model

【论文阅读】ARM: Adaptive Reasoning Model


ARM: Adaptive Reasoning Model

    • 方法
      • 第一阶段:SFT for Reasoning Formats Understanding
      • 第二阶段:RL for Encouraging Efficient Format Selection

ARM: Adaptive Reasoning Model 这篇文章介绍了自适应推理模型(Adaptive Reasoning Model, ARM),该模型能够根据任务难度自适应地选择推理格式,从而在保持性能的同时提高计算效率。ARM支持四种推理格式:三种高效的格式——直接回答(Direct Answer)、短链思考(Short CoT)和代码(Code),以及一种详细的格式——长链思考(Long CoT)

【论文阅读】ARM: Adaptive Reasoning Model

本篇博客仅介绍方法上的创新, 项目地址:https://team-arm.github.io/arm/

方法

第一阶段:SFT for Reasoning Formats Understanding

在这一阶段,文章利用SFT作为冷启动,将模型引入可以用于解决问题的各种推理格式。文章使用特殊标记(例如,)来包含思考逻辑:

  1. 直接回答(Direct Answer):这种格式直接给出答案,不包含推理过程。
  2. 短链推理(Short CoT):这种格式提供简短的推理过程,通常用于简单的任务。
  3. 代码(Code):这种格式使用代码来解决问题,适用于需要编程的场景。
  4. 长链推理(Long CoT):这种格式提供详细的推理过程,适用于复杂的任务。

为了确保生成的推理逻辑的质量,文章过滤掉那些导致错误答案的推理逻辑,最终生成的训练集包含3.0K个多项选择题和7.8K个开放形式问题,每个问题都有四种推理格式。文章在这一阶段使用AQuA-Rat数据集因为它可以自然地转化为四种不同的推理形式。除了数据集中提供的Direct AnswerShort CoT推理外,文章还利用GPT-4o 和DeepSeek-R1 分别补充了CodeLong CoT推理。

第二阶段:RL for Encouraging Efficient Format Selection

经过SFT后,模型学会了使用各种推理格式进行响应,但缺乏根据任务自适应切换格式的能力。为了解决这一问题,文章在第二阶段使用RL来鼓励模型选择更高效的推理格式,同时保持准确性。在这一阶段,文章使用了三个额外的数据集,这些数据集涵盖了从相对简单的常识推理任务到更复杂的数学问题的范围。这些数据集包括:

  • CSQA:常识推理任务
  • GSM8K:数学问题
  • MATH:数学问题

文章主要技术上的创新在强化学习训练的Reward设置上:

首先,作者定义了一组重塑后的奖励 r ′ ={ r 1 ′ , r 2 ′ ,⋯   , r G ′ } r\' = \\{r\'_1, r\'_2, \\cdots, r\'_G\\} r={r1,r2,,rG},这些奖励用于评估模型生成的响应。具体来说,每个响应 o i o_i oi 的奖励 r i ′ r\'_i ri 通过以下公式计算:

r i ′ = α i ( t ) ⋅ r i r\'_i = \\alpha_i(t) \\cdot r_i ri=αi(t)ri

其中, r i r_i ri 是原始奖励, α i (t) \\alpha_i(t) αi(t) 是一个格式多样性缩放因子,用于放大较少采样的推理格式的奖励,防止这些格式在训练过程中消失。格式多样性缩放因子 α i (t) \\alpha_i(t) αi(t) 的计算公式如下:

α i ( t ) = G F ( o i )⋅ decay i ( t ) \\alpha_i(t) = \\frac{G}{F(o_i)} \\cdot \\text{decay}_i(t) αi(t)=F(oi)Gdecayi(t)

其中, F( o i ) F(o_i) F(oi) 表示在组 O O O 中与 o i o_i oi 对应的推理格式出现的次数, t t t 表示训练步数。衰减因子 decay i (t) \\text{decay}_i(t) decayi(t) 的计算公式为:

decay i ( t ) = F ( o i ) G + 0.5 ⋅ ( 1 − F ( o i ) G ) ⋅ ( 1 + cos ⁡ ( π ⋅ t T ) ) \\text{decay}_i(t) = \\frac{F(o_i)}{G} + 0.5 \\cdot \\left(1 - \\frac{F(o_i)}{G}\\right) \\cdot \\left(1 + \\cos\\left(\\frac{\\pi \\cdot t}{T}\\right)\\right) decayi(t)=GF(oi)+0.5(1GF(oi))(1+cos(Tπt))

为了将GRPO扩展为Ada-GRPO,文章引入了格式多样性缩放因子 α i (t) \\alpha_i(t) αi(t),使模型能够自适应地选择推理格式。具体来 α i (t) \\alpha_i(t) αi(t) 由两个组件组成:

  1. Format Diversity Scaling Factor G F ( o i ) \\frac{G}{F(o_i)} F(oi)G

    • 为了防止模型过早收敛到最高准确率的格式(即格式坍缩到长链推理 Long CoT),文章通过增加较少采样格式的奖励来鼓励探索。具体来说,如果某个推理格式出现次数较少,其奖励会被放大,从而促使模型更多地尝试这些格式。
  2. Decay Factor decay i ( t ) \\text{decay}_i(t) decayi(t)

    • 为了避免因过度奖励稀有格式而导致的长期不一致,这一项逐渐减少多样性的影响。例如,Format Diversity Scaling Factor G F ( o i ) \\frac{G}{F(o_i)} F(oi)G 可能会使模型在训练初期更倾向于选择低准确率的格式(如短链推理 Short CoT),仅仅因为这些格式出现次数较少,从而获得更高的奖励。虽然这种探索在训练初期是有益的,但后期可能会阻碍模型的收敛。衰减机制通过在训练初期促进多样性,然后随着训练的进行逐渐将重点转移到准确性上来,从而缓解这一问题。