10 Commits

Author SHA1 Message Date
Danh Tran
ebbd7bd91f Update WAV File Naming and Dependencies 📝🔊 (#1091)
* Update infer_cli.py

* Update pyproject.toml

* formalized

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-24 23:23:00 +08:00
Yushen CHEN
ac42286d04 update finetune_gradio.py, not to force lower case
Do not force lower case; otherwise the training/inference text handling mismatches the main inference code
2025-06-23 16:37:51 +08:00
Yushen CHEN
d937efa6f3 fix finetune_gradio.py, not to force lower case 2025-06-23 16:22:33 +08:00
Yushen CHEN
8975fca803 Merge pull request #1084 from starkwj/main
Speed up inference by batching CFG in DiT
2025-06-12 03:54:04 +08:00
SWivid
8b0053ad0c backward compatibility 2025-06-12 03:52:12 +08:00
SWivid
b3ef4ed1d7 correct imple., minor fixes 2025-06-12 03:32:19 +08:00
starkwj
b1a9438496 Batch cfg DiT forward 2025-06-11 09:03:30 +00:00
Zhikang Niu
0914170e98 Add flash_attn2 support attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug in getting the length of ref_mel

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-11 12:14:32 +08:00
SWivid
c6ebad0220 switch sync-hf workflow logic on release, avoid hidden space error with pypi/local_editable mismatch 2025-06-06 07:23:54 +08:00
SWivid
cfaba6387f refresh hf-space first 2025-06-06 07:22:02 +08:00
16 changed files with 278 additions and 106 deletions

View File

@@ -1,9 +1,8 @@
name: Sync to HF Space
on:
push:
branches:
- main
release:
types: [published]
jobs:
trigger_curl:

View File

@@ -38,6 +38,7 @@ dependencies = [
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"unidecode",
"vocos",
"wandb",
"x_transformers>=1.31.14",

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -32,6 +32,8 @@ model:
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
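
The same two keys are added to each model config above and are forwarded into the DiT constructor shown later in this diff. A hedged sketch of how they might be set when building the backbone directly (import path assumed from the repo layout; the surrounding architecture arguments are placeholders, not the released values):

```python
from f5_tts.model import DiT  # import path assumed

# Placeholder sizes; only attn_backend / attn_mask_enabled are the new fields.
backbone = DiT(
    dim=1024, depth=22, heads=16, dim_head=64,   # illustrative values
    conv_layers=4, pe_attn_head=1,
    attn_backend="flash_attn",    # "torch" | "flash_attn"
    attn_mask_enabled=True,       # attend only to non-padded frames in batched inference
)
```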

View File

@@ -148,10 +148,15 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
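
The hunk above extends the evaluation script to try a released `.pt`, then a `.safetensors`, before falling back to self-organized training checkpoints. A standalone sketch of the same lookup order (the helper name and fallback argument are hypothetical):

```python
import os

def resolve_ckpt_path(ckpt_prefix: str, fallback_path: str) -> str:
    # Hypothetical helper mirroring the lookup order above.
    for ext in (".pt", ".safetensors"):
        if os.path.exists(ckpt_prefix + ext):
            return ckpt_prefix + ext
    print("Loading from self-organized training checkpoints rather than released pretrained.")
    return fallback_path
```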

View File

@@ -126,8 +126,13 @@ def get_inference_prompt(
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (
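
The reference mel length is now read off the computed spectrogram instead of being derived from the raw sample count, so any frame-count difference introduced by STFT padding/centering no longer skews the duration estimate. In isolation (names as in the hunk above):

```python
# ref_audio: [1, num_samples] waveform already resampled to target_sample_rate
ref_mel = mel_spectrogram(ref_audio)   # [1, n_mels, frames]
ref_mel = ref_mel.squeeze(0)           # [n_mels, frames]
ref_mel_len = ref_mel.shape[-1]        # was: ref_audio.shape[-1] // hop_length
```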

View File

@@ -12,6 +12,7 @@ import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode
from f5_tts.infer.utils_infer import (
cfg_strength,
@@ -112,6 +113,11 @@ parser.add_argument(
action="store_true",
help="To save each audio chunks during inference",
)
parser.add_argument(
"--no_legacy_text",
action="store_false",
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
)
parser.add_argument(
"--remove_silence",
action="store_true",
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
)
save_chunk = args.save_chunk or config.get("save_chunk", False)
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
if save_chunk and use_legacy_text:
print(
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
@@ -344,6 +356,8 @@ def main():
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
if use_legacy_text:
gen_text_ = unidecode(gen_text_)
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
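
`--no_legacy_text` is a `store_false` flag, so its parsed value defaults to `True` (legacy ASCII file names) and flips to `False` only when the flag is passed. A minimal argparse sketch of that behaviour:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no_legacy_text", action="store_false")

print(parser.parse_args([]).no_legacy_text)                    # True  -> transliterate names to ASCII
print(parser.parse_args(["--no_legacy_text"]).no_legacy_text)  # False -> keep unicode in file names
```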

View File

@@ -116,6 +116,8 @@ class DiT(nn.Module):
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]
@@ -178,26 +182,16 @@ class DiT(nn.Module):
return ckpt_forward
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +203,41 @@ class DiT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
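
The refactor above splits input embedding into `get_input_embed` (which also hosts the text-embedding cache) so that, when `cfg_infer` is set, the conditional and unconditional branches can be packed along the batch dimension and served by a single transformer pass. A minimal sketch of that packing pattern, detached from the model (tensor shapes and the `run_blocks` callable are illustrative assumptions):

```python
import torch

def packed_cfg_forward(run_blocks, x_cond, x_uncond, t, mask=None):
    # Pack cond/uncond along the batch dim: b n d -> 2b n d, one forward instead of two.
    x = torch.cat((x_cond, x_uncond), dim=0)
    t = torch.cat((t, t), dim=0)
    mask = torch.cat((mask, mask), dim=0) if mask is not None else None
    out = run_blocks(x, t, mask)        # assumed: the stacked transformer blocks
    return torch.chunk(out, 2, dim=0)   # (pred_cond, pred_uncond)
```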

View File

@@ -141,26 +141,15 @@ class MMDiT(nn.Module):
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
@@ -174,6 +163,41 @@ class MMDiT(nn.Module):
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
return x, c
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -178,26 +178,16 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +199,41 @@ class UNetT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
# postfix time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:

View File

@@ -162,16 +162,31 @@ class CFM(nn.Module):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
# noise input
@@ -275,10 +290,9 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
# flow matching loss
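
With the packed forward in place, `CFM.sample` splits the doubled batch back into conditional and null predictions and applies classifier-free guidance; when `cfg_strength` is effectively zero only the conditional pass runs. The combination step on dummy tensors (shapes are illustrative):

```python
import torch

cfg_strength = 2.0
pred_cfg = torch.randn(2 * 4, 300, 100)             # packed output: 2b x n x d (illustrative shape)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)   # split cond / uncond halves
guided = pred + (pred - null_pred) * cfg_strength   # classifier-free guidance
```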

View File

@@ -312,7 +312,7 @@ def collate_fn(batch):
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
return dict(
mel=mel_specs,
mel_lengths=mel_lengths,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)
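
`collate_fn` now simply documents that `mel_lengths` is what downstream code can use to rebuild a padding mask when rigorous masking is wanted. A hedged sketch of such a helper (the function name is hypothetical):

```python
import torch

def lengths_to_mask(lengths: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
    # Hypothetical helper: True for real frames, False for padding, built from mel_lengths.
    max_len = max_len if max_len is not None else int(lengths.max())
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

mel_lengths = torch.tensor([120, 95, 60])
mask = lengths_to_mask(mel_lengths)   # bool tensor of shape [3, 120]
```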

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
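
On the default `torch` backend, the padding mask is broadcast from `b n` to `b h n n` and handed to `scaled_dot_product_attention` as `attn_mask`; on the `flash_attn` backend the same mask instead drives unpadding for the varlen kernel. A self-contained sketch of the torch path with dummy shapes:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 6, 16
q = k = v = torch.randn(b, h, n, d)
mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])   # b n, False = padded frames
attn_mask = mask[:, None, None, :].expand(b, h, n, n)         # 'b n -> b 1 1 n' -> b h n n
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
```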

View File

@@ -35,6 +35,16 @@ def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers
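
The new `is_package_available` helper reports whether a package can be found without importing it, and is used in `modules.py` above to gate the optional flash-attn backend. Typical usage (the fallback logic here is illustrative):

```python
from f5_tts.model.utils import is_package_available

# Pick the attention backend based on what is installed (illustrative gating).
attn_backend = "flash_attn" if is_package_available("flash_attn") else "torch"
```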

View File

@@ -178,11 +178,6 @@ def get_audio_duration(audio_path):
return audio.shape[1] / sample_rate
def clear_text(text):
"""Clean and prepare text by lowering the case and stripping whitespace."""
return text.lower().strip()
def get_rms(
y,
frame_length=2048,
@@ -707,7 +702,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
try:
text = transcribe(file_segment, language)
text = text.lower().strip().replace('"', "")
text = text.strip()
data += f"{name_segment}|{text}\n"
@@ -816,7 +811,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
error_files.append([file_audio, "very short text length 3"])
continue
text = clear_text(text)
text = text.strip()
text = convert_char_to_pinyin([text], polyphone=True)[0]
audio_path_list.append(file_audio)
@@ -1127,7 +1122,7 @@ def vocab_check(project_name, tokenizer_type):
if len(sp) != 2:
continue
text = sp[1].lower().strip()
text = sp[1].strip()
if tokenizer_type == "pinyin":
text = convert_char_to_pinyin([text], polyphone=True)[0]
@@ -1234,8 +1229,8 @@ def infer(
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
ref_text=ref_text.lower().strip(),
gen_text=gen_text.lower().strip(),
ref_text=ref_text.strip(),
gen_text=gen_text.strip(),
nfe_step=nfe_step,
speed=speed,
remove_silence=remove_silence,
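
The lowercasing removals above keep fine-tuning metadata and Gradio inference consistent with the main inference code, which preserves the user's casing. A hypothetical illustration of the mismatch that forced lowercasing would create:

```python
# Hypothetical example: with forced lowercasing only on the fine-tuning side,
# the model trains on "hello there" but the main infer path (case preserved)
# later asks it to synthesize "Hello there".
train_label = "Hello there".lower().strip()   # old behaviour -> "hello there"
infer_input = "Hello there".strip()           # main infer path keeps case
assert train_label != infer_input
```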