diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py
index ba964ea..57bfe25 100644
--- a/src/f5_tts/train/datasets/prepare_csv_wavs.py
+++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -42,9 +42,11 @@
 CHUNK_SIZE = 100  # Number of files to process per worker batch
 executor = None  # Global executor for cleanup
 
+
 @contextmanager
 def graceful_exit():
     """Context manager for graceful shutdown on signals"""
+
     def signal_handler(signum, frame):
         print("\nReceived signal to terminate. Cleaning up...")
         if executor is not None:
@@ -55,13 +57,14 @@ def graceful_exit():
     # Set up signal handlers
     signal.signal(signal.SIGINT, signal_handler)
     signal.signal(signal.SIGTERM, signal_handler)
-    
+
     try:
         yield
     finally:
         if executor is not None:
             executor.shutdown(wait=False)
 
+
 def process_audio_file(audio_path, text, polyphone):
     """Process a single audio file by checking its existence and extracting duration."""
     if not Path(audio_path).exists():
@@ -76,15 +79,17 @@ def process_audio_file(audio_path, text, polyphone):
         print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
         return None
 
+
 def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
     """Convert a list of texts to pinyin in batches."""
     converted_texts = []
     for i in range(0, len(texts), batch_size):
-        batch = texts[i:i + batch_size]
+        batch = texts[i : i + batch_size]
         converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
         converted_texts.extend(converted_batch)
     return converted_texts
 
+
 def prepare_csv_wavs_dir(input_dir, num_workers=None):
     global executor
     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
@@ -94,7 +99,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
 
     polyphone = True
     total_files = len(audio_path_text_pairs)
-    
+
     # Use provided worker count or calculate optimal number
     worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
     print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
@@ -102,26 +107,22 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
     with graceful_exit():
         # Initialize thread pool with optimized settings
         with concurrent.futures.ThreadPoolExecutor(
-            max_workers=worker_count,
-            thread_name_prefix=THREAD_NAME_PREFIX
+            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
         ) as exec:
             executor = exec
             results = []
-            
+
             # Process files in chunks for better efficiency
             for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
-                chunk = audio_path_text_pairs[i:i + CHUNK_SIZE]
+                chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
                 # Submit futures in order
-                chunk_futures = [
-                    executor.submit(process_audio_file, pair[0], pair[1], polyphone)
-                    for pair in chunk
-                ]
-                
+                chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
+
                 # Iterate over futures in the original submission order to preserve ordering
                 for future in tqdm(
                     chunk_futures,
                     total=len(chunk),
-                    desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}"
+                    desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
                 ):
                     try:
                         result = future.result()
@@ -129,28 +130,28 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
                             results.append(result)
                     except Exception as e:
                         print(f"Error processing file: {e}")
-    
+
     executor = None
-    
+
     # Filter out failed results
     processed = [res for res in results if res is not None]
     if not processed:
         raise RuntimeError("No valid audio files were processed!")
-    
+
     # Batch process text conversion
     raw_texts = [item[1] for item in processed]
    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
-    
+
     # Prepare final results
     sub_result = []
     durations = []
     vocab_set = set()
-    
+
     for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
         sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
         durations.append(duration)
         vocab_set.update(list(conv_text))
-    
+
     return sub_result, durations, vocab_set
 
 
@@ -161,13 +162,18 @@
     """
     try:
         cmd = [
-            "ffprobe", "-v", "error",
-            "-show_entries", "format=duration",
-            "-of", "default=noprint_wrappers=1:nokey=1",
-            audio_path
+            "ffprobe",
+            "-v",
+            "error",
+            "-show_entries",
+            "format=duration",
+            "-of",
+            "default=noprint_wrappers=1:nokey=1",
+            audio_path,
         ]
-        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
-                                text=True, check=True, timeout=timeout)
+        result = subprocess.run(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
+        )
         duration_str = result.stdout.strip()
         if duration_str:
             return float(duration_str)
@@ -241,8 +247,10 @@ def cli():
     try:
         # Before processing, check if ffprobe is available.
         if shutil.which("ffprobe") is None:
-            print("Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower).")
-            
+            print(
+                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
+            )
+
         # Usage examples in help text
         parser = argparse.ArgumentParser(
             description="Prepare and save dataset.",
@@ -256,20 +264,15 @@ Examples:
 
     # With custom worker count:
     python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
-    """
+    """,
         )
         parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
         parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
         parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
         parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
         args = parser.parse_args()
-        
-        prepare_and_save_set(
-            args.inp_dir,
-            args.out_dir,
-            is_finetune=not args.pretrain,
-            num_workers=args.workers
-        )
+
+        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
     except KeyboardInterrupt:
         print("\nOperation cancelled by user. Cleaning up...")
         if executor is not None:
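
For reviewers: the diff above is a pure formatting pass, and the behavioral core it touches is the chunked, order-preserving `ThreadPoolExecutor` loop. Here is a minimal standalone sketch of that pattern; `process_item`, `process_in_chunks`, and the demo input are hypothetical stand-ins for `process_audio_file` and the audio/text pairs, not code from this file:

```python
import concurrent.futures

CHUNK_SIZE = 100  # per-batch submission size, mirroring the script


def process_item(item):
    # Hypothetical stand-in for process_audio_file(); return None to skip an item.
    return item * 2


def process_in_chunks(items, max_workers=4):
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        for i in range(0, len(items), CHUNK_SIZE):
            chunk = items[i : i + CHUNK_SIZE]
            # Submit one future per item, then consume them in submission order
            # so results stay aligned with inputs.
            futures = [pool.submit(process_item, it) for it in chunk]
            for future in futures:
                try:
                    result = future.result()
                    if result is not None:
                        results.append(result)
                except Exception as e:
                    print(f"Error processing item: {e}")
    return results


if __name__ == "__main__":
    print(process_in_chunks(list(range(10))))  # -> [0, 2, 4, ..., 18]
```

Consuming futures in submission order (rather than via `as_completed`) is what keeps results aligned with inputs, which the script's later `zip(processed, converted_texts)` step depends on.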