mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-09 19:57:49 -08:00
finish trainer modification
This commit is contained in:
@@ -191,7 +191,7 @@ class Trainer:
|
||||
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
|
||||
|
||||
vocoder = load_vocoder()
|
||||
target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
|
||||
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate
|
||||
log_samples_path = f"{self.checkpoint_path}/samples"
|
||||
os.makedirs(log_samples_path, exist_ok=True)
|
||||
|
||||
@@ -314,12 +314,12 @@ class Trainer:
|
||||
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
||||
self.save_checkpoint(global_step)
|
||||
|
||||
if self.log_samples:
|
||||
ref_audio, ref_audio_len = vocoder.decode([batch["mel"][0]].cpu()), mel_lengths[0]
|
||||
if self.log_samples and self.accelerator.is_local_main_process:
|
||||
ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
|
||||
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
|
||||
with torch.inference_mode():
|
||||
generated, _ = self.model.sample(
|
||||
cond=[mel_spec[0][:ref_audio_len]],
|
||||
generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
||||
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
||||
text=[text_inputs[0] + [" "] + text_inputs[0]],
|
||||
duration=ref_audio_len * 2,
|
||||
steps=nfe_step,
|
||||
|
||||
@@ -83,6 +83,7 @@ def main():
|
||||
wandb_run_name=exp_name,
|
||||
wandb_resume_id=wandb_resume_id,
|
||||
last_per_steps=last_per_steps,
|
||||
log_samples=True,
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user