17 Commits
1.1.5 ... 1.1.7

Author SHA1 Message Date
Yushen CHEN
c96c3aeed8 Update pyproject.toml 2025-07-14 14:36:26 +08:00
Yushen CHEN
9b60fe6a34 update pyproject.toml, set gradio<=5.35.0 until fix #1126 2025-07-14 14:29:19 +08:00
SWivid
a275798a2f last fix patch-1 2025-07-08 18:44:47 +08:00
SWivid
efc7a7498b fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0 2025-07-08 18:39:43 +08:00
SWivid
9842314127 update slicer in finetune_gradio, legacy min_length 2s changed to 20s 2025-07-08 16:59:46 +08:00
SWivid
69b0e0110e v1.1.6 fla support, several changed for finetune and infer-cli 2025-07-03 00:08:42 +08:00
SWivid
52c84776e5 fine-grained speed control for infer-cli. #1112 2025-07-02 23:41:55 +08:00
Danh Tran
ebbd7bd91f Update WAV File Naming and Dependencies 📝🔊 (#1091)
* Update infer_cli.py

* Update pyproject.toml

* formalized

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-24 23:23:00 +08:00
Yushen CHEN
ac42286d04 update finetune_gradio.py, not to force lower case
Not to force lower case, otherwise train infer mismatch with main infer code
2025-06-23 16:37:51 +08:00
Yushen CHEN
d937efa6f3 fix finetune_gradio.py, not to force lower case 2025-06-23 16:22:33 +08:00
Yushen CHEN
8975fca803 Merge pull request #1084 from starkwj/main
Speedup inference by batching CFG in DiT
2025-06-12 03:54:04 +08:00
SWivid
8b0053ad0c backward compatibility 2025-06-12 03:52:12 +08:00
SWivid
b3ef4ed1d7 correct imple., minor fixes 2025-06-12 03:32:19 +08:00
starkwj
b1a9438496 Batch cfg DiT forward 2025-06-11 09:03:30 +00:00
Zhikang Niu
0914170e98 Add flash_attn2 support attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug of get the length of ref_mel

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-11 12:14:32 +08:00
SWivid
c6ebad0220 switch sync-hf workflow logic on release, avoid hidden space error with pypi/local_editable mismatch 2025-06-06 07:23:54 +08:00
SWivid
cfaba6387f refresh hf-space first 2025-06-06 07:22:02 +08:00
19 changed files with 290 additions and 151 deletions

View File

@@ -1,9 +1,8 @@
name: Sync to HF Space
on:
push:
branches:
- main
release:
types: [published]
jobs:
trigger_curl:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.5"
version = "1.1.7"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -14,13 +14,13 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"accelerate>=0.33.0,!=1.7.0",
"accelerate>=0.33.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"cached_path",
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=3.45.2",
"gradio<=5.35.0",
"hydra-core>=1.3.0",
"jieba",
"librosa",
@@ -38,6 +38,7 @@ dependencies = [
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"unidecode",
"vocos",
"wandb",
"x_transformers>=1.31.14",

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -32,6 +32,8 @@ model:
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -148,10 +148,15 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -126,8 +126,13 @@ def get_inference_prompt(
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (

View File

@@ -13,8 +13,8 @@ output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
speed = 0.8 # will ignore global speed
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""

View File

@@ -1 +1 @@
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land. [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] Goodbye, [main] said he, [country] “Im off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."

View File

@@ -12,6 +12,7 @@ import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode
from f5_tts.infer.utils_infer import (
cfg_strength,
@@ -112,6 +113,11 @@ parser.add_argument(
action="store_true",
help="To save each audio chunks during inference",
)
parser.add_argument(
"--no_legacy_text",
action="store_false",
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
)
parser.add_argument(
"--remove_silence",
action="store_true",
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
)
save_chunk = args.save_chunk or config.get("save_chunk", False)
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
if save_chunk and use_legacy_text:
print(
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
@@ -321,6 +333,7 @@ def main():
text = re.sub(reg2, "", text)
ref_audio_ = voices[voice]["ref_audio"]
ref_text_ = voices[voice]["ref_text"]
local_speed = voices[voice].get("speed", speed)
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio_segment, final_sample_rate, spectrogram = infer_process(
@@ -335,7 +348,7 @@ def main():
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
speed=local_speed,
fix_duration=fix_duration,
device=device,
)
@@ -344,6 +357,8 @@ def main():
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
if use_legacy_text:
gen_text_ = unidecode(gen_text_)
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,

View File

@@ -116,6 +116,8 @@ class DiT(nn.Module):
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -145,6 +147,8 @@ class DiT(nn.Module):
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]
@@ -178,26 +182,16 @@ class DiT(nn.Module):
return ckpt_forward
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +203,41 @@ class DiT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:

View File

@@ -141,26 +141,15 @@ class MMDiT(nn.Module):
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
@@ -174,6 +163,41 @@ class MMDiT(nn.Module):
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
return x, c
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -178,26 +178,16 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +199,41 @@ class UNetT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
# postfix time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:

View File

@@ -162,16 +162,31 @@ class CFM(nn.Module):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
# noise input
@@ -275,10 +290,9 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
# flow matching loss

View File

@@ -312,7 +312,7 @@ def collate_fn(batch):
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
return dict(
mel=mel_specs,
mel_lengths=mel_lengths,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d

View File

@@ -149,7 +149,7 @@ class Trainer:
if self.is_main:
checkpoint = dict(
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
optimizer_state_dict=self.optimizer.state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
update=update,
@@ -242,7 +242,7 @@ class Trainer:
del checkpoint["model_state_dict"][key]
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
update = checkpoint["update"]

View File

@@ -35,6 +35,16 @@ def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers

View File

@@ -178,50 +178,12 @@ def get_audio_duration(audio_path):
return audio.shape[1] / sample_rate
def clear_text(text):
"""Clean and prepare text by lowering the case and stripping whitespace."""
return text.lower().strip()
def get_rms(
y,
frame_length=2048,
hop_length=512,
pad_mode="constant",
): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
padding = (int(frame_length // 2), int(frame_length // 2))
y = np.pad(y, padding, mode=pad_mode)
axis = -1
# put our new within-frame axis at the end for now
out_strides = y.strides + tuple([y.strides[axis]])
# Reduce the shape on the framing axis
x_shape_trimmed = list(y.shape)
x_shape_trimmed[axis] -= frame_length - 1
out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
if axis < 0:
target_axis = axis - 1
else:
target_axis = axis + 1
xw = np.moveaxis(xw, -1, target_axis)
# Downsample along the target axis
slices = [slice(None)] * xw.ndim
slices[axis] = slice(0, None, hop_length)
x = xw[tuple(slices)]
# Calculate power
power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
return np.sqrt(power)
class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
def __init__(
self,
sr: int,
threshold: float = -40.0,
min_length: int = 2000,
min_length: int = 20000, # 20 seconds
min_interval: int = 300,
hop_size: int = 20,
max_sil_kept: int = 2000,
@@ -252,7 +214,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
samples = waveform
if samples.shape[0] <= self.min_length:
return [waveform]
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
sil_tags = []
silence_start = None
clip_start = 0
@@ -306,8 +268,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
silence_end = min(total_frames, silence_start + self.max_sil_kept)
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
sil_tags.append((pos, total_frames + 1))
# Apply and return slices.
####音频+起始时间+终止时间
# Apply and return slices: [chunk, start, end]
if len(sil_tags) == 0:
return [[waveform, 0, int(total_frames * self.hop_size)]]
else:
@@ -707,7 +668,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
try:
text = transcribe(file_segment, language)
text = text.lower().strip().replace('"', "")
text = text.strip()
data += f"{name_segment}|{text}\n"
@@ -816,7 +777,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
error_files.append([file_audio, "very short text length 3"])
continue
text = clear_text(text)
text = text.strip()
text = convert_char_to_pinyin([text], polyphone=True)[0]
audio_path_list.append(file_audio)
@@ -1127,7 +1088,7 @@ def vocab_check(project_name, tokenizer_type):
if len(sp) != 2:
continue
text = sp[1].lower().strip()
text = sp[1].strip()
if tokenizer_type == "pinyin":
text = convert_char_to_pinyin([text], polyphone=True)[0]
@@ -1234,8 +1195,8 @@ def infer(
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
ref_text=ref_text.lower().strip(),
gen_text=gen_text.lower().strip(),
ref_text=ref_text.strip(),
gen_text=gen_text.strip(),
nfe_step=nfe_step,
speed=speed,
remove_silence=remove_silence,