import os
import time
import random
from tqdm import tqdm
import argparse

import torch
import torchaudio
from accelerate import Accelerator
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos

from model import CFM, UNetT, DiT
from model.utils import (
    get_tokenizer,
    get_seedtts_testset_metainfo,
    get_librispeech_test_clean_metainfo,
    get_inference_prompt,
)

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"

# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1

tokenizer = "pinyin"

# ---------------------- infer setting ---------------------- #

parser = argparse.ArgumentParser(description="batch inference")

parser.add_argument('-s', '--seed', default=None, type=int)
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
parser.add_argument('-n', '--expname', required=True)
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)

parser.add_argument('-nfe', '--nfestep', default=32, type=int)
parser.add_argument('-o', '--odemethod', default="euler")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)

parser.add_argument('-t', '--testset', required=True)

args = parser.parse_args()

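# Example invocation (illustrative; the script path and available test sets depend on your checkout):
#   accelerate launch eval_infer_batch.py -n F5TTS_Base -t seedtts_test_zh -nfe 32 -o euler
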
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)

nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling

testset = args.testset


infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.
speed = 1.
use_truth_duration = False
no_ref_audio = False

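# The experiment name selects the transformer backbone and its hyperparameters:
# DiT for the F5-TTS base model, UNetT for the E2-TTS baseline.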
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)

if testset == "ls_pc_test_clean":
    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

elif testset == "seedtts_test_zh":
    metalst = "data/seedtts_testset/zh/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

elif testset == "seedtts_test_en":
    metalst = "data/seedtts_testset/en/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

# path to save generated wavs
if seed is None:
    seed = random.randint(-10000, 10000)
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
             f"seed{seed}_{ode_method}_nfe{nfe_step}" \
             f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
             f"_cfg{cfg_strength}_speed{speed}" \
             f"{'_gt-dur' if use_truth_duration else ''}" \
             f"{'_no-ref-audio' if no_ref_audio else ''}"
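# The directory name encodes seed, ODE method, NFE steps, sway sampling, CFG strength
# and speed, so runs with different sampling settings are kept apart.
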
# -------------------------------------------------#

use_ema = True

prompts_all = get_inference_prompt(
    metainfo,
    speed = speed,
    tokenizer = tokenizer,
    target_sample_rate = target_sample_rate,
    n_mel_channels = n_mel_channels,
    hop_length = hop_length,
    target_rms = target_rms,
    use_truth_duration = use_truth_duration,
    infer_batch_size = infer_batch_size,
)
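# Each entry of prompts_all unpacks into
# (utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list)
# in the inference loop below.
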
# Vocoder model
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
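# Vocos decodes the generated mel-spectrograms back to waveforms; the pretrained
# "charactr/vocos-mel-24khz" checkpoint matches the 24 kHz / 100-mel settings above.
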
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model
model = CFM(
    transformer = model_cls(
        **model_cfg,
        text_num_embeds = vocab_size,
        mel_dim = n_mel_channels
    ),
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    ),
    odeint_kwargs = dict(
        method = ode_method,
    ),
    vocab_char_map = vocab_char_map,
).to(device)
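# CFM pairs the backbone with the mel-spectrogram frontend and the ODE solver settings;
# sampling integrates the flow with nfe_step evaluations of the chosen method.
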
if use_ema:
    ema_model = EMA(model, include_online_model = False).to(device)
    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    ema_model.copy_params_from_ema_to_model()
else:
    model.load_state_dict(checkpoint['model_state_dict'])
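# With use_ema, the exponentially averaged weights stored in the checkpoint are copied
# into the online model, so inference runs with the EMA parameters.
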
if not os.path.exists(output_dir) and accelerator.is_main_process:
    os.makedirs(output_dir)

# start batch inference
accelerator.wait_for_everyone()
start = time.time()

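# Prompts are sharded across processes; each rank synthesizes its share and writes
# one wav per utterance into output_dir.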
with accelerator.split_between_processes(prompts_all) as prompts:

    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
        ref_mels = ref_mels.to(device)
        ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
        total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)

        # Inference
        with torch.inference_mode():
            generated, _ = model.sample(
                cond = ref_mels,
                text = final_text_list,
                duration = total_mel_lens,
                lens = ref_mel_lens,
                steps = nfe_step,
                cfg_strength = cfg_strength,
                sway_sampling_coef = sway_sampling_coef,
                no_ref_audio = no_ref_audio,
                seed = seed,
            )
            # Final result
            for i, gen in enumerate(generated):
                # keep only the generated frames, dropping the reference prompt part
                gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
                gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
                generated_wave = vocos.decode(gen_mel_spec.cpu())
                if ref_rms_list[i] < target_rms:
                    # rescale to the reference loudness if the reference was quieter than target_rms
                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    timediff = time.time() - start
    print(f"Done batch inference in {timediff / 60:.2f} minutes.")