diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index f451efc..33f9a3f 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=device, dtype=None):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     global asr_pipe
     asr_pipe = pipeline(
         "automatic-speech-recognition",
         model="openai/whisper-large-v3-turbo",
-        torch_dtype=torch.float16,
+        torch_dtype=dtype,
         device=device,
     )
 
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index feb1735..bd96da7 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -325,7 +325,9 @@ class Trainer:
 
                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                    torchaudio.save(
+                        f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
+                    )
                     with torch.inference_mode():
                         generated, _ = self.accelerator.unwrap_model(self.model).sample(
                             cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +338,12 @@ class Trainer:
                             sway_sampling_coef=sway_sampling_coef,
                         )
                         generated = generated.to(torch.float32)
-                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
-                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                        gen_audio = vocoder.decode(
+                            generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                        )
+                        torchaudio.save(
+                            f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
+                        )
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)