add and run pre-commit with ruff

Tom Hunn
2024-10-21 14:46:45 +10:00
parent 77e00db01b
commit a4ca14b5f6
29 changed files with 1827 additions and 1328 deletions
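The diff below is the mechanical result of running ruff's lint and format hooks across the repository. The .pre-commit-config.yaml introduced by this commit is not shown in this excerpt, so the following is only a minimal sketch of what such a config typically looks like; the pinned rev and exact hook selection are assumptions, not taken from the commit.

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.7.0          # assumed pin; the actual revision is not visible in this excerpt
    hooks:
      - id: ruff         # linter, with autofixes applied
        args: [--fix]
      - id: ruff-format  # formatter (double quotes, normalized keyword spacing, line wrapping)

With a config like this in place, pre-commit install registers the git hook, and pre-commit run --all-files reproduces a whole-repository pass like the one recorded here.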

View File

@@ -1,6 +1,7 @@
'''ADAPTIVE BATCH SIZE'''
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
"""ADAPTIVE BATCH SIZE"""
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
# data
total_hours = 95282
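The two print statements above describe the data-loading strategy: utterances are gathered into a batch until an accumulated frame budget (frames_per_gpu) is reached, after length-based grouping so that padding is minimized. The repository's actual sampler is not part of this excerpt; the snippet below is a hypothetical minimal sketch of that idea, and the helper name and sorting step are assumptions.

import random

def make_frame_budget_batches(durations_frames, frames_per_gpu, shuffle=True, seed=0):
    """Group sample indices so each batch's accumulated frame count stays under a budget."""
    # Sort by length so samples sharing a batch have similar durations (least padding).
    order = sorted(range(len(durations_frames)), key=lambda i: durations_frames[i])
    batches, batch, batch_frames = [], [], 0
    for idx in order:
        n = durations_frames[idx]
        if batch and batch_frames + n > frames_per_gpu:
            batches.append(batch)
            batch, batch_frames = [], 0
        batch.append(idx)
        batch_frames += n
    if batch:
        batches.append(batch)
    if shuffle:  # shuffle batch order, not sample order, to keep length buckets intact
        random.Random(seed).shuffle(batches)
    return batches

# Example: a 2048-frame budget over a handful of variable-length samples.
print(make_frame_budget_batches([400, 1500, 350, 900, 1200], frames_per_gpu=2048))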

View File

@@ -1,13 +1,15 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from model import M2_TTS, UNetT, DiT, MMDiT
from model import M2_TTS, DiT
import torch
import thop
''' ~155M '''
""" ~155M """
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -15,11 +17,11 @@ import thop
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
''' ~335M '''
""" ~335M """
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = M2_TTS(transformer=transformer)
@@ -30,6 +32,8 @@ duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
flops, params = thop.profile(
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")

View File

@@ -1,4 +1,6 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import time
@@ -14,9 +16,9 @@ from vocos import Vocos
from model import CFM, UNetT, DiT
from model.utils import (
load_checkpoint,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument('-s', '--seed', default=None, type=int)
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
parser.add_argument('-n', '--expname', required=True)
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
parser.add_argument('-o', '--odemethod', default="euler")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument('-t', '--testset', required=True)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
@@ -66,26 +68,26 @@ testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.
speed = 1.
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if testset == "ls_pc_test_clean":
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = "data/seedtts_testset/zh/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
# path to save generated wavs
if seed is None: seed = random.randint(-10000, 10000)
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
f"_cfg{cfg_strength}_speed{speed}" \
f"{'_gt-dur' if use_truth_duration else ''}" \
if seed is None:
seed = random.randint(-10000, 10000)
output_dir = (
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed = speed,
tokenizer = tokenizer,
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
target_rms = target_rms,
use_truth_duration = use_truth_duration,
infer_batch_size = infer_batch_size,
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
# Vocoder model
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
odeint_kwargs=dict(
method=ode_method,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond = ref_mels,
text = final_text_list,
duration = total_mel_lens,
lens = ref_mel_lens,
steps = nfe_step,
cfg_strength = cfg_strength,
sway_sampling_coef = sway_sampling_coef,
no_ref_audio = no_ref_audio,
seed = seed,
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:

View File

@@ -1,6 +1,8 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import multiprocessing as mp
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
gpus = [0,1,2,3,4,5,6,7]
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -46,7 +48,7 @@ if eval_task == "wer":
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
@@ -62,6 +64,6 @@ if eval_task == "sim":
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")

View File

@@ -1,6 +1,8 @@
# Evaluate with Seed-TTS testset
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import multiprocessing as mp
@@ -14,21 +16,21 @@ from model.utils import (
eval_task = "wer" # sim | wer
lang = "zh" # zh | en
lang = "zh" # zh | en
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = [0,1,2,3,4,5,6,7]
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
local = False
if local: # use local custom checkpoint dir
if lang == "zh":
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
elif lang == "en":
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
@@ -48,7 +50,7 @@ if eval_task == "wer":
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
@@ -64,6 +66,6 @@ if eval_task == "sim":
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")

View File

@@ -1,4 +1,6 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from pathlib import Path
@@ -17,10 +19,11 @@ from model.utils import (
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
def is_csv_wavs_format(input_dataset_dir):
fpath = Path(input_dataset_dir)
metadata = fpath / "metadata.csv"
wavs = fpath / 'wavs'
wavs = fpath / "wavs"
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
return sub_result, durations, vocab_set
def get_audio_duration(audio_path):
audio, sample_rate = torchaudio.load(audio_path)
num_channels = audio.shape[0]
return audio.shape[1] / (sample_rate * num_channels)
def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []
parent = Path(csv_file_path).parent
with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter='|')
with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
if len(row) >= 2:
audio_file = row[0].strip() # First column: audio file path
text = row[1].strip() # Second column: text
text = row[1].strip() # Second column: text
audio_file_path = parent / audio_file
audio_text_pairs.append((audio_file_path.as_posix(), text))
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
dur_json_path = out_dir / "duration.json"
with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
@@ -120,13 +125,14 @@ def cli():
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
# pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
args = parser.parse_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
if __name__ == "__main__":
cli()

View File

@@ -4,7 +4,9 @@
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from pathlib import Path
@@ -12,7 +14,6 @@ import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter
from model.utils import (
@@ -21,13 +22,89 @@ from model.utils import (
)
out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
out_zh = {
"ZH_B00041_S06226",
"ZH_B00042_S09204",
"ZH_B00065_S09430",
"ZH_B00065_S09431",
"ZH_B00066_S09327",
"ZH_B00066_S09328",
}
zh_filters = ["", ""]
# seems synthesized audios, or heavily code-switched
out_en = {
"EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
"EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
"EN_B00013_S00913",
"EN_B00042_S00120",
"EN_B00055_S04111",
"EN_B00061_S00693",
"EN_B00061_S01494",
"EN_B00061_S03375",
"EN_B00059_S00092",
"EN_B00111_S04300",
"EN_B00100_S03759",
"EN_B00087_S03811",
"EN_B00059_S00950",
"EN_B00089_S00946",
"EN_B00078_S05127",
"EN_B00070_S04089",
"EN_B00074_S09659",
"EN_B00061_S06983",
"EN_B00061_S07060",
"EN_B00059_S08397",
"EN_B00082_S06192",
"EN_B00091_S01238",
"EN_B00089_S07349",
"EN_B00070_S04343",
"EN_B00061_S02400",
"EN_B00076_S01262",
"EN_B00068_S06467",
"EN_B00076_S02943",
"EN_B00064_S05954",
"EN_B00061_S05386",
"EN_B00066_S06544",
"EN_B00076_S06944",
"EN_B00072_S08620",
"EN_B00076_S07135",
"EN_B00076_S09127",
"EN_B00065_S00497",
"EN_B00059_S06227",
"EN_B00063_S02859",
"EN_B00075_S01547",
"EN_B00061_S08286",
"EN_B00079_S02901",
"EN_B00092_S03643",
"EN_B00096_S08653",
"EN_B00063_S04297",
"EN_B00063_S04614",
"EN_B00079_S04698",
"EN_B00104_S01666",
"EN_B00061_S09504",
"EN_B00061_S09694",
"EN_B00065_S05444",
"EN_B00063_S06860",
"EN_B00065_S05725",
"EN_B00069_S07628",
"EN_B00083_S03875",
"EN_B00071_S07665",
"EN_B00071_S07665",
"EN_B00062_S04187",
"EN_B00065_S09873",
"EN_B00065_S09922",
"EN_B00084_S02463",
"EN_B00067_S05066",
"EN_B00106_S08060",
"EN_B00073_S06399",
"EN_B00073_S09236",
"EN_B00087_S00432",
"EN_B00085_S05618",
"EN_B00064_S01262",
"EN_B00072_S01739",
"EN_B00059_S03913",
"EN_B00069_S04036",
"EN_B00067_S05623",
"EN_B00060_S05389",
"EN_B00060_S07290",
"EN_B00062_S08995",
}
en_filters = ["ا", "", ""]
@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
obj = json.loads(line)
text = obj["text"]
if obj['language'] == "zh":
if obj["language"] == "zh":
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
bad_case_zh += 1
continue
else:
text = text.translate(str.maketrans({',': '', '!': '', '?': ''})) # not "。" cuz much code-switched
if obj['language'] == "en":
if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
text = text.translate(
str.maketrans({",": "", "!": "", "?": ""})
) # not "。" cuz much code-switched
if obj["language"] == "en":
if (
obj["wav"].split("/")[1] in out_en
or any(f in text for f in en_filters)
or repetition_found(text, length=4)
):
bad_case_en += 1
continue
if tokenizer == "pinyin":
text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
duration = obj["duration"]
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
durations.append(duration)
@@ -96,11 +179,11 @@ def main():
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
@@ -114,12 +197,13 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"

View File

@@ -1,7 +1,9 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import json
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
audio_paths, texts, durations = [], [], []
for text_file in tqdm(text_files):
with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
first_line = file.readline().split("\t")
audio_nm = first_line[0]
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
audio_paths.append(audio_path)
if tokenizer == "pinyin":
texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
elif tokenizer == "char":
texts.append(text)
@@ -46,7 +48,7 @@ def main():
assert tokenizer in ["pinyin", "char"]
audio_path_list, text_list, duration_list = [], [], []
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for dataset_path in dataset_paths:
@@ -68,8 +70,10 @@ def main():
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
json.dump(
{"duration": duration_list}, f, ensure_ascii=False
) # dup a json separately saving duration in case for DynamicBatchSampler ease
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
text_vocab_set = set()
@@ -85,22 +89,21 @@ def main():
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
dataset_paths = [
"<SOME_PATH>/WenetSpeech4TTS/Basic",
"<SOME_PATH>/WenetSpeech4TTS/Standard",
"<SOME_PATH>/WenetSpeech4TTS/Premium",
][-dataset_choice:]
][-dataset_choice:]
print(f"\nChoose Dataset: {dataset_name}\n")
main()
@@ -109,8 +112,8 @@ if __name__ == "__main__":
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same
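Taken together, the reformatting in this commit (double quotes, no spaces around keyword '=', one import per line, and only lines beyond roughly 120 characters wrapped) is consistent with ruff's default formatter plus a raised line length. A hedged guess at the corresponding pyproject.toml section follows; the settings file itself is not part of this excerpt, so the value is an assumption.

[tool.ruff]
line-length = 120    # assumed: calls just under ~120 characters are left unwrapped in this diff
# quote style and keyword spacing match ruff format defaults, so no [tool.ruff.format] overrides are implied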