mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-16 23:12:24 -08:00
minor fix. speech_edit & eval_infer_batch
This commit is contained in:
@@ -187,7 +187,7 @@ def main():
|
||||
# Final result
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(gen_mel_spec)
|
||||
elif mel_spec_type == "bigvgan":
|
||||
@@ -195,7 +195,7 @@ def main():
|
||||
|
||||
if ref_rms_list[i] < target_rms:
|
||||
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
|
||||
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
|
||||
@@ -109,13 +109,16 @@ ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
||||
vocab_file = args.vocab_file if args.vocab_file else ""
|
||||
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
||||
speed = args.speed
|
||||
|
||||
wave_path = Path(output_dir) / "infer_cli_out.wav"
|
||||
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
||||
if args.vocoder_name == "vocos":
|
||||
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
||||
elif args.vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder_name = args.vocoder_name
|
||||
mel_spec_type = args.vocoder_name
|
||||
if vocoder_name == "vocos":
|
||||
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
||||
elif vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
|
||||
|
||||
@@ -125,19 +128,20 @@ if model == "F5-TTS":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
if ckpt_file == "":
|
||||
if args.vocoder_name == "vocos":
|
||||
if vocoder_name == "vocos":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base"
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
elif args.vocoder_name == "bigvgan":
|
||||
elif vocoder_name == "bigvgan":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base_bigvgan"
|
||||
ckpt_step = 1250000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
||||
|
||||
elif model == "E2-TTS":
|
||||
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
if ckpt_file == "":
|
||||
@@ -146,15 +150,10 @@ elif model == "E2-TTS":
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
elif args.vocoder_name == "bigvgan": # TODO: need to test
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base_bigvgan"
|
||||
ckpt_step = 1250000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
||||
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
|
||||
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
|
||||
|
||||
|
||||
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
|
||||
|
||||
@@ -187,5 +187,5 @@ with torch.inference_mode():
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
||||
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
|
||||
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
|
||||
print(f"Generated wav: {generated_wave.shape}")
|
||||
|
||||
Reference in New Issue
Block a user