diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index c425726..8598f48 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -189,13 +189,13 @@ def main():
                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                 gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
                 if mel_spec_type == "vocos":
-                    generated_wave = vocoder.decode(gen_mel_spec)
+                    generated_wave = vocoder.decode(gen_mel_spec).cpu()
                 elif mel_spec_type == "bigvgan":
-                    generated_wave = vocoder(gen_mel_spec)
+                    generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
                 if ref_rms_list[i] < target_rms:
                     generated_wave = generated_wave * ref_rms_list[i] / target_rms
-                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
+                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index 4eee068..fc6505c 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -181,13 +181,13 @@ with torch.inference_mode():
     generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
     if mel_spec_type == "vocos":
-        generated_wave = vocoder.decode(gen_mel_spec)
+        generated_wave = vocoder.decode(gen_mel_spec).cpu()
    elif mel_spec_type == "bigvgan":
-        generated_wave = vocoder(gen_mel_spec)
+        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
     if rms < target_rms:
         generated_wave = generated_wave * rms / target_rms
 
     save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
 
     print(f"Generated wav: {generated_wave.shape}")
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 19fb309..51ce33f 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -324,26 +324,37 @@ class Trainer:
                     self.save_checkpoint(global_step)
 
                     if self.log_samples and self.accelerator.is_local_main_process:
-                        ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                        torchaudio.save(
-                            f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
-                        )
+                        ref_audio_len = mel_lengths[0]
+                        infer_text = [
+                            text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                        ]
                         with torch.inference_mode():
                             generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                 cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
-                                text=[text_inputs[0] + [" "] + text_inputs[0]],
+                                text=infer_text,
                                 duration=ref_audio_len * 2,
                                 steps=nfe_step,
                                 cfg_strength=cfg_strength,
                                 sway_sampling_coef=sway_sampling_coef,
                             )
-                            generated = generated.to(torch.float32)
-                            gen_audio = vocoder.decode(
-                                generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
-                            )
-                            torchaudio.save(
-                                f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
-                            )
+                            generated = generated.to(torch.float32)
+                            if self.vocoder_name == "vocos":
+                                gen_audio = vocoder.decode(
+                                    generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                                ).cpu()
+                                ref_audio = vocoder.decode(batch["mel"][0].unsqueeze(0)).cpu()
+                            elif self.vocoder_name == "bigvgan":
+                                gen_audio = (
+                                    vocoder(
+                                        generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                                    )
+                                    .squeeze(0)
+                                    .cpu()
+                                )
+                                ref_audio = vocoder(batch["mel"][0].unsqueeze(0)).squeeze(0).cpu()
+
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)