mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-13 21:47:14 -08:00
update
This commit is contained in:
@@ -19,7 +19,7 @@ from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import exists, default
|
||||
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
||||
from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
|
||||
from f5_tts.model.utils import gen_sample
|
||||
from f5_tts.model.utils import get_sample
|
||||
|
||||
# trainer
|
||||
|
||||
@@ -315,7 +315,7 @@ class Trainer:
|
||||
and self.export_samples
|
||||
and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
|
||||
):
|
||||
wave_org, wave_gen, mel_org, mel_gen = gen_sample(
|
||||
wave_org, wave_gen, mel_org, mel_gen = get_sample(
|
||||
vocos,
|
||||
self.model,
|
||||
self.file_path_samples,
|
||||
|
||||
@@ -205,7 +205,7 @@ def export_mel(mel_colored_hwc, file_out):
|
||||
plt.imsave(file_out, mel_colored_hwc)
|
||||
|
||||
|
||||
def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
|
||||
def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
|
||||
audio, sr = torchaudio.load(file_wav_org)
|
||||
audio = audio.to("cuda")
|
||||
ref_audio_len = audio.shape[-1] // hop_length
|
||||
@@ -228,7 +228,7 @@ def get_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cf
|
||||
return generated_wave_gen, generated_mel_spec_gen
|
||||
|
||||
|
||||
def gen_sample(
|
||||
def get_sample(
|
||||
vocos,
|
||||
model,
|
||||
file_path_samples,
|
||||
@@ -245,7 +245,7 @@ def gen_sample(
|
||||
generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
|
||||
file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
|
||||
export_audio(file_wav_org, generated_wave_org, target_sample_rate)
|
||||
generated_wave_gen, generated_mel_spec_gen = get_sample(
|
||||
generated_wave_gen, generated_mel_spec_gen = gen_sample(
|
||||
model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
|
||||
)
|
||||
file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
|
||||
|
||||
Reference in New Issue
Block a user