diff --git a/gradio_app.py b/gradio_app.py
index 391ea2b..c696443 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -62,8 +62,8 @@ speed = 1.0
 fix_duration = None
 
 
-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -93,10 +93,10 @@ F5TTS_model_cfg = dict(
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 F5TTS_ema_model = load_model(
-    "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
+    "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
 )
 E2TTS_ema_model = load_model(
-    "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+    "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
 )
 
 def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
diff --git a/inference-cli.py b/inference-cli.py
index 7501fce..2e6020a 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -74,7 +74,7 @@ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
 ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
 gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
-exp_name = args.model if args.model else config["model"]
+model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -112,8 +112,8 @@ speed = 1.0
 # fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
@@ -238,11 +238,11 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
     return batches
 
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
-    if exp_name == "F5-TTS":
-        ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
-    elif exp_name == "E2-TTS":
-        ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+    if model == "F5-TTS":
+        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+    elif model == "E2-TTS":
+        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -320,7 +320,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
     print(spectrogram_path)
 
 
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
     if not custom_split_words.strip():
         custom_words = [word.strip() for word in custom_split_words.split(',')]
         global SPLIT_WORDS
@@ -372,8 +372,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)
 
-    print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
+    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
 
 
-infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))
\ No newline at end of file
+infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
\ No newline at end of file
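
Note for reviewers (illustration, not part of the patch): with the new repo_name parameter, the user-facing model name ("F5-TTS" / "E2-TTS") doubles as the Hugging Face repo name under SWivid/, while the experiment name stays a separate argument. A minimal sketch of the checkpoint URIs the patched load_model builds, assuming the hunks above apply cleanly:

    # Sketch only: mirrors the f-string in the patched load_model().
    # The (repo_name, exp_name) pairs match the call sites in the diff.
    for repo_name, exp_name in [("F5-TTS", "F5TTS_Base"), ("E2-TTS", "E2TTS_Base")]:
        ckpt_step = 1200000
        print(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")
    # -> hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
    # -> hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors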