Add speed control to F5-TTS inference CLI

- Added support for the --speed argument to control the speed of generated audio.
- Updated the CLI to accept a speed parameter with a default value of 1.0.
- Adjusted infer_process to apply the speed factor during TTS generation (see the sketch below).
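
For context, a minimal sketch of how a speed factor like this is typically applied during duration estimation in F5-TTS-style inference. The helper name estimate_duration and the byte-length heuristic are illustrative assumptions, not necessarily the project's exact implementation:

    def estimate_duration(ref_audio_len, ref_text, gen_text, speed=1.0):
        # Scale the reference length by the ratio of generated-to-reference
        # text length, then divide by speed: speed > 1.0 allots fewer frames
        # (faster speech), speed < 1.0 allots more (slower speech).
        # Guard against empty reference text to avoid division by zero.
        ref_text_len = max(len(ref_text.encode("utf-8")), 1)
        gen_text_len = len(gen_text.encode("utf-8"))
        return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

Under a heuristic like this, --speed 2.0 roughly halves the frames allotted to the generated portion, while --speed 0.5 roughly doubles them.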
Author: Justin John
Date: 2024-10-26 10:45:11 +05:30
Committed by: GitHub
Parent: e963929b8e
Commit: ed179067df

@@ -75,6 +75,12 @@ parser.add_argument(
     action="store_true",
     help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
 )
+parser.add_argument(
+    "--speed",
+    type=float,
+    default=1.0,
+    help="Adjust the speed of the audio generation (default: 1.0)",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
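
With this flag in place, the CLI can be invoked with, for example, --speed 1.2 or --speed 0.8; higher values are intended to speed up the generated speech and lower values to slow it down, and omitting the flag keeps the default of 1.0.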
@@ -102,6 +108,7 @@ model = args.model if args.model else config["model"]
 ckpt_file = args.ckpt_file if args.ckpt_file else ""
 vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
+speed = args.speed  # The new speed argument
 wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
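
Note that, unlike remove_silence, the speed value is read directly from the parsed arguments (falling back to the argparse default of 1.0) rather than from the TOML config file.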
@@ -134,7 +141,7 @@ print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
+def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence, speed):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -168,7 +175,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
         ref_audio = voices[voice]["ref_audio"]
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
-        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
+        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj, speed=speed)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -186,7 +193,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
 
 
 def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
+    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence, speed)
 
 
 if __name__ == "__main__":
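
The net effect is that a single --speed value is threaded from main() into main_process() and then into each per-voice infer_process() call, so it applies uniformly to every voice in the generation text.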