Ensure tensors are moved to CPU before saving with torchaudio

This commit is contained in:
Justin John
2024-11-01 23:47:00 +05:30
committed by GitHub
parent b0f482421b
commit 183ad09084

View File

@@ -61,7 +61,7 @@ class Trainer:
gradient_accumulation_steps=grad_accumulation_steps,
**accelerate_kwargs,
)
self.device = self.accelerator.device
self.logger = logger
if self.logger == "wandb":
if exists(wandb_resume_id):
@@ -325,7 +325,7 @@ class Trainer:
if self.log_samples and self.accelerator.is_local_main_process:
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate)
with torch.inference_mode():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
@@ -336,8 +336,8 @@ class Trainer:
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
gen_audio = vocoder.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.device))
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate)
if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)