diff --git a/README.md b/README.md
index d4308d1..c0ecefd 100644
--- a/README.md
+++ b/README.md
@@ -58,38 +58,28 @@ Once your datasets are prepared, you can start the training process.
 # setup accelerate config, e.g. use multi-gpu ddp, fp16
 # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
 accelerate config
-accelerate launch test_train.py
+accelerate launch train.py
 ```
 
 An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
 
 ## Inference
 
-To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
+To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or let `inference-cli` and `gradio_app` download them automatically.
 
-Currently support up to 30s generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by Gradio APP now.
+Currently a single generation supports up to 30s, which is the **TOTAL** length of the prompt audio and the generated speech. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
 
 - To avoid possible inference failures, make sure you have seen through the following instructions.
-- A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider split your text and do several separate inferences or leverage the local Gradio APP which enables a batch inference with chunks.
+- A longer prompt audio allows a shorter generated output. The part beyond 30s cannot be generated properly. Consider using a prompt audio shorter than 15s.
 - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
 - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses. If first few words skipped in code-switched generation (cuz different speed with different languages), this might help.
 
-### Single Inference
+### CLI Inference
 
-You can test single inference using the following command. Before running the command, modify the config up to your need.
+You can either specify everything in `inference-cli.toml` or override it with command-line flags. Leaving `--ref_text ""` will have the ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you encounter a network error, consider using a local checkpoint: just set `ckpt_path` in `inference-cli.py`.
 
 ```bash
-# modify the config up to your need,
-# e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
-# nfe_step (larger takes more time to do more precise inference ode)
-# ode_method (switch to 'midpoint' for better compatibility with small nfe_step, )
-# ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
-python test_infer_single.py
-```
-### Speech Editing
+python inference-cli.py --model "F5-TTS" --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" --ref_text "Some call me nature, others call me mother nature." --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
 
-To test speech editing capabilities, use the following command. 
-
-```bash
-python test_infer_single_edit.py
+python inference-cli.py --model "E2-TTS" --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" --ref_text "对,这就是我,万人敬仰的太乙真人。" --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
 ```
 
 ### Gradio App
@@ -102,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requiremen
 pip install -r requirements_gradio.txt
 ```
 
-After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
+After installing the dependencies, launch the app (it will load the checkpoint from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS and E2 TTS all at once, and thus uses more GPU memory than `inference-cli`.
 
 ```bash
 python gradio_app.py
@@ -120,6 +110,14 @@ Or launch a share link:
 python gradio_app.py --share
 ```
 
+### Speech Editing
+
+To test speech editing capabilities, use the following command.
+
+```bash
+python speech_edit.py
+```
+
 ## Evaluation
 
 ### Prepare Test Datasets
@@ -127,7 +125,7 @@ python gradio_app.py --share
 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
 3. Unzip the downloaded datasets and place them in the data/ directory.
-4. Update the path for the test-clean data in `test_infer_batch.py`
+4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
 
 ### Batch Inference for Test Set
@@ -137,7 +135,7 @@ To run batch inference for evaluations, execute the following commands:
 ```bash
 # batch inference for evaluations
 accelerate config # if not set before
-bash test_infer_batch.sh
+bash scripts/eval_infer_batch.sh
 ```
 
 ### Download Evaluation Model Checkpoints
diff --git a/inference-cli.py b/inference-cli.py
index b630600..0bc5e58 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torchaudio
@@ -16,10 +15,8 @@ from model.utils import (
     save_spectrogram,
 )
 from transformers import pipeline
-import librosa
-import click
 import soundfile as sf
-import tomllib
+import tomli
 import argparse
 import tqdm
 from pathlib import Path
@@ -42,19 +39,19 @@ parser.add_argument(
 )
 parser.add_argument(
     "-r",
-    "--reference",
+    "--ref_audio",
     type=str,
     help="Reference audio file < 15 seconds."
 )
 parser.add_argument(
     "-s",
-    "--subtitle",
+    "--ref_text",
     type=str,
     help="Subtitle for the reference audio."
) parser.add_argument( "-t", - "--text", + "--gen_text", type=str, help="Text to generate.", ) @@ -70,11 +67,11 @@ parser.add_argument( ) args = parser.parse_args() -config = tomllib.load(open(args.config, "rb")) +config = tomli.load(open(args.config, "rb")) -ref_audio = args.reference if args.reference else config["reference"] -ref_text = args.subtitle if args.subtitle else config["subtitle"] -gen_text = args.text if args.text else config["text"] +ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"] +ref_text = args.ref_text if args.ref_text else config["ref_text"] +gen_text = args.gen_text if args.gen_text else config["gen_text"] output_dir = args.output_dir if args.output_dir else config["output_dir"] exp_name = args.model if args.model else config["model"] remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"] @@ -100,13 +97,6 @@ device = ( print(f"Using {device} device") -pipe = pipeline( - "automatic-speech-recognition", - model="openai/whisper-large-v3-turbo", - torch_dtype=torch.float16, - device=device, -) - # --------------------- Settings -------------------- # target_sample_rate = 24000 @@ -151,13 +141,6 @@ F5TTS_model_cfg = dict( ) E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) -F5TTS_ema_model = load_model( - "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000 -) -E2TTS_ema_model = load_model( - "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000 -) - def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS): if len(text.encode('utf-8')) <= max_chars: return [text] @@ -256,9 +239,9 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS): def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence): if exp_name == "F5-TTS": - ema_model = F5TTS_ema_model + ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000) elif exp_name == "E2-TTS": - ema_model = E2TTS_ema_model + ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000) audio, sr = torchaudio.load(ref_audio) if audio.shape[0] > 1: @@ -363,6 +346,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s if not ref_text.strip(): print("No reference text provided, transcribing reference audio...") + pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3-turbo", + torch_dtype=torch.float16, + device=device, + ) ref_text = pipe( ref_audio, chunk_length_s=30, diff --git a/inference-cli.toml b/inference-cli.toml index ca89ca1..20e3c38 100644 --- a/inference-cli.toml +++ b/inference-cli.toml @@ -1,8 +1,8 @@ # F5-TTS | E2-TTS model = "F5-TTS" -reference = "tests/ref_audio/test_en_1_ref_short.wav" +ref_audio = "tests/ref_audio/test_en_1_ref_short.wav" # If an empty "", transcribes the reference audio automatically. -subtitle = "Some call me nature, others call me mother nature." -text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences." +ref_text = "Some call me nature, others call me mother nature." +gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences." 
remove_silence = true output_dir = "tests" \ No newline at end of file diff --git a/model/dataset.py b/model/dataset.py index d1af96a..bb5ec8e 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -188,7 +188,7 @@ def load_dataset( dataset_type: str = "CustomDataset", audio_type: str = "raw", mel_spec_kwargs: dict = dict() - ) -> CustomDataset | HFDataset: + ) -> CustomDataset: print("Loading dataset ...") diff --git a/requirements.txt b/requirements.txt index 62ffa90..fa63937 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,21 @@ accelerate>=0.33.0 +cached_path +click datasets einops>=0.8.0 einx>=0.3.0 ema_pytorch>=0.5.2 faster_whisper funasr +gradio jieba jiwer librosa matplotlib +pydub pypinyin safetensors +soundfile # torch>=2.0 # torchaudio>=2.3.0 torchdiffeq @@ -20,6 +25,4 @@ vocos wandb x_transformers>=1.31.14 zhconv -zhon -pydub -cached_path \ No newline at end of file +zhon \ No newline at end of file diff --git a/requirements_gradio.txt b/requirements_gradio.txt deleted file mode 100644 index a28f7f8..0000000 --- a/requirements_gradio.txt +++ /dev/null @@ -1,5 +0,0 @@ -cached_path -click -gradio -pydub -soundfile \ No newline at end of file diff --git a/test_infer_batch.py b/scripts/eval_infer_batch.py similarity index 99% rename from test_infer_batch.py rename to scripts/eval_infer_batch.py index afd1b28..726eb93 100644 --- a/test_infer_batch.py +++ b/scripts/eval_infer_batch.py @@ -1,4 +1,6 @@ -import os +import sys, os +sys.path.append(os.getcwd()) + import time import random from tqdm import tqdm diff --git a/scripts/eval_infer_batch.sh b/scripts/eval_infer_batch.sh new file mode 100644 index 0000000..45b0717 --- /dev/null +++ b/scripts/eval_infer_batch.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# e.g. F5-TTS, 16 NFE +accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 +accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 +accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 + +# e.g. Vanilla E2 TTS, 32 NFE +accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 +accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 +accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 + +# etc. diff --git a/test_infer_single_edit.py b/speech_edit.py similarity index 100% rename from test_infer_single_edit.py rename to speech_edit.py diff --git a/test_infer_batch.sh b/test_infer_batch.sh deleted file mode 100644 index c9c7a19..0000000 --- a/test_infer_batch.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# e.g. F5-TTS, 16 NFE -accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 -accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 -accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 - -# e.g. Vanilla E2 TTS, 32 NFE -accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 -accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 -accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 - -# etc. 
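
The renamed keys in `inference-cli.toml` (`ref_audio`, `ref_text`, `gen_text`) and the matching `--ref_audio`/`--ref_text`/`--gen_text` flags follow a simple precedence: a flag, when given, overrides the toml value. Below is a minimal sketch of that pattern, assuming only the names visible in the hunks above; the `prog` name and the prints are illustrative, and this is not the full `inference-cli.py`.

```python
# Minimal sketch of the flag-over-toml precedence used by inference-cli.py;
# the prog name and the prints at the end are illustrative only.
import argparse

import tomli

parser = argparse.ArgumentParser(prog="precedence_sketch")
parser.add_argument("-c", "--config", default="inference-cli.toml", help="Config file path.")
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
parser.add_argument("-s", "--ref_text", type=str, help="Text read in the reference audio.")
parser.add_argument("-t", "--gen_text", type=str, help="Text to generate.")
args = parser.parse_args()

# tomli works on binary file handles, matching the call in the diff.
with open(args.config, "rb") as f:
    config = tomli.load(f)

# A flag, when provided, wins; otherwise the renamed toml key is used.
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]

print(f"ref_audio: {ref_audio}")
print(f"ref_text : {ref_text}")
print(f"gen_text : {gen_text}")
```

Invoked without flags, this reads everything from `inference-cli.toml`; passing, say, `--ref_audio my_voice.wav` (a hypothetical file) overrides just that value while keeping the texts from the toml file.
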
diff --git a/test_infer_single.py b/test_infer_single.py deleted file mode 100644 index d76c6b4..0000000 --- a/test_infer_single.py +++ /dev/null @@ -1,161 +0,0 @@ -import os -import re - -import torch -import torchaudio -from einops import rearrange -from vocos import Vocos - -from model import CFM, UNetT, DiT, MMDiT -from model.utils import ( - load_checkpoint, - get_tokenizer, - convert_char_to_pinyin, - save_spectrogram, -) - -device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - - -# --------------------- Dataset Settings -------------------- # - -target_sample_rate = 24000 -n_mel_channels = 100 -hop_length = 256 -target_rms = 0.1 - -tokenizer = "pinyin" -dataset_name = "Emilia_ZH_EN" - - -# ---------------------- infer setting ---------------------- # - -seed = None # int | None - -exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base -ckpt_step = 1200000 - -nfe_step = 32 # 16, 32 -cfg_strength = 2. -ode_method = 'euler' # euler | midpoint -sway_sampling_coef = -1. -speed = 1. -fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio) - -if exp_name == "F5TTS_Base": - model_cls = DiT - model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4) - -elif exp_name == "E2TTS_Base": - model_cls = UNetT - model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4) - -ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" -output_dir = "tests" - -ref_audio = "tests/ref_audio/test_en_1_ref_short.wav" -ref_text = "Some call me nature, others call me mother nature." -gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences." 
- -# ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav" -# ref_text = "对,这就是我,万人敬仰的太乙真人。" -# gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\"" - - -# -------------------------------------------------# - -use_ema = True - -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -# Vocoder model -local = False -if local: - vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" - vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) - vocos.load_state_dict(state_dict) - vocos.eval() -else: - vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") - -# Tokenizer -vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) - -# Model -model = CFM( - transformer = model_cls( - **model_cfg, - text_num_embeds = vocab_size, - mel_dim = n_mel_channels - ), - mel_spec_kwargs = dict( - target_sample_rate = target_sample_rate, - n_mel_channels = n_mel_channels, - hop_length = hop_length, - ), - odeint_kwargs = dict( - method = ode_method, - ), - vocab_char_map = vocab_char_map, -).to(device) - -model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema) - -# Audio -audio, sr = torchaudio.load(ref_audio) -if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) -rms = torch.sqrt(torch.mean(torch.square(audio))) -if rms < target_rms: - audio = audio * target_rms / rms -if sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(sr, target_sample_rate) - audio = resampler(audio) -audio = audio.to(device) - -# Text -if len(ref_text[-1].encode('utf-8')) == 1: - ref_text = ref_text + " " -text_list = [ref_text + gen_text] -if tokenizer == "pinyin": - final_text_list = convert_char_to_pinyin(text_list) -else: - final_text_list = [text_list] -print(f"text : {text_list}") -print(f"pinyin: {final_text_list}") - -# Duration -ref_audio_len = audio.shape[-1] // hop_length -if fix_duration is not None: - duration = int(fix_duration * target_sample_rate / hop_length) -else: # simple linear scale calcul - zh_pause_punc = r"。,、;:?!" - ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text)) - gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text)) - duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) - -# Inference -with torch.inference_mode(): - generated, trajectory = model.sample( - cond = audio, - text = final_text_list, - duration = duration, - steps = nfe_step, - cfg_strength = cfg_strength, - sway_sampling_coef = sway_sampling_coef, - seed = seed, - ) -print(f"Generated mel: {generated.shape}") - -# Final result -generated = generated[:, ref_audio_len:, :] -generated_mel_spec = rearrange(generated, '1 n d -> 1 d n') -generated_wave = vocos.decode(generated_mel_spec.cpu()) -if rms < target_rms: - generated_wave = generated_wave * rms / target_rms - -save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png") -torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate) -print(f"Generated wav: {generated_wave.shape}") diff --git a/test_train.py b/train.py similarity index 100% rename from test_train.py rename to train.py
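
One behavioural change in the `inference-cli.py` hunks above is that the Whisper pipeline is no longer built at import time; it is constructed inside `infer` only when no reference text is given. Below is a minimal sketch of that lazy-loading pattern, assuming only the model name and `transformers.pipeline` call visible in the diff; the helper name `resolve_ref_text` is illustrative, and the remaining kwargs of the original ASR call are not shown in the diff and are omitted here.

```python
# Sketch of the lazy ASR loading pattern from inference-cli.py: the Whisper
# pipeline is only constructed when the caller did not supply reference text,
# so GPU memory is not spent on ASR otherwise.
import torch
from transformers import pipeline

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


def resolve_ref_text(ref_audio: str, ref_text: str) -> str:
    """Return ref_text as-is, or transcribe ref_audio when it is empty."""
    if ref_text.strip():
        return ref_text
    print("No reference text provided, transcribing reference audio...")
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=torch.float16,
        device=device,
    )
    # chunk_length_s mirrors the call in the diff; the transcription is
    # returned under the "text" key of the pipeline output.
    return asr(ref_audio, chunk_length_s=30)["text"].strip()
```

Compared with building the pipeline at module import (the code removed in the diff), this keeps the CLI's memory footprint lower whenever `ref_text` is provided, at the cost of a first-use delay when transcription is actually needed.
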