mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-17 15:32:49 -08:00
0.4.5 fix extremely short case that lengths of text_seq > audio_seq, causing wrong cond_mask
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
|
||||
@@ -120,10 +120,6 @@ class CFM(nn.Module):
|
||||
text = list_str_to_tensor(text).to(device)
|
||||
assert text.shape[0] == batch
|
||||
|
||||
if exists(text):
|
||||
text_lens = (text != -1).sum(dim=-1)
|
||||
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
||||
|
||||
# duration
|
||||
|
||||
cond_mask = lens_to_mask(lens)
|
||||
@@ -133,7 +129,9 @@ class CFM(nn.Module):
|
||||
if isinstance(duration, int):
|
||||
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
||||
|
||||
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
|
||||
duration = torch.maximum(
|
||||
torch.maximum((text != -1).sum(dim=-1), lens) + 1, duration
|
||||
) # duration at least text/audio prompt length plus one token, so something is generated
|
||||
duration = duration.clamp(max=max_duration)
|
||||
max_duration = duration.amax()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user