diff --git a/inference-cli.py b/inference-cli.py
index 480fc5c..3d6e3d2 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -93,17 +93,6 @@ wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
 vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

-SPLIT_WORDS = [
-    "but", "however", "nevertheless", "yet", "still",
-    "therefore", "thus", "hence", "consequently",
-    "moreover", "furthermore", "additionally",
-    "meanwhile", "alternatively", "otherwise",
-    "namely", "specifically", "for example", "such as",
-    "in fact", "indeed", "notably",
-    "in contrast", "on the other hand", "conversely",
-    "in conclusion", "to summarize", "finally"
-]
-
 device = (
     "cuda"
     if torch.cuda.is_available()
@@ -167,103 +156,36 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

-def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
-    if len(text.encode('utf-8')) <= max_chars:
-        return [text]
-    if text[-1] not in ['。', '.', '!', '!', '?', '?']:
-        text += '.'
-
-    sentences = re.split('([。.!?!?])', text)
-    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
-
-    batches = []
-    current_batch = ""
-
-    def split_by_words(text):
-        words = text.split()
-        current_word_part = ""
-        word_batches = []
-        for word in words:
-            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
-                current_word_part += word + ' '
-            else:
-                if current_word_part:
-                    # Try to find a suitable split word
-                    for split_word in split_words:
-                        split_index = current_word_part.rfind(' ' + split_word + ' ')
-                        if split_index != -1:
-                            word_batches.append(current_word_part[:split_index].strip())
-                            current_word_part = current_word_part[split_index:].strip() + ' '
-                            break
-                    else:
-                        # If no suitable split word found, just append the current part
-                        word_batches.append(current_word_part.strip())
-                        current_word_part = ""
-                current_word_part += word + ' '
-        if current_word_part:
-            word_batches.append(current_word_part.strip())
-        return word_batches
+
+def chunk_text(text, max_chars=135):
+    """
+    Splits the input text into chunks, each with a maximum number of characters.
+    Args:
+        text (str): The text to be split.
+        max_chars (int): The maximum number of characters per chunk.
+    Returns:
+        List[str]: A list of text chunks.
+    """
+    chunks = []
+    current_chunk = ""
+    # Split the text into sentences based on punctuation followed by whitespace
+    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)

     for sentence in sentences:
-        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_batch += sentence
+        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
+            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
         else:
-            # If adding this sentence would exceed the limit
-            if current_batch:
-                batches.append(current_batch)
-                current_batch = ""
-
-            # If the sentence itself is longer than max_chars, split it
-            if len(sentence.encode('utf-8')) > max_chars:
-                # First, try to split by colon
-                colon_parts = sentence.split(':')
-                if len(colon_parts) > 1:
-                    for part in colon_parts:
-                        if len(part.encode('utf-8')) <= max_chars:
-                            batches.append(part)
-                        else:
-                            # If colon part is still too long, split by comma
-                            comma_parts = re.split('[,,]', part)
-                            if len(comma_parts) > 1:
-                                current_comma_part = ""
-                                for comma_part in comma_parts:
-                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                        current_comma_part += comma_part + ','
-                                    else:
-                                        if current_comma_part:
-                                            batches.append(current_comma_part.rstrip(','))
-                                        current_comma_part = comma_part + ','
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                            else:
-                                # If no comma, split by words
-                                batches.extend(split_by_words(part))
-                else:
-                    # If no colon, split by comma
-                    comma_parts = re.split('[,,]', sentence)
-                    if len(comma_parts) > 1:
-                        current_comma_part = ""
-                        for comma_part in comma_parts:
-                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
-                                current_comma_part += comma_part + ','
-                            else:
-                                if current_comma_part:
-                                    batches.append(current_comma_part.rstrip(','))
-                                current_comma_part = comma_part + ','
-                        if current_comma_part:
-                            batches.append(current_comma_part.rstrip(','))
-                    else:
-                        # If no comma, split by words
-                        batches.extend(split_by_words(sentence))
-            else:
-                current_batch = sentence
-
-    if current_batch:
-        batches.append(current_batch)
-
-    return batches
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence

-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
+
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
         ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
     elif model == "E2-TTS":
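
For context (not part of the patch): the new chunk_text greedily packs whole sentences into chunks under a UTF-8 byte budget, so max_chars counts bytes rather than characters, and the regex also splits after full-width CJK punctuation even with no following whitespace. A minimal standalone illustration, with the function body copied verbatim from the hunk above and a made-up sample string and budget:

import re

def chunk_text(text, max_chars=135):
    # Copied verbatim from the hunk above so this snippet runs on its own.
    chunks = []
    current_chunk = ""
    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
    for sentence in sentences:
        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks

# A 20-byte budget forces a break after the first sentence; the next two still fit together.
print(chunk_text("First sentence. Second one! Third?", max_chars=20))
# -> ['First sentence.', 'Second one! Third?']
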
+ """ + chunks = [] + current_chunk = "" + # Split the text into sentences based on punctuation followed by whitespace + sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text) for sentence in sentences: - if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars: - current_batch += sentence + if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars: + current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence else: - # If adding this sentence would exceed the limit - if current_batch: - batches.append(current_batch) - current_batch = "" - - # If the sentence itself is longer than max_chars, split it - if len(sentence.encode('utf-8')) > max_chars: - # First, try to split by colon - colon_parts = sentence.split(':') - if len(colon_parts) > 1: - for part in colon_parts: - if len(part.encode('utf-8')) <= max_chars: - batches.append(part) - else: - # If colon part is still too long, split by comma - comma_parts = re.split('[,,]', part) - if len(comma_parts) > 1: - current_comma_part = "" - for comma_part in comma_parts: - if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars: - current_comma_part += comma_part + ',' - else: - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - current_comma_part = comma_part + ',' - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - else: - # If no comma, split by words - batches.extend(split_by_words(part)) - else: - # If no colon, split by comma - comma_parts = re.split('[,,]', sentence) - if len(comma_parts) > 1: - current_comma_part = "" - for comma_part in comma_parts: - if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars: - current_comma_part += comma_part + ',' - else: - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - current_comma_part = comma_part + ',' - if current_comma_part: - batches.append(current_comma_part.rstrip(',')) - else: - # If no comma, split by words - batches.extend(split_by_words(sentence)) - else: - current_batch = sentence - - if current_batch: - batches.append(current_batch) - - return batches + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence -def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence): + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15): if model == "F5-TTS": ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000) elif model == "E2-TTS": @@ -321,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence): generated_waves.append(generated_wave) spectrograms.append(generated_mel_spec[0].cpu().numpy()) - # Combine all generated waves - final_wave = np.concatenate(generated_waves) + # Combine all generated waves with cross-fading + if cross_fade_duration <= 0: + # Simply concatenate + final_wave = np.concatenate(generated_waves) + else: + final_wave = generated_waves[0] + for i in range(1, len(generated_waves)): + prev_wave = final_wave + next_wave = generated_waves[i] + + # Calculate cross-fade samples, ensuring it does not exceed wave lengths + cross_fade_samples = int(cross_fade_duration * target_sample_rate) + cross_fade_samples = min(cross_fade_samples, 
@@ -343,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):

     print(spectrogram_path)

-def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
-    if not custom_split_words.strip():
-        custom_words = [word.strip() for word in custom_split_words.split(',')]
-        global SPLIT_WORDS
-        SPLIT_WORDS = custom_words
+def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):

     print(gen_text)

@@ -355,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)

-        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
         non_silent_wave = AudioSegment.silent(duration=0)
         for non_silent_seg in non_silent_segs:
             non_silent_wave += non_silent_seg
@@ -387,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     else:
         print("Using custom reference text...")

+    # Add the functionality to ensure it ends with ". "
+    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+        if ref_text.endswith("."):
+            ref_text += " "
+        else:
+            ref_text += ". "
+
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
-    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
+    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     print('ref_text', ref_text)
     for i, gen_text in enumerate(gen_text_batches):
         print(f'gen_text {i}', gen_text)

     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)

-infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
+infer(ref_audio, ref_text, gen_text, model, remove_silence)
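
One non-obvious change above: max_chars is now derived from the reference clip itself. The formula treats len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) as the speaker's "UTF-8 bytes per second" and budgets text for the seconds left in a roughly 25-second window (tightened from 30) after the reference audio. A worked example with a hypothetical 4-second clip:

ref_text = "Some call me nature, others call me mother nature. "
ref_bytes = len(ref_text.encode('utf-8'))    # 51 (after the new trailing ". " fix-up)
ref_seconds = 4.0                            # hypothetical clip length, i.e. audio.shape[-1] / sr
bytes_per_second = ref_bytes / ref_seconds   # 12.75
max_chars = int(bytes_per_second * (25 - ref_seconds))
print(max_chars)                             # 267

So a short or fast-talking reference yields proportionally smaller text chunks, keeping each generated batch within the model's window.
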
diff --git a/inference-cli.toml b/inference-cli.toml
index d2f3bbb..b6bea1c 100644
--- a/inference-cli.toml
+++ b/inference-cli.toml
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
 gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
 # File with text to generate. Ignores the text above.
 gen_file = ""
-remove_silence = true
+remove_silence = false
 output_dir = "tests"
\ No newline at end of file
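
The inference-cli.toml defaults now leave silence removal off. For completeness, a sketch of reading this config back with the standard-library TOML parser; the diff does not show how inference-cli.py actually loads the file, so treat this as illustrative only:

import tomllib  # Python 3.11+; the tomli package offers the same API on older versions

with open("inference-cli.toml", "rb") as f:
    config = tomllib.load(f)

print(config["remove_silence"])  # False after this patch
print(config["output_dir"])      # tests
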