0.4.5 fix extremely short case that lengths of text_seq > audio_seq, causing wrong cond_mask

This commit is contained in:
unknown
2025-01-28 12:38:16 +08:00
parent ee2b77064e
commit 607b92b391
2 changed files with 4 additions and 6 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.4.4"
version = "0.4.5"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}

View File

@@ -120,10 +120,6 @@ class CFM(nn.Module):
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim=-1)
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
# duration
cond_mask = lens_to_mask(lens)
@@ -133,7 +129,9 @@ class CFM(nn.Module):
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
duration = torch.maximum(
torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
) # duration at least text/audio prompt length plus one token, so something is generated
duration = duration.clamp(max=max_duration)
max_duration = duration.amax()