diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 58b582d..08903d6 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -191,7 +191,7 @@ class Trainer:
             from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
 
             vocoder = load_vocoder()
-            target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
+            target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
             log_samples_path = f"{self.checkpoint_path}/samples"
             os.makedirs(log_samples_path, exist_ok=True)
 
@@ -314,12 +314,12 @@ class Trainer:
                 if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                     self.save_checkpoint(global_step)
 
-                    if self.log_samples:
-                        ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
+                    if self.log_samples and self.accelerator.is_local_main_process:
+                        ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
                         torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
                         with torch.inference_mode():
-                            generated, _ = self.model.sample(
-                                cond=[mel_spec[0][:ref_audio_len]],
+                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
+                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
                                 text=[text_inputs[0] + [" "] + text_inputs[0]],
                                 duration=ref_audio_len * 2,
                                 steps=nfe_step,
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index a40efba..94fe9b5 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -83,6 +83,7 @@ def main():
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=last_per_steps,
+        log_samples=True,
     )
 
     train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
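
Note: the trainer changes above address two distinct failure modes. First, under a
multi-GPU Accelerate launch, `self.model` is a DistributedDataParallel wrapper, and
DDP does not proxy custom attributes such as `.mel_spec` or `.sample`, so they must be
reached through `accelerator.unwrap_model`. Second, `[batch["mel"][0]].cpu()` called
`.cpu()` on a Python list (an AttributeError); the decode and sample paths take a
batched tensor, hence `.unsqueeze(0)`. The `is_local_main_process` guard also keeps
every rank from writing duplicate sample WAVs. A minimal standalone sketch of both
points follows; `TinyModel` is a hypothetical stand-in, not the project's CFM class:

    import torch
    import torch.nn as nn
    from accelerate import Accelerator

    # Hypothetical stand-in for the real model: an nn.Module with a custom
    # `sample` method that is not part of the nn.Module interface.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def sample(self, cond):
            assert cond.dim() == 3  # expects a batched tensor of shape [b, t, d]
            return self.proj(cond)

    accelerator = Accelerator()
    model = accelerator.prepare(TinyModel())

    mel = torch.randn(16, 4)   # one unbatched item, shape [t, d]
    cond = mel.unsqueeze(0)    # -> [1, t, d]; a plain Python list has no .cpu()/.dim()

    # On multi-GPU launches `model` is a DistributedDataParallel wrapper, which
    # does not forward custom attributes like `.sample`; unwrap_model recovers
    # the underlying module in both the single- and multi-process cases.
    out = accelerator.unwrap_model(model).sample(cond=cond)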