跳转至

RL: Training Inference Mismatch

导言

  • 25年,RL训练崩溃归因于训推不一致;
  • 为此提出了很多方法,TIS,Router Replay,FP16训推,batch一致性...
  • 如何判断 模型当前训推不一致,并找到不一致实现处,是实践的要点。

基本概念

KL 散度

Reverse KL Divergence

GRPO 使用的是 Reverse KL(采样自当前策略 π_θ,参考策略为 π_ref):

\[D_{KL}(\pi_\theta \| \pi_{ref}) = \mathbb{E}_{x \sim \pi_\theta}\left[\log\frac{\pi_\theta(x)}{\pi_{ref}(x)}\right]\]

在 LLM 场景下,按 token 级别计算:

KL = Σ_t π_θ(o_t | q, o_<t) · log[π_θ(o_t | q, o_<t) / π_ref(o_t | q, o_<t)]

三种 KL 估计方法

由于直接计算 KL 的期望成本高,实践中采用蒙特卡洛近似。常见有三种估计器 [[49]][[50]]:

估计器 公式 特性 适用场景
k1 -log(r),其中 r = π_ref/π_θ 无偏但方差极大,梯度不含 π_ref PPO 中作为 reward shaping,不适合作为独立 KL loss
k2 0.5 * (log(r))² 有偏但方差低,梯度等价于 Reverse KL ✅ GRPO 推荐(VeRL/TRL 默认)
k3 (r - 1) - log(r) 无偏、方差低,但梯度等价于 Forward KL 需注意:采样分布不匹配时可能不稳定

关键代码逻辑(VeRL/TRL 实现):

# 假设已获取 per-token logps
log_ratio = ref_logps - actor_logps  # log(π_ref/π_θ)

# k1: 原始 KL (高方差)
kl_k1 = -log_ratio

# k2: 平方近似 (低方差,推荐)
kl_k2 = 0.5 * (log_ratio ** 2)

# k3: Bregman 形式 (无偏,但梯度对应 Forward KL)
kl_k3 = (log_ratio.exp() - 1) - log_ratio

# 最终 loss 加入
loss = policy_loss + beta * kl_loss_type(actor_logps, ref_logps)

🔍 为什么 k2 更推荐?
- k2 的梯度:∇_θ [0.5*(log r)²] = (log r) · ∇_θ log π_θ,恰好匹配 Reverse KL 的实用梯度形式 [[49]]
- k3 虽然无偏,但其梯度对应 Forward KL,在 π_θ 与 π_ref 差距较大时,重要性采样权重 π_ref/π_θ 可能爆炸,导致训练不稳定

监控指标的计算流程(每步 RL)

1️⃣ 采样阶段:
   - 对每个 query,用当前策略 π_θ 生成 G 个 completions
   - 同时用参考策略 π_ref 计算相同 tokens 的 log-probs

2️⃣ 计算 per-token KL:
   log_ratio = log(π_ref) - log(π_θ)
   kl_token = kl_loss_type(log_ratio)  # k1/k2/k3 选一

3️⃣ 聚合为标量指标:
   - 按 completion 平均:kl_per_seq = mean(kl_token over tokens)
   - 按 batch 平均:kl_monitor = mean(kl_per_seq over all completions)

4️⃣ 用于:
   - 📊 监控:TensorBoard/W&B 记录 kl 曲线,判断策略漂移
   - ⚖️ 正则:loss += β * kl_monitor(β=0.001~0.04,依任务调整)

四、实践建议

  1. 配置选择(以 VeRL 为例)[[13]][[41]]:

    actor_rollout_ref:
      actor:
        use_kl_loss: true          # 启用 KL 正则(GRPO 必须)
        kl_loss_coef: 0.001        # β 系数,数学任务可增至 0.04
        kl_loss_type: "k2"         # 推荐 k2;若用 k3+ 可开启 straight-through 梯度修正
    

  2. 监控阈值参考

  3. KL < 0.01:策略变化过小,可能学习缓慢
  4. KL ∈ [0.01, 0.1]:健康更新区间
  5. KL > 0.2:策略漂移过大,需检查 β 或奖励设计

  6. 调试技巧

  7. 同时记录 kl_k2kl_k3,若二者差异显著,说明 π_θ 与 π_ref 已偏离较大
  8. 若 KL 持续上升且 reward 不增,考虑定期重置 reference model(DeepSeek-R1 实践)[[50]]

log p

log_p 是模型对每个位置实际生成的 token 计算出的对数概率(Log Probability),属于标量值,而 hidden_state 是 Transformer 层输出的高维稠密向量。两者在计算链路、数据形态和用途上完全不同。


🔍 log_p 的完整计算链路

Input Tokens 
   ↓ [Embedding]
Hidden States (layer 0) 
   ↓ [Transformer Blocks × N]
