mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-04 17:18:19 -08:00
Save per-line WER results and update the UTMOS evaluation
This commit is contained in:
@@ -10,7 +10,7 @@ import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
|
||||
import json
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_librispeech_test,
|
||||
run_asr_wer,
|
||||
@@ -56,12 +56,19 @@ def main():
|
||||
# --------------------------- WER ---------------------------
|
||||
if eval_task == "wer":
|
||||
wers = []
|
||||
wer_results = []
|
||||
with mp.Pool(processes=len(gpus)) as pool:
|
||||
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
||||
results = pool.map(run_asr_wer, args)
|
||||
for wers_ in results:
|
||||
wers.extend(wers_)
|
||||
|
||||
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
|
||||
for line in wers:
|
||||
wer_results.append(line["wer"])
|
||||
json_line = json.dumps(line, ensure_ascii=False)
|
||||
f.write(json_line + "\n")
|
||||
|
||||
wer = round(np.mean(wers) * 100, 3)
|
||||
print(f"\nTotal {len(wers)} samples")
|
||||
print(f"WER : {wer}%")
|
||||
|
||||
@@ -10,7 +10,7 @@ import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
|
||||
import json
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_seed_tts_test,
|
||||
run_asr_wer,
|
||||
@@ -56,12 +56,19 @@ def main():
|
||||
|
||||
if eval_task == "wer":
|
||||
wers = []
|
||||
wer_results = []
|
||||
with mp.Pool(processes=len(gpus)) as pool:
|
||||
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
||||
results = pool.map(run_asr_wer, args)
|
||||
for wers_ in results:
|
||||
wers.extend(wers_)
|
||||
|
||||
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
|
||||
for line in wers:
|
||||
wer_results.append(line["wer"])
|
||||
json_line = json.dumps(line, ensure_ascii=False)
|
||||
f.write(json_line + "\n")
|
||||
|
||||
wer = round(np.mean(wers) * 100, 3)
|
||||
print(f"\nTotal {len(wers)} samples")
|
||||
print(f"WER : {wer}%")
|
||||
|
||||
47
src/f5_tts/eval/eval_utmos.py
Normal file
47
src/f5_tts/eval/eval_utmos.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import librosa
|
||||
from pathlib import Path
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
|
||||
|
||||
def _resolve_device(requested):
    """Return the torch device string to run inference on.

    Args:
        requested: Value of the ``--device`` option, e.g. ``"cuda"``,
            ``"cuda:1"`` or ``"cpu"``.

    Returns:
        ``requested`` unchanged, except that an empty value, or a CUDA request
        on a machine where CUDA is unavailable, falls back to ``"cpu"``.
    """
    if not requested:
        return "cpu"
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return "cpu"
    return requested


def main(argv=None):
    """Score every audio file under ``--audio_dir`` with UTMOS and save the results.

    The directory is searched recursively for files with the ``--ext``
    extension. Each file is loaded as mono audio at its native sampling rate and
    scored by the ``utmos22_strong`` predictor from ``tarepan/SpeechMOS:v1.2.0``
    (fetched through ``torch.hub``). The average score is printed and the
    per-file scores are written to ``<audio_dir>/utmos_results.json``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
    parser.add_argument(
        "--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
    )
    parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")

    args = parser.parse_args(argv)

    audio_dir = Path(args.audio_dir)
    # Sorted so the processing order and the JSON key order are reproducible.
    audio_paths = sorted(audio_dir.rglob(f"*.{args.ext}"))

    results = {}
    utmos_total = 0

    # Only fetch and load the predictor when there is something to score.
    if audio_paths:
        # Honour the requested device instead of always preferring CUDA.
        device = _resolve_device(args.device)

        predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
        predictor = predictor.to(device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for audio_path in tqdm(audio_paths, desc="Processing"):
                wave, sr = librosa.load(audio_path, sr=None, mono=True)
                wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
                score = predictor(wave_tensor, sr).item()
                # NOTE: keyed by file stem, so same-named files in different
                # subdirectories overwrite each other in the JSON; the average
                # below still counts every file.
                results[str(audio_path.stem)] = score
                utmos_total += score

    avg_score = utmos_total / len(audio_paths) if audio_paths else 0
    print(f"UTMOS: {avg_score}")

    output_path = audio_dir / "utmos_results.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    print(f"Results have been saved to {output_path}")


if __name__ == "__main__":
    main()
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
|
||||
from f5_tts.model.modules import MelSpec
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
@@ -360,7 +360,14 @@ def run_asr_wer(args):
|
||||
# dele = measures["deletions"] / len(ref_list)
|
||||
# inse = measures["insertions"] / len(ref_list)
|
||||
|
||||
wers.append(wer)
|
||||
wers.append(
|
||||
{
|
||||
"wav": Path(gen_wav).stem, # wav name
|
||||
"truth": truth, # raw_truth
|
||||
"hypo": hypo, # raw_hypo
|
||||
"wer": wer, # wer score
|
||||
}
|
||||
)
|
||||
|
||||
return wers
|
||||
|
||||
|
||||
Reference in New Issue
Block a user