From 7931fcb7d09454b078bdb937cf437a17d2e7daeb Mon Sep 17 00:00:00 2001
From: ZhikangNiu
Date: Sat, 14 Dec 2024 10:41:03 +0800
Subject: [PATCH] save per-line WER results and add UTMOS evaluation

---
 .../eval/eval_librispeech_test_clean.py | 11 ++++-
 src/f5_tts/eval/eval_seedtts_testset.py | 11 ++++-
 src/f5_tts/eval/eval_utmos.py           | 47 +++++++++++++++++++
 src/f5_tts/eval/utils_eval.py           | 11 ++++-
 4 files changed, 74 insertions(+), 6 deletions(-)
 create mode 100644 src/f5_tts/eval/eval_utmos.py

diff --git a/src/f5_tts/eval/eval_librispeech_test_clean.py b/src/f5_tts/eval/eval_librispeech_test_clean.py
index a5f76e0..631f284 100644
--- a/src/f5_tts/eval/eval_librispeech_test_clean.py
+++ b/src/f5_tts/eval/eval_librispeech_test_clean.py
@@ -10,7 +10,7 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
+import json
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -56,12 +56,19 @@ def main():
     # --------------------------- WER ---------------------------
 
     if eval_task == "wer":
         wers = []
+        wer_results = []
 
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for wers_ in results:
                 wers.extend(wers_)
 
+        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
+            for line in wers:
+                wer_results.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
-        wer = round(np.mean(wers) * 100, 3)
+        wer = round(np.mean(wer_results) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER      : {wer}%")

diff --git a/src/f5_tts/eval/eval_seedtts_testset.py b/src/f5_tts/eval/eval_seedtts_testset.py
index 5cc1987..967a919 100644
--- a/src/f5_tts/eval/eval_seedtts_testset.py
+++ b/src/f5_tts/eval/eval_seedtts_testset.py
@@ -10,7 +10,7 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
+import json
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -56,12 +56,19 @@ def main():
     if eval_task == "wer":
         wers = []
+        wer_results = []
 
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for wers_ in results:
                 wers.extend(wers_)
 
+        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
+            for line in wers:
+                wer_results.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
-        wer = round(np.mean(wers) * 100, 3)
+        wer = round(np.mean(wer_results) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER      : {wer}%")

diff --git a/src/f5_tts/eval/eval_utmos.py b/src/f5_tts/eval/eval_utmos.py
new file mode 100644
index 0000000..65196ce
--- /dev/null
+++ b/src/f5_tts/eval/eval_utmos.py
@@ -0,0 +1,47 @@
+import torch
+import librosa
+from pathlib import Path
+import json
+from tqdm import tqdm
+import argparse
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
+    parser.add_argument(
+        "--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
+    )
+    parser.add_argument("--ext", type=str, default="wav", help="Audio file extension.")
+    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
'cuda' or 'cpu').") + + args = parser.parse_args() + + device = "cuda" if args.device and torch.cuda.is_available() else "cpu" + + predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True) + predictor = predictor.to(device) + + lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}")) + results = {} + utmos_result = 0 + + for line in tqdm(lines, desc="Processing"): + wave_name = line.stem + wave, sr = librosa.load(line, sr=None, mono=True) + wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0) + score = predictor(wave_tensor, sr) + results[str(wave_name)] = score.item() + utmos_result += score.item() + + avg_score = utmos_result / len(lines) if len(lines) > 0 else 0 + print(f"UTMOS: {avg_score}") + + output_path = Path(args.audio_dir) / "utmos_results.json" + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=4) + + print(f"Results have been saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py index 00cd97a..72f1759 100644 --- a/src/f5_tts/eval/utils_eval.py +++ b/src/f5_tts/eval/utils_eval.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F import torchaudio from tqdm import tqdm - +from pathlib import Path from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL from f5_tts.model.modules import MelSpec from f5_tts.model.utils import convert_char_to_pinyin @@ -360,7 +360,14 @@ def run_asr_wer(args): # dele = measures["deletions"] / len(ref_list) # inse = measures["insertions"] / len(ref_list) - wers.append(wer) + wers.append( + { + "wav": Path(gen_wav).stem, # wav name + "truth": truth, # raw_truth + "hypo": hypo, # raw_hypo + "wer": wer, # wer score + } + ) return wers