跳转至

AI Post Traning: DanceGRPO

导言

DanceGRPO是25年5月发表的论文,把GRPO的方法引入到了生成领域。(类似的有flowGRPO)。字节客户基于此魔改,故学习。

背景知识:

Diffusion Model

去噪(Denoising)的本质不是“修改”,而是“根据模糊的线索,重新生成一个全新的清晰样本”。

第一步:前向加噪

(理解数据是怎么毁掉的)

\[ z_t = \alpha_t x + \sigma_t \epsilon \]
  • \(z_t\):时间 \(t\) 时刻的图片(它既不是纯图,也不是纯噪,是两者的混合体)。
  • \(x\):原始的清晰图片(比如一只猫)。
  • \(\epsilon\):纯随机噪声(电视屏幕的雪花点)。
  • \(\alpha_t, \sigma_t\):两个控制比例的阀门。

直观理解:

这就像是在调鸡尾酒:

  • 杯子里原本有 80% 的猫(\(x\)20% 的雪花噪点(\(\epsilon\)
  • 随着时间 \(t\) 增加,\(\alpha_t\) 变小(猫的味道变淡),\(\sigma_t\) 变大(油漆味变浓)。
  • 到最后(\(t=1\)),杯子里 0% 是猫,100% 是油漆

第二步:反向去噪

(AI 怎么把猫变回来)公式 (2) 看起来很奇怪,它其实分两步走:

目的:我们要算出 \(z_{s}\)(一个比 \(z_t\) 更清晰一点的图片)。

过程

  1. AI 的工作(预测噪声): 模型 \(\epsilon_\theta\) 看着满是油漆的图片 \(z_t\),猜这里面有多少是噪声。 它猜的结果叫 \(\hat{\epsilon}\)
  2. 人类的工作(重新配比): 我们拿到 AI 猜的噪声 \(\hat{\epsilon}\),然后利用公式 (2) 强行配出一杯新酒: $$ z_s = \alpha_s \mathbf{\hat{x}} + \sigma_s \mathbf{\hat{\epsilon}} $$

关键点解析: * \(\hat{x}\)(x-hat):这是根据公式 (1) 倒推出来的“假原图”。 * 既然 \(z_t = \alpha_t x + \sigma_t \epsilon\), * 那么 \(x\) 应该等于 \((z_t - \sigma_t \epsilon) / \alpha_t\)。 * 但是我们不知道真实的 \(\epsilon\),所以我们用 AI 猜的 \(\hat{\epsilon}\) 来代替。 * \(\hat{\epsilon}\)(epsilon-hat):这是 AI 预测的噪声。

所以,整个过程是:

  1. 拆解:AI 看着模糊图 \(z_t\),说:“我觉得这里面大概有 80% 是噪声,20% 是图。”(它输出了 \(\hat{\epsilon}\))。
  2. 重组:我们根据 AI 的判断,把噪声去掉,然后利用公式 (2) 重新混合一个清晰度高一点的图 \(z_s\)

3. 乘以 \(\sigma_s\)\(\alpha_s\)

为什么公式里还要乘以 \(\sigma_s\)\(\alpha_s\)? 这其实是在模拟“时间的流逝”

  • 假设 \(s\) 是第 99 步(非常接近原始图片了)。
    • 这时候 \(\alpha_s\) 接近 1(全是图),\(\sigma_s\) 接近 0(几乎没有噪)。
    • 公式变成:\(z_s \approx 1 \times \text{AI猜的图} + 0 \times \text{AI猜的噪}\)
    • 结果:输出就是 AI 猜的图,非常清晰。
  • 假设 \(s\) 是第 1 步(刚开始生成)。
    • 这时候 \(\alpha_s\) 接近 0,\(\sigma_s\) 接近 1。
    • 公式变成:\(z_s \approx 0 \times \text{AI猜的图} + 1 \times \text{AI猜的噪}\)
    • 结果:输出就是一堆随机噪声(从 \(z_1\) 开始)。

4. 去噪过程

如果你要画一张图来解释这个过程,请按以下步骤:

  1. 起点:手里拿着一个全是噪点的球(\(z_1\))。
  2. 循环 50 次
    • 看一眼:把球给 AI 看,问:“这里面大概是什么图案的噪声?”(模型输出 \(\hat{\epsilon}\))。
    • 捏造型:根据 AI 的回答,你手里的泥巴(参数)开始变形。
      • 你不是在“擦除”噪点,你是在根据 AI 的描述,重新捏一个新球
      • 捏的时候,你遵循公式 (2) 的比例:刚开始捏的时候,主要听噪声的(因为 \(\sigma\) 大);越往后捏,越听 AI 描述的图案(因为 \(\alpha\) 大)。
    • 更新:把手里的球换成新捏的球(\(z_s\))。
  3. 终点:最后一步,\(\sigma\) 变成了 0,你完全不听噪声了,只听 AI 描述的图案。于是你手里出现了一只完美的猫。

一句话总结: 去噪过程就是利用 AI 对噪声的预测 (\(\hat{\epsilon}\)),通过公式 (2) 不断重新混合样本,逐渐把“全是噪点的液体”变成“全是数据的固体”的过程。

timestep的作用

在扩散模型(例如:DDPM、Stable Diffusion)中,timestep t 主要标识着模型当前的“降噪”进度。实际上,timestep 并不直接代表物理时间,而是代表扩散过程中的步骤,通常从 t=0 到 t=T。

  • t=0 是最清晰的图像(无噪声)。
  • t=T 是完全噪声的图像。

扩散过程的目标是通过一系列时间步 逐步添加噪声,而反向过程则是 逐步去噪,直到恢复出清晰图像。

  • 简单的 timestep 是线性地表示时间步,但扩散过程需要更精细的噪声控制。
  • 通过调整 timestep的分布/频率(随机去除),你能 控制每个时间步的噪声强度,这对模型的降噪效果有重要影响。

timestep代码解释

SFT 以MM为例, sora_model.py

noised_latents, noise, timesteps = self.diffusion.q_sample(latents, model_kwargs=kwargs, mask=video_mask)

def _set_timesteps(self, num_steps=100, training=False):
    sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
    if self.extra_one_step:
        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_steps + 1)[:-1]
    else:
        # 一般是 0~1,切num_steps个
        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_steps)
    if self.inverse_timesteps:
        self.sigmas = torch.flip(self.sigmas, dims=[0])

    # 这里也有shift
    self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
    if self.reverse_sigmas:
        self.sigmas = 1 - self.sigmas

    # 推理不能用训练没见过的timestep范围,但是范围内的子区间是允许的。
    # 把0~1 放大到 0~self.num_train_timesteps
    self.timesteps = self.sigmas * self.num_train_timesteps
    if training:
        y = torch.exp(-2 * ((self.timesteps - num_steps / 2) / num_steps) ** 2)
        y_shifted = y - y.min()
        bsmntw_weighing = y_shifted * (num_steps / y_shifted.sum())
        self.linear_timesteps_weights = bsmntw_weighing

