
NaN errors with fp16 training on Anima.
快速结论:使用 Kohya SD-Scripts 的 Anima LoRA 训练时,若设置 --mixed_precision="fp16" 会立即触发 NaN 错误。优先排查是否已合并 PR #2302 或手动应用 #2274 补丁;如果 GPU 不支持 bf16,fp16 可以通过模型精度补丁正常工作。
问题场景
用户在 Kohya SD-Scripts 的 SD3 分支上,使用 anima_train_network.py 训练 Anima LoRA,命令行参数中包含 --mixed_precision="fp16"。用户 GPU 不支持 bf16,因此只能依赖 fp16 训练。报错在训练启动后立即出现。
报错原文
NaN errors with fp16 training on Anima:
(立即产生 NaN,无额外报错代码;训练过程中 loss 变为 NaN 并中断)
原因分析
Anima 模型在 fp16 混合精度训练时,特定注意力层和 MLP 层的前向计算存在数值不稳定性,导致 loss 立即爆炸为 NaN。该问题在 PR #2274 中得到初步缓解,但仍有约 50% 概率在训练中后期复现。最终修复由 PR #2302 完全解决——该 PR 修改了 anima_util.py 中的 apply_fp16_patch 函数,对后续 block 的 self_attn、cross_attn、mlp 等子模块强制使用 torch.float16 autocast,同时保持整个 block 的 forward 为 torch.float32,并启用了 fp16 累加。
注意:如果用户使用 FP8 原始模型(如 anima-preview2-fp8.safetensors),fp16 训练可能会稳定,但本 Issue 中用户使用的是非 FP8 版本。
环境排查
- 确认 Kohya SD-Scripts 版本是否为 SD3 分支,以及是否已合并 PR #2302。
- 检查
anima_util.py文件中的apply_fp16_patch函数是否包含 Issue 评论中给出的修改代码(针对self_attn、cross_attn、mlp的 autocast 设置)。 - 确认 GPU 型号是否支持 bf16(本场景中用户 GPU 不支持,因此只能使用 fp16)。
- 确认使用的预训练模型文件是否为标准 safetensors(非 FP8 变体),以及
--mixed_precision设置是否为"fp16"。 - 确认 PyTorch 版本是否支持
torch.autocast和torch.set_float32_matmul_precision(建议 PyTorch >= 2.0)。
解决步骤
- 合并 PR #2302:将 GitHub 上的 pull request #2302 合并到本地 SD3 分支。这包含了最终的 fp16 修复,已在 Issue 中得到用户验证:“PR #2302 is working perfectly”。
- 若无法合并 PR,手动修改
anima_util.py:定位到apply_fp16_patch函数,按照 Issue 评论中的代码替换其实现。关键改动包括:- 对 index > 1 的 block 中的
self_attn、cross_attn、mlp的 forward 方法使用make_autocast(torch.float16, ...)包装。 - 对每个 block 的
forward方法使用make_autocast(torch.float32, ...)包装(保持 block 级精度为 float32)。 - 设置
torch.set_float32_matmul_precision("high")和torch.backends.cuda.matmul.allow_fp16_accumulation = True。
- 对 index > 1 的 block 中的
- 调整训练参数(可选):如果仍出现 NaN(原 Issue 提到即使打上 #2274 补丁,训练深入后仍有 50% 失败概率),可尝试:
- 降低学习率(如从 1e-4 降至 5e-5)。
- 修改
--timestep_sampling为其他策略(如"uniform")。 - 调整
--discrete_flow_shift的值(如从 1.0 改为 0.5 或 2.0)。
注意:这些调整在 Issue 中没有明确证据,属于可能原因下的探索性步骤。
- 考虑使用 FP8 模型:如果上述步骤无效,可尝试更换预训练模型为
anima-preview2-fp8.safetensors(来自 HuggingFace 社区版本)。Issue 中有用户反馈 FP8 模型配合 #2274 修改后训练稳定。但需确保--mixed_precision仍设为"fp16"。
验证方法
执行训练命令后,观察终端日志中的 loss 值是否稳定下降而不是变成 NaN。如果正常,应看到 loss 在 1.0-5.0 范围内正常浮动,训练不会中途崩溃。此外,请确认训练生成的 LoRA 文件可以正常推理并产生合理图像。



