mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-12 21:24:19 -08:00
update
This commit is contained in:
@@ -326,29 +326,32 @@ class Trainer:
|
||||
and self.export_samples
|
||||
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
|
||||
):
|
||||
wave_org, wave_gen, mel_org, mel_gen = get_sample(
|
||||
vocos,
|
||||
self.model,
|
||||
self.file_path_samples,
|
||||
global_step,
|
||||
batch["mel"][0],
|
||||
text_inputs,
|
||||
target_sample_rate,
|
||||
hop_length,
|
||||
nfe_step,
|
||||
cfg_strength,
|
||||
sway_sampling_coef,
|
||||
)
|
||||
try:
|
||||
wave_org, wave_gen, mel_org, mel_gen = get_sample(
|
||||
vocos,
|
||||
self.model,
|
||||
self.file_path_samples,
|
||||
global_step,
|
||||
batch["mel"][0],
|
||||
text_inputs,
|
||||
target_sample_rate,
|
||||
hop_length,
|
||||
nfe_step,
|
||||
cfg_strength,
|
||||
sway_sampling_coef,
|
||||
)
|
||||
|
||||
if self.logger == "tensorboard":
|
||||
self.writer.add_audio(
|
||||
"Audio/original", wave_org, global_step, sample_rate=target_sample_rate
|
||||
)
|
||||
self.writer.add_audio(
|
||||
"Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
|
||||
)
|
||||
self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
|
||||
self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
|
||||
if self.logger == "tensorboard":
|
||||
self.writer.add_audio(
|
||||
"Audio/original", wave_org, global_step, sample_rate=target_sample_rate
|
||||
)
|
||||
self.writer.add_audio(
|
||||
"Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
|
||||
)
|
||||
self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
|
||||
self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
|
||||
except Exception as e:
|
||||
print("An error occurred:", e)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user