def q_sample(self, latents, noise=None, t=None, **kwargs):
    curr_rank = torch.distributed.get_rank()
    seed = int(time.time() * 1000) + curr_rank
    generator = torch.Generator().manual_seed(seed)
    noise = torch.randn(latents.shape, generator=generator).to(latents.device)

    # 随机值
    timestep_idx = torch.randint(0, self.num_train_timesteps, (1,), generator=generator).to(latents.device)
    cp_src_rank = list(mpu.get_context_parallel_global_ranks())[0]
    if mpu.get_context_parallel_world_size() > 1:
        torch.distributed.broadcast(noise, cp_src_rank, group=mpu.get_context_parallel_group())
        torch.distributed.broadcast(timestep_idx, cp_src_rank, group=mpu.get_context_parallel_group())
    timestep_idx = timestep_idx.to("cpu")

    # 抽取对应map
    timestep = self.timesteps[timestep_idx].to(latents.device)
    sigma = self.sigmas[timestep_idx].to(latents.device)

    # 计算噪声
    noised_latents = (1 - sigma) * latents + sigma * noise
    return noised_latents, noise, timestep

RL以danceGRPO为例子, 就用的WanFlowMatchSchedulerself.timesteps

self.scheduler = DiffusionModel(args.mm.model.diffusion).get_model()
timesteps = self.scheduler.timesteps

修正流(Rectified Flow)

1. 前向过程(Forward Process)

