From f05ceda4cbf4a990f4fa1638bc73ea9c8be0bbfa Mon Sep 17 00:00:00 2001 From: SWivid Date: Sat, 15 Mar 2025 16:34:32 +0800 Subject: [PATCH] v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False --- pyproject.toml | 2 +- src/f5_tts/model/backbones/dit.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f76806a..aeb1e78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts" -version = "1.0.1" +version = "1.0.2" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 2ff9670..271e482 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -219,8 +219,7 @@ class DiT(nn.Module): for block in self.transformer_blocks: if self.checkpoint_activations: - # if you have question, please check: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint - # After PyTorch 2.4, we must pass the use_reentrant explicitly + # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) else: x = block(x, t, mask=mask, rope=rope)