mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
clean-up eval scripts
This commit is contained in:
@@ -14,16 +14,20 @@ pip install -e .[eval]
|
||||
1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
||||
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
|
||||
3. Unzip the downloaded datasets and place them in the `data/` directory.
|
||||
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
|
||||
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||
4. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
|
||||
|
||||
### Batch Inference for Test Set
|
||||
|
||||
To run batch inference for evaluations, execute the following commands:
|
||||
|
||||
```bash
|
||||
# batch inference for evaluations
|
||||
accelerate config # if not set before
|
||||
# if not setup accelerate config yet
|
||||
accelerate config
|
||||
|
||||
# if only perform inference
|
||||
bash src/f5_tts/eval/eval_infer_batch.sh --infer-only
|
||||
|
||||
# if inference and with corresponding evaluation, setup the following tools first
|
||||
bash src/f5_tts/eval/eval_infer_batch.sh
|
||||
```
|
||||
|
||||
@@ -35,9 +39,13 @@ bash src/f5_tts/eval/eval_infer_batch.sh
|
||||
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
|
||||
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
|
||||
|
||||
Then update in the following scripts with the paths you put evaluation model ckpts to.
|
||||
> [!NOTE]
|
||||
> ASR model will be automatically downloaded if `--local` not set for evaluation scripts.
|
||||
> Otherwise, you should update the `asr_ckpt_dir` path values in `eval_librispeech_test_clean.py` or `eval_seedtts_testset.py`.
|
||||
>
|
||||
> WavLM model must be downloaded and your `wavlm_ckpt_dir` path updated in `eval_librispeech_test_clean.py` and `eval_seedtts_testset.py`.
|
||||
|
||||
### Objective Evaluation
|
||||
### Objective Evaluation Examples
|
||||
|
||||
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
|
||||
```bash
|
||||
@@ -50,3 +58,6 @@ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_
|
||||
# Evaluation [UTMOS]. --ext: Audio extension
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Evaluation results can also be found in `_*_results.jsonl` files saved in `<GEN_WAV_DIR>`/`<WAV_DIR>`.
|
||||
|
||||
@@ -48,6 +48,11 @@ def main():
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument("-t", "--testset", required=True)
|
||||
parser.add_argument(
|
||||
"-p", "--librispeech_test_clean_path", default=f"{rel_path}/data/LibriSpeech/test-clean", type=str
|
||||
)
|
||||
|
||||
parser.add_argument("--local", action="store_true", help="Use local vocoder checkpoint directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -83,7 +88,7 @@ def main():
|
||||
|
||||
if testset == "ls_pc_test_clean":
|
||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
||||
librispeech_test_clean_path = args.librispeech_test_clean_path
|
||||
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
||||
|
||||
elif testset == "seedtts_test_zh":
|
||||
@@ -121,7 +126,7 @@ def main():
|
||||
)
|
||||
|
||||
# Vocoder model
|
||||
local = False
|
||||
local = args.local
|
||||
if mel_spec_type == "vocos":
|
||||
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
elif mel_spec_type == "bigvgan":
|
||||
|
||||
@@ -1,18 +1,116 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
export PYTHONWARNINGS="ignore::UserWarning,ignore::FutureWarning"
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
# Configuration parameters
|
||||
MODEL_NAME="F5TTS_v1_Base"
|
||||
SEEDS=(0 1 2)
|
||||
CKPTSTEPS=(1250000)
|
||||
TASKS=("seedtts_test_zh" "seedtts_test_en" "ls_pc_test_clean")
|
||||
LS_TEST_CLEAN_PATH="data/LibriSpeech/test-clean"
|
||||
GPUS="[0,1,2,3,4,5,6,7]"
|
||||
OFFLINE_MODE=false
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
# Parse arguments
|
||||
if [ $OFFLINE_MODE = true ]; then
|
||||
LOCAL="--local"
|
||||
else
|
||||
LOCAL=""
|
||||
fi
|
||||
INFER_ONLY=false
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--infer-only)
|
||||
INFER_ONLY=true
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "======== Unknown parameter: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
|
||||
echo "======== Starting F5-TTS batch evaluation task..."
|
||||
if [ "$INFER_ONLY" = true ]; then
|
||||
echo "======== Mode: Execute infer tasks only"
|
||||
else
|
||||
echo "======== Mode: Execute full pipeline (infer + eval)"
|
||||
fi
|
||||
|
||||
# etc.
|
||||
# Function: Execute eval tasks
|
||||
execute_eval_tasks() {
|
||||
local ckptstep=$1
|
||||
local seed=$2
|
||||
local task_name=$3
|
||||
|
||||
local gen_wav_dir="results/${MODEL_NAME}_${ckptstep}/${task_name}/seed${seed}_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0"
|
||||
|
||||
echo ">>>>>>>> Starting eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
|
||||
|
||||
case $task_name in
|
||||
"seedtts_test_zh")
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
"seedtts_test_en")
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l en -g "$gen_wav_dir" -n "$GPUS" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
"ls_pc_test_clean")
|
||||
python src/f5_tts/eval/eval_librispeech_test_clean.py -e wer -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
|
||||
python src/f5_tts/eval/eval_librispeech_test_clean.py -e sim -g "$gen_wav_dir" -n "$GPUS" -p "$LS_TEST_CLEAN_PATH" $LOCAL
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir "$gen_wav_dir"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ">>>>>>>> Completed eval task: ckptstep=${ckptstep}, seed=${seed}, task=${task_name}"
|
||||
}
|
||||
|
||||
# Main execution loop
|
||||
for ckptstep in "${CKPTSTEPS[@]}"; do
|
||||
echo "======== Processing ckptstep: ${ckptstep}"
|
||||
|
||||
for seed in "${SEEDS[@]}"; do
|
||||
echo "-------- Processing seed: ${seed}"
|
||||
|
||||
# Store eval task PIDs for current seed (if not infer-only mode)
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
declare -a eval_pids
|
||||
fi
|
||||
|
||||
# Execute each infer task sequentially
|
||||
for task in "${TASKS[@]}"; do
|
||||
echo ">>>>>>>> Executing infer task: accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n \"${MODEL_NAME}\" -t \"${task}\" -c ${ckptstep} $LOCAL"
|
||||
|
||||
# Execute infer task (foreground execution, wait for completion)
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s ${seed} -n "${MODEL_NAME}" -t "${task}" -c ${ckptstep} -p "${LS_TEST_CLEAN_PATH}" $LOCAL
|
||||
|
||||
# If not infer-only mode, launch corresponding eval task
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
# Launch corresponding eval task (background execution, non-blocking for next infer)
|
||||
execute_eval_tasks $ckptstep $seed $task &
|
||||
eval_pids+=($!)
|
||||
fi
|
||||
done
|
||||
|
||||
# If not infer-only mode, wait for all eval tasks of current seed to complete
|
||||
if [ "$INFER_ONLY" = false ]; then
|
||||
echo ">>>>>>>> All infer tasks for seed ${seed} completed, waiting for corresponding eval tasks to finish..."
|
||||
|
||||
for pid in "${eval_pids[@]}"; do
|
||||
wait $pid
|
||||
done
|
||||
|
||||
unset eval_pids # Clean up array
|
||||
fi
|
||||
echo "-------- All eval tasks for seed ${seed} completed"
|
||||
done
|
||||
|
||||
echo "======== Completed ckptstep: ${ckptstep}"
|
||||
echo
|
||||
done
|
||||
|
||||
echo "======== All tasks completed!"
|
||||
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
18
src/f5_tts/eval/eval_infer_batch_example.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16 -p data/LibriSpeech/test-clean
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0 -p data/LibriSpeech/test-clean
|
||||
|
||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe16_vocos_ss-1_cfg2.0_speed1.0
|
||||
|
||||
# etc.
|
||||
@@ -1,6 +1,7 @@
|
||||
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -25,11 +26,26 @@ def get_args():
|
||||
parser.add_argument("-l", "--lang", type=str, default="en")
|
||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||
parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
|
||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
||||
parser.add_argument(
|
||||
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||
)
|
||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_gpu_nums(gpu_nums_str):
|
||||
try:
|
||||
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
|
||||
gpu_list = ast.literal_eval(gpu_nums_str)
|
||||
if isinstance(gpu_list, list):
|
||||
return gpu_list
|
||||
return list(range(int(gpu_nums_str)))
|
||||
except (ValueError, SyntaxError):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
eval_task = args.eval_task
|
||||
@@ -38,7 +54,7 @@ def main():
|
||||
gen_wav_dir = args.gen_wav_dir
|
||||
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
|
||||
gpus = list(range(args.gpu_nums))
|
||||
gpus = parse_gpu_nums(args.gpu_nums)
|
||||
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
||||
|
||||
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Evaluate with Seed-TTS testset
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -24,11 +25,26 @@ def get_args():
|
||||
parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
|
||||
parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
|
||||
parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
|
||||
parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
|
||||
parser.add_argument(
|
||||
"-n", "--gpu_nums", type=str, default="8", help="Number of GPUs to use (e.g., 8) or GPU list (e.g., [0,1,2,3])"
|
||||
)
|
||||
parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_gpu_nums(gpu_nums_str):
|
||||
try:
|
||||
if gpu_nums_str.startswith("[") and gpu_nums_str.endswith("]"):
|
||||
gpu_list = ast.literal_eval(gpu_nums_str)
|
||||
if isinstance(gpu_list, list):
|
||||
return gpu_list
|
||||
return list(range(int(gpu_nums_str)))
|
||||
except (ValueError, SyntaxError):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid GPU specification: {gpu_nums_str}. Use a number (e.g., 8) or a list (e.g., [0,1,2,3])"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
eval_task = args.eval_task
|
||||
@@ -38,7 +54,7 @@ def main():
|
||||
|
||||
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
||||
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
||||
gpus = list(range(args.gpu_nums))
|
||||
gpus = parse_gpu_nums(args.gpu_nums)
|
||||
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
||||
|
||||
local = args.local
|
||||
|
||||
@@ -395,14 +395,21 @@ def run_sim(args):
|
||||
wav1, sr1 = torchaudio.load(gen_wav)
|
||||
wav2, sr2 = torchaudio.load(prompt_wav)
|
||||
|
||||
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
||||
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
||||
wav1 = resample1(wav1)
|
||||
wav2 = resample2(wav2)
|
||||
|
||||
if use_gpu:
|
||||
wav1 = wav1.cuda(device)
|
||||
wav2 = wav2.cuda(device)
|
||||
|
||||
if sr1 != 16000:
|
||||
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
|
||||
if use_gpu:
|
||||
resample1 = resample1.cuda(device)
|
||||
wav1 = resample1(wav1)
|
||||
if sr2 != 16000:
|
||||
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
|
||||
if use_gpu:
|
||||
resample2 = resample2.cuda(device)
|
||||
wav2 = resample2(wav2)
|
||||
|
||||
with torch.no_grad():
|
||||
emb1 = model(wav1)
|
||||
emb2 = model(wav2)
|
||||
|
||||
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
32
src/f5_tts/scripts/count_max_epoch_precise.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import math
|
||||
|
||||
from torch.utils.data import SequentialSampler
|
||||
|
||||
from f5_tts.model.dataset import DynamicBatchSampler, load_dataset
|
||||
|
||||
|
||||
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin")
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
|
||||
gpus = 8
|
||||
batch_size_per_gpu = 38400
|
||||
max_samples_per_gpu = 64
|
||||
max_updates = 1250000
|
||||
|
||||
batch_sampler = DynamicBatchSampler(
|
||||
sampler,
|
||||
batch_size_per_gpu,
|
||||
max_samples=max_samples_per_gpu,
|
||||
random_seed=666,
|
||||
drop_residual=False,
|
||||
)
|
||||
|
||||
print(
|
||||
f"One epoch has {len(batch_sampler) / gpus} updates if gpus={gpus}, with "
|
||||
f"batch_size_per_gpu={batch_size_per_gpu} (frames) & "
|
||||
f"max_samples_per_gpu={max_samples_per_gpu}."
|
||||
)
|
||||
print(
|
||||
f"If gpus={gpus}, for max_updates={max_updates} "
|
||||
f"should set epoch={math.ceil(max_updates / len(batch_sampler) * gpus)}."
|
||||
)
|
||||
Reference in New Issue
Block a user