在修正流中,前向过程被视为数据(\(x\)噪声(\(\epsilon\)之间的线性插值。

\[ z_t = (1 - t)x + t\epsilon \tag{3} \]
  • \(z_t\):时刻 \(t\) 的状态(混合了数据和噪声的中间态)。
  • \(x\):原始清晰数据。
  • \(\epsilon\):高斯噪声(Gaussian noise),作为生成过程的起点。
  • \(t\):时间步,取值范围通常为 。
    • \(t=0\) 时,\(z_0 = x\)(纯数据)。
    • \(t=1\) 时,\(z_1 = \epsilon\)(纯噪声)。

2. 速度场定义(Velocity)

定义 \(u = \epsilon - x\) 为“速度”或“向量场”。

  • 通俗理解:这代表了数据点需要以多大的“速度”和方向移动,才能从当前位置变形为噪声(或反之)。

3. 采样/去噪过程(Reverse/Sampling)

类似于扩散模型,给定去噪模型在时间步 \(t\) 的输出预测值 \(\hat{u}\),我们可以计算更低噪声水平 \(s\)(即 \(s < t\))的状态:

\[ z_s = z_t + \hat{u} \cdot (s - t) \tag{4} \]
  • \(\hat{u}\):模型预测的“速度”。
  • \((s - t)\):时间差(负值,因为 \(s < t\))。
  • 逻辑:根据预测的速度,回溯到前一个时间步的状态。这相当于在向量场中进行了一次“反向跳跃”。

这段文字揭示了扩散模型(Diffusion Models)和修正流模型(Rectified Flow Models)在数学形式上的统一性。虽然原文中提到的公式(5)没有直接显示,但根据文字描述,核心在于两者都可以归纳为一种线性的预测形式。

以下是针对该段落的易懂且干练的翻译:

扩散模型与修正流的统一性

虽然扩散模型和修正流拥有不同的理论基础,但在实践中,它们本质上是同一事物的两面(一体两面),其通用公式如下:

\[ \tilde{z}_s = \tilde{z}_t + \text{Network output} \cdot (\eta_s - \eta_t) \]

具体的映射关系如下表所示:

模型类型 隐变量变换 (\(\tilde{z}\)) 时间尺度 (\(\eta\)) 备注
扩散模型 (\(\epsilon\)-prediction) \(\tilde{z} = z / \alpha\) \(\eta = \sigma / \alpha\) 基于公式 (2) 推导
修正流 (Rectified Flows) \(\tilde{z} = z\) \(\eta = t\) 基于公式 (4) 推导

💡 核心解读::无论你是用扩散模型还是修正流,其实都是在计算“当前状态 + 网络预测值 × 时间变化量”。两者的区别仅在于是否对数据进行了缩放处理(即是否除以 \(\alpha\) 系数)。

设计

将去噪视为马尔可夫决策过程 (MDP)

本文参考 DDPO ,将扩散模型和修正流的去噪过程形式化为一个强化学习的 MDP 问题:

1. 核心要素定义

  • 状态 (\(s_t\)):当前的环境状态,包含三部分信息:
    • c:输入的提示词(Prompt)。
    • t:当前的时间步(代表噪声水平)。
    • \(z_t\):当前带噪声的隐变量(Latent)。
  • 动作 (\(a_t\)):智能体采取的动作,即生成下一个时间步的去噪结果 (\(z_{t-1}\))。
  • 奖励 (\(R\)):仅在去噪结束时\(t=0\),即得到最终清晰图像 \(z_0\) 时)根据提示词 \(c\) 计算奖励;过程中无奖励。

2. 数学公式详解

$$ s_t \triangleq (\mathbf{c}, t, z_t) \tag{State} $$ * 状态由提示词 c、当前步数 t 和当前噪声数据 \(z_t\) 组成。

$$ \pi(a_t | s_t) \triangleq p(z_{t-1} | z_t, c) \tag{Policy} $$ * 策略:在给定当前状态(当前噪声 \(z_t\) + 提示词 c)下,模型预测下一个状态(去噪一步 \(z_{t-1}\))的概率。

$$ a_t \triangleq z_{t-1} \tag{Action} $$ * 动作:直接等于去噪后的结果 \(z_{t-1}\)

$$ R(s_t, a_t) \triangleq \begin{cases} r(z_0, c), & \text{if } t = 0 \ 0, & \text{otherwise} \end{cases} \tag{Reward} $$ * 奖励函数: * 如果是最后一步\(t=0\)),奖励等于奖励模型 \(r\) 的打分(基于生成的图像 \(z_0\) 和提示词 c)。 * 如果是中间步骤,奖励为 0。

