minor fix. speech_edit & eval_infer_batch

SWivid
2024-11-17 06:25:07 +08:00
parent 4b9441f8e0
commit 0f80f25c5f
3 changed files with 14 additions and 15 deletions

eval_infer_batch.py

@@ -187,7 +187,7 @@ def main():
                 # Final result
                 for i, gen in enumerate(generated):
                     gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
-                    gen_mel_spec = gen.permute(0, 2, 1)
+                    gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                     if mel_spec_type == "vocos":
                         generated_wave = vocoder.decode(gen_mel_spec)
                     elif mel_spec_type == "bigvgan":
@@ -195,7 +195,7 @@ def main():
                     if ref_rms_list[i] < target_rms:
                         generated_wave = generated_wave * ref_rms_list[i] / target_rms

-                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
+                    torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:

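The added .to(torch.float32) reads as a dtype fix: the Vocos and BigVGAN vocoders are typically loaded with float32 weights, while the sampled mel can come out in half precision if the model ran under fp16/bf16. A minimal sketch of that mismatch, using a stand-in float32 module (a plain Conv1d named decoder) instead of the real vocoder; the shapes and dtypes here are illustrative only:

import torch

decoder = torch.nn.Conv1d(100, 1, kernel_size=1)       # stand-in for a float32 vocoder (Vocos / BigVGAN)
gen = torch.randn(1, 250, 100, dtype=torch.float16)    # (batch, frames, n_mels) as a half-precision model might emit

gen_mel_spec = gen.permute(0, 2, 1)                     # (batch, n_mels, frames)
try:
    decoder(gen_mel_spec)                               # float16 input vs float32 weights
except RuntimeError as err:
    print("without cast:", err)

out = decoder(gen_mel_spec.to(torch.float32))           # the cast added in this commit makes the dtypes agree
print("with cast:", out.shape, out.dtype)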
infer_cli.py

@@ -109,13 +109,16 @@ ckpt_file = args.ckpt_file if args.ckpt_file else ""
 vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 speed = args.speed

 wave_path = Path(output_dir) / "infer_cli_out.wav"
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
-if args.vocoder_name == "vocos":
-    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
-elif args.vocoder_name == "bigvgan":
-    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+vocoder_name = args.vocoder_name
+mel_spec_type = args.vocoder_name
+if vocoder_name == "vocos":
+    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
+elif vocoder_name == "bigvgan":
+    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

 vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
@@ -125,19 +128,20 @@ if model == "F5-TTS":
     model_cls = DiT
     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
     if ckpt_file == "":
-        if args.vocoder_name == "vocos":
+        if vocoder_name == "vocos":
             repo_name = "F5-TTS"
             exp_name = "F5TTS_Base"
             ckpt_step = 1200000
             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
             # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
-        elif args.vocoder_name == "bigvgan":
+        elif vocoder_name == "bigvgan":
             repo_name = "F5-TTS"
             exp_name = "F5TTS_Base_bigvgan"
             ckpt_step = 1250000
             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))

 elif model == "E2-TTS":
+    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
     if ckpt_file == "":
@@ -146,15 +150,10 @@ elif model == "E2-TTS":
             ckpt_step = 1200000
             ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
             # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
-        elif args.vocoder_name == "bigvgan": # TODO: need to test
-            repo_name = "F5-TTS"
-            exp_name = "F5TTS_Base_bigvgan"
-            ckpt_step = 1250000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))

 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)


 def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):

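Taken together, the infer_cli hunks route everything through two locals set once from --vocoder_name: vocoder_name picks the local vocoder checkpoint directory and gates the new E2-TTS assert, while mel_spec_type is what gets handed to load_vocoder and load_model. A self-contained sketch of that post-commit flow, with plain print calls standing in for the real load_vocoder/load_model and the checkpoint paths copied from the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="F5-TTS", choices=["F5-TTS", "E2-TTS"])
parser.add_argument("--vocoder_name", default="vocos", choices=["vocos", "bigvgan"])
args = parser.parse_args()

vocoder_name = args.vocoder_name
mel_spec_type = args.vocoder_name  # kept identical so the vocoder and the mel frontend always agree
if vocoder_name == "vocos":
    vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
    vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

if args.model == "E2-TTS":
    # fail fast instead of silently trying the untested E2-TTS + BigVGAN combination
    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"

print("load_vocoder <-", mel_spec_type, vocoder_local_path)
print("load_model   <- mel_spec_type =", mel_spec_type)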
speech_edit.py

@@ -187,5 +187,5 @@ with torch.inference_mode():
         generated_wave = generated_wave * rms / target_rms

     save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")
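The dropped .squeeze(0) in eval_infer_batch and speech_edit lines up with what torchaudio.save expects: a 2-D (channels, time) tensor. Assuming the batch-of-one decode already yields shape (1, num_samples), the waveform can be saved as is, and squeezing would strip the channel axis. A small sketch with a synthetic waveform and a hypothetical output path:

import torch
import torchaudio

target_sample_rate = 24000
generated_wave = torch.randn(1, target_sample_rate)     # (channels=1, time), stand-in for the decoded waveform

# torchaudio.save expects (channels, time), so the 2-D tensor is saved directly
torchaudio.save("squeeze_demo_out.wav", generated_wave.cpu(), target_sample_rate)

print(generated_wave.shape)             # torch.Size([1, 24000]) -- what the saver wants
print(generated_wave.squeeze(0).shape)  # torch.Size([24000])    -- channel axis gone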