mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-06 18:14:38 -08:00
Merge pull request #102 from jpgallegoar/main
Improved batching and added reference text ending
This commit is contained in:
140
gradio_app.py
140
gradio_app.py
@@ -112,101 +112,34 @@ E2TTS_ema_model = load_model(
|
||||
"E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
|
||||
)
|
||||
|
||||
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
||||
if len(text.encode('utf-8')) <= max_chars:
|
||||
return [text]
|
||||
if text[-1] not in ['。', '.', '!', '!', '?', '?']:
|
||||
text += '.'
|
||||
|
||||
sentences = re.split('([。.!?!?])', text)
|
||||
sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
|
||||
|
||||
batches = []
|
||||
current_batch = ""
|
||||
|
||||
def split_by_words(text):
|
||||
words = text.split()
|
||||
current_word_part = ""
|
||||
word_batches = []
|
||||
for word in words:
|
||||
if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
|
||||
current_word_part += word + ' '
|
||||
else:
|
||||
if current_word_part:
|
||||
# Try to find a suitable split word
|
||||
for split_word in split_words:
|
||||
split_index = current_word_part.rfind(' ' + split_word + ' ')
|
||||
if split_index != -1:
|
||||
word_batches.append(current_word_part[:split_index].strip())
|
||||
current_word_part = current_word_part[split_index:].strip() + ' '
|
||||
break
|
||||
else:
|
||||
# If no suitable split word found, just append the current part
|
||||
word_batches.append(current_word_part.strip())
|
||||
current_word_part = ""
|
||||
current_word_part += word + ' '
|
||||
if current_word_part:
|
||||
word_batches.append(current_word_part.strip())
|
||||
return word_batches
|
||||
def chunk_text(text, max_chars=135):
|
||||
"""
|
||||
Splits the input text into chunks, each with a maximum number of characters.
|
||||
|
||||
Args:
|
||||
text (str): The text to be split.
|
||||
max_chars (int): The maximum number of characters per chunk.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of text chunks.
|
||||
"""
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
# Split the text into sentences based on punctuation followed by whitespace
|
||||
sentences = re.split(r'(?<=[;:,.!?])\s+', text)
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
||||
current_batch += sentence
|
||||
if len(current_chunk) + len(sentence) <= max_chars:
|
||||
current_chunk += sentence + " "
|
||||
else:
|
||||
# If adding this sentence would exceed the limit
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = ""
|
||||
|
||||
# If the sentence itself is longer than max_chars, split it
|
||||
if len(sentence.encode('utf-8')) > max_chars:
|
||||
# First, try to split by colon
|
||||
colon_parts = sentence.split(':')
|
||||
if len(colon_parts) > 1:
|
||||
for part in colon_parts:
|
||||
if len(part.encode('utf-8')) <= max_chars:
|
||||
batches.append(part)
|
||||
else:
|
||||
# If colon part is still too long, split by comma
|
||||
comma_parts = re.split('[,,]', part)
|
||||
if len(comma_parts) > 1:
|
||||
current_comma_part = ""
|
||||
for comma_part in comma_parts:
|
||||
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
||||
current_comma_part += comma_part + ','
|
||||
else:
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
current_comma_part = comma_part + ','
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
else:
|
||||
# If no comma, split by words
|
||||
batches.extend(split_by_words(part))
|
||||
else:
|
||||
# If no colon, split by comma
|
||||
comma_parts = re.split('[,,]', sentence)
|
||||
if len(comma_parts) > 1:
|
||||
current_comma_part = ""
|
||||
for comma_part in comma_parts:
|
||||
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
||||
current_comma_part += comma_part + ','
|
||||
else:
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
current_comma_part = comma_part + ','
|
||||
if current_comma_part:
|
||||
batches.append(current_comma_part.rstrip(','))
|
||||
else:
|
||||
# If no comma, split by words
|
||||
batches.extend(split_by_words(sentence))
|
||||
else:
|
||||
current_batch = sentence
|
||||
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
|
||||
return batches
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + " "
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
@gpu_decorator
|
||||
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
|
||||
@@ -306,7 +239,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
non_silent_wave += non_silent_seg
|
||||
@@ -332,13 +267,20 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
|
||||
else:
|
||||
gr.Info("Using custom reference text...")
|
||||
|
||||
# Split the input text into batches
|
||||
# Add the functionality to ensure it ends with ". "
|
||||
if not ref_text.endswith(". "):
|
||||
if ref_text.endswith("."):
|
||||
ref_text += " "
|
||||
else:
|
||||
ref_text += ". "
|
||||
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
|
||||
gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
|
||||
|
||||
# Use the new chunk_text function to split gen_text
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=135)
|
||||
print('ref_text', ref_text)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f'gen_text {i}', gen_text)
|
||||
for i, batch_text in enumerate(gen_text_batches):
|
||||
print(f'gen_text {i}', batch_text)
|
||||
|
||||
gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
|
||||
return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
|
||||
@@ -823,4 +765,4 @@ def main(port, host, share, api):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
Reference in New Issue
Block a user