Update: Empirically Pruned Step Sampling (#1077)

* update Empirically Pruned Step Sampling

---------

Co-authored-by: Fast-F5-TTS <2942755472@qq.com>
Co-authored-by: SWivid <swivid@qq.com>
This commit is contained in:
Jerrister Zheng
2025-06-04 22:59:30 +08:00
committed by GitHub
parent 6fbe7592f5
commit 2e2acc6ea2
2 changed files with 25 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
@@ -92,6 +93,7 @@ class CFM(nn.Module):
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
@@ -190,7 +192,10 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
else:
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

View File

@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
if count > tolerance:
return True
return False
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
dt = 1 / 32
predefined_timesteps = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = predefined_timesteps.get(n, [])
if not t:
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
return dt * torch.tensor(t, device=device, dtype=dtype)