DiffusionNFT
导言
DiffusionNFT 直接在前向加噪过程(forward process)上进行优化,在彻底摆脱似然估计与特定采样器依赖的同时,显著提升了训练效率与生成质量。在GenEval任务上,DiffusionNFT仅用约1.7k步就达到0.94分,而对比方法FlowGRPO需要超过5k步且依赖CFG才达到0.95分。这表明DiffusionNFT的训练效率比FlowGRPO快约25倍。
动机¶
- 似然估计困难:自回归模型的似然可精确计算,而扩散模型的似然只能以高开销近似,导致 RL 优化过程存在系统性偏差。1
- 解释:指扩散模型的打分相对于LLM困难
- 前向–反向不一致:现有方法仅在反向去噪过程中施加优化,没有对扩散模型原生的前向加噪过程的一致性进行约束,模型在训练后可能退化为与前向不一致的级联高斯。
- 采样器受限:需要依赖特定的一阶 SDE 采样器,无法充分发挥 ODE 或高阶求解器在效率与质量上的优势。
- CFG 依赖与复杂性:现有 RL 方案在集成无分类器引导 (CFG) 时需要在训练中对双模型进行优化,效率低下。
思路¶
为什么不直接在“加噪”的前向过程中融入奖励信号呢?与其在去噪的每一步艰难地“纠正”方向,不如从一开始就引导整个扩散过程,使其“避开”通往低奖励样本的路径。DiffusionNFT 将强化学习的目标巧妙地转化为一个对前向过程的微调任务,从而完全绕开了棘手的似然估计问题。
创新点¶
1 负例感知微调 (Negative-aware FineTuning, NFT)¶
核心公式解释看1,
DiffusionNFT 是 DPO变种
从 DiffusionNFT 到 OmniNFT¶
DiffusionNFT 的核心转向是:不要在反向去噪轨迹上强行做 policy gradient,而是在前向 diffusion / flow matching 目标里引入 reward 方向。这解决的是“扩散模型如何做在线 RL”的基础问题:避免似然估计、减少轨迹保存、放开采样器约束。2
OmniNFT 则把问题推进到另一个层面:当生成目标从图像或视频扩展到 joint audio-video generation 时,reward 不再只是一个标量质量分数,而是至少包含视频质量、音频质量、音画同步和跨模态语义一致性。此时如果仍然把所有 reward 合成一个 global advantage,再把同一个 advantage 广播给所有分支,就会出现信用分配错误。3
核心差异
DiffusionNFT 解决“优化发生在哪里”:它把在线 RL 从 reverse denoising trajectory 转移到 forward process / flow matching objective。
OmniNFT 解决“奖励应该更新谁”:它把全局 advantage 拆成 modality-wise、layer-wise 和 region-wise 三层信用分配。
OmniNFT 的三类失配¶
OmniNFT 论文把 vanilla RLVR / GRPO 直接用于音视频联合生成时的失败原因归纳为三类优化失配:3
- Multi-objective advantage inconsistency 同一个生成样本可能视频 reward 高、音频 reward 低,或音频自然但画面质量差。论文图 2 指出,视频和音频 advantage 的相关性很弱,约一半样本在两个模态上收到相反 reward。此时用一个总分 advantage 更新所有分支,会把错误方向传给本不该惩罚的模态。
- Multi-modal gradient imbalance 音频浅层更偏 intra-modal generation,负责音频自身质量;中后层更偏 audio-video interaction,负责同步和跨模态对齐。如果视频分支梯度泄漏到音频浅层,音频质量会被视频目标污染。
- Uniform credit assignment 音画同步通常只发生在关键发声区域,例如手拍桌、嘴部发声、乐器动作。均匀更新整段视频 latent 会把优化预算浪费在无关区域。
三层信用分配¶
OmniNFT 的技术主线不是“把奖励加权求和”,而是做更细的 credit routing:
- Modality-wise advantage routing:为 video reward、audio reward、cross-modal synchronization reward 分别计算 advantage;单模态 advantage 只监督对应分支,同步 advantage 才同时影响音频和视频分支。
- Layer-wise gradient surgery:在音频分支浅层对来自视频流的部分梯度做 stop-gradient,避免视频目标污染音频自身生成;在更深的 cross-modal interaction 层保留有效梯度。
- Region-wise loss reweighting:利用 V2A cross-attention map 作为关键发声区域的内部代理,把视频侧 RL loss 权重集中到影响音画同步的区域,而不是均匀更新所有位置。
不要把 OmniNFT 理解成普通多奖励加权
多奖励加权只是在 reward 标量层面调比例;OmniNFT 的重点是 advantage 路由和梯度路径控制。它回答的是“哪个 reward 应该更新哪个分支、哪一层、哪一片区域”,不是简单调一个总奖励公式。
工程实现线索¶
OmniNFT 已公开代码、项目页和 LoRA 权重。仓库说明显示,它以 LTX-2 / LTX-2.3 为基础模型,训练时使用 HPSv3、VideoAlign、AudioBox、CLAP、ImageBind、Synchformer 等 reward model;其中 HPSv3 和 VideoAlign 以 remote HTTP server 形式运行,再通过 bash_train_omninft_ltx_fsdp.sh branch_aware_layer_surgery_avweight 启动训练。4
这条工程线有两个值得后续复现实验时优先检查的点:
- Reward server 的延迟和稳定性:音视频 RL 的瓶颈不只在模型前向,还在多个 reward model 的吞吐、队列和失败重试。
- LoRA 合并后的推理差异:Hugging Face 发布的是 LTX-2 / LTX-2.3 的 RL-LoRA,需要 merge 到 base checkpoint 后推理;不同推理端对 LoRA alpha、rank、dtype 的处理可能影响观感。5
和 UniGRPO / DanceGRPO 的关系¶
如果说 UniGRPO 关心的是 reasoning + image 的 统一 MDP,DanceGRPO / FlowGRPO 关心的是生成轨迹上的 policy gradient,那么 OmniNFT 更像是给 Omni 生成任务补上了一个前提:统一 rollout 不等于共享所有 advantage。
更稳的 Omni RL 结构应当是:
- 统一 rollout:让视频、音频、文本条件和同步关系在同一条生成轨迹中被评估。
- 分模态 reward:分别评估视频质量、音频质量、音画同步和语义对齐。
- 分层 advantage routing:单模态 reward 只更新相应分支,同步 reward 更新交互层和关键区域。
- 前向过程优化:尽量复用 DiffusionNFT 的 forward-process 思想,减少对完整反向采样轨迹和 log-prob ratio 的依赖。
实践判断
在音视频联合生成里,真正危险的不是 reward 不够多,而是 reward 太多后 归因不清。OmniNFT 的价值在于把多目标 RL 从“调权重”推进到“调路由”:先判断 reward 属于哪个模态,再判断它应该穿过哪几层、落到哪些区域。
代码跳读¶
入口scripts/train_nft_sd3.py
for epoch in range(first_epoch, config.num_epochs):
for i in tqdm(
range(config.sample.num_batches_per_epoch),
desc=f"Epoch {epoch}: sampling",
disable=not is_main_process(rank),
position=0,
):
# eval_fn 函数的主要目的是在训练过程中对模型进行评估。
images, latents, _ = pipeline_with_logprob()
rewards_future = executor.submit(reward_fn, images, prompts, prompt_metadata, only_strict=True)
# advantages 来自 reward 的平均值等计算,对应公式5的 r 项
normalized_advantages_clip = (advantages_clip / config.train.adv_clip_max) / 2.0 + 0.5
r = torch.clamp(normalized_advantages_clip, 0, 1)
# 切换模型到"旧"参数状态以获取参考预测值
transformer_ddp.module.set_adapter("old")
with torch.no_grad():
# prediction v
old_prediction = transformer_ddp(
hidden_states=xt,
timestep=train_sample_batch["timesteps"][:, j_idx],
encoder_hidden_states=embeds,
pooled_projections=pooled_embeds,
return_dict=False,
)[0].detach()
transformer_ddp.module.set_adapter("default")
# prediction v
forward_prediction = transformer_ddp(
hidden_states=xt,
timestep=train_sample_batch["timesteps"][:, j_idx],
encoder_hidden_states=embeds,
pooled_projections=pooled_embeds,
return_dict=False,
)[0]
# 对应公式5的 正负 v 项
positive_prediction = config.beta * forward_prediction + (1 - config.beta) * old_prediction.detach()
implicit_negative_prediction = (
1.0 + config.beta
) * old_prediction.detach() - config.beta * forward_prediction
timestep对效果影响很大,原仓没实现
Adaptive Weighting. We find stability improves when the flow-matching loss is given higher weight at larger t, whereas inverse strategies (e.g., w(t) = 1 − t) lead t
生成sampling¶
def run_sampling(
v_pred_fn,
z,
sigma_schedule,
solver="flow",
determistic=False,
eta=0.7,
):
assert solver in ["flow", "dance", "ddim", "dpm1", "dpm2"]
dtype = z.dtype
all_latents = [z]
all_log_probs = []
if "dpm" in solver:
order = int(solver[-1])
dpm_state = DPMState(order=order)
for i in tqdm(
range(len(sigma_schedule) - 1),
desc="Sampling Progress",
disable=not dist.is_initialized() or dist.get_rank() != 0,
):
sigma = sigma_schedule[i]
pred = v_pred_fn(z.to(dtype), sigma)
if solver == "flow":
z, pred_original, log_prob = flow_grpo_step(
)
elif solver == "dance":
z, pred_original, log_prob = dance_grpo_step(
pred.float(), z.float(), eta if not determistic else 0, sigmas=sigma_schedule, index=i, prev_sample=None
)
elif solver == "ddim":
z, pred_original, log_prob = ddim_step(
pred.float(), z.float(), eta if not determistic else 0, sigmas=sigma_schedule, index=i, prev_sample=None
)
elif "dpm" in solver:
assert determistic
z, pred_original, log_prob = dpm_step(
)
else:
assert False
z = z.to(dtype)
all_latents.append(z)
all_log_probs.append(log_prob)
latents = z.to(dtype)
return latents, all_latents, all_log_probs
这部分逻辑和danceGRPO是类似的,但是训练没有对应逻辑,并且all_latents只用于计算x0
x0 = train_sample_batch["latents_clean"]
t = train_sample_batch["timesteps"][:, j_idx] / 1000.0
t_expanded = t.view(-1, *([1] * (len(x0.shape) - 1)))
noise = torch.randn_like(x0.float())
xt = (1 - t_expanded) * x0 + t_expanded * noise



