finish trainer modification

This commit is contained in:
SWivid
2024-10-30 03:57:09 +08:00
parent 87c4f9ff06
commit aaa92f6e6d
2 changed files with 6 additions and 5 deletions

View File

@@ -191,7 +191,7 @@ class Trainer:
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
vocoder = load_vocoder()
target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
log_samples_path = f"{self.checkpoint_path}/samples"
os.makedirs(log_samples_path, exist_ok=True)
@@ -314,12 +314,12 @@ class Trainer:
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
self.save_checkpoint(global_step)
if self.log_samples:
ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
if self.log_samples and self.accelerator.is_local_main_process:
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
with torch.inference_mode():
generated, _ = self.model.sample(
cond=[mel_spec[0][:ref_audio_len]],
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
text=[text_inputs[0] + [" "] + text_inputs[0]],
duration=ref_audio_len * 2,
steps=nfe_step,

View File

@@ -83,6 +83,7 @@ def main():
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_steps=last_per_steps,
log_samples=True,
)
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)