From 6b07fb03b243dee8790e9de6749c86ee9501ef4b Mon Sep 17 00:00:00 2001 From: SWivid Date: Fri, 24 Oct 2025 08:30:55 +0000 Subject: [PATCH] clean-up ruff lint --- src/f5_tts/model/backbones/mmdit.py | 15 ++++++++------- src/f5_tts/model/backbones/unett.py | 15 ++++++++------- src/f5_tts/model/cfm.py | 17 +++++++++-------- src/f5_tts/model/modules.py | 16 +++++++--------- src/f5_tts/model/utils.py | 14 ++++++++------ src/f5_tts/runtime/triton_trtllm/benchmark.py | 2 +- 6 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index cd56de9..3019bef 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -6,6 +6,7 @@ nt - text sequence nw - raw wave length d - dimension """ +# ruff: noqa: F722 F821 from __future__ import annotations @@ -36,7 +37,7 @@ class TextEmbedding(nn.Module): self.precompute_max_pos = 1024 self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) - def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 + def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() if self.mask_padding: text_mask = text == 0 @@ -69,7 +70,7 @@ class AudioEmbedding(nn.Module): self.linear = nn.Linear(2 * in_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): if drop_audio_cond: cond = torch.zeros_like(cond) x = torch.cat((x, cond), dim=-1) @@ -170,11 +171,11 @@ class MMDiT(nn.Module): def forward( self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - mask: bool["b n"] | None = None, # noqa: F722 + x: float["b n d"], # nosied input audio + cond: float["b n d"], # masked cond audio + text: int["b nt"], # text + time: float["b"] | float[""], # time step + mask: bool["b n"] | None = None, drop_audio_cond: bool = False, # cfg for cond audio drop_text: bool = False, # cfg for text cfg_infer: bool = False, # cfg inference, pack cond & uncond forward diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py index f01081c..2cab47f 100644 --- a/src/f5_tts/model/backbones/unett.py +++ b/src/f5_tts/model/backbones/unett.py @@ -6,6 +6,7 @@ nt - text sequence nw - raw wave length d - dimension """ +# ruff: noqa: F722 F821 from __future__ import annotations @@ -49,7 +50,7 @@ class TextEmbedding(nn.Module): else: self.extra_modeling = False - def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + def forward(self, text: int["b nt"], seq_len, drop_text=False): text = text + 1 # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] @@ -91,7 +92,7 @@ class InputEmbedding(nn.Module): self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722 + def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): if drop_audio_cond: # cfg for cond audio cond = torch.zeros_like(cond) @@ -215,11 +216,11 @@ class UNetT(nn.Module): def forward( self, - x: float["b n d"], # nosied input audio # noqa: F722 - cond: float["b n d"], # masked cond audio # noqa: F722 - text: int["b nt"], # text # noqa: F722 - time: float["b"] | float[""], # time step # noqa: F821 F722 - mask: bool["b n"] | None = None, # noqa: F722 + x: float["b n d"], # nosied input audio + cond: float["b n d"], # masked cond audio + text: int["b nt"], # text + time: float["b"] | float[""], # time step + mask: bool["b n"] | None = None, drop_audio_cond: bool = False, # cfg for cond audio drop_text: bool = False, # cfg for text cfg_infer: bool = False, # cfg inference, pack cond & uncond forward diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 17fda1d..a114953 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -6,6 +6,7 @@ nt - text sequence nw - raw wave length d - dimension """ +# ruff: noqa: F722 F821 from __future__ import annotations @@ -82,17 +83,17 @@ class CFM(nn.Module): @torch.no_grad() def sample( self, - cond: float["b n d"] | float["b nw"], # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 - duration: int | int["b"], # noqa: F821 + cond: float["b n d"] | float["b nw"], + text: int["b nt"] | list[str], + duration: int | int["b"], *, - lens: int["b"] | None = None, # noqa: F821 + lens: int["b"] | None = None, steps=32, cfg_strength=1.0, sway_sampling_coef=None, seed: int | None = None, max_duration=4096, - vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, use_epss=True, no_ref_audio=False, duplicate_test=False, @@ -229,10 +230,10 @@ class CFM(nn.Module): def forward( self, - inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 - text: int["b nt"] | list[str], # noqa: F722 + inp: float["b n d"] | float["b nw"], # mel or raw wave + text: int["b nt"] | list[str], *, - lens: int["b"] | None = None, # noqa: F821 + lens: int["b"] | None = None, noise_scheduler: str | None = None, ): # handle raw wave diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py index 9689d64..274c857 100644 --- a/src/f5_tts/model/modules.py +++ b/src/f5_tts/model/modules.py @@ -181,21 +181,19 @@ class ConvPositionEmbedding(nn.Module): def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): if mask is not None: - mask = mask.unsqueeze(-1) - mask_t = mask.permute(0, 2, 1) - x = x.masked_fill(~mask, 0.0) + mask = mask.unsqueeze(1) # [B 1 N] + x = x.permute(0, 2, 1) # [B D N] - x = x.permute(0, 2, 1) + if mask is not None: + x = x.masked_fill(~mask, 0.0) for i, block in enumerate(self.conv1d): x = block(x) if mask is not None and i in self.layer_need_mask_idx: - x = x.masked_fill(~mask_t, 0.0) - out = x.permute(0, 2, 1) + x = x.masked_fill(~mask, 0.0) - if mask is not None: - out = 
out.masked_fill(~mask, 0.0) + x = x.permute(0, 2, 1) # [B N D] - return out + return x # rotary positional embedding related diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index c5c3829..cd5b3a0 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -1,3 +1,5 @@ +# ruff: noqa: F722 F821 + from __future__ import annotations import os @@ -48,7 +50,7 @@ def is_package_available(package_name: str) -> bool: # tensor helpers -def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821 +def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: if not exists(length): length = t.amax() @@ -56,7 +58,7 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa return seq[None, :] < t[:, None] -def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821 +def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, device=start.device).long() start_mask = seq[None, :] >= start[:, None] @@ -64,7 +66,7 @@ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b" return start_mask & end_mask -def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821 +def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths @@ -75,7 +77,7 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa return mask_from_start_end_indices(seq_len, start, end) -def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722 +def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: if not exists(mask): return t.mean(dim=1) @@ -87,7 +89,7 @@ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d # simple utf-8 tokenizer, since paper went character based -def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) return text @@ -98,7 +100,7 @@ def list_str_to_idx( text: list[str] | list[list[str]], vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, -) -> int["b nt"]: # noqa: F722 +) -> int["b nt"]: list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) return text diff --git a/src/f5_tts/runtime/triton_trtllm/benchmark.py b/src/f5_tts/runtime/triton_trtllm/benchmark.py index c7e3121..4f6b4b5 100644 --- a/src/f5_tts/runtime/triton_trtllm/benchmark.py +++ b/src/f5_tts/runtime/triton_trtllm/benchmark.py @@ -307,7 +307,7 @@ def main(): text_mask_padding=pretrained_config["text_mask_padding"], conv_layers=pretrained_config["conv_layers"], pe_attn_head=pretrained_config["pe_attn_head"], - # attn_backend="flash_attn", # torch | flash_attn + # attn_backend="flash_attn", # attn_mask_enabled=True, ) model = load_model(DiT, pt_model_config, args.model_path)
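
Note on the noqa consolidation in mmdit.py, unett.py, cfm.py, and utils.py: the shape-string annotations used throughout these files (e.g. float["b n d"], int["b"]) are documentary only, but ruff's pyflakes rules flag them. Multi-token strings such as "b n d" are not valid Python expressions and raise F722 (forward-annotation syntax error), while single-name strings such as "b" parse as references to undefined names and raise F821. The file-level "# ruff: noqa: F722 F821" directive added by this patch suppresses both rules for the whole module, which is why the per-line "# noqa" comments can be dropped. A minimal standalone sketch of the pattern (the scale_mel helper is hypothetical, not part of the codebase):

    # ruff: noqa: F722 F821
    from __future__ import annotations  # annotations stay as strings, never evaluated

    import torch


    def scale_mel(x: float["b n d"], gain: float) -> float["b n d"]:
        # "b n d" documents the shape (batch, frames, mel dim); without the
        # file-level directive this annotation would trigger F722 on every line.
        return x * gain


    mel = torch.randn(2, 100, 80)
    print(scale_mel(mel, 0.5).shape)  # torch.Size([2, 100, 80])

Because the directive is scoped to the file, new shape-annotated signatures added later need no per-line suppression either.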
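Note on the modules.py hunk: it reworks how ConvPositionEmbedding applies its padding mask. The mask is unsqueezed once to [B, 1, N] so it broadcasts over the channel axis of the [B, D, N] tensor that Conv1d operates on, replacing the earlier pair of mask ([B, N, 1]) and mask_t views and the extra masked_fill after permuting back. A simplified, self-contained sketch of that masking pattern (MaskedConvStack, its sizes, and the layer_need_mask_idx value are illustrative assumptions, not the real layer):

    from __future__ import annotations

    import torch
    from torch import nn


    class MaskedConvStack(nn.Module):
        def __init__(self, dim: int = 64, kernel_size: int = 31):
            super().__init__()
            self.conv1d = nn.ModuleList(
                nn.Conv1d(dim, dim, kernel_size, padding=kernel_size // 2) for _ in range(2)
            )
            self.layer_need_mask_idx = [0]  # re-zero padded frames after these layers

        def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
            # x: [B, N, D] float, mask: [B, N] bool (True = valid frame)
            if mask is not None:
                mask = mask.unsqueeze(1)  # [B, 1, N], broadcasts over channels
            x = x.permute(0, 2, 1)  # [B, D, N] layout expected by Conv1d
            if mask is not None:
                x = x.masked_fill(~mask, 0.0)
            for i, block in enumerate(self.conv1d):
                x = block(x)
                if mask is not None and i in self.layer_need_mask_idx:
                    x = x.masked_fill(~mask, 0.0)
            return x.permute(0, 2, 1)  # back to [B, N, D]


    x = torch.randn(2, 10, 64)
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[1, 6:] = False  # second batch item is padded after frame 6
    print(MaskedConvStack()(x, mask).shape)  # torch.Size([2, 10, 64])

Keeping the mask in the pre-permuted [B, 1, N] layout lets the same tensor serve both the fill before the conv stack and the re-fills inside the loop, which is what lets the hunk drop the duplicate mask_t view.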