From ed179067df5228615e67f96ec3feb63df3c6f686 Mon Sep 17 00:00:00 2001
From: Justin John <34035011+justinjohn0306@users.noreply.github.com>
Date: Sat, 26 Oct 2024 10:45:11 +0530
Subject: [PATCH] Add speed control to F5-TTS inference CLI

- Added support for the --speed argument to control the speed of generated audio.
- Updated the CLI to accept a speed parameter with a default value of 1.0.
- Adjusted the infer_process call to apply the speed factor during TTS generation.
---
 src/f5_tts/infer/infer_cli.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index 5109667..cea4cb3 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -75,6 +75,12 @@ parser.add_argument(
     action="store_true",
     help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
 )
+parser.add_argument(
+    "--speed",
+    type=float,
+    default=1.0,
+    help="Adjust the speed of the audio generation (default: 1.0)",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
@@ -102,6 +108,7 @@ model = args.model if args.model else config["model"]
 ckpt_file = args.ckpt_file if args.ckpt_file else ""
 vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
+speed = args.speed  # The new speed argument
 wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
@@ -134,7 +141,7 @@ print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
+def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -168,7 +175,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
         ref_audio = voices[voice]["ref_audio"]
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
+        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -186,7 +193,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
 
 
 def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence, speed)
 
 
 if __name__ == "__main__":
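
Usage note (editor's sketch, not part of the commit): with this patch applied, the speed factor flows from the new --speed flag into main_process and on to infer_process(..., speed=speed). A minimal example invocation follows; the config.toml path is a placeholder, and the expectation that values above 1.0 produce faster speech is an assumption about infer_process, not something the patch itself confirms:

    # hypothetical invocation; replace config.toml with a real inference config
    python src/f5_tts/infer/infer_cli.py --config config.toml --speed 1.2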