From 8a5041ef9f4b2d217a7709e6d86a3f3fb37c30d2 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 29 Oct 2024 17:14:03 +0200 Subject: [PATCH] update --- src/f5_tts/model/trainer.py | 47 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 420089a..10aa275 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -326,29 +326,32 @@ class Trainer: and self.export_samples and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0 ): - wave_org, wave_gen, mel_org, mel_gen = get_sample( - vocos, - self.model, - self.file_path_samples, - global_step, - batch["mel"][0], - text_inputs, - target_sample_rate, - hop_length, - nfe_step, - cfg_strength, - sway_sampling_coef, - ) + try: + wave_org, wave_gen, mel_org, mel_gen = get_sample( + vocos, + self.model, + self.file_path_samples, + global_step, + batch["mel"][0], + text_inputs, + target_sample_rate, + hop_length, + nfe_step, + cfg_strength, + sway_sampling_coef, + ) - if self.logger == "tensorboard": - self.writer.add_audio( - "Audio/original", wave_org, global_step, sample_rate=target_sample_rate - ) - self.writer.add_audio( - "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate - ) - self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW") - self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW") + if self.logger == "tensorboard": + self.writer.add_audio( + "Audio/original", wave_org, global_step, sample_rate=target_sample_rate + ) + self.writer.add_audio( + "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate + ) + self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW") + self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW") + except Exception as e: + print("An error occurred:", e) self.accelerator.backward(loss)