diff --git a/inference-cli.py b/inference-cli.py index dbee49b..58ef5e4 100644 --- a/inference-cli.py +++ b/inference-cli.py @@ -174,6 +174,32 @@ F5TTS_model_cfg = dict( ) E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) +if model == "F5-TTS": + + if ckpt_file == "": + repo_name= "F5-TTS" + exp_name = "F5TTS_Base" + ckpt_step= 1200000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) + + ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,vocab_file) + +elif model == "E2-TTS": + if ckpt_file == "": + repo_name= "E2-TTS" + exp_name = "E2TTS_Base" + ckpt_step= 1200000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) + + ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,vocab_file) + +asr_pipe = pipeline( + "automatic-speech-recognition", + model="openai/whisper-large-v3-turbo", + torch_dtype=torch.float16, + device=device, +) + def chunk_text(text, max_chars=135): """ Splits the input text into chunks, each with a maximum number of characters. @@ -205,26 +231,7 @@ def chunk_text(text, max_chars=135): #if not Path(ckpt_path).exists(): #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) -def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15): - if model == "F5-TTS": - - if ckpt_file == "": - repo_name= "F5-TTS" - exp_name = "F5TTS_Base" - ckpt_step= 1200000 - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - - ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab) - - elif model == "E2-TTS": - if ckpt_file == "": - repo_name= "E2-TTS" - exp_name = "E2TTS_Base" - ckpt_step= 1200000 - ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) - - ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab) - +def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15): audio, sr = ref_audio if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) @@ -341,13 +348,7 @@ def process_voice(ref_audio_orig, ref_text): if not ref_text.strip(): print("No reference text provided, transcribing reference audio...") - pipe = pipeline( - "automatic-speech-recognition", - model="openai/whisper-large-v3-turbo", - torch_dtype=torch.float16, - device=device, - ) - ref_text = pipe( + ref_text = asr_pipe( ref_audio, chunk_length_s=30, batch_size=128, @@ -359,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text): print("Using custom reference text...") return ref_audio, ref_text -def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15): +def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15): # Add the functionality to ensure it ends with ". " if not ref_text.endswith(". ") and not ref_text.endswith("。"): if ref_text.endswith("."): @@ -375,10 +376,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile print(f'gen_text {i}', gen_text) print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...") - return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration) + return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration) -def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence): +def process(ref_audio, ref_text, text_gen, model, remove_silence): main_voice = {"ref_audio":ref_audio, "ref_text":ref_text} if "voices" not in config: voices = {"main": main_voice} @@ -406,7 +407,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si ref_audio = voices[voice]['ref_audio'] ref_text = voices[voice]['ref_text'] print(f"Voice: {voice}") - audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence) + audio, spectragram = infer(ref_audio, ref_text, gen_text, model,remove_silence) generated_audio_segments.append(audio) if generated_audio_segments: @@ -425,4 +426,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si print(f.name) -process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence) \ No newline at end of file +process(ref_audio, ref_text, gen_text, model, remove_silence)