redirect to split hf ckpt repos

Author: SWivid
Date: 2024-10-15 02:12:20 +08:00
Parent: 372f6ab44e
Commit: e54fee3b7f
2 changed files with 16 additions and 16 deletions
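
In short: the F5TTS_Base and E2TTS_Base checkpoints, previously both hosted in the single Hugging Face repo SWivid/F5-TTS, now live in separate repos, so load_model takes an extra repo_name argument that is spliced into the hf:// URL. A minimal sketch of how the new paths resolve, using the same cached_path package the code already relies on (the two URLs follow directly from the diff; the variable names are illustrative):

from cached_path import cached_path

# After this commit each model is fetched from its own repo.
f5_ckpt = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
e2_ckpt = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
print(f5_ckpt)  # local cache path of the downloaded .safetensors file
print(e2_ckpt)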

Changed file 1 of 2:

@@ -62,8 +62,8 @@ speed = 1.0
 fix_duration = None


-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -93,10 +93,10 @@ F5TTS_model_cfg = dict(
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

 F5TTS_ema_model = load_model(
-    "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
+    "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
 E2TTS_ema_model = load_model(
-    "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+    "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )

 def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
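
Since every call site must now pass the repo name alongside the experiment name, the two are easy to desynchronize. A hypothetical tidy-up (not part of this commit) that keeps each pairing in one table:

# Hypothetical refactor, not in the commit: pair each display name with
# its experiment name, backbone class, and config in one place.
MODELS = {
    "F5-TTS": ("F5TTS_Base", DiT, F5TTS_model_cfg),
    "E2-TTS": ("E2TTS_Base", UNetT, E2TTS_model_cfg),
}

def load_named_model(name, ckpt_step=1200000):
    exp_name, model_cls, model_cfg = MODELS[name]
    # The display name doubles as the Hugging Face repo name.
    return load_model(name, exp_name, model_cls, model_cfg, ckpt_step)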

Changed file 2 of 2:

@@ -74,7 +74,7 @@ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
 ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
-exp_name = args.model if args.model else config["model"]
+model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -112,8 +112,8 @@ speed = 1.0
 # fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None

-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -238,11 +238,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
     return batches


-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
-    if exp_name == "F5-TTS":
-        ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
-    elif exp_name == "E2-TTS":
-        ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+    if model == "F5-TTS":
+        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+    elif model == "E2-TTS":
+        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -320,7 +320,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
     print(spectrogram_path)


-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
     if not custom_split_words.strip():
         custom_words = [word.strip() for word in custom_split_words.split(',')]
         global SPLIT_WORDS
@@ -372,8 +372,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)

-    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)


-infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
+infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
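
One caveat carried over by the rename: infer_batch has no else branch, so an unrecognized model string falls through both comparisons and ema_model is never bound, surfacing later as a NameError. A hypothetical guard (not part of this commit):

# Hypothetical guard, not in the commit: fail fast on an unknown name
# instead of raising NameError when ema_model is first used.
if model == "F5-TTS":
    ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
elif model == "E2-TTS":
    ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
else:
    raise ValueError(f"unknown model {model!r}; expected 'F5-TTS' or 'E2-TTS'")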