Merge branch 'main' of github.com:SWivid/F5-TTS into main

Author: SWivid
Date:   2024-10-20 20:43:25 +08:00


@@ -174,6 +174,32 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+if model == "F5-TTS":
+    if ckpt_file == "":
+        repo_name = "F5-TTS"
+        exp_name = "F5TTS_Base"
+        ckpt_step = 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, vocab_file)
+elif model == "E2-TTS":
+    if ckpt_file == "":
+        repo_name = "E2-TTS"
+        exp_name = "E2TTS_Base"
+        ckpt_step = 1200000
+        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, vocab_file)
+
+asr_pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
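The `hf://` URI that `cached_path` resolves above is equivalent to a direct Hugging Face Hub download; a minimal sketch of that lookup, assuming the `huggingface_hub` client and the published `SWivid/F5-TTS` repo layout:

from huggingface_hub import hf_hub_download

# Fetch (or reuse from the local cache) the same checkpoint that
# cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors") returns.
ckpt_file = hf_hub_download(
    repo_id="SWivid/F5-TTS",
    filename="F5TTS_Base/model_1200000.safetensors",
)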
@@ -205,26 +231,7 @@ def chunk_text(text, max_chars=135):
-# if not Path(ckpt_path).exists():
-#     ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
-    if model == "F5-TTS":
-        if ckpt_file == "":
-            repo_name = "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step = 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, file_vocab)
-    elif model == "E2-TTS":
-        if ckpt_file == "":
-            repo_name = "E2-TTS"
-            exp_name = "E2TTS_Base"
-            ckpt_step = 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, file_vocab)
-
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
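The retained lines at the end of this hunk downmix multi-channel reference audio while keeping the channel axis; a small illustration with a made-up stereo tensor:

import torch

audio = torch.randn(2, 24000)  # hypothetical 2-channel, 1 s clip at 24 kHz
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)  # downmix to mono
print(audio.shape)  # torch.Size([1, 24000])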
@@ -341,13 +348,7 @@ def process_voice(ref_audio_orig, ref_text):
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model="openai/whisper-large-v3-turbo",
-            torch_dtype=torch.float16,
-            device=device,
-        )
-        ref_text = pipe(
+        ref_text = asr_pipe(
             ref_audio,
             chunk_length_s=30,
             batch_size=128,
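The `asr_pipe` call is cut off by the hunk; a self-contained sketch of how such a transformers ASR pipeline is typically invoked, where the `generate_kwargs` and `["text"]` access are assumptions beyond what the diff shows:

import torch
from transformers import pipeline

asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",  # assumed; the module uses its own `device`
)
result = asr_pipe(
    "ref_audio.wav",    # placeholder path; raw audio arrays also work
    chunk_length_s=30,  # split long audio into 30 s windows
    batch_size=128,
    generate_kwargs={"task": "transcribe"},  # assumption, not in the hunk
)
ref_text = result["text"].strip()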
@@ -359,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
print("Using custom reference text...")
return ref_audio, ref_text
def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_silence, cross_fade_duration=0.15):
def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
# Add the functionality to ensure it ends with ". "
if not ref_text.endswith(". ") and not ref_text.endswith(""):
if ref_text.endswith("."):
@@ -375,10 +376,10 @@ def infer(ref_audio, ref_text, gen_text, model,ckpt_file,file_vocab, remove_sile
     print(f'gen_text {i}', gen_text)

     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)

-def process(ref_audio, ref_text, text_gen, model, ckpt_file, file_vocab, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, remove_silence):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
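`cross_fade_duration` threads through unchanged; a minimal numpy sketch of the overlap-add joining such a parameter implies, with mono float arrays and a 24 kHz rate assumed rather than taken from the diff:

import numpy as np

def cross_fade_concat(a, b, sr=24000, duration=0.15):
    """Join two mono clips with a linear cross-fade of `duration` seconds."""
    n = int(sr * duration)
    if n <= 0 or len(a) < n or len(b) < n:
        return np.concatenate([a, b])  # too short to overlap; hard join
    fade = np.linspace(0.0, 1.0, n)
    overlap = a[-n:] * (1.0 - fade) + b[:n] * fade
    return np.concatenate([a[:-n], overlap, b[n:]])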
@@ -406,7 +407,7 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
         generated_audio_segments.append(audio)

     if generated_audio_segments:
@@ -425,4 +426,4 @@ def process(ref_audio, ref_text, text_gen, model,ckpt_file,file_vocab, remove_si
             print(f.name)

-process(ref_audio, ref_text, gen_text, model, ckpt_file, vocab_file, remove_silence)
+process(ref_audio, ref_text, gen_text, model, remove_silence)
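After the refactor the entry point only threads `model` and `remove_silence`, since checkpoints, vocab, and the ASR pipeline are resolved once at import; a hedged sketch of the resulting call, with placeholder values:

process(
    "ref.wav",            # ref_audio
    "Reference text.",    # ref_text
    "Text to generate.",  # text_gen
    "F5-TTS",             # model: selects the module-level ema_model
    remove_silence=True,
)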