mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-02-04 19:01:32 -08:00
redirect to split hf ckpt repos
This commit is contained in:
@@ -62,8 +62,8 @@ speed = 1.0
 fix_duration = None


-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -93,10 +93,10 @@ F5TTS_model_cfg = dict(
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

 F5TTS_ema_model = load_model(
-    "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
+    "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
 E2TTS_ema_model = load_model(
-    "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+    "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )

 def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
||||
@@ -74,7 +74,7 @@ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
 ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
-exp_name = args.model if args.model else config["model"]
+model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -112,8 +112,8 @@ speed = 1.0
 # fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None

-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -238,11 +238,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):

     return batches

-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
-    if exp_name == "F5-TTS":
-        ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
-    elif exp_name == "E2-TTS":
-        ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+    if model == "F5-TTS":
+        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+    elif model == "E2-TTS":
+        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -320,7 +320,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
     print(spectrogram_path)


-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
     if not custom_split_words.strip():
         custom_words = [word.strip() for word in custom_split_words.split(',')]
         global SPLIT_WORDS
@@ -372,8 +372,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)

-    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)


-infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
+infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
Reference in New Issue
Block a user