v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False

This commit is contained in:
SWivid
2025-03-15 16:34:32 +08:00
parent 2bd39dd813
commit f05ceda4cb
2 changed files with 2 additions and 3 deletions

View File

@@ -219,8 +219,7 @@ class DiT(nn.Module):
for block in self.transformer_blocks:
if self.checkpoint_activations:
# if you have question, please check: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
# After PyTorch 2.4, we must pass the use_reentrant explicitly
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
x = block(x, t, mask=mask, rope=rope)