mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-09 19:57:49 -08:00
v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False
This commit is contained in:
@@ -219,8 +219,7 @@ class DiT(nn.Module):
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.checkpoint_activations:
|
||||
# if you have question, please check: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
|
||||
# After PyTorch 2.4, we must pass the use_reentrant explicitly
|
||||
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
|
||||
else:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
Reference in New Issue
Block a user