Mirror of https://github.com/SWivid/F5-TTS.git (synced 2026-01-18 15:57:35 -08:00)

Merge branch 'main' of github.com:SWivid/F5-TTS into main
@@ -174,6 +174,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
+if model == "F5-TTS":
+
+    if ckpt_file == "":
+        repo_name= "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,vocab_file)
+
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name= "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step= 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
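
Note on the hunk above: checkpoint resolution and model loading move to module scope, with the released weights fetched from the Hugging Face Hub when no local checkpoint is given. A self-contained sketch of that fallback pattern follows; resolve_ckpt is a hypothetical helper written for illustration, not a function in this repository, and it assumes the cached-path package that provides cached_path:

# Hypothetical helper mirroring the default-checkpoint fallback in the diff:
# an explicit local path wins, otherwise the released weights are downloaded
# from the Hugging Face Hub and cached locally by cached-path.
from cached_path import cached_path

def resolve_ckpt(ckpt_file: str, repo_name: str, exp_name: str, ckpt_step: int) -> str:
    if ckpt_file == "":
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    return ckpt_file

print(resolve_ckpt("", "F5-TTS", "F5TTS_Base", 1200000))  # downloads on first use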
@@ -205,26 +231,7 @@ def chunk_text(text, max_chars=135):
 #if not Path(ckpt_path).exists():
 #ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-
-        if ckpt_file == "":
-            repo_name= "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file,file_vocab)
-
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name= "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step= 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file,file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
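
Taken together with the first hunk, this change makes load_model run once at import instead of on every infer_batch call. A generic, runnable sketch of the same load-once behavior using memoization rather than module-level globals; the loader below is a stand-in, nothing in it is repository code:

from functools import lru_cache

@lru_cache(maxsize=None)
def load_tts_model(model_name: str, ckpt_file: str) -> object:
    print(f"loading {model_name} from {ckpt_file}")  # runs once per argument pair
    return object()  # stand-in for the real EMA model

a = load_tts_model("F5-TTS", "model_1200000.safetensors")
b = load_tts_model("F5-TTS", "model_1200000.safetensors")
assert a is b  # second call is served from the cache, no reload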
@@ -341,13 +348,7 @@ def process_voice(ref_audio_orig, ref_text):
 
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
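
With the per-call pipeline(...) construction deleted, every transcription now reuses the module-level asr_pipe from the first hunk. A self-contained sketch of that shared-transcriber pattern; the ["text"] extraction at the end is an assumption, since the diff truncates before the close of the call:

import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Built once at import time and shared by all callers, as in the first hunk.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device=device,
)

def transcribe(ref_audio: str) -> str:
    # chunk_length_s and batch_size match the values visible in this hunk;
    # pulling out ["text"] is assumed, the hunk ends before that point.
    return asr_pipe(ref_audio, chunk_length_s=30, batch_size=128)["text"].strip()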
@@ -359,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
@@ -375,10 +376,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -406,7 +407,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model,remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -425,4 +426,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
             print(f.name)
 
 
-process(ref_audio, ref_text, gen_text, model,ckpt_file,vocab_file, remove_silence)
+process(ref_audio, ref_text, gen_text, model, remove_silence)
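
After this commit the top-level call needs only the simplified signature. A hypothetical driver with illustrative argument values (process is the function defined in this file; the file names and text below are examples, not repository data):

ref_audio = "ref.wav"    # reference clip to clone
ref_text = ""            # empty string: transcribed via the shared asr_pipe
gen_text = "Text to synthesize."
model = "F5-TTS"         # or "E2-TTS"
remove_silence = True

process(ref_audio, ref_text, gen_text, model, remove_silence)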