change prepare_csv_wavs from relative path to absolute path and get duration info with soundfile and torchaudio

This commit is contained in:
ZhikangNiu
2026-01-22 12:27:23 +08:00
parent 1d2f7c5389
commit 97fdc7fbb4
2 changed files with 76 additions and 63 deletions

View File

@@ -27,11 +27,12 @@ python src/f5_tts/train/datasets/prepare_libritts.py
python src/f5_tts/train/datasets/prepare_ljspeech.py
```
### 2. Create custom dataset with metadata.csv
### 2. Create custom dataset with CSV
Prepare a CSV with two columns using a required header: `audio_file|text`. Audio paths must be absolute.
For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
```bash
python src/f5_tts/train/datasets/prepare_csv_wavs.py
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/metadata.csv /path/to/output
```
## Training & Finetuning

View File

@@ -1,9 +1,22 @@
"""
Usage:
python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]
CSV format (header required, "|" delimiter):
audio_file|text
/path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
/path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.
Notes:
- audio_file must be an absolute path.
"""
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import subprocess
import sys
from contextlib import contextmanager
@@ -16,6 +29,7 @@ import json
from importlib.resources import files
from pathlib import Path
import soundfile as sf
import torchaudio
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
@@ -25,23 +39,19 @@ from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
def is_csv_wavs_format(input_dataset_dir):
    """Return True when the directory holds a `metadata.csv` file and a `wavs` subdirectory."""
    root = Path(input_dataset_dir)
    metadata_file = root / "metadata.csv"
    wavs_dir = root / "wavs"
    # is_file() / is_dir() already return False for missing paths,
    # so a separate exists() check is not needed.
    return metadata_file.is_file() and wavs_dir.is_dir()
# Configuration constants
BATCH_SIZE = 100 # Batch size for text conversion (default `batch_size` of batch_convert_texts)
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free, but never go below one worker
THREAD_NAME_PREFIX = "AudioProcessor" # NOTE(review): presumably the name prefix for worker threads; confirm where the executor is created
CHUNK_SIZE = 100 # Number of files to process per worker batch
executor = None # Global executor for cleanup; declared `global` inside prepare_csv_wavs_dir
def is_csv_wavs_format(input_path):
    """Return True when `input_path` (with `~` expanded) is an existing file with a `.csv` suffix.

    The suffix comparison is case-insensitive.
    """
    candidate = Path(input_path).expanduser()
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() == ".csv"
@contextmanager
def graceful_exit():
"""Context manager for graceful shutdown on signals"""
@@ -89,15 +99,16 @@ def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
return converted_texts
def prepare_csv_wavs_dir(input_dir, num_workers=None):
def prepare_csv_wavs_dir(input_path, num_workers=None):
global executor
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
input_dir = Path(input_dir)
metadata_path = input_dir / "metadata.csv"
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
if not is_csv_wavs_format(input_path):
raise ValueError(f"input must be a .csv file: {input_path}")
audio_path_text_pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix())
polyphone = True
total_files = len(audio_path_text_pairs)
if total_files == 0:
raise RuntimeError("No valid rows found in CSV.")
# Use provided worker count or calculate optimal number
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
@@ -155,10 +166,12 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
def get_audio_duration(audio_path, timeout=5):
"""
Get the duration of an audio file in seconds using ffmpeg's ffprobe.
Falls back to torchaudio.load() if ffprobe fails.
"""
"""Get the duration of an audio file in seconds with fallbacks."""
try:
return sf.info(audio_path).duration
except Exception as e:
print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.")
try:
cmd = [
"ffprobe",
@@ -178,27 +191,39 @@ def get_audio_duration(audio_path, timeout=5):
return float(duration_str)
raise ValueError("Empty duration string from ffprobe.")
except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
try:
audio, sample_rate = torchaudio.load(audio_path)
return audio.shape[1] / sample_rate
except Exception as e:
raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.info.")
try:
info = torchaudio.info(audio_path)
if info.sample_rate > 0:
return info.num_frames / info.sample_rate
raise ValueError("Invalid sample_rate from torchaudio.info.")
except Exception as e:
raise RuntimeError(f"failed to get duration for {audio_path}: {e}")
def read_audio_text_pairs(csv_file_path):
    """
    Read (audio path, text) pairs from a "|"-delimited metadata CSV.

    The first row must be the header `audio_file|text`. Each following row
    supplies an absolute audio path in the first column and its transcript in
    the second; both values are stripped of surrounding whitespace. Rows with
    fewer than two columns, or with an empty audio path, are skipped.

    Args:
        csv_file_path: Path to the CSV file (`~` is expanded).

    Returns:
        A list of `(audio_path, text)` tuples, where `audio_path` is a
        POSIX-style string. A completely empty file yields an empty list.

    Raises:
        ValueError: If the header is not `audio_file|text`, or if any
            audio path is not absolute.
    """
    audio_text_pairs = []
    csv_path = Path(csv_file_path).expanduser().absolute()
    # "utf-8-sig" drops a leading BOM so the header comparison below still
    # matches files exported by spreadsheet tools.
    with open(csv_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")

        header = next(reader, None)
        if header is None:
            # Empty file: nothing to read, let the caller decide how to react.
            return audio_text_pairs
        if len(header) < 2 or header[0].strip() != "audio_file" or header[1].strip() != "text":
            raise ValueError("CSV header must be: audio_file|text")

        # Data rows start at 2 because the header occupies row 1.
        for row_idx, row in enumerate(reader, start=2):
            if len(row) < 2:
                continue
            audio_file = row[0].strip()
            text = row[1].strip()
            if not audio_file:
                continue
            audio_path = Path(audio_file).expanduser()
            # Paths are no longer resolved against the CSV's directory, so a
            # relative entry is rejected instead of being guessed at.
            if not audio_path.is_absolute():
                raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}")
            audio_text_pairs.append((audio_path.as_posix(), text))

    return audio_text_pairs
@@ -242,35 +267,22 @@ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
def get_args():
    """Parse the command-line arguments for dataset preparation.

    Returns:
        argparse.Namespace with `inp_dir`, `out_dir`, `pretrain` and `workers`
        (`workers` is None when the flag is not given).
    """
    arg_parser = argparse.ArgumentParser(description="Prepare and save dataset.")

    # Required positionals, in the order they appear on the command line.
    positionals = (
        ("inp_dir", "Input CSV with header 'audio_file|text' and absolute wav paths."),
        ("out_dir", "Output directory to save the prepared data."),
    )
    for arg_name, arg_help in positionals:
        arg_parser.add_argument(arg_name, type=str, help=arg_help)

    # Optional switches.
    arg_parser.add_argument(
        "--pretrain",
        action="store_true",
        help="Enable for new pretrain, otherwise is a fine-tune",
    )
    arg_parser.add_argument(
        "--workers",
        type=int,
        help=f"Number of worker threads (default: {MAX_WORKERS})",
    )
    return arg_parser.parse_args()
def cli():
try:
# Before processing, check if ffprobe is available.
if shutil.which("ffprobe") is None:
print(
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
)
# Usage examples in help text
parser = argparse.ArgumentParser(
description="Prepare and save dataset.",
epilog="""
Examples:
# For fine-tuning (default):
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
# For pre-training:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
# With custom worker count:
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
)
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
args = parser.parse_args()
args = get_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
except KeyboardInterrupt:
print("\nOperation cancelled by user. Cleaning up...")