From 183ad0908403785ed7eb9ee07f5db613def13cae Mon Sep 17 00:00:00 2001
From: Justin John <34035011+justinjohn0306@users.noreply.github.com>
Date: Fri, 1 Nov 2024 23:47:00 +0530
Subject: [PATCH 1/3] Ensure tensors are moved to CPU before saving with torchaudio

---
 src/f5_tts/model/trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index feb1735..96f5d74 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -61,7 +61,7 @@ class Trainer:
             gradient_accumulation_steps=grad_accumulation_steps,
             **accelerate_kwargs,
         )
-
+        self.device = self.accelerator.device
         self.logger = logger
         if self.logger == "wandb":
             if exists(wandb_resume_id):
@@ -325,7 +325,7 @@ class Trainer:
 
                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate)
                     with torch.inference_mode():
                         generated, _ = self.accelerator.unwrap_model(self.model).sample(
                             cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +336,8 @@ class Trainer:
                             sway_sampling_coef=sway_sampling_coef,
                         )
                         generated = generated.to(torch.float32)
-                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
-                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device))
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate)
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)

From f7e248e2ced0f1bc6885093d29893a1e4463bc71 Mon Sep 17 00:00:00 2001
From: SWivid
Date: Sat, 2 Nov 2024 12:58:28 +0800
Subject: [PATCH 2/3] formatting

---
 src/f5_tts/model/trainer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 96f5d74..bd96da7 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -61,7 +61,7 @@ class Trainer:
             gradient_accumulation_steps=grad_accumulation_steps,
             **accelerate_kwargs,
         )
-        self.device = self.accelerator.device
+
         self.logger = logger
         if self.logger == "wandb":
             if exists(wandb_resume_id):
@@ -325,7 +325,9 @@ class Trainer:
 
                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate)
+                    torchaudio.save(
+                        f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
+                    )
                     with torch.inference_mode():
                         generated, _ = self.accelerator.unwrap_model(self.model).sample(
                             cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +338,12 @@ class Trainer:
                             sway_sampling_coef=sway_sampling_coef,
                         )
                         generated = generated.to(torch.float32)
-                        gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device))
-                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate)
+                        gen_audio = vocoder.decode(
+                            generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                        )
+                        torchaudio.save(
+                            f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
+                        )
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)

From ea90244d6274c649252834d645765ad6811cd053 Mon Sep 17 00:00:00 2001
From: SWivid
Date: Sat, 2 Nov 2024 13:48:37 +0800
Subject: [PATCH 3/3] fix. add dtype check for asr pipeline addressing #356

---
 src/f5_tts/infer/utils_infer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index f451efc..33f9a3f 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -119,12 +119,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=device, dtype=None):
+    if dtype is None:
+        dtype = (
+            torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        )
     global asr_pipe
     asr_pipe = pipeline(
         "automatic-speech-recognition",
         model="openai/whisper-large-v3-turbo",
-        torch_dtype=torch.float16,
+        torch_dtype=dtype,
         device=device,
     )
 