This commit is contained in:
unknown
2024-10-29 16:46:38 +02:00
parent 3409192662
commit 2ca1fb7c25
2 changed files with 5 additions and 5 deletions

View File

@@ -19,7 +19,7 @@ from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
from f5_tts.model.utils import gen_sample
from f5_tts.model.utils import get_sample
# trainer
@@ -315,7 +315,7 @@ class Trainer:
and self.export_samples
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
):
wave_org, wave_gen, mel_org, mel_gen = gen_sample(
wave_org, wave_gen, mel_org, mel_gen = get_sample(
vocos,
self.model,
self.file_path_samples,

View File

@@ -205,7 +205,7 @@ def export_mel(mel_colored_hwc, file_out):
plt.imsave(file_out, mel_colored_hwc)
def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
audio, sr = torchaudio.load(file_wav_org)
audio = audio.to("cuda")
ref_audio_len = audio.shape[-1] // hop_length
@@ -228,7 +228,7 @@ def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cf
return generated_wave_gen, generated_mel_spec_gen
def gen_sample(
def get_sample(
vocos,
model,
file_path_samples,
@@ -245,7 +245,7 @@ def gen_sample(
generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
export_audio(file_wav_org, generated_wave_org, target_sample_rate)
generated_wave_gen, generated_mel_spec_gen = get_sample(
generated_wave_gen, generated_mel_spec_gen = gen_sample(
model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
)
file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")