Final Hidden States: h ∈ ℝ^{B×L×D}  ← 这才是你问的 hidden_state
   ↓ [LM Head: Linear + Bias]
Logits: z ∈ ℝ^{B×L×V}            ← 词表大小 V 的未归一化得分
   ↓ [Log Softmax]
Log Probabilities: log_p_all ∈ ℝ^{B×L×V}
   ↓ [Gather 实际 token ID]
log_p ∈ ℝ^{B×L}                  ← 你问的 log_p(每个位置一个标量)

📐 维度与形态对比

概念 形状 数据类型 物理含义
hidden_state [B, L, D] float32/16 Transformer 输出的上下文表征向量
logits [B, L, V] float32/16 词表每个 token 的原始得分
log_p [B, L] float32/16 当前策略下,每个位置真实 token 的对数概率

💡 D 通常为 4096/7680 等,V 为词表大小(如 32k/128k),而 log_p 已坍缩到 [B, L]只保留实际生成 token 的概率信息


💻 代码级直观实现(PyTorch)

import torch
import torch.nn.functional as F

# 假设 model 已加载,input_ids 为 [B, L]
outputs = model(input_ids=input_ids, return_dict=True)
logits = outputs.logits  # [B, L, V]

# 1. 计算全词表 log softmax
log_probs_all = F.log_softmax(logits, dim=-1)  # [B, L, V]

# 2. 提取实际 token 对应的 log_p
# input_ids.unsqueeze(-1) -> [B, L, 1]
token_log_p = log_probs_all.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)  # [B, L]

# 3. 通常只保留 response 部分(prompt 和 padding 用 mask 过滤)
response_mask = (input_ids >= tokenizer.vocab_size)  # 示例:假设 response token ID 较大
valid_log_p = token_log_p * response_mask  # 后续计算 KL/Loss 时会 mask 掉无效位置

⚠️ 常见误区澄清

误区 正确理解
log_p 就是 hidden_state” hidden_state 是向量表征;log_p 是经过 LM Head + LogSoftmax + Gather 后的标量概率
“推理时不需要 log_p 推理(生成)时模型会隐式计算它用于采样;RL 训练时需显式保存用于 loss
log_p 越大越好” 仅表示模型对该 token 更自信;RL 中需与 reward/advantage 结合,盲目最大化会导致 mode collapse
“KL 直接用 hidden_state 算” KL 是概率分布距离,必须基于 log_p;hidden_state 是特征空间,无法直接算分布散度

📌 总结

  • log_p = 每个位置真实 token 的对数概率,形状 [B, L]
  • hidden_state → logits → log_softmax → gather 得到,不是 hidden_state 本身
  • 在 GRPO 中用于:KL 正则、importance ratio、策略梯度、loss mask
  • 训练时需同时保存 actor_logpref_logp,推理时通常不显式输出但底层会计算

训推一致性

LogP Diff vs KL 散度:本质区别与阈值含义

维度 🔹 LogP Diff (训推一致性) 🔹 KL 散度 (RL 正则/监控)
比较对象 同一模型,train mode vs infer mode 两个策略,actor π_θ vs reference π_ref
核心目标 验证数值计算一致性(精度对齐) 控制策略更新幅度(防止分布崩溃)
数学形式 δ = \|log_p^train - log_p^infer\| KL = 𝔼[log(π_θ/π_ref)] 或其近似
是否取绝对值 ✅ 是,关注偏差大小 ❌ 否,保留方向信息(谁更自信)
是否加权 ❌ 所有 token 平等对待 ✅ 高概率 token 贡献更大(p·log(p/q))
量纲/单位 log 空间的相对误差(无单位比值) 信息论距离(nat,无单位但数值意义不同)
典型阈值 rel_diff < 0.01 (1%) KL < 0.01~0.1 (依任务/β系数调整)

# 在 trainer 的 logging 阶段
def compute_consistency_and_kl_metrics(batch):
    metrics = {}

    # 🔹 训推一致性(仅调试阶段启用)
    if config.debug_consistency:
        with torch.no_grad():
            logp_train = model_train(input_ids).logps      # train mode
            logp_infer = model_infer(input_ids).logps      # eval mode
        rel_diff = (logp_train - logp_infer).abs() / (logp_train.abs() + 1e-8)
        metrics["debug/consistency_rel_diff"] = verl_F.masked_mean(
            rel_diff, batch["response_mask"]
        ).item()

    # 🔹 KL 散度(训练必选)
    log_ratio = batch["ref_logps"] - batch["actor_logps"]
    kl_token = 0.5 * (log_ratio ** 2)  # k2
    metrics["actor/kl_loss"] = agg_loss(
        kl_token, batch["response_mask"], config.loss_agg_mode
    ).item()

    return metrics

评论