Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-05 20:40:12 -08:00)
@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# ruff: noqa: F722 F821
 
 from __future__ import annotations
 
@@ -86,7 +87,7 @@ class TextEmbedding(nn.Module):
 
         return upsampled_text
 
-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
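Not part of the diff: a minimal sketch of the curtail-then-pad step in TextEmbedding.forward above, on toy tensors (all names and values here are illustrative).

import torch
import torch.nn.functional as F

# toy batch of two text sequences, batch-padded with -1 (as list_str_to_idx() does)
text = torch.tensor([[5, 9, 2, -1], [7, 3, -1, -1]])
seq_len = 6  # length of the mel spectrogram the text has to line up with

text = text + 1                                             # shift: pad value -1 becomes filler token 0
text = text[:, :seq_len]                                    # curtail if there are more tokens than frames
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)   # right-pad with the filler up to seq_len

print(text.shape)  # torch.Size([2, 6])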
@@ -127,12 +128,19 @@ class InputEmbedding(nn.Module):
         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
         self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
 
-    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+    def forward(
+        self,
+        x: float["b n d"],
+        cond: float["b n d"],
+        text_embed: float["b n d"],
+        drop_audio_cond=False,
+        audio_mask: bool["b n"] | None = None,
+    ):
         if drop_audio_cond:  # cfg for cond audio
             cond = torch.zeros_like(cond)
 
         x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
-        x = self.conv_pos_embed(x) + x
+        x = self.conv_pos_embed(x, mask=audio_mask) + x
         return x
 
 
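Not part of the diff: the new audio_mask argument is a boolean (b, n) frame mask; a sketch of how such a mask is typically derived from per-sample lengths before being passed into InputEmbedding.forward (the lens_to_mask helper here is an assumption for illustration, not taken from the repo).

import torch

def lens_to_mask(lens: torch.Tensor, max_len: int) -> torch.Tensor:
    # True for real audio frames, False for batch padding; shape (b, n)
    return torch.arange(max_len, device=lens.device)[None, :] < lens[:, None]

lens = torch.tensor([4, 6])          # valid frame counts per sample
audio_mask = lens_to_mask(lens, max_len=6)
# tensor([[ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True]])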
@@ -235,7 +243,7 @@ class DiT(nn.Module):
         drop_audio_cond: bool = False,
         drop_text: bool = False,
         cache: bool = True,
-        audio_mask: bool["b n"] | None = None,  # noqa: F722
+        audio_mask: bool["b n"] | None = None,
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
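Not part of the diff: the cache flag above follows the usual sampling-time pattern of computing the conditional and unconditional text embeddings once and reusing them across ODE steps, since only the noisy input and the time step change between steps. A generic sketch of that pattern (the class and attribute names below are illustrative, not the repo's exact control flow):

import torch.nn as nn

class CachedTextEmbed(nn.Module):
    """Wraps a text-embedding module and memoizes its cond/uncond outputs."""

    def __init__(self, text_embed: nn.Module):
        super().__init__()
        self.text_embed = text_embed
        self.text_cond = None
        self.text_uncond = None

    def clear_cache(self):
        self.text_cond = None
        self.text_uncond = None

    def forward(self, text, seq_len, drop_text=False, cache=True):
        if drop_text:  # unconditional branch (text dropped for CFG)
            if self.text_uncond is None or not cache:
                self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
            return self.text_uncond
        if self.text_cond is None or not cache:
            self.text_cond = self.text_embed(text, seq_len, drop_text=False)
        return self.text_cond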
@@ -265,7 +273,7 @@ class DiT(nn.Module):
         else:
             text_embed = self.text_cond
 
-        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)
 
         return x
 
@@ -274,11 +282,11 @@ class DiT(nn.Module):
 
     def forward(
         self,
-        x: float["b n d"],  # nosied input audio  # noqa: F722
-        cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
-        time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # nosied input audio
+        cond: float["b n d"],  # masked cond audio
+        text: int["b nt"],  # text
+        time: float["b"] | float[""],  # time step
+        mask: bool["b n"] | None = None,
         drop_audio_cond: bool = False,  # cfg for cond audio
         drop_text: bool = False,  # cfg for text
         cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
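Not part of the diff: cfg_infer packs the text-conditioned and text-dropped branches into a single batched forward; a minimal sketch of the usual classifier-free-guidance combination applied to such packed output (cfg_strength and the chunk layout are assumptions for illustration):

import torch

def cfg_combine(packed_pred: torch.Tensor, cfg_strength: float = 2.0) -> torch.Tensor:
    # packed_pred: (2b, n, d) with the conditional half first, unconditional half second
    pred, null_pred = torch.chunk(packed_pred, 2, dim=0)
    return pred + (pred - null_pred) * cfg_strength

out = cfg_combine(torch.randn(4, 6, 8))   # -> shape (2, 6, 8)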
@@ -6,7 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
-# flake8: noqa
+# ruff: noqa: F722 F821
 
 from __future__ import annotations
 
@@ -177,14 +177,19 @@ class ConvPositionEmbedding(nn.Module):
             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
             nn.Mish(),
         )
+        self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)]
 
     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
-            mask = mask[..., None]
+            mask = mask.unsqueeze(-1)
+            mask_t = mask.permute(0, 2, 1)
             x = x.masked_fill(~mask, 0.0)
 
         x = x.permute(0, 2, 1)
-        x = self.conv1d(x)
+        for i, block in enumerate(self.conv1d):
+            x = block(x)
+            if mask is not None and i in self.layer_need_mask_idx:
+                x = x.masked_fill(~mask_t, 0.0)
         out = x.permute(0, 2, 1)
 
         if mask is not None:
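Not part of the diff: why the loop above re-applies the mask after every nn.Conv1d instead of only zeroing the input once: each convolution's receptive field (and its bias) writes non-zero values back into the padded frames, which would then contaminate later layers. A toy check under assumed shapes:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1)

x = torch.randn(1, 4, 8)                       # (b, d, n)
mask = torch.zeros(1, 1, 8, dtype=torch.bool)
mask[..., :5] = True                           # last 3 frames are padding

x = x.masked_fill(~mask, 0.0)                  # zero the padded frames once
y = conv(x)
print(y.masked_fill(mask, 0.0).abs().max())    # > 0: padding is non-zero again after the conv

y = y.masked_fill(~mask, 0.0)                  # re-zero after the conv, as the diff does per layer
print(y.masked_fill(mask, 0.0).abs().max())    # 0: padded frames stay silent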
@@ -435,8 +440,8 @@ class Attention(nn.Module):
 # Attention processor
 
 if is_package_available("flash_attn"):
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func
 
 
 class AttnProcessor:
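Not part of the diff: the import shuffle above keeps flash_attn_varlen_func plus the pad_input/unpad_input helpers available for variable-length batches. As a dependency-free illustration of the same padding-mask idea, this is roughly what the torch-SDPA backend (attn_backend="torch" with attn_mask_enabled, mentioned further down) amounts to; it is a generic sketch, not the repo's AttnProcessor:

import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 6, 64
q = k = v = torch.randn(b, h, n, d)

audio_mask = torch.tensor([[True] * 4 + [False] * 2,
                           [True] * 6])                    # (b, n), True = real frame

# boolean key-padding mask broadcast to (b, 1, 1, n): padded keys are never attended to
attn_mask = audio_mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# zero the rows at padded query positions so they cannot leak into later layers
out = out.masked_fill(~audio_mask[:, None, :, None], 0.0)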
@@ -307,6 +307,8 @@ def main():
         text_mask_padding=pretrained_config["text_mask_padding"],
         conv_layers=pretrained_config["conv_layers"],
         pe_attn_head=pretrained_config["pe_attn_head"],
+        # attn_backend="flash_attn",  # torch | flash_attn
+        # attn_mask_enabled=True,
     )
     model = load_model(DiT, pt_model_config, args.model_path)
 
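Not part of the diff beyond the two commented kwargs above: enabling them is just a matter of passing the toggles into the DiT config; a hypothetical example with placeholder hyperparameters (values not taken from the checkpoint):

pt_model_config = dict(
    dim=1024,
    depth=22,
    heads=16,
    ff_mult=2,
    text_dim=512,
    conv_layers=4,
    attn_backend="flash_attn",   # torch | flash_attn
    attn_mask_enabled=True,      # apply the audio padding mask inside attention
)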