This commit is contained in:
unknown
2024-10-29 17:14:03 +02:00
parent 886500ac97
commit 8a5041ef9f

View File

@@ -326,29 +326,32 @@ class Trainer:
and self.export_samples
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
):
wave_org, wave_gen, mel_org, mel_gen = get_sample(
vocos,
self.model,
self.file_path_samples,
global_step,
batch["mel"][0],
text_inputs,
target_sample_rate,
hop_length,
nfe_step,
cfg_strength,
sway_sampling_coef,
)
try:
wave_org, wave_gen, mel_org, mel_gen = get_sample(
vocos,
self.model,
self.file_path_samples,
global_step,
batch["mel"][0],
text_inputs,
target_sample_rate,
hop_length,
nfe_step,
cfg_strength,
sway_sampling_coef,
)
if self.logger == "tensorboard":
self.writer.add_audio(
"Audio/original", wave_org, global_step, sample_rate=target_sample_rate
)
self.writer.add_audio(
"Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
)
self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
if self.logger == "tensorboard":
self.writer.add_audio(
"Audio/original", wave_org, global_step, sample_rate=target_sample_rate
)
self.writer.add_audio(
"Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
)
self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
except Exception as e:
print("An error occurred:", e)
self.accelerator.backward(loss)