From 0f80f25c5fc95aed21a560bec22fed9d237948bf Mon Sep 17 00:00:00 2001
From: SWivid
Date: Sun, 17 Nov 2024 06:25:07 +0800
Subject: [PATCH] minor fix. speech_edit & eval_infer_batch

---
 src/f5_tts/eval/eval_infer_batch.py |  4 ++--
 src/f5_tts/infer/infer_cli.py       | 23 +++++++++++------------
 src/f5_tts/infer/speech_edit.py     |  2 +-
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index bbccd4f..c425726 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -187,7 +187,7 @@ def main():
     # Final result
     for i, gen in enumerate(generated):
         gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
-        gen_mel_spec = gen.permute(0, 2, 1)
+        gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
         if mel_spec_type == "vocos":
             generated_wave = vocoder.decode(gen_mel_spec)
         elif mel_spec_type == "bigvgan":
@@ -195,7 +195,7 @@
         if ref_rms_list[i] < target_rms:
             generated_wave = generated_wave * ref_rms_list[i] / target_rms
 
-        torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
+        torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index 487ce4a..10d6928 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -109,13 +109,16 @@
 ckpt_file = args.ckpt_file if args.ckpt_file else ""
 vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 speed = args.speed
+
 wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
-if args.vocoder_name == "vocos":
-    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
-elif args.vocoder_name == "bigvgan":
-    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+
+vocoder_name = args.vocoder_name
 mel_spec_type = args.vocoder_name
+if vocoder_name == "vocos":
+    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
+elif vocoder_name == "bigvgan":
+    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
 vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
@@ -125,19 +128,20 @@ if model == "F5-TTS":
     model_cls = DiT
     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
     if ckpt_file == "":
-        if args.vocoder_name == "vocos":
+        if vocoder_name == "vocos":
             repo_name = "F5-TTS"
             exp_name = "F5TTS_Base"
             ckpt_step = 1200000
             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
             # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
-        elif args.vocoder_name == "bigvgan":
+        elif vocoder_name == "bigvgan":
             repo_name = "F5-TTS"
             exp_name = "F5TTS_Base_bigvgan"
             ckpt_step = 1250000
             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
 
 elif model == "E2-TTS":
+    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
     if ckpt_file == "":
@@ -146,15 +150,10 @@ elif model == "E2-TTS":
         ckpt_step = 1200000
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
         # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
-    elif args.vocoder_name == "bigvgan":  # TODO: need to test
-        repo_name = "F5-TTS"
-        exp_name = "F5TTS_Base_bigvgan"
-        ckpt_step = 1250000
-        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
 
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
 
 
 def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index c33b21f..07bb6d6 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -187,5 +187,5 @@ with torch.inference_mode():
         generated_wave = generated_wave * rms / target_rms
 
     save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")