mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 09:39:52 -08:00
Fix code formatting
This commit is contained in:
@@ -42,9 +42,11 @@ CHUNK_SIZE = 100 # Number of files to process per worker batch
|
||||
|
||||
executor = None # Global executor for cleanup
|
||||
|
||||
|
||||
@contextmanager
|
||||
def graceful_exit():
|
||||
"""Context manager for graceful shutdown on signals"""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
print("\nReceived signal to terminate. Cleaning up...")
|
||||
if executor is not None:
|
||||
@@ -55,13 +57,14 @@ def graceful_exit():
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if executor is not None:
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def process_audio_file(audio_path, text, polyphone):
|
||||
"""Process a single audio file by checking its existence and extracting duration."""
|
||||
if not Path(audio_path).exists():
|
||||
@@ -76,15 +79,17 @@ def process_audio_file(audio_path, text, polyphone):
|
||||
print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
|
||||
return None
|
||||
|
||||
|
||||
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
|
||||
"""Convert a list of texts to pinyin in batches."""
|
||||
converted_texts = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i + batch_size]
|
||||
batch = texts[i : i + batch_size]
|
||||
converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
|
||||
converted_texts.extend(converted_batch)
|
||||
return converted_texts
|
||||
|
||||
|
||||
def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
global executor
|
||||
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
|
||||
@@ -94,7 +99,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
|
||||
polyphone = True
|
||||
total_files = len(audio_path_text_pairs)
|
||||
|
||||
|
||||
# Use provided worker count or calculate optimal number
|
||||
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
|
||||
print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
|
||||
@@ -102,26 +107,22 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
with graceful_exit():
|
||||
# Initialize thread pool with optimized settings
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=worker_count,
|
||||
thread_name_prefix=THREAD_NAME_PREFIX
|
||||
max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
|
||||
) as exec:
|
||||
executor = exec
|
||||
results = []
|
||||
|
||||
|
||||
# Process files in chunks for better efficiency
|
||||
for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
|
||||
chunk = audio_path_text_pairs[i:i + CHUNK_SIZE]
|
||||
chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
|
||||
# Submit futures in order
|
||||
chunk_futures = [
|
||||
executor.submit(process_audio_file, pair[0], pair[1], polyphone)
|
||||
for pair in chunk
|
||||
]
|
||||
|
||||
chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
|
||||
|
||||
# Iterate over futures in the original submission order to preserve ordering
|
||||
for future in tqdm(
|
||||
chunk_futures,
|
||||
total=len(chunk),
|
||||
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}"
|
||||
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
|
||||
):
|
||||
try:
|
||||
result = future.result()
|
||||
@@ -129,28 +130,28 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"Error processing file: {e}")
|
||||
|
||||
|
||||
executor = None
|
||||
|
||||
|
||||
# Filter out failed results
|
||||
processed = [res for res in results if res is not None]
|
||||
if not processed:
|
||||
raise RuntimeError("No valid audio files were processed!")
|
||||
|
||||
|
||||
# Batch process text conversion
|
||||
raw_texts = [item[1] for item in processed]
|
||||
converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
|
||||
|
||||
|
||||
# Prepare final results
|
||||
sub_result = []
|
||||
durations = []
|
||||
vocab_set = set()
|
||||
|
||||
|
||||
for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
|
||||
sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
|
||||
durations.append(duration)
|
||||
vocab_set.update(list(conv_text))
|
||||
|
||||
|
||||
return sub_result, durations, vocab_set
|
||||
|
||||
|
||||
@@ -161,13 +162,18 @@ def get_audio_duration(audio_path, timeout=5):
|
||||
"""
|
||||
try:
|
||||
cmd = [
|
||||
"ffprobe", "-v", "error",
|
||||
"-show_entries", "format=duration",
|
||||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||
audio_path
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
audio_path,
|
||||
]
|
||||
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
text=True, check=True, timeout=timeout)
|
||||
result = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
|
||||
)
|
||||
duration_str = result.stdout.strip()
|
||||
if duration_str:
|
||||
return float(duration_str)
|
||||
@@ -241,8 +247,10 @@ def cli():
|
||||
try:
|
||||
# Before processing, check if ffprobe is available.
|
||||
if shutil.which("ffprobe") is None:
|
||||
print("Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower).")
|
||||
|
||||
print(
|
||||
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
|
||||
)
|
||||
|
||||
# Usage examples in help text
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare and save dataset.",
|
||||
@@ -256,20 +264,15 @@ Examples:
|
||||
|
||||
# With custom worker count:
|
||||
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
|
||||
"""
|
||||
""",
|
||||
)
|
||||
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
|
||||
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
|
||||
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
|
||||
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_and_save_set(
|
||||
args.inp_dir,
|
||||
args.out_dir,
|
||||
is_finetune=not args.pretrain,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
|
||||
except KeyboardInterrupt:
|
||||
print("\nOperation cancelled by user. Cleaning up...")
|
||||
if executor is not None:
|
||||
|
||||
Reference in New Issue
Block a user