diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index c45b6b1..e9ed5d4 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -19,7 +19,7 @@ from f5_tts.model import CFM from f5_tts.model.utils import exists, default from f5_tts.model.dataset import DynamicBatchSampler, collate_fn from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos -from f5_tts.model.utils import gen_sample +from f5_tts.model.utils import get_sample # trainer @@ -315,7 +315,7 @@ class Trainer: and self.export_samples and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0 ): - wave_org, wave_gen, mel_org, mel_gen = gen_sample( + wave_org, wave_gen, mel_org, mel_gen = get_sample( vocos, self.model, self.file_path_samples, diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index 8877a7e..0314d1f 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -205,7 +205,7 @@ def export_mel(mel_colored_hwc, file_out): plt.imsave(file_out, mel_colored_hwc) -def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef): +def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef): audio, sr = torchaudio.load(file_wav_org) audio = audio.to("cuda") ref_audio_len = audio.shape[-1] // hop_length @@ -228,7 +228,7 @@ def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cf return generated_wave_gen, generated_mel_spec_gen -def gen_sample( +def get_sample( vocos, model, file_path_samples, @@ -245,7 +245,7 @@ def gen_sample( generated_wave_org = generated_wave_org.squeeze().cpu().numpy() file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav") export_audio(file_wav_org, generated_wave_org, target_sample_rate) - generated_wave_gen, generated_mel_spec_gen = get_sample( + generated_wave_gen, generated_mel_spec_gen = gen_sample( model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef ) file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")