Mirror of https://github.com/SWivid/F5-TTS.git, synced 2026-01-16 06:53:17 -08:00
add and run pre-commit with ruff
@@ -1,6 +1,7 @@
-'''ADAPTIVE BATCH SIZE'''
-print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
-print(' -> least padding, gather wavs with accumulated frames in a batch\n')
+"""ADAPTIVE BATCH SIZE"""
+
+print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
+print(" -> least padding, gather wavs with accumulated frames in a batch\n")
 
 # data
 total_hours = 95282
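The prints in this hunk describe the training script's batching scheme: instead of a fixed number of samples per batch, utterances are sorted by length and accumulated until a frame budget (frames_per_gpu) is reached, so each batch holds similar-length wavs and padding is minimized. A minimal sketch of that idea follows; the function and variable names are illustrative, not the repo's actual sampler:

import random

def batches_by_frame_budget(frame_lens, frames_per_gpu, shuffle=True):
    # sort indices by utterance length so each batch holds similar lengths
    order = sorted(range(len(frame_lens)), key=lambda i: frame_lens[i])
    batches, cur, cur_frames = [], [], 0
    for idx in order:
        # start a new batch once the accumulated frame count would exceed the budget
        if cur and cur_frames + frame_lens[idx] > frames_per_gpu:
            batches.append(cur)
            cur, cur_frames = [], 0
        cur.append(idx)
        cur_frames += frame_lens[idx]
    if cur:
        batches.append(cur)
    if shuffle:
        random.shuffle(batches)  # randomize batch order, keep batch contents grouped by length
    return batches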
@@ -1,13 +1,15 @@
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
-from model import M2_TTS, UNetT, DiT, MMDiT
+from model import M2_TTS, DiT
 
 import torch
 import thop
 
 
-''' ~155M '''
+""" ~155M """
 # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
 # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
 # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -15,11 +17,11 @@ import thop
 # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
 # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
 
-''' ~335M '''
+""" ~335M """
 # FLOPs: 622.1 G, Params: 333.2 M
 # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 # FLOPs: 363.4 G, Params: 335.8 M
-transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 
 model = M2_TTS(transformer=transformer)
@@ -30,6 +32,8 @@ duration = 20
 frame_length = int(duration * target_sample_rate / hop_length)
 text_length = 150
 
-flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
+flops, params = thop.profile(
+    model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
+)
 print(f"FLOPs: {flops / 1e9} G")
 print(f"Params: {params / 1e6} M")
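For reference, with duration = 20 and assuming F5-TTS's usual mel settings of a 24 kHz sample rate and hop length 256 (those constants are defined outside this hunk), the profiled input works out to 1875 mel frames:

# assumed values; target_sample_rate and hop_length are set elsewhere in this script
target_sample_rate = 24000
hop_length = 256
duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
print(frame_length)  # 1875

thop.profile returns raw operation and parameter counts, which the prints above scale to gigaflops and millions of parameters.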
@@ -1,4 +1,6 @@
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 import time
@@ -14,9 +16,9 @@ from vocos import Vocos
 from model import CFM, UNetT, DiT
 from model.utils import (
     load_checkpoint,
-    get_tokenizer, 
-    get_seedtts_testset_metainfo, 
-    get_librispeech_test_clean_metainfo, 
+    get_tokenizer,
+    get_seedtts_testset_metainfo,
+    get_librispeech_test_clean_metainfo,
     get_inference_prompt,
 )
 
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
 
 parser = argparse.ArgumentParser(description="batch inference")
 
-parser.add_argument('-s', '--seed', default=None, type=int)
-parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
-parser.add_argument('-n', '--expname', required=True)
-parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
+parser.add_argument("-s", "--seed", default=None, type=int)
+parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
+parser.add_argument("-n", "--expname", required=True)
+parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
 
-parser.add_argument('-nfe', '--nfestep', default=32, type=int)
-parser.add_argument('-o', '--odemethod', default="euler")
-parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
+parser.add_argument("-nfe", "--nfestep", default=32, type=int)
+parser.add_argument("-o", "--odemethod", default="euler")
+parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
 
-parser.add_argument('-t', '--testset', required=True)
+parser.add_argument("-t", "--testset", required=True)
 
 args = parser.parse_args()
 
@@ -66,26 +68,26 @@ testset = args.testset
 
 
 infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
-cfg_strength = 2.
-speed = 1.
+cfg_strength = 2.0
+speed = 1.0
 use_truth_duration = False
 no_ref_audio = False
 
 
 if exp_name == "F5TTS_Base":
     model_cls = DiT
-    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 
 elif exp_name == "E2TTS_Base":
     model_cls = UNetT
-    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
 
 if testset == "ls_pc_test_clean":
     metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
     librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
     metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
 
 elif testset == "seedtts_test_zh":
     metalst = "data/seedtts_testset/zh/meta.lst"
     metainfo = get_seedtts_testset_metainfo(metalst)
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
 
 
 # path to save genereted wavs
-if seed is None: seed = random.randint(-10000, 10000)
-output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
-             f"seed{seed}_{ode_method}_nfe{nfe_step}" \
-             f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
-             f"_cfg{cfg_strength}_speed{speed}" \
-             f"{'_gt-dur' if use_truth_duration else ''}" \
-             f"{'_no-ref-audio' if no_ref_audio else ''}"
+if seed is None:
+    seed = random.randint(-10000, 10000)
+output_dir = (
+    f"results/{exp_name}_{ckpt_step}/{testset}/"
+    f"seed{seed}_{ode_method}_nfe{nfe_step}"
+    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
+    f"_cfg{cfg_strength}_speed{speed}"
+    f"{'_gt-dur' if use_truth_duration else ''}"
+    f"{'_no-ref-audio' if no_ref_audio else ''}"
+)
 
 
 # -------------------------------------------------#
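The output_dir rewrite above swaps backslash line continuations for a parenthesized expression; adjacent string literals inside parentheses are concatenated at compile time, so there are no trailing backslashes to keep in sync. A toy example with made-up values:

name, step = "F5TTS_Base", 1200000
path = (
    f"results/{name}_{step}/"
    "seedtts_test_zh/"  # adjacent literals merge into one string
    "seed0_euler_nfe32"
)
print(path)  # results/F5TTS_Base_1200000/seedtts_test_zh/seed0_euler_nfe32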
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
 use_ema = True
 
 prompts_all = get_inference_prompt(
-    metainfo, 
-    speed = speed,
-    tokenizer = tokenizer,
-    target_sample_rate = target_sample_rate,
-    n_mel_channels = n_mel_channels,
-    hop_length = hop_length,
-    target_rms = target_rms,
-    use_truth_duration = use_truth_duration,
-    infer_batch_size = infer_batch_size,
+    metainfo,
+    speed=speed,
+    tokenizer=tokenizer,
+    target_sample_rate=target_sample_rate,
+    n_mel_channels=n_mel_channels,
+    hop_length=hop_length,
+    target_rms=target_rms,
+    use_truth_duration=use_truth_duration,
+    infer_batch_size=infer_batch_size,
 )
 
 # Vocoder model
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
 
 # Model
 model = CFM(
-    transformer = model_cls(
-        **model_cfg,
-        text_num_embeds = vocab_size,
-        mel_dim = n_mel_channels
-    ),
-    mel_spec_kwargs = dict(
-        target_sample_rate = target_sample_rate,
-        n_mel_channels = n_mel_channels,
-        hop_length = hop_length,
-    ),
-    odeint_kwargs = dict(
-        method = ode_method,
-    ),
-    vocab_char_map = vocab_char_map,
+    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+    mel_spec_kwargs=dict(
+        target_sample_rate=target_sample_rate,
+        n_mel_channels=n_mel_channels,
+        hop_length=hop_length,
+    ),
+    odeint_kwargs=dict(
+        method=ode_method,
+    ),
+    vocab_char_map=vocab_char_map,
 ).to(device)
 
-model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
+model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
 
 if not os.path.exists(output_dir) and accelerator.is_main_process:
     os.makedirs(output_dir)
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
 start = time.time()
 
 with accelerator.split_between_processes(prompts_all) as prompts:
-
     for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
         utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
         ref_mels = ref_mels.to(device)
-        ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
-        total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
+        ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
+        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
 
         # Inference
         with torch.inference_mode():
             generated, _ = model.sample(
-                cond = ref_mels,
-                text = final_text_list,
-                duration = total_mel_lens,
-                lens = ref_mel_lens,
-                steps = nfe_step,
-                cfg_strength = cfg_strength,
-                sway_sampling_coef = sway_sampling_coef,
-                no_ref_audio = no_ref_audio,
-                seed = seed,
+                cond=ref_mels,
+                text=final_text_list,
+                duration=total_mel_lens,
+                lens=ref_mel_lens,
+                steps=nfe_step,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+                no_ref_audio=no_ref_audio,
+                seed=seed,
             )
         # Final result
         for i, gen in enumerate(generated):
-            gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
+            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
             gen_mel_spec = gen.permute(0, 2, 1)
             generated_wave = vocos.decode(gen_mel_spec.cpu())
             if ref_rms_list[i] < target_rms:
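The inference loop above distributes work with Accelerate's split_between_processes, which hands each rank its own slice of the prompt list inside a context manager. Stripped to its skeleton (the payload and the upper-casing stand-in are illustrative):

from accelerate import Accelerator

accelerator = Accelerator()
prompts_all = [f"prompt_{i}" for i in range(10)]  # stand-in payload

with accelerator.split_between_processes(prompts_all) as prompts:
    # each process sees only its own contiguous slice of prompts_all
    results = [p.upper() for p in prompts]  # stand-in for model.sample(...)

accelerator.wait_for_everyone()  # sync ranks before timing or writing outputs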
@@ -1,6 +1,8 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
 
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
 librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
 gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
 
-gpus = [0,1,2,3,4,5,6,7]
+gpus = [0, 1, 2, 3, 4, 5, 6, 7]
 test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
 
 ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -46,7 +48,7 @@ if eval_task == "wer":
     for wers_ in results:
         wers.extend(wers_)
 
-    wer = round(np.mean(wers)*100, 3)
+    wer = round(np.mean(wers) * 100, 3)
     print(f"\nTotal {len(wers)} samples")
     print(f"WER : {wer}%")
@@ -62,6 +64,6 @@ if eval_task == "sim":
     for sim_ in results:
         sim_list.extend(sim_)
 
-    sim = round(sum(sim_list)/len(sim_list), 3)
+    sim = round(sum(sim_list) / len(sim_list), 3)
     print(f"\nTotal {len(sim_list)} samples")
     print(f"SIM : {sim}")
@@ -1,6 +1,8 @@
 # Evaluate with Seed-TTS testset
 
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
@@ -14,21 +16,21 @@ from model.utils import (
 
 
 eval_task = "wer"  # sim | wer
-lang = "zh" # zh | en
+lang = "zh"  # zh | en
 metalst = f"data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
 # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs"  # ground truth wavs
-gen_wav_dir = f"PATH_TO_GENERATED"  # generated wavs
+gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
 
 
 # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
-# zh 1.254 seems a result of 4 workers wer_seed_tts 
-gpus = [0,1,2,3,4,5,6,7]
+# zh 1.254 seems a result of 4 workers wer_seed_tts
+gpus = [0, 1, 2, 3, 4, 5, 6, 7]
 test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
 
 local = False
 if local:  # use local custom checkpoint dir
     if lang == "zh":
-        asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
+        asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
     elif lang == "en":
         asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
     else:
@@ -48,7 +50,7 @@ if eval_task == "wer":
     for wers_ in results:
         wers.extend(wers_)
 
-    wer = round(np.mean(wers)*100, 3)
+    wer = round(np.mean(wers) * 100, 3)
     print(f"\nTotal {len(wers)} samples")
     print(f"WER : {wer}%")
@@ -64,6 +66,6 @@ if eval_task == "sim":
     for sim_ in results:
         sim_list.extend(sim_)
 
-    sim = round(sum(sim_list)/len(sim_list), 3)
+    sim = round(sum(sim_list) / len(sim_list), 3)
     print(f"\nTotal {len(sim_list)} samples")
     print(f"SIM : {sim}")
@@ -1,4 +1,6 @@
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 from pathlib import Path
@@ -17,10 +19,11 @@ from model.utils import (
 
 PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
 
+
 def is_csv_wavs_format(input_dataset_dir):
     fpath = Path(input_dataset_dir)
     metadata = fpath / "metadata.csv"
-    wavs = fpath / 'wavs'
+    wavs = fpath / "wavs"
     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
 
 
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
 
     return sub_result, durations, vocab_set
 
+
 def get_audio_duration(audio_path):
     audio, sample_rate = torchaudio.load(audio_path)
     num_channels = audio.shape[0]
     return audio.shape[1] / (sample_rate * num_channels)
 
+
 def read_audio_text_pairs(csv_file_path):
     audio_text_pairs = []
 
     parent = Path(csv_file_path).parent
-    with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
-        reader = csv.reader(csvfile, delimiter='|')
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+        reader = csv.reader(csvfile, delimiter="|")
         next(reader)  # Skip the header row
         for row in reader:
             if len(row) >= 2:
                 audio_file = row[0].strip()  # First column: audio file path
-                text = row[1].strip() # Second column: text
+                text = row[1].strip()  # Second column: text
                 audio_file_path = parent / audio_file
                 audio_text_pairs.append((audio_file_path.as_posix(), text))
 
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
     # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
     raw_arrow_path = out_dir / "raw.arrow"
     with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
-        for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     dur_json_path = out_dir / "duration.json"
-    with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
+    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     # vocab map, i.e. tokenizer
@@ -120,13 +125,14 @@ def cli():
     # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
     parser = argparse.ArgumentParser(description="Prepare and save dataset.")
-    parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
-    parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
-    parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
+    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
 
     args = parser.parse_args()
 
     prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
 
+
 if __name__ == "__main__":
     cli()
@@ -4,7 +4,9 @@
 # generate audio text map for Emilia ZH & EN
 # evaluate for vocab size
 
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 from pathlib import Path
@@ -12,7 +14,6 @@ import json
 from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor
 
-from datasets import Dataset
 from datasets.arrow_writer import ArrowWriter
 
 from model.utils import (
@@ -21,13 +22,89 @@ from model.utils import (
 )
 
 
-out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
+out_zh = {
+    "ZH_B00041_S06226",
+    "ZH_B00042_S09204",
+    "ZH_B00065_S09430",
+    "ZH_B00065_S09431",
+    "ZH_B00066_S09327",
+    "ZH_B00066_S09328",
+}
 zh_filters = ["い", "て"]
 # seems synthesized audios, or heavily code-switched
 out_en = {
-    "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
-
-    "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
+    "EN_B00013_S00913",
+    "EN_B00042_S00120",
+    "EN_B00055_S04111",
+    "EN_B00061_S00693",
+    "EN_B00061_S01494",
+    "EN_B00061_S03375",
+    "EN_B00059_S00092",
+    "EN_B00111_S04300",
+    "EN_B00100_S03759",
+    "EN_B00087_S03811",
+    "EN_B00059_S00950",
+    "EN_B00089_S00946",
+    "EN_B00078_S05127",
+    "EN_B00070_S04089",
+    "EN_B00074_S09659",
+    "EN_B00061_S06983",
+    "EN_B00061_S07060",
+    "EN_B00059_S08397",
+    "EN_B00082_S06192",
+    "EN_B00091_S01238",
+    "EN_B00089_S07349",
+    "EN_B00070_S04343",
+    "EN_B00061_S02400",
+    "EN_B00076_S01262",
+    "EN_B00068_S06467",
+    "EN_B00076_S02943",
+    "EN_B00064_S05954",
+    "EN_B00061_S05386",
+    "EN_B00066_S06544",
+    "EN_B00076_S06944",
+    "EN_B00072_S08620",
+    "EN_B00076_S07135",
+    "EN_B00076_S09127",
+    "EN_B00065_S00497",
+    "EN_B00059_S06227",
+    "EN_B00063_S02859",
+    "EN_B00075_S01547",
+    "EN_B00061_S08286",
+    "EN_B00079_S02901",
+    "EN_B00092_S03643",
+    "EN_B00096_S08653",
+    "EN_B00063_S04297",
+    "EN_B00063_S04614",
+    "EN_B00079_S04698",
+    "EN_B00104_S01666",
+    "EN_B00061_S09504",
+    "EN_B00061_S09694",
+    "EN_B00065_S05444",
+    "EN_B00063_S06860",
+    "EN_B00065_S05725",
+    "EN_B00069_S07628",
+    "EN_B00083_S03875",
+    "EN_B00071_S07665",
+    "EN_B00071_S07665",
+    "EN_B00062_S04187",
+    "EN_B00065_S09873",
+    "EN_B00065_S09922",
+    "EN_B00084_S02463",
+    "EN_B00067_S05066",
+    "EN_B00106_S08060",
+    "EN_B00073_S06399",
+    "EN_B00073_S09236",
+    "EN_B00087_S00432",
+    "EN_B00085_S05618",
+    "EN_B00064_S01262",
+    "EN_B00072_S01739",
+    "EN_B00059_S03913",
+    "EN_B00069_S04036",
+    "EN_B00067_S05623",
+    "EN_B00060_S05389",
+    "EN_B00060_S07290",
+    "EN_B00062_S08995",
 }
 en_filters = ["ا", "い", "て"]
 
@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
     for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
         obj = json.loads(line)
         text = obj["text"]
-        if obj['language'] == "zh":
+        if obj["language"] == "zh":
             if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
                 bad_case_zh += 1
                 continue
             else:
-                text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'}))  # not "。" cuz much code-switched
-        if obj['language'] == "en":
-            if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
+                text = text.translate(
+                    str.maketrans({",": ",", "!": "!", "?": "?"})
+                )  # not "。" cuz much code-switched
+        if obj["language"] == "en":
+            if (
+                obj["wav"].split("/")[1] in out_en
+                or any(f in text for f in en_filters)
+                or repetition_found(text, length=4)
+            ):
                 bad_case_en += 1
                 continue
         if tokenizer == "pinyin":
-            text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
+            text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
         duration = obj["duration"]
         sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
         durations.append(duration)
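The filters in this hunk drop Emilia utterances whose transcripts look degenerate: known-bad clip IDs (out_zh / out_en), stray foreign characters (zh_filters / en_filters), and repeated n-grams via repetition_found from model.utils. That function's implementation isn't shown in this diff; one plausible sketch of such a check, with the threshold assumed:

from collections import Counter

def repetition_found(text, length=2, tolerance=10):
    # flag transcripts where some n-gram (n <= length) repeats suspiciously often
    for n in range(1, length + 1):
        counts = Counter(text[i : i + n] for i in range(len(text) - n + 1))
        if counts and max(counts.values()) > tolerance:
            return True
    return False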
@@ -96,11 +179,11 @@ def main():
     # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
     # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
     with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
-        for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
-    with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
+    with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     # vocab map, i.e. tokenizer
@@ -114,12 +197,13 @@ def main():
     print(f"\nFor {dataset_name}, sample count: {len(result)}")
     print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
-    if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
-    if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
+    if "ZH" in langs:
+        print(f"Bad zh transcription case: {total_bad_case_zh}")
+    if "EN" in langs:
+        print(f"Bad en transcription case: {total_bad_case_en}\n")
 
 
 if __name__ == "__main__":
 
     max_workers = 32
 
     tokenizer = "pinyin"  # "pinyin" | "char"
@@ -1,7 +1,9 @@
 # generate audio text map for WenetSpeech4TTS
 # evaluate for vocab size
 
-import sys, os
+import sys
+import os
+
 sys.path.append(os.getcwd())
 
 import json
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
 
     audio_paths, texts, durations = [], [], []
     for text_file in tqdm(text_files):
-        with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
+        with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
             first_line = file.readline().split("\t")
             audio_nm = first_line[0]
             audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
         audio_paths.append(audio_path)
 
         if tokenizer == "pinyin":
-            texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
+            texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
         elif tokenizer == "char":
             texts.append(text)
 
@@ -46,7 +48,7 @@ def main():
     assert tokenizer in ["pinyin", "char"]
 
     audio_path_list, text_list, duration_list = [], [], []
-
+
     executor = ProcessPoolExecutor(max_workers=max_workers)
     futures = []
     for dataset_path in dataset_paths:
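main() fans the WenetSpeech4TTS sub-paths out to a ProcessPoolExecutor and later drains the futures into flat path/text/duration lists. The pattern in isolation, with a placeholder worker standing in for deal_with_sub_path_files:

from concurrent.futures import ProcessPoolExecutor

def deal_with_sub_path(sub_path):  # placeholder worker
    return [f"{sub_path}/x.wav"], ["text"], [1.0]

if __name__ == "__main__":
    audio_paths, texts, durations = [], [], []
    with ProcessPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(deal_with_sub_path, p) for p in ["Premium_0", "Premium_1"]]
        for future in futures:  # results collected in submission order
            paths, txts, durs = future.result()
            audio_paths.extend(paths)
            texts.extend(txts)
            durations.extend(durs)
    print(len(audio_paths), sum(durations))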
@@ -68,8 +70,10 @@ def main():
     dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
     dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB")  # arrow format
 
-    with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
-        json.dump({"duration": duration_list}, f, ensure_ascii=False)  # dup a json separately saving duration in case for DynamicBatchSampler ease
+    with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
+        json.dump(
+            {"duration": duration_list}, f, ensure_ascii=False
+        )  # dup a json separately saving duration in case for DynamicBatchSampler ease
 
     print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
     text_vocab_set = set()
@@ -85,22 +89,21 @@ def main():
         f.write(vocab + "\n")
     print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
     print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
 
 
-
 if __name__ == "__main__":
 
     max_workers = 32
 
     tokenizer = "pinyin"  # "pinyin" | "char"
     polyphone = True
     dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic
 
-    dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
+    dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
     dataset_paths = [
         "<SOME_PATH>/WenetSpeech4TTS/Basic",
         "<SOME_PATH>/WenetSpeech4TTS/Standard",
         "<SOME_PATH>/WenetSpeech4TTS/Premium",
-    ][-dataset_choice:]
+    ][-dataset_choice:]
     print(f"\nChoose Dataset: {dataset_name}\n")
 
     main()
@@ -109,8 +112,8 @@ if __name__ == "__main__":
 
 # WenetSpeech4TTS       Basic               Standard            Premium
 # samples count         3932473             1941220             407494
 # pinyin vocab size     1349                1348                1344    (no polyphone)
-#                       -                   -                   1459    (polyphone)
+#                       -                   -                   1459    (polyphone)
 # char vocab size       5264                5219                5042
 
 
 # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
 # please be careful if using pretrained model, make sure the vocab.txt is same