diff --git a/gradio_app.py b/gradio_app.py
index c696443..c70f8de 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -46,6 +46,7 @@ pipe = pipeline(
     torch_dtype=torch.float16,
     device=device,
 )
+vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 # --------------------- Settings -------------------- #
 
@@ -243,8 +244,6 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
-
-        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
diff --git a/inference-cli.py b/inference-cli.py
index 2e6020a..770f128 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -95,6 +95,7 @@ device = (
     if torch.cuda.is_available()
     else "mps" if torch.backends.mps.is_available() else "cpu"
 )
+vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 print(f"Using {device} device")
 
@@ -286,8 +287,6 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
-
-        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
         generated_wave = vocos.decode(generated_mel_spec.cpu())
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
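
Note on the change above: previously the Vocos vocoder ("charactr/vocos-mel-24khz") was instantiated inside `infer_batch`, so every call re-built the model; both files now load it once at module level and reuse the shared `vocos` instance for decoding. A minimal sketch of the resulting pattern, assuming the same `vocos` library and a hypothetical `decode_mel` helper (illustrative only, not code from either script):

```python
import torch
from vocos import Vocos

# Loaded once at module import: Vocos.from_pretrained builds the
# "charactr/vocos-mel-24khz" vocoder a single time instead of per batch.
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

def decode_mel(generated_mel_spec: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: reuse the shared vocoder to turn a mel
    # spectrogram (shape [1, n_mels, frames]) into a waveform, rather
    # than calling Vocos.from_pretrained(...) inside the inference loop.
    return vocos.decode(generated_mel_spec.cpu())
```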