diff --git a/README.md b/README.md
index 3023f26..dcbfb75 100644
--- a/README.md
+++ b/README.md
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tensor notation
 ## Acknowledgements
 
 - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
-- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
 - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
-- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
diff --git a/src/f5_tts/eval/README.md b/src/f5_tts/eval/README.md
index ff324a1..c33ef92 100644
--- a/src/f5_tts/eval/README.md
+++ b/src/f5_tts/eval/README.md
@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckpts to.
 
 ### Objective Evaluation
 
-Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
 ```bash
-# Evaluation for Seed-TTS test set
-python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAV_DIR>
+# Evaluation [WER] for Seed-TTS test [ZH] set
+python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
 
-# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <LIBRISPEECH_TEST_CLEAN_PATH>
-```
\ No newline at end of file
+# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <LIBRISPEECH_TEST_CLEAN_PATH>
+
+# Evaluation [UTMOS]. --ext: Audio extension
+python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
+```
diff --git a/src/f5_tts/eval/eval_librispeech_test_clean.py b/src/f5_tts/eval/eval_librispeech_test_clean.py
index 631f284..f172286 100644
--- a/src/f5_tts/eval/eval_librispeech_test_clean.py
+++ b/src/f5_tts/eval/eval_librispeech_test_clean.py
@@ -1,8 +1,9 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
 
-import sys
-import os
 import argparse
+import json
+import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-import json
 
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -54,36 +54,41 @@ def main():
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
     # --------------------------- WER ---------------------------
+
     if eval_task == "wer":
-        wers = []
         wer_results = []
+        wers = []
+
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
-            for wers_ in results:
-                wers.extend(wers_)
+            for r in results:
+                wer_results.extend(r)
 
-        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
-            for line in wers:
-                wer_results.append(line["wer"])
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
                 json_line = json.dumps(line, ensure_ascii=False)
                 f.write(json_line + "\n")
 
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
 
     # --------------------------- SIM ---------------------------
+
     if eval_task == "sim":
-        sim_list = []
+        sims = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
-            for sim_ in results:
-                sim_list.extend(sim_)
+            for r in results:
+                sims.extend(r)
 
-        sim = round(sum(sim_list) / len(sim_list), 3)
-        print(f"\nTotal {len(sim_list)} samples")
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
         print(f"SIM : {sim}")
diff --git a/src/f5_tts/eval/eval_seedtts_testset.py b/src/f5_tts/eval/eval_seedtts_testset.py
index 967a919..95a5f44 100644
--- a/src/f5_tts/eval/eval_seedtts_testset.py
+++ b/src/f5_tts/eval/eval_seedtts_testset.py
@@ -1,8 +1,9 @@
 # Evaluate with Seed-TTS testset
 
-import sys
-import os
 import argparse
+import json
+import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-import json
 
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -55,35 +55,39 @@ def main():
 
     # --------------------------- WER ---------------------------
 
     if eval_task == "wer":
-        wers = []
         wer_results = []
+        wers = []
+
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
-            for wers_ in results:
-                wers.extend(wers_)
+            for r in results:
+                wer_results.extend(r)
 
-        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
-            for line in wers:
-                wer_results.append(line["wer"])
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
                 json_line = json.dumps(line, ensure_ascii=False)
                 f.write(json_line + "\n")
 
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
 
     # --------------------------- SIM ---------------------------
+
     if eval_task == "sim":
-        sim_list = []
+        sims = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
-            for sim_ in results:
-                sim_list.extend(sim_)
+            for r in results:
+                sims.extend(r)
 
-        sim = round(sum(sim_list) / len(sim_list), 3)
-        print(f"\nTotal {len(sim_list)} samples")
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
         print(f"SIM : {sim}")
diff --git a/src/f5_tts/eval/eval_utmos.py b/src/f5_tts/eval/eval_utmos.py
index 65196ce..9b069cd 100644
--- a/src/f5_tts/eval/eval_utmos.py
+++ b/src/f5_tts/eval/eval_utmos.py
@@ -1,46 +1,43 @@
-import torch
-import librosa
-from pathlib import Path
-import json
-from tqdm import tqdm
 import argparse
+import json
+from pathlib import Path
+
+import librosa
+import torch
+from tqdm import tqdm
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
-    parser.add_argument(
-        "--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
-    )
-    parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
-    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
-
+    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
+    parser.add_argument("--audio_dir", type=str, required=True, help="Audio directory path.")
+    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
 
-    lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
-    results = {}
-    utmos_result = 0
+    audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+    utmos_results = {}
+    utmos_score = 0
 
-    for line in tqdm(lines, desc="Processing"):
-        wave_name = line.stem
-        wave, sr = librosa.load(line, sr=None, mono=True)
-        wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
-        score = predictor(wave_tensor, sr)
-        results[str(wave_name)] = score.item()
-        utmos_result += score.item()
+    for audio_path in tqdm(audio_paths, desc="Processing"):
+        wav_name = audio_path.stem
+        wav, sr = librosa.load(audio_path, sr=None, mono=True)
+        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+        score = predictor(wav_tensor, sr)
+        utmos_results[str(wav_name)] = score.item()
+        utmos_score += score.item()
 
-    avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
+    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
     print(f"UTMOS: {avg_score}")
 
-    output_path = Path(args.audio_dir) / "utmos_results.json"
-    with open(output_path, "w", encoding="utf-8") as f:
-        json.dump(results, f, ensure_ascii=False, indent=4)
+    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    with open(utmos_result_path, "w", encoding="utf-8") as f:
+        json.dump(utmos_results, f, ensure_ascii=False, indent=4)
 
-    print(f"Results have been saved to {output_path}")
+    print(f"Results have been saved to {utmos_result_path}")
{utmos_result_path}") if __name__ == "__main__": diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py index 72f1759..7c0a8a8 100644 --- a/src/f5_tts/eval/utils_eval.py +++ b/src/f5_tts/eval/utils_eval.py @@ -2,12 +2,13 @@ import math import os import random import string +from pathlib import Path import torch import torch.nn.functional as F import torchaudio from tqdm import tqdm -from pathlib import Path + from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL from f5_tts.model.modules import MelSpec from f5_tts.model.utils import convert_char_to_pinyin @@ -320,7 +321,7 @@ def run_asr_wer(args): from zhon.hanzi import punctuation punctuation_all = punctuation + string.punctuation - wers = [] + wer_results = [] from jiwer import compute_measures @@ -335,8 +336,8 @@ def run_asr_wer(args): for segment in segments: hypo = hypo + " " + segment.text - # raw_truth = truth - # raw_hypo = hypo + raw_truth = truth + raw_hypo = hypo for x in punctuation_all: truth = truth.replace(x, "") @@ -360,16 +361,16 @@ def run_asr_wer(args): # dele = measures["deletions"] / len(ref_list) # inse = measures["insertions"] / len(ref_list) - wers.append( + wer_results.append( { - "wav": Path(gen_wav).stem, # wav name - "truth": truth, # raw_truth - "hypo": hypo, # raw_hypo - "wer": wer, # wer score + "wav": Path(gen_wav).stem, + "truth": raw_truth, + "hypo": raw_hypo, + "wer": wer, } ) - return wers + return wer_results # SIM Evaluation @@ -388,7 +389,7 @@ def run_sim(args): model = model.cuda(device) model.eval() - sim_list = [] + sims = [] for wav1, wav2, truth in tqdm(test_set): wav1, sr1 = torchaudio.load(wav1) wav2, sr2 = torchaudio.load(wav2) @@ -407,6 +408,6 @@ def run_sim(args): sim = F.cosine_similarity(emb1, emb2)[0].item() # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).") - sim_list.append(sim) + sims.append(sim) - return sim_list + return sims diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md index fe48a78..0c706b5 100644 --- a/src/f5_tts/infer/README.md +++ b/src/f5_tts/infer/README.md @@ -64,6 +64,9 @@ f5-tts_infer-cli \ # Choose Vocoder f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file + +# More instructions +f5-tts_infer-cli --help ``` And a `.toml` file would help with more flexible usage. diff --git a/src/f5_tts/infer/SHARED.md b/src/f5_tts/infer/SHARED.md index c67f60f..a09bef9 100644 --- a/src/f5_tts/infer/SHARED.md +++ b/src/f5_tts/infer/SHARED.md @@ -22,12 +22,12 @@ - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi) - [French](#french) - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr) +- [Hindi](#hindi) + - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi) - [Italian](#italian) - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it) - [Japanese](#japanese) - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja) -- [Hindi](#hindi) - - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi) - [Mandarin](#mandarin) - [Spanish](#spanish) - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es) @@ -81,6 +81,23 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434). 
+## Hindi
+
+#### F5-TTS Small @ pretrain @ hi
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
+
+```bash
+MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
+VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+```
+
+Authors: SPRING Lab, Indian Institute of Technology, Madras
+
+Website: https://asr.iitm.ac.in/
+
+
 ## Italian
 
 #### F5-TTS Italian @ finetune @ it
@@ -110,21 +127,6 @@ MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
 VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 ```
 
-Website: https://asr.iitm.ac.in/ ## Mandarin diff --git a/src/f5_tts/infer/examples/basic/basic.toml b/src/f5_tts/infer/examples/basic/basic.toml index 4c594c7..c43af38 100644 --- a/src/f5_tts/infer/examples/basic/basic.toml +++ b/src/f5_tts/infer/examples/basic/basic.toml @@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator, gen_file = "" remove_silence = false output_dir = "tests" -output_file = "infer_cli_out.wav" +output_file = "infer_cli_basic.wav" diff --git a/src/f5_tts/infer/examples/multi/story.toml b/src/f5_tts/infer/examples/multi/story.toml index c637062..10ba3fc 100644 --- a/src/f5_tts/infer/examples/multi/story.toml +++ b/src/f5_tts/infer/examples/multi/story.toml @@ -8,6 +8,7 @@ gen_text = "" gen_file = "infer/examples/multi/story.txt" remove_silence = true output_dir = "tests" +output_file = "infer_cli_story.wav" [voices.town] ref_audio = "infer/examples/multi/town.flac" diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py index cc28182..47a71a5 100644 --- a/src/f5_tts/infer/infer_cli.py +++ b/src/f5_tts/infer/infer_cli.py @@ -2,6 +2,7 @@ import argparse import codecs import os import re +from datetime import datetime from importlib.resources import files from pathlib import Path @@ -11,6 +12,14 @@ import tomli from cached_path import cached_path from f5_tts.infer.utils_infer import ( + mel_spec_type, + target_rms, + cross_fade_duration, + nfe_step, + cfg_strength, + sway_sampling_coef, + speed, + fix_duration, infer_process, load_model, load_vocoder, @@ -19,6 +28,7 @@ from f5_tts.infer.utils_infer import ( ) from f5_tts.model import DiT, UNetT + parser = argparse.ArgumentParser( prog="python3 infer-cli.py", description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.", @@ -27,86 +37,161 @@ parser = argparse.ArgumentParser( parser.add_argument( "-c", "--config", - help="Configuration file. Default=infer/examples/basic/basic.toml", + type=str, default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"), + help="The configuration file, default see infer/examples/basic/basic.toml", ) + + +# Note. Not to provide default value here in order to read default from config file + parser.add_argument( "-m", "--model", - help="F5-TTS | E2-TTS", + type=str, + help="The model name: F5-TTS | E2-TTS", ) parser.add_argument( "-p", "--ckpt_file", - help="The Checkpoint .pt", + type=str, + help="The path to model checkpoint .pt, leave blank to use default", ) parser.add_argument( "-v", "--vocab_file", - help="The vocab .txt", + type=str, + help="The path to vocab file .txt, leave blank to use default", +) +parser.add_argument( + "-r", + "--ref_audio", + type=str, + help="The reference audio file.", +) +parser.add_argument( + "-s", + "--ref_text", + type=str, + help="The transcript/subtitle for the reference audio", ) -parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.") -parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.") parser.add_argument( "-t", "--gen_text", type=str, - help="Text to generate.", + help="The text to make model synthesize a speech", ) parser.add_argument( "-f", "--gen_file", type=str, - help="File with text to generate. 
Ignores --gen_text", + help="The file with text to generate, will ignore --gen_text", ) parser.add_argument( "-o", "--output_dir", type=str, - help="Path to output folder..", + help="The path to output folder", ) parser.add_argument( "-w", "--output_file", type=str, - help="Filename of output file..", + help="The name of output file", ) parser.add_argument( "--save_chunk", action="store_true", - help="Save chunk audio if your text is too long.", + help="To save each audio chunks during inference", ) parser.add_argument( "--remove_silence", - help="Remove silence.", + action="store_true", + help="To remove long silence found in ouput", ) -parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name") parser.add_argument( "--load_vocoder_from_local", action="store_true", - help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz", + help="To load vocoder from local dir, default to ../checkpoints/charactr/vocos-mel-24khz", ) parser.add_argument( - "--speed", + "--vocoder_name", + type=str, + choices=["vocos", "bigvgan"], + help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}", +) +parser.add_argument( + "--target_rms", type=float, - default=1.0, - help="Adjust the speed of the audio generation (default: 1.0)", + help=f"Target output speech loudness normalization value, default {target_rms}", +) +parser.add_argument( + "--cross_fade_duration", + type=float, + help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}", ) parser.add_argument( "--nfe_step", type=int, - default=32, - help="Set the number of denoising steps (default: 32)", + help=f"The number of function evaluation (denoising steps), default {nfe_step}", +) +parser.add_argument( + "--cfg_strength", + type=float, + help=f"Classifier-free guidance strength, default {cfg_strength}", +) +parser.add_argument( + "--sway_sampling_coef", + type=float, + help=f"Sway Sampling coefficient, default {sway_sampling_coef}", +) +parser.add_argument( + "--speed", + type=float, + help=f"The speed of the generated audio, default {speed}", +) +parser.add_argument( + "--fix_duration", + type=float, + help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}", ) args = parser.parse_args() + +# config file + config = tomli.load(open(args.config, "rb")) -ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"] -ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"] -gen_text = args.gen_text if args.gen_text else config["gen_text"] -gen_file = args.gen_file if args.gen_file else config["gen_file"] -save_chunk = args.save_chunk if args.save_chunk else False + +# command-line interface parameters + +model = args.model or config.get("model", "F5-TTS") +ckpt_file = args.ckpt_file or config.get("ckpt_file", "") +vocab_file = args.vocab_file or config.get("vocab_file", "") + +ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav") +ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.") +gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.") +gen_file = args.gen_file or config.get("gen_file", "") + +output_dir = args.output_dir or config.get("output_dir", "tests") +output_file = args.output_file or config.get( + "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav" +) + +save_chunk = args.save_chunk +remove_silence = 
+load_vocoder_from_local = args.load_vocoder_from_local
+
+# numeric flags are compared against None (not truthiness), so that zero values such as --cross_fade_duration 0 still take effect
+vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
+target_rms = args.target_rms if args.target_rms is not None else config.get("target_rms", target_rms)
+cross_fade_duration = args.cross_fade_duration if args.cross_fade_duration is not None else config.get("cross_fade_duration", cross_fade_duration)
+nfe_step = args.nfe_step if args.nfe_step is not None else config.get("nfe_step", nfe_step)
+cfg_strength = args.cfg_strength if args.cfg_strength is not None else config.get("cfg_strength", cfg_strength)
+sway_sampling_coef = args.sway_sampling_coef if args.sway_sampling_coef is not None else config.get("sway_sampling_coef", sway_sampling_coef)
+speed = args.speed if args.speed is not None else config.get("speed", speed)
+fix_duration = args.fix_duration if args.fix_duration is not None else config.get("fix_duration", fix_duration)
+
 
 # patches for pip pkg user
 if "infer/examples/" in ref_audio:
@@ -119,35 +204,39 @@ if "voices" in config:
         if "infer/examples/" in voice_ref_audio:
             config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
 
+
+# ignore gen_text if gen_file provided
+
 if gen_file:
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
 
-output_dir = args.output_dir if args.output_dir else config["output_dir"]
-output_file = args.output_file if args.output_file else config["output_file"]
-model = args.model if args.model else config["model"]
-ckpt_file = args.ckpt_file if args.ckpt_file else ""
-vocab_file = args.vocab_file if args.vocab_file else ""
-remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
-speed = args.speed
-nfe_step = args.nfe_step
+
+
+# output path
 
 wave_path = Path(output_dir) / output_file
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
+if save_chunk:
+    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
+    if not os.path.exists(output_chunk_dir):
+        os.makedirs(output_chunk_dir)
+
+
+# load vocoder
 
-vocoder_name = args.vocoder_name
-mel_spec_type = args.vocoder_name
 if vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
-vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
+vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
 
 
-# load models
+# load TTS model
+
 if model == "F5-TTS":
     model_cls = DiT
     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-    if ckpt_file == "":
+    if not ckpt_file:  # path not specified, download from repo
         if vocoder_name == "vocos":
             repo_name = "F5-TTS"
             exp_name = "F5TTS_Base"
@@ -164,19 +253,21 @@ elif model == "E2-TTS":
     assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-    if ckpt_file == "":
+    if not ckpt_file:  # path not specified, download from repo
         repo_name = "E2-TTS"
         exp_name = "E2TTS_Base"
         ckpt_step = 1200000
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
         # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
-
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
+# inference process
+
+
+def main():
+    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
{"ref_audio": ref_audio, "ref_text": ref_text} if "voices" not in config: voices = {"main": main_voice} @@ -184,16 +275,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove voices = config["voices"] voices["main"] = main_voice for voice in voices: + print("Voice:", voice) + print("ref_audio ", voices[voice]["ref_audio"]) voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text( voices[voice]["ref_audio"], voices[voice]["ref_text"] ) - print("Voice:", voice) - print("Ref_audio:", voices[voice]["ref_audio"]) - print("Ref_text:", voices[voice]["ref_text"]) + print("ref_audio_", voices[voice]["ref_audio"], "\n\n") generated_audio_segments = [] reg1 = r"(?=\[\w+\])" - chunks = re.split(reg1, text_gen) + chunks = re.split(reg1, gen_text) reg2 = r"\[(\w+)\]" for text in chunks: if not text.strip(): @@ -208,21 +299,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove print(f"Voice {voice} not found, using main.") voice = "main" text = re.sub(reg2, "", text) - gen_text = text.strip() - ref_audio = voices[voice]["ref_audio"] - ref_text = voices[voice]["ref_text"] + ref_audio_ = voices[voice]["ref_audio"] + ref_text_ = voices[voice]["ref_text"] + gen_text_ = text.strip() print(f"Voice: {voice}") - audio, final_sample_rate, spectragram = infer_process( - ref_audio, - ref_text, - gen_text, - model_obj, + audio_segment, final_sample_rate, spectragram = infer_process( + ref_audio_, + ref_text_, + gen_text_, + ema_model, vocoder, - mel_spec_type=mel_spec_type, - speed=speed, + mel_spec_type=vocoder_name, + target_rms=target_rms, + cross_fade_duration=cross_fade_duration, nfe_step=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + speed=speed, + fix_duration=fix_duration, ) - generated_audio_segments.append(audio) + generated_audio_segments.append(audio_segment) + + if save_chunk: + if len(gen_text_) > 200: + gen_text_ = gen_text_[:200] + " ... " + sf.write( + os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"), + audio_segment, + final_sample_rate, + ) if generated_audio_segments: final_wave = np.concatenate(generated_audio_segments) @@ -236,22 +341,6 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove if remove_silence: remove_silence_for_generated_wav(f.name) print(f.name) - # Ensure the gen_text chunk directory exists - - if save_chunk: - gen_text_chunk_dir = os.path.join(output_dir, "chunks") - if not os.path.exists(gen_text_chunk_dir): # if Not create directory - os.makedirs(gen_text_chunk_dir) - - # Save individual chunks as separate files - for idx, segment in enumerate(generated_audio_segments): - gen_text_chunk_path = os.path.join(output_dir, gen_text_chunk_dir, f"chunk_{idx}.wav") - sf.write(gen_text_chunk_path, segment, final_sample_rate) - print(f"Saved gen_text chunk {idx} at {gen_text_chunk_path}") - - -def main(): - main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed) if __name__ == "__main__":