mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-18 07:46:06 -08:00
address #75
This commit is contained in:
@@ -46,6 +46,7 @@ pipe = pipeline(
|
||||
torch_dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
# --------------------- Settings -------------------- #
|
||||
|
||||
@@ -243,8 +244,6 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
|
||||
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
||||
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
@@ -95,6 +95,7 @@ device = (
|
||||
if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
)
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
print(f"Using {device} device")
|
||||
|
||||
@@ -286,8 +287,6 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
||||
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
||||
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
Reference in New Issue
Block a user