
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you
快速结论:该报错在使用 Diffusers Flux2 模型并启用梯度检查点( gradient checkpointing )时发生,通常是因为 torch 的 reentrant 检查点无法识别作为元组(tuple)传递的张量。优先排查是否在训练/微调 Flux2 时使用了梯度检查点,并且检查点的use_reentrant参数是否为True。
问题场景
在 Huggingface Diffusers 库中使用 Flux2 模型(transformer_flux2.py)进行训练或微调,启用了梯度检查点(gradient checkpointing)功能。具体发生在调制(modulation)计算的输出以元组(tuple of Tensors)形式传递到检查点化的 transformer 模块中,而这些张量又需要梯度时。
报错原文
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
原因分析
Flux2 模型中,从transformer_flux2.py的 L716 计算的调制量(modulations)在 L628 处返回了张量元组(tuples of Tensors)。这些元组从 transformer 块外部传递到检查点化的 transformer 块内部。PyTorch 的 reentrant 检查点实现(torch.utils.checkpoint)不会将元组识别为张量——它只识别直接的张量。具体来说,use_reentrant=True 的检查点机制不会将嵌套结构(如自定义对象、列表、字典、元组)中的张量视为参与自动求导。这导致在反向传播时,保存的中间值被释放,从而触发 “Trying to backward through the graph a second time” 报错。值得注意的是,Diffusers 中默认的标准检查点使用 use_reentrant=False(非重入版本),该版本可以正确处理嵌套结构;只有当显式或隐式启用 use_reentrant=True 时才触发此问题。
环境排查
- 确认使用的 PyTorch 版本(Issue 中提及 torch 2.8)。
- 确认 Diffusers 版本(需要检查是否包含修复 PR #12777)。
- 确认 Flux2 模型是否启用了梯度检查点(gradient checkpointing)。
- 确认梯度检查点的
use_reentrant参数是否为True(Diffusers 标准检查点默认使用use_reentrant=False,不受此影响)。 - 确认是否有足够的显存运行全量反向传播(Flux2 模型较大,通常需要 128GB 显存或使用 LoRA、混合精度等优化方案)。
解决步骤
- 首选方案:升级 Diffusers 到包含 PR #12777 的版本。该 PR 通过将元组展平后传入检查点块,在块内部再重建元组,从而避免了检查点机制无法识别元组的问题。具体修改是:在
transformer_flux2.py中,将调制量的拆分操作移入检查点化的 transformer 块内部执行。 - 如果无法升级:在训练脚本中显式设置梯度检查点的
use_reentrant=False。例如:model.enable_gradient_checkpointing(use_reentrant=False)。这是 Diffusers 中许多模型的标准配置,可以正确处理元组中的张量。 - 备选方案:修改调制量输出(不推荐,除非熟悉 Flux2 架构)。将返回元组改为返回拼接后的单一张量(concatenated tensor),或者为元组创建一个自定义的类包装,使检查点机制能正确识别其包含的张量。
- 如果上述步骤仍无法复现:需要提供更具体的复现脚本。由于 Flux2 模型规模较大,可能需要在启用 offloading 和 fused backward pass 的环境下才能复现。也可以尝试使用 Diffusers 测试中使用的较小模型变体,其初始化字典可在
tests/models/transformers/test_models_transformer_flux2.py的L85附近找到。
验证方法
运行训练或微调脚本,确认梯度检查点功能正常工作,不再出现 RuntimeError 报错。如果使用测试命令,可以尝试:pytest tests/models/transformers/test_models_transformer_flux2.py -k "gradient_checkpointing"(该测试默认会执行反向传播,且使用较小的模型变体)。需要注意的是,即使测试通过,在某些特定配置(如 use_reentrant=True)下仍可能出现问题,因此建议在实际训练场景中进一步验证。



