From 765a2ae390810081606d3a0f4c1f267b8b3ee1c3 Mon Sep 17 00:00:00 2001
From: chigkim
Date: Sun, 20 Oct 2024 08:01:22 -0400
Subject: [PATCH] Load model once at startup.

---
 inference-cli.py | 65 ++++++++++++++++++++++++------------------------
 1 file changed, 33 insertions(+), 32 deletions(-)

diff --git a/inference-cli.py b/inference-cli.py
index 7162790..c938dc7 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -175,6 +175,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
+if model == "F5-TTS":
+
+    if ckpt_file == "":
+        repo_name = "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step = 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, vocab_file)
+
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name = "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step = 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -206,26 +232,7 @@ def chunk_text(text, max_chars=135):
 #if not Path(ckpt_path).exists():
 #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-
-        if ckpt_file == "":
-            repo_name= "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
-
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name= "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -342,13 +349,7 @@ def process_voice(ref_audio_orig, ref_text):
 
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
@@ -360,7 +361,7 @@
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
@@ -376,10 +377,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -407,7 +408,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
             ref_audio = voices[voice]['ref_audio']
             ref_text = voices[voice]['ref_text']
             print(f"Voice: {voice}")
-            audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+            audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
             generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -426,4 +427,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
             print(f.name)
 
 
-process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence)
\ No newline at end of file
+process(ref_audio, ref_text, gen_text, model, remove_silence)
\ No newline at end of file