# Empirically Pruned Step Sampling (EPSS) schedules.
# Keys are the number of sampling steps; values are the retained grid indices
# on a uniform grid that splits [0, 1] into ``_EPSS_GRID_INTERVALS`` intervals.
# Every schedule has ``steps + 1`` entries, starts at 0 and ends at the last
# grid index, and keeps more points near 0 than near 1.
_EPSS_GRID_INTERVALS = 32
_EPSS_STEP_INDICES = {
    5: (0, 2, 4, 8, 16, 32),
    6: (0, 2, 4, 6, 8, 16, 32),
    7: (0, 2, 4, 6, 8, 16, 24, 32),
    10: (0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32),
    12: (0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32),
    16: (0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32),
}


def get_epss_timesteps(n, device, dtype):
    """Return the empirically pruned timestep schedule for ``n`` sampling steps.

    For the step counts that have a predefined schedule (5, 6, 7, 10, 12, 16)
    the result is that schedule's grid indices scaled by ``1 / 32``, i.e. a
    non-uniform, increasing sequence of ``n + 1`` values from 0 to 1. For any
    other ``n`` the function falls back to ``n + 1`` uniformly spaced values
    from 0 to 1.

    Args:
        n: Number of sampling steps.
        device: Device passed through to the tensor constructor.
        dtype: Dtype passed through to the tensor constructor.

    Returns:
        A 1-D tensor of ``n + 1`` timesteps.
    """
    indices = _EPSS_STEP_INDICES.get(n)
    if indices is None:
        # No pruned schedule for this step count: use the uniform grid.
        return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
    # Scale integer grid indices onto [0, 1].
    return (1 / _EPSS_GRID_INTERVALS) * torch.tensor(indices, device=device, dtype=dtype)