$$ \rho_0(s_0) \triangleq (p(c), \delta_T, \mathcal{N}(0, I)) \tag{Initial} $$ * 初始状态分布: * 提示词来自分布 \(p(c)\)。 * 初始步数为 \(T\)(最大噪声步数)。 * 初始隐变量 \(z_T\) 是纯高斯噪声 \(\mathcal{N}(0, I)\)


通俗理解

这段话的意思是:把生成图片的过程当成玩一个“猜图”游戏。

  1. 状态 (State):你现在的手里的草稿(\(z_t\))、题目要求(\(c\))和当前是第几步(\(t\))。
  2. 动作 (Action):你擦掉一部分噪点,画出一个新草稿(\(z_{t-1}\))。
  3. 奖励 (Reward)在整个游戏结束前,你都看不到分数。只有当你画出最终结果(\(t=0\))并交给裁判(Reward Model,如 CLIP)看时,才会根据“画得像不像题目”给出最终分数。

公式

将生成模型的采样过程改写为随机微分方程(SDE),以便适配强化学习算法(GRPO)的需求。

🎯 采样 SDE 的公式化 (Formulation of Sampling SDEs)

为了适配 GRPO 算法对随机探索(Stochastic Exploration)的需求(即策略更新依赖于轨迹的概率分布),我们将扩散模型和修正流的采样过程统一改写为 SDE 形式。

1. 扩散模型 (Diffusion Models)

前向

  • 形式:这是一个 SDE(随机微分方程)
  • 公式\(dz_t = f_t z_t dt + g_t dw\)
  • 含义:它天生带有随机性(\(dw\) 代表布朗运动/噪声)。这是为了模拟数据逐渐变成噪声的过程。

反向去噪(Reverse Process - ODE)

  • 逻辑:通常情况下,生成模型(如扩散模型或修正流)在生成图片时,往往追求确定性(Deterministic)。
  • 形式:通常被写成 ODE(常微分方程)
  • 特点:给定一个输入,永远得到同一个输出(像按固定轨道运行的卫星)。

反向 SDE(Reverse SDE)

改写表达式如下:

\[ \mathrm{d} \mathbf{z}_{t} = \left( f_{t} \mathbf{z}_{t} - \frac{1 + \varepsilon_{t}^{2}}{2} g_{t}^{2} \nabla \log p_{t} ( \mathbf{z_{t}} ) \right) \mathrm{d} t + \varepsilon_{t} g_{t} \mathrm{d} \mathbf{w} \tag{7} \]

公式中各符号的直观含义

  • \(dz_t\):下一步状态的变化量。
  • \((f_t z_t ... ) dt\)“确定性部分”。这部分是基于模型预测的方向(比如去噪的方向),是主要的移动趋势。
  • \(- \frac{1 + \varepsilon_{t}^{2}}{2} g_{t}^{2} \nabla \log p_{t} ( z_{t} )\)“修正项”。这是为了让生成的分布符合真实数据分布而做的数学修正(Score-based 项)。
  • \(\varepsilon_t g_t dw\)“随机性部分”(关键点)。
    • \(\varepsilon_t\):控制随机程度的系数。
    • \(dw\):随机噪声(来自布朗运动)。
    • 作用:这就是为了让 AI 在画图时能“随机发挥”,从而满足强化学习需要试错的需求。

2. 修正流 (Rectified Flows)

修正流的前向过程通常是确定性的 ODE(\(dz_t = u_t dt\)),但这无法满足 GRPO 所需的随机性。受启发,我们在其反向过程中引入 SDE:

\[ \mathrm{d} \mathbf{z}_{t} = \left( \mathbf{u}_{t} - \frac{1}{2} \varepsilon_{t}^{2} \nabla \log p_{t} ( \mathbf{z}_{t} ) \right) \mathrm{d} t + \varepsilon_{t} \mathrm{d} \mathbf{w} \tag{8} \]

3. 核心变量说明

  • \(\varepsilon_t\) (Epsilon):引入采样过程中的随机性(Stochasticity),这是连接确定性模型与强化学习探索需求的关键桥梁。
  • \(\nabla \log p_t(z_t)\):得分函数(Score Function)。若给定正态分布 \(p_t(z_t) = N(z_t | \alpha_t x, \sigma_t^2 I)\),则有 \(\nabla \log p_{t} ( \mathbf{z}_{t} ) = -(\mathbf{z}_{t} - \alpha_{t} \mathbf{x}) / \sigma_{t}^{2}\)
  • \(dw\):布朗运动(Brownian motion),代表随机噪声。

