mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 01:27:55 -08:00
update inference-cli with features in gradio
This commit is contained in:
193
inference-cli.py
193
inference-cli.py
@@ -93,17 +93,6 @@ wave_path = Path(output_dir)/"out.wav"
|
||||
spectrogram_path = Path(output_dir)/"out.png"
|
||||
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
|
||||
SPLIT_WORDS = [
|
||||
"but", "however", "nevertheless", "yet", "still",
|
||||
"therefore", "thus", "hence", "consequently",
|
||||
"moreover", "furthermore", "additionally",
|
||||
"meanwhile", "alternatively", "otherwise",
|
||||
"namely", "specifically", "for example", "such as",
|
||||
"in fact", "indeed", "notably",
|
||||
"in contrast", "on the other hand", "conversely",
|
||||
"in conclusion", "to summarize", "finally"
|
||||
]
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
@@ -167,103 +156,36 @@ F5TTS_model_cfg = dict(
|
||||
)
|
||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
|
||||
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
||||
if len(text.encode('utf-8')) <= max_chars:
|
||||
return [text]
|
||||
if text[-1] not in ['。', '.', '!', '!', '?', '?']:
|
||||
text += '.'
|
||||
|
||||
sentences = re.split('([。.!?!?])', text)
|
||||
sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
|
||||
|
||||
batches = []
|
||||
current_batch = ""
|
||||
|
||||
def split_by_words(text):
|
||||
words = text.split()
|
||||
current_word_part = ""
|
||||
word_batches = []
|
||||
for word in words:
|
||||
if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
|
||||
current_word_part += word + ' '
|
||||
else:
|
||||
if current_word_part:
|
||||
# Try to find a suitable split word
|
||||
for split_word in split_words:
|
||||
split_index = current_word_part.rfind(' ' + split_word + ' ')
|
||||
if split_index != -1:
|
||||
word_batches.append(current_word_part[:split_index].strip())
|
||||
current_word_part = current_word_part[split_index:].strip() + ' '
|
||||
break
|
||||
else:
|
||||
# If no suitable split word found, just append the current part
|
||||
word_batches.append(current_word_part.strip())
|
||||
current_word_part = ""
|
||||
current_word_part += word + ' '
|
||||
if current_word_part:
|
||||
word_batches.append(current_word_part.strip())
|
||||
return word_batches
|
||||
|
||||
def chunk_text(text, max_chars=135):
|
||||
"""
|
||||
Splits the input text into chunks, each with a maximum number of characters.
|
||||
Args:
|
||||
text (str): The text to be split.
|
||||
max_chars (int): The maximum number of characters per chunk.
|
||||
Returns:
|
||||
List[str]: A list of text chunks.
|
||||
"""
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
# Split the text into sentences based on punctuation followed by whitespace
|
||||
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
||||
current_batch += sentence
|
||||
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
else:
|
||||
# If adding this sentence would exceed the limit
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = ""
|
||||
|
||||
# If the sentence itself is longer than max_chars, split it
|
||||
if len(sentence.encode('utf-8')) > max_chars:
|
||||
# First, try to split by colon
|
||||
colon_parts = sentence.split(':')
|
||||
if len(colon_parts) > 1:
|
||||
for part in colon_parts:
|
||||
if len(part.encode('utf-8')) <= max_chars:
|
||||
batches.append(part)
|
||||
else:
|
||||
# If colon part is still too long, split by comma
|
||||
comma_parts = re.split('[,,]', part)
|
||||
if len(comma_parts) > 1:
|
||||
current_comma_part = ""
|
||||
for comma_part in comma_parts:
|
||||
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
||||
current_comma_part += comma_part + ','
|
||||
else:
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
current_comma_part = comma_part + ','
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
else:
|
||||
# If no comma, split by words
|
||||
batches.extend(split_by_words(part))
|
||||
else:
|
||||
# If no colon, split by comma
|
||||
comma_parts = re.split('[,,]', sentence)
|
||||
if len(comma_parts) > 1:
|
||||
current_comma_part = ""
|
||||
for comma_part in comma_parts:
|
||||
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
||||
current_comma_part += comma_part + ','
|
||||
else:
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
current_comma_part = comma_part + ','
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
else:
|
||||
# If no comma, split by words
|
||||
batches.extend(split_by_words(sentence))
|
||||
else:
|
||||
current_batch = sentence
|
||||
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
|
||||
return batches
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
|
||||
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
|
||||
if model == "F5-TTS":
|
||||
ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
||||
elif model == "E2-TTS":
|
||||
@@ -321,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
||||
generated_waves.append(generated_wave)
|
||||
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
||||
|
||||
# Combine all generated waves
|
||||
final_wave = np.concatenate(generated_waves)
|
||||
# Combine all generated waves with cross-fading
|
||||
if cross_fade_duration <= 0:
|
||||
# Simply concatenate
|
||||
final_wave = np.concatenate(generated_waves)
|
||||
else:
|
||||
final_wave = generated_waves[0]
|
||||
for i in range(1, len(generated_waves)):
|
||||
prev_wave = final_wave
|
||||
next_wave = generated_waves[i]
|
||||
|
||||
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
||||
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
||||
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
||||
|
||||
if cross_fade_samples <= 0:
|
||||
# No overlap possible, concatenate
|
||||
final_wave = np.concatenate([prev_wave, next_wave])
|
||||
continue
|
||||
|
||||
# Overlapping parts
|
||||
prev_overlap = prev_wave[-cross_fade_samples:]
|
||||
next_overlap = next_wave[:cross_fade_samples]
|
||||
|
||||
# Fade out and fade in
|
||||
fade_out = np.linspace(1, 0, cross_fade_samples)
|
||||
fade_in = np.linspace(0, 1, cross_fade_samples)
|
||||
|
||||
# Cross-faded overlap
|
||||
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
||||
|
||||
# Combine
|
||||
new_wave = np.concatenate([
|
||||
prev_wave[:-cross_fade_samples],
|
||||
cross_faded_overlap,
|
||||
next_wave[cross_fade_samples:]
|
||||
])
|
||||
|
||||
final_wave = new_wave
|
||||
|
||||
with open(wave_path, "wb") as f:
|
||||
sf.write(f.name, final_wave, target_sample_rate)
|
||||
@@ -343,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
||||
print(spectrogram_path)
|
||||
|
||||
|
||||
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
|
||||
if not custom_split_words.strip():
|
||||
custom_words = [word.strip() for word in custom_split_words.split(',')]
|
||||
global SPLIT_WORDS
|
||||
SPLIT_WORDS = custom_words
|
||||
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
|
||||
|
||||
print(gen_text)
|
||||
|
||||
@@ -355,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
||||
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
non_silent_wave += non_silent_seg
|
||||
@@ -387,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
||||
else:
|
||||
print("Using custom reference text...")
|
||||
|
||||
# Add the functionality to ensure it ends with ". "
|
||||
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
||||
if ref_text.endswith("."):
|
||||
ref_text += " "
|
||||
else:
|
||||
ref_text += ". "
|
||||
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
|
||||
gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
|
||||
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
print('ref_text', ref_text)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f'gen_text {i}', gen_text)
|
||||
|
||||
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
||||
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
||||
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
|
||||
|
||||
|
||||
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|
||||
infer(ref_audio, ref_text, gen_text, model, remove_silence)
|
||||
|
||||
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
|
||||
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
||||
# File with text to generate. Ignores the text above.
|
||||
gen_file = ""
|
||||
remove_silence = true
|
||||
remove_silence = false
|
||||
output_dir = "tests"
|
||||
Reference in New Issue
Block a user