From a051a68552d1fa857261a351883199bab070ef7b Mon Sep 17 00:00:00 2001
From: SWivid
Date: Fri, 24 Oct 2025 05:50:25 +0000
Subject: [PATCH] pytorch impl.: fix batch inference skipping last words in shorter sentences (issues #1039, #1179)

---
 src/f5_tts/model/backbones/dit.py             | 28 ++++++++++++-------
 src/f5_tts/model/modules.py                   | 13 ++++++---
 src/f5_tts/runtime/triton_trtllm/benchmark.py |  2 ++
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index cf64255..0e29b79 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# ruff: noqa: F722 F821
 
 from __future__ import annotations
 
@@ -86,7 +87,7 @@ class TextEmbedding(nn.Module):
         return upsampled_text
 
-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.)
         if not self.average_upsampling:
@@ -127,12 +128,19 @@ class InputEmbedding(nn.Module):
         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
         self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
 
-    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+    def forward(
+        self,
+        x: float["b n d"],
+        cond: float["b n d"],
+        text_embed: float["b n d"],
+        drop_audio_cond=False,
+        audio_mask: bool["b n"] | None = None,
+    ):
         if drop_audio_cond:  # cfg for cond audio
             cond = torch.zeros_like(cond)
 
         x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
-        x = self.conv_pos_embed(x) + x
+        x = self.conv_pos_embed(x, mask=audio_mask) + x
         return x
 
 
@@ -235,7 +243,7 @@ class DiT(nn.Module):
         drop_audio_cond: bool = False,
         drop_text: bool = False,
         cache: bool = True,
-        audio_mask: bool["b n"] | None = None,  # noqa: F722
+        audio_mask: bool["b n"] | None = None,
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
@@ -265,7 +273,7 @@ class DiT(nn.Module):
         else:
             text_embed = self.text_cond
 
-        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)
 
         return x
 
@@ -274,11 +282,11 @@ class DiT(nn.Module):
 
     def forward(
         self,
-        x: float["b n d"],  # nosied input audio  # noqa: F722
-        cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
-        time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input audio
+        cond: float["b n d"],  # masked cond audio
+        text: int["b nt"],  # text
+        time: float["b"] | float[""],  # time step
+        mask: bool["b n"] | None = None,
         drop_audio_cond: bool = False,  # cfg for cond audio
         drop_text: bool = False,  # cfg for text
         cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index 1e2ee4a..9689d64 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -6,7 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
-# flake8: noqa
+# ruff: noqa: F722 F821
 
 from __future__ import annotations
 
@@ -177,14 +177,19 @@ class ConvPositionEmbedding(nn.Module):
             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
             nn.Mish(),
         )
+        self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)]
 
     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
-            mask = mask[..., None]
+            mask = mask.unsqueeze(-1)
+            mask_t = mask.permute(0, 2, 1)
             x = x.masked_fill(~mask, 0.0)
 
         x = x.permute(0, 2, 1)
-        x = self.conv1d(x)
+        for i, block in enumerate(self.conv1d):
+            x = block(x)
+            if mask is not None and i in self.layer_need_mask_idx:
+                x = x.masked_fill(~mask_t, 0.0)
         out = x.permute(0, 2, 1)
 
         if mask is not None:
@@ -435,8 +440,8 @@ class Attention(nn.Module):
 # Attention processor
 
 if is_package_available("flash_attn"):
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn import flash_attn_varlen_func, flash_attn_func
 
 
 class AttnProcessor:
diff --git a/src/f5_tts/runtime/triton_trtllm/benchmark.py b/src/f5_tts/runtime/triton_trtllm/benchmark.py
index 2eed1ec..c7e3121 100644
--- a/src/f5_tts/runtime/triton_trtllm/benchmark.py
+++ b/src/f5_tts/runtime/triton_trtllm/benchmark.py
@@ -307,6 +307,8 @@ def main():
         text_mask_padding=pretrained_config["text_mask_padding"],
         conv_layers=pretrained_config["conv_layers"],
         pe_attn_head=pretrained_config["pe_attn_head"],
+        # attn_backend="flash_attn",  # torch | flash_attn
+        # attn_mask_enabled=True,
     )
     model = load_model(DiT, pt_model_config, args.model_path)
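
Why the ConvPositionEmbedding change matters, as a minimal standalone sketch (not part of the patch; the helper names masked_forward and need_mask_idx below are made up for illustration). Both Conv1d layers use padding=kernel_size // 2, so without re-masking between layers the first Conv1d plus Mish turns the zero-padded frames of shorter batch items into non-zero activations, and the second Conv1d smears those back into the last kernel_size // 2 real frames. That makes batched outputs diverge from single-item outputs near the end of shorter sequences, which is consistent with the "skipping last words in shorter sentences" symptom this patch addresses. Re-zeroing the padded frames after each Conv1d keeps the valid frames identical to running each item alone:

# Standalone illustration, assuming PyTorch >= 1.9 (nn.Mish); not the F5-TTS API.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, kernel_size, groups = 8, 31, 4
conv1d = nn.Sequential(
    nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
    nn.Mish(),
    nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
    nn.Mish(),
)
# Indices of the Conv1d layers, mirroring self.layer_need_mask_idx in the patch.
need_mask_idx = [i for i, m in enumerate(conv1d) if isinstance(m, nn.Conv1d)]


def masked_forward(x, mask):
    # x: (b, n, d) features, mask: (b, n) bool, True on real (non-padded) frames.
    mask_d = mask.unsqueeze(-1)       # (b, n, 1), broadcasts over the feature dim
    mask_t = mask_d.permute(0, 2, 1)  # (b, 1, n), broadcasts over the channel dim
    x = x.masked_fill(~mask_d, 0.0).permute(0, 2, 1)  # -> (b, d, n)
    for i, block in enumerate(conv1d):
        x = block(x)
        if i in need_mask_idx:
            # Re-zero padded frames so the next Conv1d sees the same zeros it
            # would see when the item is processed without batch padding.
            x = x.masked_fill(~mask_t, 0.0)
    return x.permute(0, 2, 1).masked_fill(~mask_d, 0.0)


# One short item padded out to the length of a longer batch mate.
short, padded = 40, 120
x_short = torch.randn(1, short, dim)
x_batch = torch.cat([x_short, torch.zeros(1, padded - short, dim)], dim=1)
mask = torch.zeros(1, padded, dtype=torch.bool)
mask[:, :short] = True

with torch.no_grad():
    ref = masked_forward(x_short, torch.ones(1, short, dtype=torch.bool))
    old = conv1d(x_batch.permute(0, 2, 1)).permute(0, 2, 1)  # previous behaviour
    new = masked_forward(x_batch, mask)

print(torch.allclose(old[:, :short], ref))  # False: tail frames of the short item are contaminated
print(torch.allclose(new[:, :short], ref))  # True: matches the unbatched result

The same reasoning explains the dit.py side of the patch: audio_mask is threaded from DiT through InputEmbedding.forward into self.conv_pos_embed(x, mask=audio_mask), because the per-layer re-masking can only happen if the mask actually reaches the convolution stack.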