通过将上述公式代入,我们可以推导出策略概率 \(\pi(a_t | s_t)\)

DanceGRPO 目标函数

🎯 DanceGRPO 优化目标 (Objective Function)

受 DeepSeek-R1 启发,该算法通过最大化以下目标函数来更新策略模型(\(\pi_\theta\)):

\[ J(\theta) = \mathbb{E} \left[ \frac{1}{G} \sum_{i=1}^{G} \frac{1}{T} \sum_{t=1}^{T} \min \left( \rho_{t,i} A_i, \text{clip}(\rho_{t,i}, 1-\epsilon, 1+\epsilon) A_i \right) \right] \]

1. 核心逻辑:PPO 风格的剪切目标 (Clipped Surrogate Objective)

给定一个提示词 \(c\),模型生成一组样本(Group of outputs),通过比较新旧策略的差异来更新模型,防止更新幅度过大。

  • \(\rho_{t,i}\) (概率比率 Probability Ratio): $$ \rho_{t,i} = \frac{\pi_\theta(a_{t,i}|s_{t,i})}{\pi_{\theta_{old}}(a_{t,i}|s_{t,i})} $$
    • 衡量新策略 \(\pi_\theta\) 与旧策略 \(\pi_{\theta_{old}}\) 在状态 \(s\) 下采取动作 \(a\) 的概率差异。
  • \(\epsilon\) (超参数): 一个很小的数(如 0.2),用于限制更新的步长,确保训练稳定。
  • \(\text{clip}(\cdot)\): 剪切函数,将比率限制在 \([1-\epsilon, 1+\epsilon]\) 范围内。

2. 优势函数 (Advantage Function) \(A_i\)

用于衡量第 \(i\) 个样本在组内的表现优劣,通过组内标准化计算:

\[ A_i = \frac{r_i - \text{mean}(\{r_1, ..., r_G\})}{\text{std}(\{r_1, ..., r_G\})} \]
  • \(r_i\): 第 \(i\) 个样本获得的奖励分数。
  • 逻辑: 如果样本的奖励高于组内平均水平(\(r_i > \text{mean}\)),则 \(A_i\) 为正,模型会增加该样本生成的概率;反之则降低。

💡 通俗总结

这段公式的意思是: 1. 分组打分:给一个题目(Prompt),让模型做 \(G\) 道题(生成 \(G\) 个样本)。 2. 相对评价:算出每道题的分数后,看谁比平均分高,谁比平均分低。 3. 安全改进:表扬那些分数高于平均的样本(增加生成概率),批评低于平均的样本(减少概率),但表扬和批评的幅度不能太激进(由 \(\epsilon\) 控制),以免把模型改“坏”了。

没有kl loss

虽然传统的GRPO公式使用kl正则化来防止奖励过度优化,但我们从经验上观察到,当忽略这一组件时,性能差异最小。因此,我们在默认情况下省略了KLregularization项。

开CFG会不稳定

背景: CFG 是条件生成建模中生成高质量样本的常用技术。

存在问题: 在当前的训练设置中,将 CFG 集成到训练流水线(Training pipelines)中会导致优化过程不稳定

解决方案: 基于实证研究,建议对高保真度模型(如 HunyuanVideo 和 FLUX)在采样阶段禁用 CFG

  • 目的:减少梯度震荡(Gradient oscillation)。
  • 效果:在保持输出质量的同时,确保训练稳定性。

针对依赖 CFG 来保证生成质量的模型(如 SkyReels-I2V 和 Stable Diffusion),我们发现了一个关键问题及应对策略:

  1. 核心问题:优化轨迹发散

    • 现象:如果仅针对“条件输出”(Conditional Objective)进行训练,会导致优化过程失控(发散)。
    • 原因:必须同时对“条件输出”和“无条件输出”(Unconditional outputs)进行联合优化。
    • 代价:这需要同时运行两个网络进行计算,导致 显存(VRAM)消耗直接翻倍
  2. 解决方案:降低参数更新频率

    • 策略:建议降低每个训练迭代中的参数更新频率。
    • 实证效果:经验证,将每次迭代的更新次数限制为 1 次,能显著提升 SkyReels-I2V 模型的训练稳定性。
    • 副作用:对模型的收敛速度影响极小。

优化Trick

时间步选择 (Ablation on Timestep Selection)

实验目的: 探究不同时间步(Timestep)选择策略对 HunyuanVideo-T2I 模型训练效果的影响。

