diff --git a/.github/workflows/publish-pypi.yaml b/.github/workflows/publish-pypi.yaml new file mode 100644 index 0000000..6e446d9 --- /dev/null +++ b/.github/workflows/publish-pypi.yaml @@ -0,0 +1,66 @@ +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +# GitHub recommends pinning actions to a commit SHA. +# To get a newer version, you will need to update the SHA. +# You can also reference a tag or branch, but the action may change without warning. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + release-build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Build release distributions + run: | + # NOTE: put your own distribution build steps here. + python -m pip install build + python -m build + + - name: Upload distributions + uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: + runs-on: ubuntu-latest + + needs: + - release-build + + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing + id-token: write + + # Dedicated environments with protections for publishing are strongly recommended. + environment: + name: pypi + # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: + # url: https://pypi.org/p/YOURPROJECT + + steps: + - name: Retrieve release distributions + uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + + - name: Publish release distributions to PyPI + uses: pypa/gh-action-pypi-publish@6f7e8d9c0b1a2c3d4e5f6a7b8c9d0e1f2a3b4c5d diff --git a/README.md b/README.md index 9b8855a..55c4777 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ ### Thanks to all the contributors ! ## News +- **2025/03/12**: F5-TTS v1 base model with better training and inference performance. - **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN). ## Installation @@ -37,7 +38,7 @@ conda activate f5-tts > ```bash > # Install pytorch with your CUDA version, e.g. -> pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 +> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124 > ``` @@ -159,7 +160,7 @@ volumes: # Run with flags # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage) f5-tts_infer-cli \ ---model "F5-TTS" \ +--model "F5-TTS_v1" \ --ref_audio "ref_audio.wav" \ --ref_text "The content, subtitle or transcription of reference audio." \ --gen_text "Some text you want TTS model generate for you." diff --git a/ckpts/README.md b/ckpts/README.md index 45ac3b0..0d6b048 100644 --- a/ckpts/README.md +++ b/ckpts/README.md @@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS ``` ckpts/ - E2TTS_Base/ - model_1200000.pt + F5TTS_v1_Base/ + model_1250000.safetensors F5TTS_Base/ - model_1200000.pt + model_1200000.safetensors + E2TTS_Base/ + model_1200000.safetensors ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0019b26..26e42c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts" -version = "0.6.2" +version = "1.0.0" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} @@ -25,7 +25,6 @@ dependencies = [ "jieba", "librosa", "matplotlib", - "nltk", "numpy<=1.26.4", "pydub", "pypinyin", diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index d9ca38e..7c73c87 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -5,43 +5,43 @@ from importlib.resources import files import soundfile as sf import tqdm from cached_path import cached_path +from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( - hop_length, - infer_process, load_model, load_vocoder, + transcribe, preprocess_ref_audio_text, + infer_process, remove_silence_for_generated_wav, save_spectrogram, - transcribe, - target_sample_rate, ) -from f5_tts.model import DiT, UNetT +from f5_tts.model import DiT, UNetT # noqa: F401. used for config from f5_tts.model.utils import seed_everything class F5TTS: def __init__( self, - model_type="F5-TTS", + model="F5TTS_v1_Base", ckpt_file="", vocab_file="", ode_method="euler", use_ema=True, - vocoder_name="vocos", - local_path=None, + vocoder_local_path=None, device=None, hf_cache_dir=None, ): - # Initialize parameters - self.final_wave = None - self.target_sample_rate = target_sample_rate - self.hop_length = hop_length - self.seed = -1 - self.mel_spec_type = vocoder_name + model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) + model_cls = globals()[model_cfg.model.backbone] + model_arc = model_cfg.model.arch + + self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type + self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate + + self.ode_method = ode_method + self.use_ema = use_ema - # Set device if device is not None: self.device = device else: @@ -58,39 +58,31 @@ class F5TTS: ) # Load models - self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir) - self.load_ema_model( - model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir + self.vocoder = load_vocoder( + self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir ) - def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None): - self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir) + repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" - def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None): - if model_type == "F5-TTS": - if not ckpt_file: - if mel_spec_type == "vocos": - ckpt_file = str( - cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) - ) - elif mel_spec_type == "bigvgan": - ckpt_file = str( - cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir) - ) - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) - model_cls = DiT - elif model_type == "E2-TTS": - if not ckpt_file: - ckpt_file = str( - cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) - ) - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - model_cls = UNetT + # override for previous models + if model == "F5TTS_Base": + if self.mel_spec_type == "vocos": + ckpt_step = 1200000 + elif self.mel_spec_type == "bigvgan": + model = "F5TTS_Base_bigvgan" + ckpt_type = "pt" + elif model == "E2TTS_Base": + repo_name = "E2-TTS" + ckpt_step = 1200000 else: - raise ValueError(f"Unknown model type: {model_type}") + raise ValueError(f"Unknown model type: {model}") + if not ckpt_file: + ckpt_file = str( + cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir) + ) self.ema_model = load_model( - model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device + model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device ) def transcribe(self, ref_audio, language=None): @@ -102,8 +94,8 @@ class F5TTS: if remove_silence: remove_silence_for_generated_wav(file_wave) - def export_spectrogram(self, spect, file_spect): - save_spectrogram(spect, file_spect) + def export_spectrogram(self, spec, file_spec): + save_spectrogram(spec, file_spec) def infer( self, @@ -121,17 +113,16 @@ class F5TTS: fix_duration=None, remove_silence=False, file_wave=None, - file_spect=None, - seed=-1, + file_spec=None, + seed=None, ): - if seed == -1: - seed = random.randint(0, sys.maxsize) - seed_everything(seed) - self.seed = seed + if seed is None: + self.seed = random.randint(0, sys.maxsize) + seed_everything(self.seed) ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) - wav, sr, spect = infer_process( + wav, sr, spec = infer_process( ref_file, ref_text, gen_text, @@ -153,22 +144,22 @@ class F5TTS: if file_wave is not None: self.export_wav(wav, file_wave, remove_silence) - if file_spect is not None: - self.export_spectrogram(spect, file_spect) + if file_spec is not None: + self.export_spectrogram(spec, file_spec) - return wav, sr, spect + return wav, sr, spec if __name__ == "__main__": f5tts = F5TTS() - wav, sr, spect = f5tts.infer( + wav, sr, spec = f5tts.infer( ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), ref_text="some call me nature, others call me mother nature.", gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), - file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")), - seed=-1, # random seed = -1 + file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")), + seed=None, ) print("seed :", f5tts.seed) diff --git a/src/f5_tts/configs/E2TTS_Base_train.yaml b/src/f5_tts/configs/E2TTS_Base.yaml similarity index 71% rename from src/f5_tts/configs/E2TTS_Base_train.yaml rename to src/f5_tts/configs/E2TTS_Base.yaml index da23b05..ee70182 100644 --- a/src/f5_tts/configs/E2TTS_Base_train.yaml +++ b/src/f5_tts/configs/E2TTS_Base.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # "frame" or "sample" + batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 15 + epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,25 +20,29 @@ optim: model: name: E2TTS_Base tokenizer: pinyin - tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) + backbone: UNetT arch: dim: 1024 depth: 24 heads: 16 ff_mult: 4 + text_mask_padding: False + pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # 'vocos' or 'bigvgan' + mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not - local_path: None # local vocoder path + local_path: null # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | None + logger: wandb # wandb | tensorboard | null + log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/src/f5_tts/configs/E2TTS_Small_train.yaml b/src/f5_tts/configs/E2TTS_Small.yaml similarity index 70% rename from src/f5_tts/configs/E2TTS_Small_train.yaml rename to src/f5_tts/configs/E2TTS_Small.yaml index b2d1a6c..cbb1f44 100644 --- a/src/f5_tts/configs/E2TTS_Small_train.yaml +++ b/src/f5_tts/configs/E2TTS_Small.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # "frame" or "sample" + batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 15 + epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,25 +20,29 @@ optim: model: name: E2TTS_Small tokenizer: pinyin - tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) + backbone: UNetT arch: dim: 768 depth: 20 heads: 12 ff_mult: 4 + text_mask_padding: False + pe_attn_head: 1 mel_spec: target_sample_rate: 24000 n_mel_channels: 100 hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # 'vocos' or 'bigvgan' + mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not - local_path: None # local vocoder path + local_path: null # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | None + logger: wandb # wandb | tensorboard | null + log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/src/f5_tts/configs/F5TTS_Base_train.yaml b/src/f5_tts/configs/F5TTS_Base.yaml similarity index 73% rename from src/f5_tts/configs/F5TTS_Base_train.yaml rename to src/f5_tts/configs/F5TTS_Base.yaml index ff8639f..9a2eeb9 100644 --- a/src/f5_tts/configs/F5TTS_Base_train.yaml +++ b/src/f5_tts/configs/F5TTS_Base.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN # dataset name batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # "frame" or "sample" + batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 15 + epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,14 +20,17 @@ optim: model: name: F5TTS_Base # model name tokenizer: pinyin # tokenizer type - tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) + backbone: DiT arch: dim: 1024 depth: 22 heads: 16 ff_mult: 2 text_dim: 512 + text_mask_padding: False conv_layers: 4 + pe_attn_head: 1 checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 @@ -35,13 +38,14 @@ model: hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # 'vocos' or 'bigvgan' + mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not - local_path: None # local vocoder path + local_path: null # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | None + logger: wandb # wandb | tensorboard | null + log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/src/f5_tts/configs/F5TTS_Small_train.yaml b/src/f5_tts/configs/F5TTS_Small.yaml similarity index 73% rename from src/f5_tts/configs/F5TTS_Small_train.yaml rename to src/f5_tts/configs/F5TTS_Small.yaml index 790be06..faae390 100644 --- a/src/f5_tts/configs/F5TTS_Small_train.yaml +++ b/src/f5_tts/configs/F5TTS_Small.yaml @@ -1,16 +1,16 @@ hydra: run: dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} - + datasets: name: Emilia_ZH_EN batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 - batch_size_type: frame # "frame" or "sample" + batch_size_type: frame # frame | sample max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models num_workers: 16 optim: - epochs: 15 + epochs: 11 learning_rate: 7.5e-5 num_warmup_updates: 20000 # warmup updates grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps @@ -20,14 +20,17 @@ optim: model: name: F5TTS_Small tokenizer: pinyin - tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) + tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) + backbone: DiT arch: dim: 768 depth: 18 heads: 12 ff_mult: 2 text_dim: 512 + text_mask_padding: False conv_layers: 4 + pe_attn_head: 1 checkpoint_activations: False # recompute activations and save memory for extra compute mel_spec: target_sample_rate: 24000 @@ -35,13 +38,14 @@ model: hop_length: 256 win_length: 1024 n_fft: 1024 - mel_spec_type: vocos # 'vocos' or 'bigvgan' + mel_spec_type: vocos # vocos | bigvgan vocoder: is_local: False # use local offline ckpt or not - local_path: None # local vocoder path + local_path: null # local vocoder path ckpts: - logger: wandb # wandb | tensorboard | None + logger: wandb # wandb | tensorboard | null + log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates diff --git a/src/f5_tts/configs/F5TTS_v1_Base.yaml b/src/f5_tts/configs/F5TTS_v1_Base.yaml new file mode 100644 index 0000000..c7717fa --- /dev/null +++ b/src/f5_tts/configs/F5TTS_v1_Base.yaml @@ -0,0 +1,53 @@ +hydra: + run: + dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +datasets: + name: Emilia_ZH_EN # dataset name + batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200 + batch_size_type: frame # frame | sample + max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models + num_workers: 16 + +optim: + epochs: 11 + learning_rate: 7.5e-5 + num_warmup_updates: 20000 # warmup updates + grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps + max_grad_norm: 1.0 # gradient clipping + bnb_optimizer: False # use bnb 8bit AdamW optimizer or not + +model: + name: F5TTS_v1_Base # model name + tokenizer: pinyin # tokenizer type + tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt) + backbone: DiT + arch: + dim: 1024 + depth: 22 + heads: 16 + ff_mult: 2 + text_dim: 512 + text_mask_padding: True + qk_norm: null # null | rms_norm + conv_layers: 4 + pe_attn_head: null + checkpoint_activations: False # recompute activations and save memory for extra compute + mel_spec: + target_sample_rate: 24000 + n_mel_channels: 100 + hop_length: 256 + win_length: 1024 + n_fft: 1024 + mel_spec_type: vocos # vocos | bigvgan + vocoder: + is_local: False # use local offline ckpt or not + local_path: null # local vocoder path + +ckpts: + logger: wandb # wandb | tensorboard | null + log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples + save_per_updates: 50000 # save checkpoint per updates + keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints + last_per_updates: 5000 # save last checkpoint per updates + save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index 785880c..e779ff0 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -10,6 +10,7 @@ from importlib.resources import files import torch import torchaudio from accelerate import Accelerator +from omegaconf import OmegaConf from tqdm import tqdm from f5_tts.eval.utils_eval import ( @@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import ( get_seedtts_testset_metainfo, ) from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder -from f5_tts.model import CFM, DiT, UNetT +from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config from f5_tts.model.utils import get_tokenizer accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" -# --------------------- Dataset Settings -------------------- # - -target_sample_rate = 24000 -n_mel_channels = 100 -hop_length = 256 -win_length = 1024 -n_fft = 1024 +use_ema = True target_rms = 0.1 + rel_path = str(files("f5_tts").joinpath("../../")) def main(): - # ---------------------- infer setting ---------------------- # - parser = argparse.ArgumentParser(description="batch inference") parser.add_argument("-s", "--seed", default=None, type=int) - parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN") parser.add_argument("-n", "--expname", required=True) - parser.add_argument("-c", "--ckptstep", default=1200000, type=int) - parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"]) - parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]) + parser.add_argument("-c", "--ckptstep", default=1250000, type=int) parser.add_argument("-nfe", "--nfestep", default=32, type=int) parser.add_argument("-o", "--odemethod", default="euler") @@ -58,12 +49,8 @@ def main(): args = parser.parse_args() seed = args.seed - dataset_name = args.dataset exp_name = args.expname ckpt_step = args.ckptstep - ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" - mel_spec_type = args.mel_spec_type - tokenizer = args.tokenizer nfe_step = args.nfestep ode_method = args.odemethod @@ -77,13 +64,19 @@ def main(): use_truth_duration = False no_ref_audio = False - if exp_name == "F5TTS_Base": - model_cls = DiT - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) + model_cls = globals()[model_cfg.model.backbone] + model_arc = model_cfg.model.arch - elif exp_name == "E2TTS_Base": - model_cls = UNetT - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + dataset_name = model_cfg.datasets.name + tokenizer = model_cfg.model.tokenizer + + mel_spec_type = model_cfg.model.mel_spec.mel_spec_type + target_sample_rate = model_cfg.model.mel_spec.target_sample_rate + n_mel_channels = model_cfg.model.mel_spec.n_mel_channels + hop_length = model_cfg.model.mel_spec.hop_length + win_length = model_cfg.model.mel_spec.win_length + n_fft = model_cfg.model.mel_spec.n_fft if testset == "ls_pc_test_clean": metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst" @@ -111,8 +104,6 @@ def main(): # -------------------------------------------------# - use_ema = True - prompts_all = get_inference_prompt( metainfo, speed=speed, @@ -139,7 +130,7 @@ def main(): # Model model = CFM( - transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, @@ -154,6 +145,10 @@ def main(): vocab_char_map=vocab_char_map, ).to(device) + ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt" + if not os.path.exists(ckpt_path): + print("Loading from self-organized training checkpoints rather than released pretrained.") + ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt" dtype = torch.float32 if mel_spec_type == "bigvgan" else None model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema) diff --git a/src/f5_tts/eval/eval_infer_batch.sh b/src/f5_tts/eval/eval_infer_batch.sh index 47361e3..a5b4f63 100644 --- a/src/f5_tts/eval/eval_infer_batch.sh +++ b/src/f5_tts/eval/eval_infer_batch.sh @@ -1,13 +1,18 @@ #!/bin/bash # e.g. F5-TTS, 16 NFE -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 # e.g. Vanilla E2 TTS, 32 NFE -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 -accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0 +accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 + +# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh +python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 +python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8 +python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 # etc. diff --git a/src/f5_tts/eval/eval_librispeech_test_clean.py b/src/f5_tts/eval/eval_librispeech_test_clean.py index f172286..0b40368 100644 --- a/src/f5_tts/eval/eval_librispeech_test_clean.py +++ b/src/f5_tts/eval/eval_librispeech_test_clean.py @@ -53,43 +53,37 @@ def main(): asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - # --------------------------- WER --------------------------- + # -------------------------------------------------------------------------- + + full_results = [] + metrics = [] if eval_task == "wer": - wer_results = [] - wers = [] - with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: - wer_results.extend(r) - - wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" - with open(wer_result_path, "w") as f: - for line in wer_results: - wers.append(line["wer"]) - json_line = json.dumps(line, ensure_ascii=False) - f.write(json_line + "\n") - - wer = round(np.mean(wers) * 100, 3) - print(f"\nTotal {len(wers)} samples") - print(f"WER : {wer}%") - print(f"Results have been saved to {wer_result_path}") - - # --------------------------- SIM --------------------------- - - if eval_task == "sim": - sims = [] + full_results.extend(r) + elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: - sims.extend(r) + full_results.extend(r) + else: + raise ValueError(f"Unknown metric type: {eval_task}") - sim = round(sum(sims) / len(sims), 3) - print(f"\nTotal {len(sims)} samples") - print(f"SIM : {sim}") + result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" + with open(result_path, "w") as f: + for line in full_results: + metrics.append(line[eval_task]) + f.write(json.dumps(line, ensure_ascii=False) + "\n") + metric = round(np.mean(metrics), 5) + f.write(f"\n{eval_task.upper()}: {metric}\n") + + print(f"\nTotal {len(metrics)} samples") + print(f"{eval_task.upper()}: {metric}") + print(f"{eval_task.upper()} results saved to {result_path}") if __name__ == "__main__": diff --git a/src/f5_tts/eval/eval_seedtts_testset.py b/src/f5_tts/eval/eval_seedtts_testset.py index 95a5f44..0bb68ee 100644 --- a/src/f5_tts/eval/eval_seedtts_testset.py +++ b/src/f5_tts/eval/eval_seedtts_testset.py @@ -52,43 +52,37 @@ def main(): asr_ckpt_dir = "" # auto download to cache dir wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth" - # --------------------------- WER --------------------------- + # -------------------------------------------------------------------------- + + full_results = [] + metrics = [] if eval_task == "wer": - wer_results = [] - wers = [] - with mp.Pool(processes=len(gpus)) as pool: args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_asr_wer, args) for r in results: - wer_results.extend(r) - - wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl" - with open(wer_result_path, "w") as f: - for line in wer_results: - wers.append(line["wer"]) - json_line = json.dumps(line, ensure_ascii=False) - f.write(json_line + "\n") - - wer = round(np.mean(wers) * 100, 3) - print(f"\nTotal {len(wers)} samples") - print(f"WER : {wer}%") - print(f"Results have been saved to {wer_result_path}") - - # --------------------------- SIM --------------------------- - - if eval_task == "sim": - sims = [] + full_results.extend(r) + elif eval_task == "sim": with mp.Pool(processes=len(gpus)) as pool: args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set] results = pool.map(run_sim, args) for r in results: - sims.extend(r) + full_results.extend(r) + else: + raise ValueError(f"Unknown metric type: {eval_task}") - sim = round(sum(sims) / len(sims), 3) - print(f"\nTotal {len(sims)} samples") - print(f"SIM : {sim}") + result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl" + with open(result_path, "w") as f: + for line in full_results: + metrics.append(line[eval_task]) + f.write(json.dumps(line, ensure_ascii=False) + "\n") + metric = round(np.mean(metrics), 5) + f.write(f"\n{eval_task.upper()}: {metric}\n") + + print(f"\nTotal {len(metrics)} samples") + print(f"{eval_task.upper()}: {metric}") + print(f"{eval_task.upper()} results saved to {result_path}") if __name__ == "__main__": diff --git a/src/f5_tts/eval/eval_utmos.py b/src/f5_tts/eval/eval_utmos.py index c4e9449..b6166e8 100644 --- a/src/f5_tts/eval/eval_utmos.py +++ b/src/f5_tts/eval/eval_utmos.py @@ -19,25 +19,23 @@ def main(): predictor = predictor.to(device) audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}")) - utmos_results = {} utmos_score = 0 - for audio_path in tqdm(audio_paths, desc="Processing"): - wav_name = audio_path.stem - wav, sr = librosa.load(audio_path, sr=None, mono=True) - wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) - score = predictor(wav_tensor, sr) - utmos_results[str(wav_name)] = score.item() - utmos_score += score.item() - - avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 - print(f"UTMOS: {avg_score}") - - utmos_result_path = Path(args.audio_dir) / "utmos_results.json" + utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl" with open(utmos_result_path, "w", encoding="utf-8") as f: - json.dump(utmos_results, f, ensure_ascii=False, indent=4) + for audio_path in tqdm(audio_paths, desc="Processing"): + wav, sr = librosa.load(audio_path, sr=None, mono=True) + wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0) + score = predictor(wav_tensor, sr) + line = {} + line["wav"], line["utmos"] = str(audio_path.stem), score.item() + utmos_score += score.item() + f.write(json.dumps(line, ensure_ascii=False) + "\n") + avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0 + f.write(f"\nUTMOS: {avg_score:.4f}\n") - print(f"Results have been saved to {utmos_result_path}") + print(f"UTMOS: {avg_score:.4f}") + print(f"UTMOS results saved to {utmos_result_path}") if __name__ == "__main__": diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py index 7c0a8a8..d8407ad 100644 --- a/src/f5_tts/eval/utils_eval.py +++ b/src/f5_tts/eval/utils_eval.py @@ -389,10 +389,10 @@ def run_sim(args): model = model.cuda(device) model.eval() - sims = [] - for wav1, wav2, truth in tqdm(test_set): - wav1, sr1 = torchaudio.load(wav1) - wav2, sr2 = torchaudio.load(wav2) + sim_results = [] + for gen_wav, prompt_wav, truth in tqdm(test_set): + wav1, sr1 = torchaudio.load(gen_wav) + wav2, sr2 = torchaudio.load(prompt_wav) resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000) resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000) @@ -408,6 +408,11 @@ def run_sim(args): sim = F.cosine_similarity(emb1, emb2)[0].item() # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") - sims.append(sim) + sim_results.append( + { + "wav": Path(gen_wav).stem, + "sim": sim, + } + ) - return sims + return sim_results diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md index d3bc877..9435fb0 100644 --- a/src/f5_tts/infer/README.md +++ b/src/f5_tts/infer/README.md @@ -68,14 +68,16 @@ Basically you can inference with flags: ```bash # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage) f5-tts_infer-cli \ ---model "F5-TTS" \ +--model F5TTS_v1_Base \ --ref_audio "ref_audio.wav" \ --ref_text "The content, subtitle or transcription of reference audio." \ --gen_text "Some text you want TTS model generate for you." -# Choose Vocoder -f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file -f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file +# Use BigVGAN as vocoder. Currently only support F5TTS_Base. +f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local + +# Use custom path checkpoint, e.g. +f5-tts_infer-cli --ckpt_file ckpts/F5TTS_Base/model_1200000.safetensors # More instructions f5-tts_infer-cli --help @@ -90,8 +92,8 @@ f5-tts_infer-cli -c custom.toml For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`: ```toml -# F5-TTS | E2-TTS -model = "F5-TTS" +# F5TTS_v1_Base | E2TTS_Base +model = "F5TTS_v1_Base" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." @@ -105,8 +107,8 @@ output_dir = "tests" You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`. ```toml -# F5-TTS | E2-TTS -model = "F5-TTS" +# F5TTS_v1_Base | E2TTS_Base +model = "F5TTS_v1_Base" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" @@ -126,6 +128,22 @@ ref_text = "" ``` You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`. +## Socket Real-time Service + +Real-time voice output with chunk stream: + +```bash +# Start socket server +python src/f5_tts/socket_server.py + +# If PyAudio not installed +sudo apt-get install portaudio19-dev +pip install pyaudio + +# Communicate with socket client +python src/f5_tts/socket_client.py +``` + ## Speech Editing To test speech editing capabilities, use the following command: @@ -134,86 +152,3 @@ To test speech editing capabilities, use the following command: python src/f5_tts/infer/speech_edit.py ``` -## Socket Realtime Client - -To communicate with socket server you need to run -```bash -python src/f5_tts/socket_server.py -``` - -
-Then create client to communicate - -```bash -# If PyAudio not installed -sudo apt-get install portaudio19-dev -pip install pyaudio -``` - -``` python -# Create the socket_client.py -import socket -import asyncio -import pyaudio -import numpy as np -import logging -import time - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): - client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) - - start_time = time.time() - first_chunk_time = None - - async def play_audio_stream(): - nonlocal first_chunk_time - p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) - - try: - while True: - data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192) - if not data: - break - if data == b"END": - logger.info("End of audio received.") - break - - audio_array = np.frombuffer(data, dtype=np.float32) - stream.write(audio_array.tobytes()) - - if first_chunk_time is None: - first_chunk_time = time.time() - - finally: - stream.stop_stream() - stream.close() - p.terminate() - - logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds") - - try: - data_to_send = f"{text}".encode("utf-8") - await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) - await play_audio_stream() - - except Exception as e: - logger.error(f"Error in listen_to_F5TTS: {e}") - - finally: - client_socket.close() - - -if __name__ == "__main__": - text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components" - - asyncio.run(listen_to_F5TTS(text_to_send)) -``` - -
- diff --git a/src/f5_tts/infer/SHARED.md b/src/f5_tts/infer/SHARED.md index 400548f..79d7f56 100644 --- a/src/f5_tts/infer/SHARED.md +++ b/src/f5_tts/infer/SHARED.md @@ -16,7 +16,7 @@ ### Supported Languages - [Multilingual](#multilingual) - - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts) + - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts) - [English](#english) - [Finnish](#finnish) - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen) @@ -37,7 +37,17 @@ ## Multilingual -#### F5-TTS Base @ zh & en @ F5-TTS +#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS +|Model|🤗Hugging Face|Data (Hours)|Model License| +|:---:|:------------:|:-----------:|:-------------:| +|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| + +```bash +Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors +Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +``` + |Model|🤗Hugging Face|Data (Hours)|Model License| |:---:|:------------:|:-----------:|:-------------:| |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0| @@ -45,7 +55,7 @@ ```bash Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...* @@ -64,7 +74,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` @@ -78,7 +88,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french). @@ -96,7 +106,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt -Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Authors: SPRING Lab, Indian Institute of Technology, Madras @@ -113,7 +123,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c ```bash Model: hf://alien79/F5-TTS-italian/model_159600.safetensors Vocab: hf://alien79/F5-TTS-italian/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Trained by [Mithril Man](https://github.com/MithrilMan) @@ -131,7 +141,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` @@ -148,7 +158,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, " ```bash Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt -Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4} +Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1} ``` - Finetuned by [HotDro4illa](https://github.com/HotDro4illa) - Any improvements are welcome diff --git a/src/f5_tts/infer/examples/basic/basic.toml b/src/f5_tts/infer/examples/basic/basic.toml index c43af38..bc3ebb4 100644 --- a/src/f5_tts/infer/examples/basic/basic.toml +++ b/src/f5_tts/infer/examples/basic/basic.toml @@ -1,5 +1,5 @@ -# F5-TTS | E2-TTS -model = "F5-TTS" +# F5TTS_v1_Base | E2TTS_Base +model = "F5TTS_v1_Base" ref_audio = "infer/examples/basic/basic_ref_en.wav" # If an empty "", transcribes the reference audio automatically. ref_text = "Some call me nature, others call me mother nature." diff --git a/src/f5_tts/infer/examples/multi/story.toml b/src/f5_tts/infer/examples/multi/story.toml index 10ba3fc..f073c26 100644 --- a/src/f5_tts/infer/examples/multi/story.toml +++ b/src/f5_tts/infer/examples/multi/story.toml @@ -1,5 +1,5 @@ -# F5-TTS | E2-TTS -model = "F5-TTS" +# F5TTS_v1_Base | E2TTS_Base +model = "F5TTS_v1_Base" ref_audio = "infer/examples/multi/main.flac" # If an empty "", transcribes the reference audio automatically. ref_text = "" diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py index c4e710a..5c7a1bb 100644 --- a/src/f5_tts/infer/infer_cli.py +++ b/src/f5_tts/infer/infer_cli.py @@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import ( preprocess_ref_audio_text, remove_silence_for_generated_wav, ) -from f5_tts.model import DiT, UNetT +from f5_tts.model import DiT, UNetT # noqa: F401. used for config parser = argparse.ArgumentParser( @@ -50,7 +50,7 @@ parser.add_argument( "-m", "--model", type=str, - help="The model name: F5-TTS | E2-TTS", + help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.", ) parser.add_argument( "-mc", @@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb")) # command-line interface parameters -model = args.model or config.get("model", "F5-TTS") -model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml"))) +model = args.model or config.get("model", "F5TTS_v1_Base") ckpt_file = args.ckpt_file or config.get("ckpt_file", "") vocab_file = args.vocab_file or config.get("vocab_file", "") @@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc # load TTS model -if model == "F5-TTS": - model_cls = DiT - model_cfg = OmegaConf.load(model_cfg).model.arch - if not ckpt_file: # path not specified, download from repo - if vocoder_name == "vocos": - repo_name = "F5-TTS" - exp_name = "F5TTS_Base" - ckpt_step = 1200000 - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path - elif vocoder_name == "bigvgan": - repo_name = "F5-TTS" - exp_name = "F5TTS_Base_bigvgan" - ckpt_step = 1250000 - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) +model_cfg = OmegaConf.load( + args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) +).model +model_cls = globals()[model_cfg.backbone] -elif model == "E2-TTS": - assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet" - assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet" - model_cls = UNetT - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - if not ckpt_file: # path not specified, download from repo - repo_name = "E2-TTS" - exp_name = "E2TTS_Base" +repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" + +if model != "F5TTS_Base": + assert vocoder_name == model_cfg.mel_spec.mel_spec_type + +# override for previous models +if model == "F5TTS_Base": + if vocoder_name == "vocos": ckpt_step = 1200000 - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path + elif vocoder_name == "bigvgan": + model = "F5TTS_Base_bigvgan" + ckpt_type = "pt" +elif model == "E2TTS_Base": + repo_name = "E2-TTS" + ckpt_step = 1200000 + +if not ckpt_file: + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) print(f"Using {model}...") -ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) +ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) # inference process diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 9695bea..72202f6 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import ( ) -DEFAULT_TTS_MODEL = "F5-TTS" +DEFAULT_TTS_MODEL = "F5-TTS_v1" tts_model_choice = DEFAULT_TTS_MODEL DEFAULT_TTS_MODEL_CFG = [ - "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", - "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt", + "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors", + "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt", json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)), ] @@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [ vocoder = load_vocoder() -def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))): - F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) +def load_f5tts(): + ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0])) + F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) return load_model(DiT, F5TTS_model_cfg, ckpt_path) -def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))): - E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) +def load_e2tts(): + ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) + E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1) return load_model(UNetT, E2TTS_model_cfg, ckpt_path) @@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None): if vocab_path.startswith("hf://"): vocab_path = str(cached_path(vocab_path)) if model_cfg is None: - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2]) return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path) @@ -130,7 +132,7 @@ def infer( ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) - if model == "F5-TTS": + if model == DEFAULT_TTS_MODEL: ema_model = F5TTS_ema_model elif model == "E2-TTS": global E2TTS_ema_model @@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip """ ) - last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt") + last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt") def load_last_used_custom(): try: @@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip custom_model_cfg = gr.Dropdown( choices=[ DEFAULT_TTS_MODEL_CFG[2], - json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)), + json.dumps( + dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), + json.dumps( + dict( + dim=768, + depth=18, + heads=12, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) + ), ], value=load_last_used_custom()[2], allow_custom_value=True, diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py index 593bc47..d8d073e 100644 --- a/src/f5_tts/infer/speech_edit.py +++ b/src/f5_tts/infer/speech_edit.py @@ -2,12 +2,15 @@ import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility +from importlib.resources import files + import torch import torch.nn.functional as F import torchaudio +from omegaconf import OmegaConf from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram -from f5_tts.model import CFM, DiT, UNetT +from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer device = ( @@ -21,44 +24,40 @@ device = ( ) -# --------------------- Dataset Settings -------------------- # - -target_sample_rate = 24000 -n_mel_channels = 100 -hop_length = 256 -win_length = 1024 -n_fft = 1024 -mel_spec_type = "vocos" # 'vocos' or 'bigvgan' -target_rms = 0.1 - -tokenizer = "pinyin" -dataset_name = "Emilia_ZH_EN" - - # ---------------------- infer setting ---------------------- # seed = None # int | None -exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base -ckpt_step = 1200000 +exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base +ckpt_step = 1250000 nfe_step = 32 # 16, 32 cfg_strength = 2.0 ode_method = "euler" # euler | midpoint sway_sampling_coef = -1.0 speed = 1.0 +target_rms = 0.1 -if exp_name == "F5TTS_Base": - model_cls = DiT - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) -elif exp_name == "E2TTS_Base": - model_cls = UNetT - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) +model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) +model_cls = globals()[model_cfg.model.backbone] +model_arc = model_cfg.model.arch -ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" +dataset_name = model_cfg.datasets.name +tokenizer = model_cfg.model.tokenizer + +mel_spec_type = model_cfg.model.mel_spec.mel_spec_type +target_sample_rate = model_cfg.model.mel_spec.target_sample_rate +n_mel_channels = model_cfg.model.mel_spec.n_mel_channels +hop_length = model_cfg.model.mel_spec.hop_length +win_length = model_cfg.model.mel_spec.win_length +n_fft = model_cfg.model.mel_spec.n_fft + + +ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors" output_dir = "tests" + # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment] # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git # [write the origin_text into a file, e.g. tests/test_edit.txt] @@ -67,7 +66,7 @@ output_dir = "tests" # [--language "zho" for Chinese, "eng" for English] # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"] -audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav" +audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")) origin_text = "Some call me nature, others call me mother nature." target_text = "Some call me optimist, others call me realist." parts_to_edit = [ @@ -106,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) # Model model = CFM( - transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( n_fft=n_fft, hop_length=hop_length, diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 6a65654..293cd69 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -301,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000: + if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: show_info("Audio is over 15s, clipping short. (1)") break non_silent_wave += non_silent_seg # 2. try to find short silence for clipping if 1. failed - if len(non_silent_wave) > 15000: + if len(non_silent_wave) > 12000: non_silent_segs = silence.split_on_silence( aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10 ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: - if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000: + if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000: show_info("Audio is over 15s, clipping short. (2)") break non_silent_wave += non_silent_seg @@ -321,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in aseg = non_silent_wave # 3. if no proper silence found for clipping - if len(aseg) > 15000: - aseg = aseg[:15000] + if len(aseg) > 12000: + aseg = aseg[:12000] show_info("Audio is over 15s, clipping short. (3)") aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50) @@ -383,7 +383,7 @@ def infer_process( ): # Split the input text into batches audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr)) + max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr)) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) for i, gen_text in enumerate(gen_text_batches): print(f"gen_text {i}", gen_text) diff --git a/src/f5_tts/model/backbones/README.md b/src/f5_tts/model/backbones/README.md index 155671e..09bd4da 100644 --- a/src/f5_tts/model/backbones/README.md +++ b/src/f5_tts/model/backbones/README.md @@ -4,7 +4,7 @@ ### unett.py - flat unet transformer - structure same as in e2-tts & voicebox paper except using rotary pos emb -- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat +- possible abs pos emb & convnextv2 blocks for embedded text before concat ### dit.py - adaln-zero dit @@ -14,7 +14,7 @@ - possible long skip connection (first layer to last layer) ### mmdit.py -- sd3 structure +- stable diffusion 3 block structure - timestep as condition - left stream: text embedded and applied a abs pos emb - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 1ecd10e..c462528 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -20,7 +20,7 @@ from f5_tts.model.modules import ( ConvNeXtV2Block, ConvPositionEmbedding, DiTBlock, - AdaLayerNormZero_Final, + AdaLayerNorm_Final, precompute_freqs_cis, get_pos_embed_indices, ) @@ -30,10 +30,12 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.mask_padding = mask_padding # mask filler and batch padding tokens or not + if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio @@ -49,6 +51,8 @@ class TextEmbedding(nn.Module): text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) + if self.mask_padding: + text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) @@ -64,7 +68,13 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed # convnextv2 blocks - text = self.text_blocks(text) + if self.mask_padding: + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + for block in self.text_blocks: + text = block(text) + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + else: + text = self.text_blocks(text) return text @@ -103,7 +113,10 @@ class DiT(nn.Module): mel_dim=100, text_num_embeds=256, text_dim=None, + text_mask_padding=True, + qk_norm=None, conv_layers=0, + pe_attn_head=None, long_skip_connection=False, checkpoint_activations=False, ): @@ -112,7 +125,10 @@ class DiT(nn.Module): self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + ) + self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -121,15 +137,40 @@ class DiT(nn.Module): self.depth = depth self.transformer_blocks = nn.ModuleList( - [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)] + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + qk_norm=qk_norm, + pe_attn_head=pe_attn_head, + ) + for _ in range(depth) + ] ) self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None - self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.norm_out = AdaLayerNorm_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) self.checkpoint_activations = checkpoint_activations + self.initialize_weights() + + def initialize_weights(self): + # Zero-out AdaLN layers in DiT blocks: + for block in self.transformer_blocks: + nn.init.constant_(block.attn_norm.linear.weight, 0) + nn.init.constant_(block.attn_norm.linear.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.norm_out.linear.weight, 0) + nn.init.constant_(self.norm_out.linear.bias, 0) + nn.init.constant_(self.proj_out.weight, 0) + nn.init.constant_(self.proj_out.bias, 0) + def ckpt_wrapper(self, module): # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py def ckpt_forward(*inputs): @@ -138,6 +179,9 @@ class DiT(nn.Module): return ckpt_forward + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -147,14 +191,25 @@ class DiT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 + cache=False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: time = time.repeat(batch) - # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + # t: conditioning time, text: text, x: noised audio + cond audio + text t = self.time_embed(time) - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + if cache: + if drop_text: + if self.text_uncond is None: + self.text_uncond = self.text_embed(text, seq_len, drop_text=True) + text_embed = self.text_uncond + else: + if self.text_cond is None: + self.text_cond = self.text_embed(text, seq_len, drop_text=False) + text_embed = self.text_cond + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) rope = self.rotary_embed.forward_from_seq_len(seq_len) diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index 64c7ef1..d150555 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -18,7 +18,7 @@ from f5_tts.model.modules import ( TimestepEmbedding, ConvPositionEmbedding, MMDiTBlock, - AdaLayerNormZero_Final, + AdaLayerNorm_Final, precompute_freqs_cis, get_pos_embed_indices, ) @@ -28,18 +28,24 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, out_dim, text_num_embeds): + def __init__(self, out_dim, text_num_embeds, mask_padding=True): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token + self.mask_padding = mask_padding # mask filler and batch padding tokens or not + self.precompute_max_pos = 1024 self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False) def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722 - text = text + 1 - if drop_text: + text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() + if self.mask_padding: + text_mask = text == 0 + + if drop_text: # cfg for text text = torch.zeros_like(text) - text = self.text_embed(text) + + text = self.text_embed(text) # b nt -> b nt d # sinus pos emb batch_start = torch.zeros((text.shape[0],), dtype=torch.long) @@ -49,6 +55,9 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed + if self.mask_padding: + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + return text @@ -83,13 +92,16 @@ class MMDiT(nn.Module): dim_head=64, dropout=0.1, ff_mult=4, - text_num_embeds=256, mel_dim=100, + text_num_embeds=256, + text_mask_padding=True, + qk_norm=None, ): super().__init__() self.time_embed = TimestepEmbedding(dim) - self.text_embed = TextEmbedding(dim, text_num_embeds) + self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding) + self.text_cond, self.text_uncond = None, None # text cache self.audio_embed = AudioEmbedding(mel_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -106,13 +118,33 @@ class MMDiT(nn.Module): dropout=dropout, ff_mult=ff_mult, context_pre_only=i == depth - 1, + qk_norm=qk_norm, ) for i in range(depth) ] ) - self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.norm_out = AdaLayerNorm_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + self.initialize_weights() + + def initialize_weights(self): + # Zero-out AdaLN layers in MMDiT blocks: + for block in self.transformer_blocks: + nn.init.constant_(block.attn_norm_x.linear.weight, 0) + nn.init.constant_(block.attn_norm_x.linear.bias, 0) + nn.init.constant_(block.attn_norm_c.linear.weight, 0) + nn.init.constant_(block.attn_norm_c.linear.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.norm_out.linear.weight, 0) + nn.init.constant_(self.norm_out.linear.bias, 0) + nn.init.constant_(self.proj_out.weight, 0) + nn.init.constant_(self.proj_out.bias, 0) + + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -122,6 +154,7 @@ class MMDiT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 + cache=False, ): batch = x.shape[0] if time.ndim == 0: @@ -129,7 +162,17 @@ class MMDiT(nn.Module): # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - c = self.text_embed(text, drop_text=drop_text) + if cache: + if drop_text: + if self.text_uncond is None: + self.text_uncond = self.text_embed(text, drop_text=True) + c = self.text_uncond + else: + if self.text_cond is None: + self.text_cond = self.text_embed(text, drop_text=False) + c = self.text_cond + else: + c = self.text_embed(text, drop_text=drop_text) x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond) seq_len = x.shape[1] diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py index acf649a..11e4d02 100644 --- a/src/f5_tts/model/backbones/unett.py +++ b/src/f5_tts/model/backbones/unett.py @@ -33,10 +33,12 @@ from f5_tts.model.modules import ( class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2): super().__init__() self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.mask_padding = mask_padding # mask filler and batch padding tokens or not + if conv_layers > 0: self.extra_modeling = True self.precompute_max_pos = 4096 # ~44s of 24khz audio @@ -52,6 +54,8 @@ class TextEmbedding(nn.Module): text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text_len), value=0) + if self.mask_padding: + text_mask = text == 0 if drop_text: # cfg for text text = torch.zeros_like(text) @@ -67,7 +71,13 @@ class TextEmbedding(nn.Module): text = text + text_pos_embed # convnextv2 blocks - text = self.text_blocks(text) + if self.mask_padding: + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + for block in self.text_blocks: + text = block(text) + text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0) + else: + text = self.text_blocks(text) return text @@ -106,7 +116,10 @@ class UNetT(nn.Module): mel_dim=100, text_num_embeds=256, text_dim=None, + text_mask_padding=True, + qk_norm=None, conv_layers=0, + pe_attn_head=None, skip_connect_type: Literal["add", "concat", "none"] = "concat", ): super().__init__() @@ -115,7 +128,10 @@ class UNetT(nn.Module): self.time_embed = TimestepEmbedding(dim) if text_dim is None: text_dim = mel_dim - self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers) + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers + ) + self.text_cond, self.text_uncond = None, None # text cache self.input_embed = InputEmbedding(mel_dim, text_dim, dim) self.rotary_embed = RotaryEmbedding(dim_head) @@ -134,11 +150,12 @@ class UNetT(nn.Module): attn_norm = RMSNorm(dim) attn = Attention( - processor=AttnProcessor(), + processor=AttnProcessor(pe_attn_head=pe_attn_head), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, + qk_norm=qk_norm, ) ff_norm = RMSNorm(dim) @@ -161,6 +178,9 @@ class UNetT(nn.Module): self.norm_out = RMSNorm(dim) self.proj_out = nn.Linear(dim, mel_dim) + def clear_cache(self): + self.text_cond, self.text_uncond = None, None + def forward( self, x: float["b n d"], # nosied input audio # noqa: F722 @@ -170,6 +190,7 @@ class UNetT(nn.Module): drop_audio_cond, # cfg for cond audio drop_text, # cfg for text mask: bool["b n"] | None = None, # noqa: F722 + cache=False, ): batch, seq_len = x.shape[0], x.shape[1] if time.ndim == 0: @@ -177,7 +198,17 @@ class UNetT(nn.Module): # t: conditioning time, c: context (text + masked cond audio), x: noised input audio t = self.time_embed(time) - text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + if cache: + if drop_text: + if self.text_uncond is None: + self.text_uncond = self.text_embed(text, seq_len, drop_text=True) + text_embed = self.text_uncond + else: + if self.text_cond is None: + self.text_cond = self.text_embed(text, seq_len, drop_text=False) + text_embed = self.text_cond + else: + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) # postfix time t to input x, [b n d] -> [b n+1 d] diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index b0cefc0..ea4b67f 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -162,13 +162,13 @@ class CFM(nn.Module): # predict flow pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True ) if cfg_strength < 1e-5: return pred null_pred = self.transformer( - x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True + x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True ) return pred + (pred - null_pred) * cfg_strength @@ -195,6 +195,7 @@ class CFM(nn.Module): t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + self.transformer.clear_cache() sampled = trajectory[-1] out = sampled diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py index 75eeddd..fd6fb11 100644 --- a/src/f5_tts/model/dataset.py +++ b/src/f5_tts/model/dataset.py @@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]): """ def __init__( - self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False + self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False ): self.sampler = sampler self.frames_threshold = frames_threshold @@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]): batch = [] batch_frames = 0 - if not drop_last and len(batch) > 0: + if not drop_residual and len(batch) > 0: batches.append(batch) del indices self.batches = batches + # Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting + self.drop_last = True + def set_epoch(self, epoch: int) -> None: """Sets the epoch for this sampler.""" self.epoch = epoch diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py index bf67fff..8e5c3c2 100644 --- a/src/f5_tts/model/modules.py +++ b/src/f5_tts/model/modules.py @@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module): return residual + x -# AdaLayerNormZero +# RMSNorm + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 + + def forward(self, x): + if self.native_rms_norm: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) + else: + variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = x * self.weight + + return x + + +# AdaLayerNorm # return with modulated x for attn input, and params for later mlp modulation -class AdaLayerNormZero(nn.Module): +class AdaLayerNorm(nn.Module): def __init__(self, dim): super().__init__() @@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module): return x, gate_msa, shift_mlp, scale_mlp, gate_mlp -# AdaLayerNormZero for final layer +# AdaLayerNorm for final layer # return only with modulated x for attn input, cuz no more mlp modulation -class AdaLayerNormZero_Final(nn.Module): +class AdaLayerNorm_Final(nn.Module): def __init__(self, dim): super().__init__() @@ -341,7 +366,8 @@ class Attention(nn.Module): dim_head: int = 64, dropout: float = 0.0, context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only=None, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, ): super().__init__() @@ -362,18 +388,32 @@ class Attention(nn.Module): self.to_k = nn.Linear(dim, self.inner_dim) self.to_v = nn.Linear(dim, self.inner_dim) + if qk_norm is None: + self.q_norm = None + self.k_norm = None + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head, eps=1e-6) + self.k_norm = RMSNorm(dim_head, eps=1e-6) + else: + raise ValueError(f"Unimplemented qk_norm: {qk_norm}") + if self.context_dim is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) self.to_k_c = nn.Linear(context_dim, self.inner_dim) self.to_v_c = nn.Linear(context_dim, self.inner_dim) - if self.context_pre_only is not None: - self.to_q_c = nn.Linear(context_dim, self.inner_dim) + if qk_norm is None: + self.c_q_norm = None + self.c_k_norm = None + elif qk_norm == "rms_norm": + self.c_q_norm = RMSNorm(dim_head, eps=1e-6) + self.c_k_norm = RMSNorm(dim_head, eps=1e-6) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, dim)) self.to_out.append(nn.Dropout(dropout)) - if self.context_pre_only is not None and not self.context_pre_only: - self.to_out_c = nn.Linear(self.inner_dim, dim) + if self.context_dim is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, context_dim) def forward( self, @@ -393,8 +433,11 @@ class Attention(nn.Module): class AttnProcessor: - def __init__(self): - pass + def __init__( + self, + pe_attn_head: int | None = None, # number of attention head to apply rope, None for all + ): + self.pe_attn_head = pe_attn_head def __call__( self, @@ -405,19 +448,11 @@ class AttnProcessor: ) -> torch.FloatTensor: batch_size = x.shape[0] - # `sample` projections. + # `sample` projections query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) - # apply rotary position embedding - if rope is not None: - freqs, xpos_scale = rope - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) - - query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) - key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) - # attention inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -425,6 +460,25 @@ class AttnProcessor: key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # qk norm + if attn.q_norm is not None: + query = attn.q_norm(query) + if attn.k_norm is not None: + key = attn.k_norm(key) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + if self.pe_attn_head is not None: + pn = self.pe_attn_head + query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale) + key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale) + else: + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + # mask. e.g. inference got a batch with different target durations, mask out the padding if mask is not None: attn_mask = mask @@ -470,16 +524,36 @@ class JointAttnProcessor: batch_size = c.shape[0] - # `sample` projections. + # `sample` projections query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) - # `context` projections. + # `context` projections c_query = attn.to_q_c(c) c_key = attn.to_k_c(c) c_value = attn.to_v_c(c) + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # qk norm + if attn.q_norm is not None: + query = attn.q_norm(query) + if attn.k_norm is not None: + key = attn.k_norm(key) + if attn.c_q_norm is not None: + c_query = attn.c_q_norm(c_query) + if attn.c_k_norm is not None: + c_key = attn.c_k_norm(c_key) + # apply rope for context and noised input independently if rope is not None: freqs, xpos_scale = rope @@ -492,16 +566,10 @@ class JointAttnProcessor: c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) - # attention - query = torch.cat([query, c_query], dim=1) - key = torch.cat([key, c_key], dim=1) - value = torch.cat([value, c_value], dim=1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # joint attention + query = torch.cat([query, c_query], dim=2) + key = torch.cat([key, c_key], dim=2) + value = torch.cat([value, c_value], dim=2) # mask. e.g. inference got a batch with different target durations, mask out the padding if mask is not None: @@ -540,16 +608,17 @@ class JointAttnProcessor: class DiTBlock(nn.Module): - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None): super().__init__() - self.attn_norm = AdaLayerNormZero(dim) + self.attn_norm = AdaLayerNorm(dim) self.attn = Attention( - processor=AttnProcessor(), + processor=AttnProcessor(pe_attn_head=pe_attn_head), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, + qk_norm=qk_norm, ) self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) @@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module): context_pre_only: last layer only do prenorm + modulation cuz no more ffn """ - def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False): + def __init__( + self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None + ): super().__init__() - + if context_dim is None: + context_dim = dim self.context_pre_only = context_pre_only - self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) - self.attn_norm_x = AdaLayerNormZero(dim) + self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim) + self.attn_norm_x = AdaLayerNorm(dim) self.attn = Attention( processor=JointAttnProcessor(), dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, - context_dim=dim, + context_dim=context_dim, context_pre_only=context_pre_only, + qk_norm=qk_norm, ) if not context_pre_only: - self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh") + self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh") else: self.ff_norm_c = None self.ff_c = None diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 26970a3..d9ab4a8 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -32,7 +32,7 @@ class Trainer: save_per_updates=1000, keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints checkpoint_path=None, - batch_size=32, + batch_size_per_gpu=32, batch_size_type: str = "sample", max_samples=32, grad_accumulation_steps=1, @@ -40,7 +40,7 @@ class Trainer: noise_scheduler: str | None = None, duration_predictor: torch.nn.Module | None = None, logger: str | None = "wandb", # "wandb" | "tensorboard" | None - wandb_project="test_e2-tts", + wandb_project="test_f5-tts", wandb_run_name="test_run", wandb_resume_id: str = None, log_samples: bool = False, @@ -51,6 +51,7 @@ class Trainer: mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path + cfg_dict: dict = dict(), # training config ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -72,21 +73,23 @@ class Trainer: else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} - self.accelerator.init_trackers( - project_name=wandb_project, - init_kwargs=init_kwargs, - config={ + if not cfg_dict: + cfg_dict = { "epochs": epochs, "learning_rate": learning_rate, "num_warmup_updates": num_warmup_updates, - "batch_size": batch_size, + "batch_size_per_gpu": batch_size_per_gpu, "batch_size_type": batch_size_type, "max_samples": max_samples, "grad_accumulation_steps": grad_accumulation_steps, "max_grad_norm": max_grad_norm, - "gpus": self.accelerator.num_processes, "noise_scheduler": noise_scheduler, - }, + } + cfg_dict["gpus"] = self.accelerator.num_processes + self.accelerator.init_trackers( + project_name=wandb_project, + init_kwargs=init_kwargs, + config=cfg_dict, ) elif self.logger == "tensorboard": @@ -111,9 +114,9 @@ class Trainer: self.save_per_updates = save_per_updates self.keep_last_n_checkpoints = keep_last_n_checkpoints self.last_per_updates = default(last_per_updates, save_per_updates) - self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") + self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts") - self.batch_size = batch_size + self.batch_size_per_gpu = batch_size_per_gpu self.batch_size_type = batch_size_type self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps @@ -179,7 +182,7 @@ class Trainer: if ( not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) - or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path)) + or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path)) ): return 0 @@ -191,7 +194,7 @@ class Trainer: all_checkpoints = [ f for f in os.listdir(self.checkpoint_path) - if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt") + if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors")) ] # First try to find regular training checkpoints @@ -205,8 +208,16 @@ class Trainer: # If no training checkpoints, use pretrained model latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_")) - # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") + if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint + from safetensors.torch import load_file + + checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu") + checkpoint = {"ema_model_state_dict": checkpoint} + elif latest_checkpoint.endswith(".pt"): + # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ + checkpoint = torch.load( + f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu" + ) # patch for backward compatibility, 305e3ea for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]: @@ -271,7 +282,7 @@ class Trainer: num_workers=num_workers, pin_memory=True, persistent_workers=True, - batch_size=self.batch_size, + batch_size=self.batch_size_per_gpu, shuffle=True, generator=generator, ) @@ -280,10 +291,10 @@ class Trainer: sampler = SequentialSampler(train_dataset) batch_sampler = DynamicBatchSampler( sampler, - self.batch_size, + self.batch_size_per_gpu, max_samples=self.max_samples, random_seed=resumable_with_seed, # This enables reproducible shuffling - drop_last=False, + drop_residual=False, ) train_dataloader = DataLoader( train_dataset, diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index d9b17b5..24502d2 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): # convert char to pinyin -jieba.initialize() -print("Word segmentation module jieba initialized.\n") - def convert_char_to_pinyin(text_list, polyphone=True): + if jieba.dt.initialized is False: + jieba.default_logger.setLevel(50) # CRITICAL + jieba.initialize() + final_text_list = [] custom_trans = str.maketrans( {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} diff --git a/src/f5_tts/scripts/count_max_epoch.py b/src/f5_tts/scripts/count_max_epoch.py index 18d36df..fe291e5 100644 --- a/src/f5_tts/scripts/count_max_epoch.py +++ b/src/f5_tts/scripts/count_max_epoch.py @@ -9,7 +9,7 @@ mel_hop_length = 256 mel_sampling_rate = 24000 # target -wanted_max_updates = 1000000 +wanted_max_updates = 1200000 # train params gpus = 8 diff --git a/src/f5_tts/socket_client.py b/src/f5_tts/socket_client.py new file mode 100644 index 0000000..4cad5e7 --- /dev/null +++ b/src/f5_tts/socket_client.py @@ -0,0 +1,61 @@ +import socket +import asyncio +import pyaudio +import numpy as np +import logging +import time + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) + + start_time = time.time() + first_chunk_time = None + + async def play_audio_stream(): + nonlocal first_chunk_time + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) + + try: + while True: + data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192) + if not data: + break + if data == b"END": + logger.info("End of audio received.") + break + + audio_array = np.frombuffer(data, dtype=np.float32) + stream.write(audio_array.tobytes()) + + if first_chunk_time is None: + first_chunk_time = time.time() + + finally: + stream.stop_stream() + stream.close() + p.terminate() + + logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds") + + try: + data_to_send = f"{text}".encode("utf-8") + await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) + await play_audio_stream() + + except Exception as e: + logger.error(f"Error in listen_to_F5TTS: {e}") + + finally: + client_socket.close() + + +if __name__ == "__main__": + text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components" + + asyncio.run(listen_to_F5TTS(text_to_send)) diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py index a053da8..344b1d7 100644 --- a/src/f5_tts/socket_server.py +++ b/src/f5_tts/socket_server.py @@ -13,8 +13,9 @@ from importlib.resources import files import torch import torchaudio from huggingface_hub import hf_hub_download +from omegaconf import OmegaConf -from f5_tts.model.backbones.dit import DiT +from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config from f5_tts.infer.utils_infer import ( chunk_text, preprocess_ref_audio_text, @@ -68,7 +69,7 @@ class AudioFileWriterThread(threading.Thread): class TTSStreamingProcessor: - def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): + def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32): self.device = device or ( "cuda" if torch.cuda.is_available() @@ -78,21 +79,24 @@ class TTSStreamingProcessor: if torch.backends.mps.is_available() else "cpu" ) - self.mel_spec_type = "vocos" + model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) + self.model_cls = globals()[model_cfg.model.backbone] + self.model_arc = model_cfg.model.arch + self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type + self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate + self.model = self.load_ema_model(ckpt_file, vocab_file, dtype) self.vocoder = self.load_vocoder_model() - self.sampling_rate = 24000 + self.update_reference(ref_audio, ref_text) self._warm_up() self.file_writer_thread = None self.first_package = True def load_ema_model(self, ckpt_file, vocab_file, dtype): - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) - model_cls = DiT return load_model( - model_cls=model_cls, - model_cfg=model_cfg, + self.model_cls, + self.model_arc, ckpt_path=ckpt_file, mel_spec_type=self.mel_spec_type, vocab_file=vocab_file, @@ -212,9 +216,14 @@ if __name__ == "__main__": parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", default=9998) + parser.add_argument( + "--model", + default="F5TTS_v1_Base", + help="The model name, e.g. F5TTS_v1_Base", + ) parser.add_argument( "--ckpt_file", - default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.safetensors")), + default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")), help="Path to the model checkpoint file", ) parser.add_argument( @@ -242,6 +251,7 @@ if __name__ == "__main__": try: # Initialize the processor with the model and vocoder processor = TTSStreamingProcessor( + model=args.model, ckpt_file=args.ckpt_file, vocab_file=args.vocab_file, ref_audio=args.ref_audio, diff --git a/src/f5_tts/train/README.md b/src/f5_tts/train/README.md index a57577f..05eb4d6 100644 --- a/src/f5_tts/train/README.md +++ b/src/f5_tts/train/README.md @@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process. accelerate config # .yaml files are under src/f5_tts/configs directory -accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml +accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml # possible to overwrite accelerate and hydra config -accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200 +accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml ++datasets.batch_size_per_gpu=19200 ``` ### 2. Finetuning practice @@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1 The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results. -### 3. Wandb Logging +### 3. W&B Logging The `wandb/` dir will be created under path you run training/finetuning scripts. @@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual To turn on wandb logging, you can either: 1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login) -2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows: +2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows: On Mac & Linux: @@ -75,7 +75,7 @@ On Windows: ``` set WANDB_API_KEY= ``` -Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows: +Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows: ``` export WANDB_MODE=offline diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index 6aeb733..28d890c 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -1,12 +1,13 @@ import argparse import os import shutil +from importlib.resources import files from cached_path import cached_path + from f5_tts.model import CFM, UNetT, DiT, Trainer from f5_tts.model.utils import get_tokenizer from f5_tts.model.dataset import load_dataset -from importlib.resources import files # -------------------------- Dataset Settings --------------------------- # @@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan' # -------------------------- Argument Parsing --------------------------- # def parse_args(): - # batch_size_per_gpu = 1000 settting for gpu 8GB - # batch_size_per_gpu = 1600 settting for gpu 12GB - # batch_size_per_gpu = 2000 settting for gpu 16GB - # batch_size_per_gpu = 3200 settting for gpu 24GB - - # num_warmup_updates = 300 for 5000 sample about 10 hours - - # change save_per_updates , last_per_updates change this value what you need , - parser = argparse.ArgumentParser(description="Train CFM Model") parser.add_argument( - "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name" + "--exp_name", + type=str, + default="F5TTS_v1_Base", + choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], + help="Experiment name", ) parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training") @@ -88,19 +84,54 @@ def main(): checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) # Model parameters based on experiment name - if args.exp_name == "F5TTS_Base": + + if args.exp_name == "F5TTS_v1_Base": wandb_resume_id = None model_cls = DiT - model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) + model_cfg = dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + conv_layers=4, + ) + if args.finetune: + if args.pretrain is None: + ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) + else: + ckpt_path = args.pretrain + + elif args.exp_name == "F5TTS_Base": + wandb_resume_id = None + model_cls = DiT + model_cfg = dict( + dim=1024, + depth=22, + heads=16, + ff_mult=2, + text_dim=512, + text_mask_padding=False, + conv_layers=4, + pe_attn_head=1, + ) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) else: ckpt_path = args.pretrain + elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT - model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + model_cfg = dict( + dim=1024, + depth=24, + heads=16, + ff_mult=4, + text_mask_padding=False, + pe_attn_head=1, + ) if args.finetune: if args.pretrain is None: ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) @@ -120,6 +151,7 @@ def main(): print("copy checkpoint for finetune") # Use the tokenizer and tokenizer_path provided in the command line arguments + tokenizer = args.tokenizer if tokenizer == "custom": if not args.tokenizer_path: @@ -156,7 +188,7 @@ def main(): save_per_updates=args.save_per_updates, keep_last_n_checkpoints=args.keep_last_n_checkpoints, checkpoint_path=checkpoint_path, - batch_size=args.batch_size_per_gpu, + batch_size_per_gpu=args.batch_size_per_gpu, batch_size_type=args.batch_size_type, max_samples=args.max_samples, grad_accumulation_steps=args.grad_accumulation_steps, diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 3d92009..578c931 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1,36 +1,36 @@ -import threading -import queue -import re - import gc import json +import numpy as np import os import platform import psutil +import queue import random +import re import signal import shutil import subprocess import sys import tempfile +import threading import time from glob import glob +from importlib.resources import files +from scipy.io import wavfile import click import gradio as gr import librosa -import numpy as np import torch import torchaudio +from cached_path import cached_path from datasets import Dataset as Dataset_ from datasets.arrow_writer import ArrowWriter -from safetensors.torch import save_file -from scipy.io import wavfile -from cached_path import cached_path +from safetensors.torch import load_file, save_file + from f5_tts.api import F5TTS from f5_tts.model.utils import convert_char_to_pinyin from f5_tts.infer.utils_infer import transcribe -from importlib.resources import files training_process = None @@ -118,16 +118,16 @@ def load_settings(project_name): # Default settings default_settings = { - "exp_name": "F5TTS_Base", - "learning_rate": 1e-05, - "batch_size_per_gpu": 1000, - "batch_size_type": "frame", + "exp_name": "F5TTS_v1_Base", + "learning_rate": 1e-5, + "batch_size_per_gpu": 1, + "batch_size_type": "sample", "max_samples": 64, - "grad_accumulation_steps": 1, + "grad_accumulation_steps": 4, "max_grad_norm": 1, "epochs": 100, - "num_warmup_updates": 2, - "save_per_updates": 300, + "num_warmup_updates": 100, + "save_per_updates": 500, "keep_last_n_checkpoints": -1, "last_per_updates": 100, "finetune": True, @@ -362,18 +362,18 @@ def terminate_process(pid): def start_training( dataset_name="", - exp_name="F5TTS_Base", - learning_rate=1e-4, - batch_size_per_gpu=400, - batch_size_type="frame", + exp_name="F5TTS_v1_Base", + learning_rate=1e-5, + batch_size_per_gpu=1, + batch_size_type="sample", max_samples=64, - grad_accumulation_steps=1, + grad_accumulation_steps=4, max_grad_norm=1.0, - epochs=11, - num_warmup_updates=200, - save_per_updates=400, + epochs=100, + num_warmup_updates=100, + save_per_updates=500, keep_last_n_checkpoints=-1, - last_per_updates=800, + last_per_updates=100, finetune=True, file_checkpoint_train="", tokenizer_type="pinyin", @@ -797,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): print(f"Error processing {file_audio}: {e}") continue - if duration < 1 or duration > 25: - if duration > 25: - error_files.append([file_audio, "duration > 25 sec"]) + if duration < 1 or duration > 30: + if duration > 30: + error_files.append([file_audio, "duration > 30 sec"]) if duration < 1: error_files.append([file_audio, "duration < 1 sec "]) continue if len(text) < 3: - error_files.append([file_audio, "very small text len 3"]) + error_files.append([file_audio, "very short text length 3"]) continue text = clear_text(text) @@ -871,40 +871,37 @@ def check_user(value): def calculate_train( name_project, + epochs, + learning_rate, + batch_size_per_gpu, batch_size_type, max_samples, - learning_rate, num_warmup_updates, - save_per_updates, - last_per_updates, finetune, ): path_project = os.path.join(path_data, name_project) - file_duraction = os.path.join(path_project, "duration.json") + file_duration = os.path.join(path_project, "duration.json") - if not os.path.isfile(file_duraction): + hop_length = 256 + sampling_rate = 24000 + + if not os.path.isfile(file_duration): return ( - 1000, + epochs, + learning_rate, + batch_size_per_gpu, max_samples, num_warmup_updates, - save_per_updates, - last_per_updates, "project not found !", - learning_rate, ) - with open(file_duraction, "r") as file: + with open(file_duration, "r") as file: data = json.load(file) duration_list = data["duration"] - samples = len(duration_list) - hours = sum(duration_list) / 3600 - - # if torch.cuda.is_available(): - # gpu_properties = torch.cuda.get_device_properties(0) - # total_memory = gpu_properties.total_memory / (1024**3) - # elif torch.backends.mps.is_available(): - # total_memory = psutil.virtual_memory().available / (1024**3) + max_sample_length = max(duration_list) * sampling_rate / hop_length + total_samples = len(duration_list) + total_duration = sum(duration_list) if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() @@ -912,64 +909,39 @@ def calculate_train( for i in range(gpu_count): gpu_properties = torch.cuda.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) # in GB - elif torch.xpu.is_available(): gpu_count = torch.xpu.device_count() total_memory = 0 for i in range(gpu_count): gpu_properties = torch.xpu.get_device_properties(i) total_memory += gpu_properties.total_memory / (1024**3) - elif torch.backends.mps.is_available(): gpu_count = 1 total_memory = psutil.virtual_memory().available / (1024**3) + avg_gpu_memory = total_memory / gpu_count + + # rough estimate of batch size if batch_size_type == "frame": - batch = int(total_memory * 0.5) - batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch) - batch_size_per_gpu = int(38400 / batch) - else: - batch_size_per_gpu = int(total_memory / 8) - batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu) - batch = batch_size_per_gpu + batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length)) + elif batch_size_type == "sample": + batch_size_per_gpu = int(200 / (total_duration / total_samples)) - if batch_size_per_gpu <= 0: - batch_size_per_gpu = 1 + if total_samples < 64: + max_samples = int(total_samples * 0.25) - if samples < 64: - max_samples = int(samples * 0.25) - else: - max_samples = 64 + num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05)) - num_warmup_updates = int(samples * 0.05) - save_per_updates = int(samples * 0.10) - last_per_updates = int(save_per_updates * 0.25) + # take 1.2M updates as the maximum + max_updates = 1200000 - max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples) - num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates) - save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates) - last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates) - if last_per_updates <= 0: - last_per_updates = 2 + if batch_size_type == "frame": + mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate + updates_per_epoch = total_duration / mini_batch_duration + elif batch_size_type == "sample": + updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count - total_hours = hours - mel_hop_length = 256 - mel_sampling_rate = 24000 - - # target - wanted_max_updates = 1000000 - - # train params - gpus = gpu_count - frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200 - grad_accum = 1 - - # intermediate - mini_batch_frames = frames_per_gpu * grad_accum * gpus - mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 - updates_per_epoch = total_hours / mini_batch_hours - # steps_per_epoch = updates_per_epoch * grad_accum - epochs = wanted_max_updates / updates_per_epoch + epochs = int(max_updates / updates_per_epoch) if finetune: learning_rate = 1e-5 @@ -977,14 +949,12 @@ def calculate_train( learning_rate = 7.5e-5 return ( + epochs, + learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, - save_per_updates, - last_per_updates, - samples, - learning_rate, - int(epochs), + total_samples, ) @@ -1021,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - ckpt = torch.load(ckpt_path, map_location="cpu") + if ckpt_path.endswith(".safetensors"): + ckpt = load_file(ckpt_path, device="cpu") + ckpt = {"ema_model_state_dict": ckpt} + elif ckpt_path.endswith(".pt"): + ckpt = torch.load(ckpt_path, map_location="cpu") ema_sd = ckpt.get("ema_model_state_dict", {}) embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight" @@ -1089,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type): with open(file_vocab_project, "w", encoding="utf-8") as f: f.write("\n".join(vocab)) - if model_type == "F5-TTS": + if model_type == "F5TTS_v1_Base": + ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")) + elif model_type == "F5TTS_Base": ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) - else: + elif model_type == "E2TTS_Base": ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) vocab_size_new = len(miss_symbols) @@ -1101,7 +1077,7 @@ def vocab_extend(project_name, symbols, model_type): os.makedirs(new_ckpt_path, exist_ok=True) # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py - new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt") + new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path)) size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new) @@ -1231,21 +1207,21 @@ def infer( vocab_file = os.path.join(path_data, project, "vocab.txt") tts_api = F5TTS( - model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema + model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema ) print("update >> ", device_test, file_checkpoint, use_ema) with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: tts_api.infer( - gen_text=gen_text.lower().strip(), - ref_text=ref_text.lower().strip(), ref_file=ref_audio, + ref_text=ref_text.lower().strip(), + gen_text=gen_text.lower().strip(), nfe_step=nfe_step, - file_wave=f.name, speed=speed, - seed=seed, remove_silence=remove_silence, + file_wave=f.name, + seed=seed, ) return f.name, tts_api.device, str(tts_api.seed) @@ -1404,14 +1380,14 @@ def get_audio_select(file_sample): with gr.Blocks() as app: gr.Markdown( """ -# E2/F5 TTS Automatic Finetune +# F5 TTS Automatic Finetune -This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models: +This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models: * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching) * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS) -The checkpoints support English and Chinese. +The pretrained checkpoints support English and Chinese. For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143) """ @@ -1488,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder. ```""") - exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") + exp_name_extend = gr.Radio( + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" + ) with gr.Row(): txt_extend = gr.Textbox( @@ -1557,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] ) - with gr.TabItem("Train Data"): + with gr.TabItem("Train Model"): gr.Markdown("""```plaintext -The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed. +The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space. If you encounter a memory error, try reducing the batch size per GPU to a smaller number. ```""") with gr.Row(): @@ -1573,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="") with gr.Row(): - exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") + exp_name = gr.Radio( + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" + ) learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5) with gr.Row(): - batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000) + batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200) max_samples = gr.Number(label="Max Samples", value=64) with gr.Row(): @@ -1585,23 +1565,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0) with gr.Row(): - epochs = gr.Number(label="Epochs", value=10) - num_warmup_updates = gr.Number(label="Warmup Updates", value=2) + epochs = gr.Number(label="Epochs", value=100) + num_warmup_updates = gr.Number(label="Warmup Updates", value=100) with gr.Row(): - save_per_updates = gr.Number(label="Save per Updates", value=300) + save_per_updates = gr.Number(label="Save per Updates", value=500) keep_last_n_checkpoints = gr.Number( label="Keep Last N Checkpoints", value=-1, step=1, precision=0, - info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints", + info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints", ) last_per_updates = gr.Number(label="Last per Updates", value=100) with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") - mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none") + mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16") cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb") start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) @@ -1718,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle fn=calculate_train, inputs=[ cm_project, + epochs, + learning_rate, + batch_size_per_gpu, batch_size_type, max_samples, - learning_rate, num_warmup_updates, - save_per_updates, - last_per_updates, ch_finetune, ], outputs=[ + epochs, + learning_rate, batch_size_per_gpu, max_samples, num_warmup_updates, - save_per_updates, - last_per_updates, lb_samples, - learning_rate, - epochs, ], ) @@ -1744,25 +1722,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle def setup_load_settings(): output_components = [ - exp_name, # 1 - learning_rate, # 2 - batch_size_per_gpu, # 3 - batch_size_type, # 4 - max_samples, # 5 - grad_accumulation_steps, # 6 - max_grad_norm, # 7 - epochs, # 8 - num_warmup_updates, # 9 - save_per_updates, # 10 - keep_last_n_checkpoints, # 11 - last_per_updates, # 12 - ch_finetune, # 13 - file_checkpoint_train, # 14 - tokenizer_type, # 15 - tokenizer_file, # 16 - mixed_precision, # 17 - cd_logger, # 18 - ch_8bit_adam, # 19 + exp_name, + learning_rate, + batch_size_per_gpu, + batch_size_type, + max_samples, + grad_accumulation_steps, + max_grad_norm, + epochs, + num_warmup_updates, + save_per_updates, + keep_last_n_checkpoints, + last_per_updates, + ch_finetune, + file_checkpoint_train, + tokenizer_type, + tokenizer_file, + mixed_precision, + cd_logger, + ch_8bit_adam, ] return output_components @@ -1784,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle gr.Markdown("""```plaintext SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random ```""") - exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") + exp_name = gr.Radio( + label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base" + ) list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) with gr.Row(): @@ -1838,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) - with gr.TabItem("Reduce Checkpoint"): + with gr.TabItem("Prune Checkpoint"): gr.Markdown("""```plaintext -Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training. +Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining. ```""") txt_path_checkpoint = gr.Text(label="Path to Checkpoint:") txt_path_checkpoint_small = gr.Text(label="Path to Output:") diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index ade54be..2e191a3 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -4,8 +4,9 @@ import os from importlib.resources import files import hydra +from omegaconf import OmegaConf -from f5_tts.model import CFM, DiT, Trainer, UNetT +from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer @@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) def main(cfg): + model_cls = globals()[cfg.model.backbone] + model_arc = cfg.model.arch tokenizer = cfg.model.tokenizer mel_spec_type = cfg.model.mel_spec.mel_spec_type + exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}" + wandb_resume_id = None # set text tokenizer if tokenizer != "custom": @@ -26,14 +31,8 @@ def main(cfg): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) # set model - if "F5TTS" in cfg.model.name: - model_cls = DiT - elif "E2TTS" in cfg.model.name: - model_cls = UNetT - wandb_resume_id = None - model = CFM( - transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), mel_spec_kwargs=cfg.model.mel_spec, vocab_char_map=vocab_char_map, ) @@ -45,9 +44,9 @@ def main(cfg): learning_rate=cfg.optim.learning_rate, num_warmup_updates=cfg.optim.num_warmup_updates, save_per_updates=cfg.ckpts.save_per_updates, - keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1), + keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints, checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), - batch_size=cfg.datasets.batch_size_per_gpu, + batch_size_per_gpu=cfg.datasets.batch_size_per_gpu, batch_size_type=cfg.datasets.batch_size_type, max_samples=cfg.datasets.max_samples, grad_accumulation_steps=cfg.optim.grad_accumulation_steps, @@ -57,11 +56,12 @@ def main(cfg): wandb_run_name=exp_name, wandb_resume_id=wandb_resume_id, last_per_updates=cfg.ckpts.last_per_updates, - log_samples=True, + log_samples=cfg.ckpts.log_samples, bnb_optimizer=cfg.optim.bnb_optimizer, mel_spec_type=mel_spec_type, is_local_vocoder=cfg.model.vocoder.is_local, local_vocoder_path=cfg.model.vocoder.local_path, + cfg_dict=OmegaConf.to_container(cfg, resolve=True), ) train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)