save every line wer results and update utmos evaluation

This commit is contained in:
ZhikangNiu
2024-12-14 10:41:03 +08:00
parent c85252f59a
commit 7931fcb7d0
4 changed files with 72 additions and 4 deletions

View File

@@ -10,7 +10,7 @@ import multiprocessing as mp
from importlib.resources import files
import numpy as np
import json
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
@@ -56,12 +56,19 @@ def main():
# --------------------------- WER ---------------------------
if eval_task == "wer":
wers = []
wer_results = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
for line in wers:
wer_results.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")

View File

@@ -10,7 +10,7 @@ import multiprocessing as mp
from importlib.resources import files
import numpy as np
import json
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
@@ -56,12 +56,19 @@ def main():
if eval_task == "wer":
wers = []
wer_results = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
for line in wers:
wer_results.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")

View File

@@ -0,0 +1,47 @@
import torch
import librosa
from pathlib import Path
import json
from tqdm import tqdm
import argparse
def main():
parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
parser.add_argument(
"--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
)
parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
args = parser.parse_args()
device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
predictor = predictor.to(device)
lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
results = {}
utmos_result = 0
for line in tqdm(lines, desc="Processing"):
wave_name = line.stem
wave, sr = librosa.load(line, sr=None, mono=True)
wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
score = predictor(wave_tensor, sr)
results[str(wave_name)] = score.item()
utmos_result += score.item()
avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
print(f"UTMOS: {avg_score}")
output_path = Path(args.audio_dir) / "utmos_results.json"
with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
print(f"Results have been saved to {output_path}")
if __name__ == "__main__":
main()

View File

@@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from pathlib import Path
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
@@ -360,7 +360,14 @@ def run_asr_wer(args):
# dele = measures["deletions"] / len(ref_list)
# inse = measures["insertions"] / len(ref_list)
wers.append(wer)
wers.append(
{
"wav": Path(gen_wav).stem, # wav name
"truth": truth, # raw_truth
"hypo": hypo, # raw_hypo
"wer": wer, # wer score
}
)
return wers