五种实验条件

  1. 仅前 30%:仅使用从噪声开始的前 30% 时间步进行训练。
  2. 随机 30%:随机采样 30% 的时间步进行训练。
  3. 仅后 40%:仅使用输出前的最后 40% 时间步进行训练。
  4. 随机 60%:随机采样 60% 的时间步进行训练。
  5. 全量 100%:使用所有时间步进行训练(基准)。

核心发现

  • 初期步数至关重要:前 30% 的时间步对学习基础生成模式(Foundational generative patterns)至关重要,对模型性能贡献最大。
  • 单一区间有缺陷:如果训练仅局限于前 30% 的时间步,模型性能反而会下降。这是因为模型缺乏对“后期精细化动态”(Late-stage refinement)的学习。

最终策略:随机时间步丢弃 (Stochastic Timestep Dropout) 为了在计算效率模型保真度之间取得平衡,作者采用了 40% 的随机时间步丢弃策略: * 做法:在训练过程中,随机屏蔽(Mask)40% 的时间步。 * 优势: * 减少了计算量(提升了效率)。 * 保留了潜在扩散过程的时间连续性。 * 证明了通过策略性地减少时间步采样,可以在基于流的生成框架中优化资源利用。

脚本

wan2.1 脚本 bash scripts/finetune/finetune_wan_2_1_grpo.sh

超参

--cfg 0.0 \                 # 设置 cfg_rate 控制 dataset 里 无文本的比率。
--sampling_steps 20 \       # 推理的步数
--eta 0.3 \                 # flux step使用的 超参
--sampler_seed 1223627 \    # 固定采样
--shift 3 \                 # 调整timestep分布
--init_same_noise \         # 推理使用相同初始值
--clip_range 1e-4 \         # clip参数
--adv_clip_max 5.0 \        # 奖励 clip
--cfg_infer 5.0             # 是否开启cfg

代码

python代码 fastvideo/train_grpo_wan_2_1.py

timestep逻辑

def sd3_time_shift(shift, t):
    return (shift * t) / (1 + (shift - 1) * t) # shift>1, timesstep 增长会先快后慢;

sigma_schedule = torch.linspace(1, 0, args.sampling_steps + 1)
sigma_schedule = sd3_time_shift(args.shift, sigma_schedule)

cfg逻辑

if args.cfg_infer>1:
    with torch.autocast("cuda", torch.bfloat16):
        pred= transformer(
            hidden_states=torch.cat([z,z],dim=0),
            timestep=torch.cat([timesteps,timesteps],dim=0),
            encoder_hidden_states=torch.cat([encoder_hidden_states,negative_prompt_embeds],dim=0),
            attention_kwargs=None,
            return_dict=False,
        )[0]
        model_pred, uncond_pred = pred.chunk(2)
        pred  =  uncond_pred.to(torch.float32) + args.cfg_infer * (model_pred.to(torch.float32) - uncond_pred.to(torch.float32))
else:
    with torch.autocast("cuda", torch.bfloat16):
        pred= transformer(
            hidden_states=z,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            attention_kwargs=None,
            return_dict=False,
        )[0]

奖励计算

gathered_reward = gather_tensor(samples["rewards"])
if dist.get_rank()==0:
    print("gathered_reward", gathered_reward)
    with open('./reward.txt', 'a') as f: 
        f.write(f"{gathered_reward.mean().item()}\n")

#计算advantage
if args.use_group:
    n = len(samples["rewards"]) // (args.num_generations)
    advantages = torch.zeros_like(samples["rewards"])

    for i in range(n):
        start_idx = i * args.num_generations
        end_idx = (i + 1) * args.num_generations
        group_rewards = samples["rewards"][start_idx:end_idx]
        group_mean = group_rewards.mean()
        group_std = group_rewards.std() + 1e-8
        advantages[start_idx:end_idx] = (group_rewards - group_mean) / group_std

    samples["advantages"] = advantages
else:
    advantages = (samples["rewards"] - gathered_reward.mean())/(gathered_reward.std()+1e-8)
    samples["advantages"] = advantages

loss

advantages = torch.clamp(
    sample["advantages"],
    -adv_clip_max,
    adv_clip_max,
)

ratio = torch.exp(new_log_probs - sample["log_probs"][:,_])

unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
    ratio,
    1.0 - clip_range,
    1.0 + clip_range,
)
loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) / (args.gradient_accumulation_steps * train_timesteps)

评论