mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Change prepare_csv_wavs to take absolute audio paths instead of relative paths, and get duration info with soundfile and torchaudio
This commit is contained in:
@@ -27,11 +27,12 @@ python src/f5_tts/train/datasets/prepare_libritts.py
|
||||
python src/f5_tts/train/datasets/prepare_ljspeech.py
|
||||
```
|
||||
|
||||
### 2. Create custom dataset with metadata.csv
|
||||
### 2. Create custom dataset with CSV
|
||||
Prepare a `|`-delimited CSV with two columns and the required header `audio_file|text`. Audio paths must be absolute.
|
||||
For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
|
||||
|
||||
```bash
|
||||
python src/f5_tts/train/datasets/prepare_csv_wavs.py
|
||||
python src/f5_tts/train/datasets/prepare_csv_wavs.py /path/to/metadata.csv /path/to/output
|
||||
```
|
||||
|
||||
## Training & Finetuning
|
||||
|
||||
@@ -1,9 +1,22 @@
|
||||
"""
|
||||
Usage:
|
||||
python prepare_csv_wavs.py /path/to/metadata.csv /output/dataset/path [--pretrain] [--workers N]
|
||||
|
||||
CSV format (header required, "|" delimiter):
|
||||
audio_file|text
|
||||
/path/to/wavs/audio_0001.wav|Yo! Hello? Hello?
|
||||
/path/to/wavs/audio_0002.wav|Hi, how are you doing today? I want to go shopping and buy me some lemons.
|
||||
|
||||
Notes:
|
||||
- audio_file must be an absolute path.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
@@ -16,6 +29,7 @@ import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torchaudio
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
@@ -25,23 +39,19 @@ from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
|
||||
|
||||
def is_csv_wavs_format(input_dataset_dir):
|
||||
fpath = Path(input_dataset_dir)
|
||||
metadata = fpath / "metadata.csv"
|
||||
wavs = fpath / "wavs"
|
||||
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
|
||||
|
||||
|
||||
# Configuration constants
|
||||
BATCH_SIZE = 100 # Batch size for text conversion
|
||||
MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1) # Leave one CPU free
|
||||
THREAD_NAME_PREFIX = "AudioProcessor"
|
||||
CHUNK_SIZE = 100 # Number of files to process per worker batch
|
||||
|
||||
executor = None # Global executor for cleanup
|
||||
|
||||
|
||||
def is_csv_wavs_format(input_path):
|
||||
fpath = Path(input_path).expanduser()
|
||||
return fpath.is_file() and fpath.suffix.lower() == ".csv"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def graceful_exit():
|
||||
"""Context manager for graceful shutdown on signals"""
|
||||
@@ -89,15 +99,16 @@ def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
|
||||
return converted_texts
|
||||
|
||||
|
||||
def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
def prepare_csv_wavs_dir(input_path, num_workers=None):
|
||||
global executor
|
||||
assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
|
||||
input_dir = Path(input_dir)
|
||||
metadata_path = input_dir / "metadata.csv"
|
||||
audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
|
||||
if not is_csv_wavs_format(input_path):
|
||||
raise ValueError(f"input must be a .csv file: {input_path}")
|
||||
audio_path_text_pairs = read_audio_text_pairs(Path(input_path).expanduser().as_posix())
|
||||
|
||||
polyphone = True
|
||||
total_files = len(audio_path_text_pairs)
|
||||
if total_files == 0:
|
||||
raise RuntimeError("No valid rows found in CSV.")
|
||||
|
||||
# Use provided worker count or calculate optimal number
|
||||
worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
|
||||
@@ -155,10 +166,12 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
|
||||
|
||||
def get_audio_duration(audio_path, timeout=5):
|
||||
"""
|
||||
Get the duration of an audio file in seconds using ffmpeg's ffprobe.
|
||||
Falls back to torchaudio.load() if ffprobe fails.
|
||||
"""
|
||||
"""Get the duration of an audio file in seconds with fallbacks."""
|
||||
try:
|
||||
return sf.info(audio_path).duration
|
||||
except Exception as e:
|
||||
print(f"Warning: soundfile failed for {audio_path} with error: {e}. Falling back to ffprobe.")
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"ffprobe",
|
||||
@@ -178,27 +191,39 @@ def get_audio_duration(audio_path, timeout=5):
|
||||
return float(duration_str)
|
||||
raise ValueError("Empty duration string from ffprobe.")
|
||||
except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
|
||||
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
|
||||
try:
|
||||
audio, sample_rate = torchaudio.load(audio_path)
|
||||
return audio.shape[1] / sample_rate
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
|
||||
print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.info.")
|
||||
|
||||
try:
|
||||
info = torchaudio.info(audio_path)
|
||||
if info.sample_rate > 0:
|
||||
return info.num_frames / info.sample_rate
|
||||
raise ValueError("Invalid sample_rate from torchaudio.info.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to get duration for {audio_path}: {e}")
|
||||
|
||||
|
||||
def read_audio_text_pairs(csv_file_path):
|
||||
audio_text_pairs = []
|
||||
|
||||
parent = Path(csv_file_path).parent
|
||||
with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
|
||||
csv_path = Path(csv_file_path).expanduser().absolute()
|
||||
with open(csv_path.as_posix(), mode="r", newline="", encoding="utf-8-sig") as csvfile:
|
||||
reader = csv.reader(csvfile, delimiter="|")
|
||||
next(reader) # Skip the header row
|
||||
for row in reader:
|
||||
if len(row) >= 2:
|
||||
audio_file = row[0].strip() # First column: audio file path
|
||||
text = row[1].strip() # Second column: text
|
||||
audio_file_path = parent / audio_file
|
||||
audio_text_pairs.append((audio_file_path.as_posix(), text))
|
||||
header = next(reader, None)
|
||||
if header is None:
|
||||
return audio_text_pairs
|
||||
if len(header) < 2 or header[0].strip() != "audio_file" or header[1].strip() != "text":
|
||||
raise ValueError("CSV header must be: audio_file|text")
|
||||
for row_idx, row in enumerate(reader, start=2):
|
||||
if len(row) < 2:
|
||||
continue
|
||||
audio_file = row[0].strip()
|
||||
text = row[1].strip()
|
||||
if not audio_file:
|
||||
continue
|
||||
audio_path = Path(audio_file).expanduser()
|
||||
if not audio_path.is_absolute():
|
||||
raise ValueError(f"audio_file must be an absolute path (row {row_idx}): {audio_file}")
|
||||
audio_text_pairs.append((audio_path.as_posix(), text))
|
||||
|
||||
return audio_text_pairs
|
||||
|
||||
@@ -242,35 +267,22 @@ def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers
|
||||
save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
|
||||
parser.add_argument(
|
||||
"inp_dir",
|
||||
type=str,
|
||||
help="Input CSV with header 'audio_file|text' and absolute wav paths.",
|
||||
)
|
||||
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
|
||||
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
|
||||
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def cli():
|
||||
try:
|
||||
# Before processing, check if ffprobe is available.
|
||||
if shutil.which("ffprobe") is None:
|
||||
print(
|
||||
"Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
|
||||
)
|
||||
|
||||
# Usage examples in help text
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare and save dataset.",
|
||||
epilog="""
|
||||
Examples:
|
||||
# For fine-tuning (default):
|
||||
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
|
||||
|
||||
# For pre-training:
|
||||
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
|
||||
|
||||
# With custom worker count:
|
||||
python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
|
||||
""",
|
||||
)
|
||||
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
|
||||
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
|
||||
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
|
||||
parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
|
||||
args = parser.parse_args()
|
||||
|
||||
args = get_args()
|
||||
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
|
||||
except KeyboardInterrupt:
|
||||
print("\nOperation cancelled by user. Cleaning up...")
|
||||
|
||||
Reference in New Issue
Block a user