diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 9f556d4..b2fd727 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -481,18 +481,16 @@ def infer_batch_process( ) del _ - generated = generated.to(torch.float32) + generated = generated.to(torch.float32) # generated mel spectrogram generated = generated[:, ref_audio_len:, :] - generated_mel_spec = generated.permute(0, 2, 1) + generated = generated.permute(0, 2, 1) if mel_spec_type == "vocos": - generated_wave = vocoder.decode(generated_mel_spec) + generated_wave = vocoder.decode(generated) elif mel_spec_type == "bigvgan": - generated_wave = vocoder(generated_mel_spec) + generated_wave = vocoder(generated) if rms < target_rms: generated_wave = generated_wave * rms / target_rms - del generated - # wav -> numpy generated_wave = generated_wave.squeeze().cpu().numpy() @@ -500,8 +498,8 @@ def infer_batch_process( for j in range(0, len(generated_wave), chunk_size): yield generated_wave[j : j + chunk_size], target_sample_rate else: - generated_cpu = generated_mel_spec[0].cpu().numpy() - del generated_mel_spec + generated_cpu = generated[0].cpu().numpy() + del generated yield generated_wave, generated_cpu if streaming: