diff --git a/pyproject.toml b/pyproject.toml index 91a42dd..bf8c473 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts" -version = "0.4.4" +version = "0.4.5" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 7a54071..b0cefc0 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -120,10 +120,6 @@ class CFM(nn.Module): text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch - if exists(text): - text_lens = (text != -1).sum(dim=-1) - lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters - # duration cond_mask = lens_to_mask(lens) @@ -133,7 +129,9 @@ class CFM(nn.Module): if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) - duration = torch.maximum(lens + 1, duration) # just add one token so something is generated + duration = torch.maximum( + torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration + ) # duration at least text/audio prompt length plus one token, so something is generated duration = duration.clamp(max=max_duration) max_duration = duration.amax()