diff --git a/model/dataset.py b/model/dataset.py
index bb5ec8e..8480f9d 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -184,11 +184,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
 
 def load_dataset(
         dataset_name: str,
-        tokenizer: str,
+        tokenizer: str = "pinyin",
         dataset_type: str = "CustomDataset",
         audio_type: str = "raw",
         mel_spec_kwargs: dict = dict()
-    ) -> CustomDataset:
+    ) -> CustomDataset | HFDataset:
+    '''
+    dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
+                    - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
+    '''
 
     print("Loading dataset ...")
 
@@ -206,7 +210,19 @@ def load_dataset(
             data_dict = json.load(f)
         durations = data_dict["duration"]
         train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
-
+
+    elif dataset_type == "CustomDatasetPath":
+        try:
+            train_dataset = load_from_disk(f"{dataset_name}/raw")
+        except:
+            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
+        preprocessed_mel = False
+
+        with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
+            data_dict = json.load(f)
+        durations = data_dict["duration"]
+        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
+
     elif dataset_type == "HFDataset":
         print("Should manually modify the path of huggingface dataset to your need.\n" +
               "May also the corresponding script cuz different dataset may have different format.")
diff --git a/model/utils.py b/model/utils.py
index 874f022..818df12 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -129,6 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                 - "char" for char-wise tokenizer, need .txt vocab_file
                 - "byte" for utf-8 tokenizer
+                - "custom" if you're directly passing in a path to the vocab.txt you want to use
     vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                 - if use "char", derived from unfiltered character & symbol counts of custom dataset
                 - if use "byte", set to 256 (unicode byte range)
@@ -144,6 +145,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     elif tokenizer == "byte":
         vocab_char_map = None
         vocab_size = 256
+    elif tokenizer == "custom":
+        with open(dataset_name, "r", encoding="utf-8") as f:
+            vocab_char_map = {}
+            for i, char in enumerate(f):
+                vocab_char_map[char[:-1]] = i
+        vocab_size = len(vocab_char_map)
 
     return vocab_char_map, vocab_size
 
diff --git a/train.py b/train.py
index e056175..f14029d 100644
--- a/train.py
+++ b/train.py
@@ -9,10 +9,10 @@
 target_sample_rate = 24000
 n_mel_channels = 100
 hop_length = 256
 
-tokenizer = "pinyin"
+tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
+tokenizer_path = None # if tokenizer == 'custom', define the path to the vocab.txt file you want to use
 dataset_name = "Emilia_ZH_EN"
-
 
 # -------------------------- Training Settings -------------------------- #
 exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
@@ -44,8 +44,11 @@ elif exp_name == "E2TTS_Base":
 # ----------------------------------------------------------------------- #
 
 def main():
-
-    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+    if tokenizer == "custom":
+        vocab_file = tokenizer_path
+    else:
+        vocab_file = dataset_name
+    vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
 
     mel_spec_kwargs = dict(
         target_sample_rate = target_sample_rate,
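
A minimal usage sketch of the two new code paths, for illustration only: the "/path/to/..." values are placeholders, and the mel_spec_kwargs entries simply mirror the constants at the top of train.py (the full set of accepted keys is not shown in this diff).

    from model.utils import get_tokenizer
    from model.dataset import load_dataset

    # tokenizer="custom": pass the path to your own vocab.txt instead of a dataset name
    vocab_char_map, vocab_size = get_tokenizer("/path/to/vocab.txt", tokenizer="custom")

    # dataset_type="CustomDatasetPath": point directly at a preprocessed dataset directory
    # that contains raw/ (or raw.arrow) and duration.json
    train_dataset = load_dataset(
        "/path/to/preprocessed_dataset",
        dataset_type="CustomDatasetPath",
        mel_spec_kwargs=dict(target_sample_rate=24000, n_mel_channels=100, hop_length=256),
    )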