From 028421eafd2523fc195800ff97116fdd7db41ede Mon Sep 17 00:00:00 2001 From: jpgallegoar Date: Tue, 15 Oct 2024 19:09:07 +0200 Subject: [PATCH] Improved batching and added reference text ending --- gradio_app.py | 140 +++++++++++++++----------------------------------- 1 file changed, 41 insertions(+), 99 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 54f2323..002c480 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -112,101 +112,34 @@ E2TTS_ema_model = load_model( "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000 ) -def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS): - if len(text.encode('utf-8')) <= max_chars: - return [text] - if text[-1] not in ['。', '.', '!', '!', '?', '?']: - text += '.' - - sentences = re.split('([。.!?!?])', text) - sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])] - - batches = [] - current_batch = "" - - def split_by_words(text): - words = text.split() - current_word_part = "" - word_batches = [] - for word in words: - if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars: - current_word_part += word + ' ' - else: - if current_word_part: - # Try to find a suitable split word - for split_word in split_words: - split_index = current_word_part.rfind(' ' + split_word + ' ') - if split_index != -1: - word_batches.append(current_word_part[:split_index].strip()) - current_word_part = current_word_part[split_index:].strip() + ' ' - break - else: - # If no suitable split word found, just append the current part - word_batches.append(current_word_part.strip()) - current_word_part = "" - current_word_part += word + ' ' - if current_word_part: - word_batches.append(current_word_part.strip()) - return word_batches +def chunk_text(text, max_chars=135): + """ + Splits the input text into chunks, each with a maximum number of characters. + + Args: + text (str): The text to be split. + max_chars (int): The maximum number of characters per chunk. + + Returns: + List[str]: A list of text chunks. + """ + chunks = [] + current_chunk = "" + # Split the text into sentences based on punctuation followed by whitespace + sentences = re.split(r'(?<=[;:,.!?])\s+', text) for sentence in sentences: - if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars: - current_batch += sentence + if len(current_chunk) + len(sentence) <= max_chars: + current_chunk += sentence + " " else: - # If adding this sentence would exceed the limit - if current_batch: - batches.append(current_batch) - current_batch = "" - - # If the sentence itself is longer than max_chars, split it - if len(sentence.encode('utf-8')) > max_chars: - # First, try to split by colon - colon_parts = sentence.split(':') - if len(colon_parts) > 1: - for part in colon_parts: - if len(part.encode('utf-8')) <= max_chars: - batches.append(part) - else: - # If colon part is still too long, split by comma - comma_parts = re.split('[,,]', part) - if len(comma_parts) > 1: - current_comma_part = "" - for comma_part in comma_parts: - if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars: - current_comma_part += comma_part + ',' - else: - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - current_comma_part = comma_part + ',' - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - else: - # If no comma, split by words - batches.extend(split_by_words(part)) - else: - # If no colon, split by comma - comma_parts = re.split('[,,]', sentence) - if len(comma_parts) > 1: - current_comma_part = "" - for comma_part in comma_parts: - if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars: - current_comma_part += comma_part + ',' - else: - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - current_comma_part = comma_part + ',' - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - else: - # If no comma, split by words - batches.extend(split_by_words(sentence)) - else: - current_batch = sentence - - if current_batch: - batches.append(current_batch) - - return batches + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + " " + + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks @gpu_decorator def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()): @@ -306,7 +239,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: aseg = AudioSegment.from_file(ref_audio_orig) - non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500) + non_silent_segs = silence.split_on_silence( + aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500 + ) non_silent_wave = AudioSegment.silent(duration=0) for non_silent_seg in non_silent_segs: non_silent_wave += non_silent_seg @@ -332,13 +267,20 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s else: gr.Info("Using custom reference text...") - # Split the input text into batches + # Add the functionality to ensure it ends with ". " + if not ref_text.endswith(". "): + if ref_text.endswith("."): + ref_text += " " + else: + ref_text += ". " + audio, sr = torchaudio.load(ref_audio) - max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr)) - gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars) + + # Use the new chunk_text function to split gen_text + gen_text_batches = chunk_text(gen_text, max_chars=135) print('ref_text', ref_text) - for i, gen_text in enumerate(gen_text_batches): - print(f'gen_text {i}', gen_text) + for i, batch_text in enumerate(gen_text_batches): + print(f'gen_text {i}', batch_text) gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches") return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence) @@ -823,4 +765,4 @@ def main(port, host, share, api): if __name__ == "__main__": - main() + main() \ No newline at end of file