diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 625a562..88fb0cb 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -2,6 +2,7 @@
 # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
 import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
 
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 
@@ -86,6 +87,8 @@ def chunk_text(text, max_chars=135):
     sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
 
     for sentence in sentences:
+        if not sentence:
+            continue
         if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
             current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
         else:
@@ -279,12 +282,12 @@ def remove_silence_edges(audio, silence_threshold=-42):
     audio = audio[non_silent_start_idx:]
 
     # Remove silence from the end
-    non_silent_end_duration = audio.duration_seconds
-    for ms in reversed(audio):
-        if ms.dBFS > silence_threshold:
-            break
-        non_silent_end_duration -= 0.001
-    trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
+    reversed_audio = audio.reverse()
+    non_silent_end_idx = silence.detect_leading_silence(reversed_audio, silence_threshold=silence_threshold)
+    if non_silent_end_idx > 0:
+        trimmed_audio = audio[: len(audio) - non_silent_end_idx]
+    else:
+        trimmed_audio = audio
 
     return trimmed_audio
 
@@ -400,11 +403,16 @@ def infer_process(
     audio, sr = torchaudio.load(ref_audio)
     max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
     gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
-    for i, gen_text in enumerate(gen_text_batches):
-        print(f"gen_text {i}", gen_text)
+    for i, gen_text_i in enumerate(gen_text_batches):
+        print(f"gen_text {i}", gen_text_i)
     print("\n")
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
+
+    if not gen_text_batches:
+        show_info("No text batches to generate.")
+        return None, target_sample_rate, None
+
     return next(
         infer_batch_process(
             (audio, sr),
@@ -466,7 +474,7 @@ def infer_batch_process(
         if len(ref_text[-1].encode("utf-8")) == 1:
             ref_text = ref_text + " "
 
-    def process_batch(gen_text):
+    def _infer_basic(gen_text):
         local_speed = speed
         if len(gen_text.encode("utf-8")) < 10:
             local_speed = 0.3
@@ -509,23 +517,34 @@ def infer_batch_process(
             # wav -> numpy
             generated_wave = generated_wave.squeeze().cpu().numpy()
 
-            if streaming:
-                for j in range(0, len(generated_wave), chunk_size):
-                    yield generated_wave[j : j + chunk_size], target_sample_rate
-            else:
-                generated_cpu = generated[0].cpu().numpy()
-                del generated
-                yield generated_wave, generated_cpu
+            return generated_wave, generated
+
+    def infer_single_process(gen_text):
+        generated_wave, generated = _infer_basic(gen_text)
+        generated_cpu = generated[0].cpu().numpy()
+        del generated
+        return generated_wave, generated_cpu
+
+    def infer_single_process_streaming(gen_text):
+        # for src/f5_tts/socket_server.py
+        generated_wave, generated = _infer_basic(gen_text)
+        del generated
+        for j in range(0, len(generated_wave), chunk_size):
+            yield generated_wave[j : j + chunk_size], target_sample_rate
 
     if streaming:
         for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
-            for chunk in process_batch(gen_text):
+            for chunk in infer_single_process_streaming(gen_text):
                 yield chunk
     else:
-        for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
-            generated_wave, generated_mel_spec = next(process_batch(gen_text))
-            generated_waves.append(generated_wave)
-            spectrograms.append(generated_mel_spec)
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(infer_single_process, gen_text) for gen_text in gen_text_batches]
+            for future in progress.tqdm(futures) if progress is not None else futures:
+                result = future.result()
+                if result:
+                    generated_wave, generated_mel_spec = result
+                    generated_waves.append(generated_wave)
+                    spectrograms.append(generated_mel_spec)
 
     if generated_waves:
         if cross_fade_duration <= 0: