8 Commits
1.0.0 ... 1.0.3

Author SHA1 Message Date
SWivid
7e4985ca56 v1.0.3 fix api.py 2025-03-17 02:39:20 +08:00
SWivid
f05ceda4cb v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False 2025-03-15 16:34:32 +08:00
Yushen CHEN
2bd39dd813 Merge pull request #859 from ZhikangNiu/main 2025-03-15 16:23:50 +08:00
           fix #858 and pass use_reentrant explicitly in checkpoint_activation mode
ZhikangNiu
f017815083 fix #858 and pass use_reentrant explicitly in checkpoint_activation mode 2025-03-15 15:48:47 +08:00
Yushen CHEN
297755fac3 v1.0.1 VRAM usage management #851 2025-03-14 17:31:44 +08:00
Yushen CHEN
d05075205f Merge pull request #851 from niknah/vram-usage 2025-03-14 17:25:56 +08:00
           VRAM usage on long texts gradually uses up memory.
Yushen CHEN
8722cf0766 Update utils_infer.py 2025-03-14 17:23:20 +08:00
niknah
48d1a9312e VRAM usage on long texts gradually uses up memory. 2025-03-14 16:53:58 +11:00
4 changed files with 14 additions and 11 deletions

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.0.0"
+version = "1.0.3"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}

src/f5_tts/api.py

@@ -74,8 +74,6 @@ class F5TTS:
         elif model == "E2TTS_Base":
             repo_name = "E2-TTS"
             ckpt_step = 1200000
-        else:
-            raise ValueError(f"Unknown model type: {model}")
         if not ckpt_file:
             ckpt_file = str(
@@ -117,8 +115,9 @@ class F5TTS:
         seed=None,
     ):
         if seed is None:
-            self.seed = random.randint(0, sys.maxsize)
-        seed_everything(self.seed)
+            seed = random.randint(0, sys.maxsize)
+        seed_everything(seed)
+        self.seed = seed
 
         ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
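
This hunk closes a real bug in the inference entry point: when a caller passed an explicit seed, the old code never assigned it to self.seed, so seed_everything(self.seed) reseeded from a stale attribute and the requested seed was silently ignored. The fixed ordering resolves the seed, applies it, then records the value that was actually used. For context, a minimal sketch of what a seed_everything helper typically covers; this body is an assumption for illustration, not F5-TTS's own implementation:

import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Assumed sketch: seed every RNG the inference path may touch, so a run
    # can be reproduced later from the recorded self.seed value.
    random.seed(seed)
    np.random.seed(seed % (2**32))    # numpy seeds must fit in 32 bits
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines

With the fix, calling infer(..., seed=1234) and later infer(..., seed=tts.seed) should reproduce the same output, which the old code could not guarantee.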

src/f5_tts/infer/utils_infer.py

@@ -479,14 +479,15 @@ def infer_batch_process(
                     cfg_strength=cfg_strength,
                     sway_sampling_coef=sway_sampling_coef,
                 )
+                del _
 
-                generated = generated.to(torch.float32)
+                generated = generated.to(torch.float32)  # generated mel spectrogram
                 generated = generated[:, ref_audio_len:, :]
-                generated_mel_spec = generated.permute(0, 2, 1)
+                generated = generated.permute(0, 2, 1)
                 if mel_spec_type == "vocos":
-                    generated_wave = vocoder.decode(generated_mel_spec)
+                    generated_wave = vocoder.decode(generated)
                 elif mel_spec_type == "bigvgan":
-                    generated_wave = vocoder(generated_mel_spec)
+                    generated_wave = vocoder(generated)
 
                 if rms < target_rms:
                     generated_wave = generated_wave * rms / target_rms
@@ -497,7 +498,9 @@ def infer_batch_process(
                     for j in range(0, len(generated_wave), chunk_size):
                         yield generated_wave[j : j + chunk_size], target_sample_rate
                 else:
-                    yield generated_wave, generated_mel_spec[0].cpu().numpy()
+                    generated_cpu = generated[0].cpu().numpy()
+                    del generated
+                    yield generated_wave, generated_cpu
 
     if streaming:
         for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
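
Both hunks apply the same memory discipline. The old code kept two names (generated and generated_mel_spec) alive for GPU tensors and yielded a value derived from a tensor that was still resident on the GPU; since a generator suspends at yield, every tensor still referenced stays allocated for as long as the consumer keeps iterating, so long texts gradually exhausted VRAM (#851). The fix deletes the unused trajectory _, reuses a single name, copies the mel to the CPU, and drops the GPU reference before yielding. A hypothetical, self-contained sketch of the pattern (model and batches are illustrative names, not F5-TTS APIs):

import torch

def generate_mels(model, batches):
    # Keep exactly one reference to each GPU tensor, snapshot the result on
    # the host, and delete the GPU reference *before* yielding: once the
    # refcount hits zero, the CUDA caching allocator can reuse the block
    # while the consumer still holds the CPU copy.
    for batch in batches:
        with torch.inference_mode():
            out = model(batch)                # GPU tensor
        out_cpu = out.detach().cpu().numpy()  # host-side snapshot
        del out                               # release the GPU block
        yield out_cpu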

src/f5_tts/model/backbones/dit.py

@@ -219,7 +219,8 @@ class DiT(nn.Module):
         for block in self.transformer_blocks:
             if self.checkpoint_activations:
-                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+                # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
             else:
                 x = block(x, t, mask=mask, rope=rope)
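
Recent PyTorch releases warn whenever torch.utils.checkpoint.checkpoint is called without an explicit use_reentrant argument, and the linked documentation recommends the non-reentrant variant; #858 hit exactly this in checkpoint_activations mode. A minimal, runnable sketch of the technique on a toy block stack (not the F5-TTS DiT):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    # Toy stand-in for a stack of transformer blocks: each block's
    # activations are recomputed during backward instead of stored during
    # forward, trading extra compute for activation memory.
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )

    def forward(self, x):
        for block in self.blocks:
            # Non-reentrant checkpointing, as the PyTorch docs recommend;
            # passing use_reentrant explicitly also silences the warning.
            x = checkpoint(block, x, use_reentrant=False)
        return x

model = CheckpointedStack()
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()  # blocks re-run forward here to rebuild activations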