Update to make passing in custom paths easier for finetuning/training

Jarod Mica
2024-10-14 20:13:07 -07:00
parent 0297be2541
commit 6fda7e5f6f
3 changed files with 35 additions and 6 deletions

View File

@@ -184,11 +184,19 @@ class DynamicBatchSampler(Sampler[list[int]]):
def load_dataset(
    dataset_name: str,
    tokenizer: str,
    tokenizer: str = "pinyin",
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
    '''
    dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
                    - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
    '''
    print("Loading dataset ...")
@@ -206,7 +214,18 @@ def load_dataset(
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
    elif dataset_type == "CustomDatasetPath":
        try:
            # preferred: a dataset saved with datasets' save_to_disk() under <dataset_name>/raw
            train_dataset = load_from_disk(f"{dataset_name}/raw")
        except Exception:
            # fall back to reading the raw .arrow file directly
            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
        with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
    elif dataset_type == "HFDataset":
        print("Should manually modify the path of the huggingface dataset to your need.\n" +
              "May also need to modify the corresponding script, since different datasets may have different formats.")

View File

@@ -129,6 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
@@ -144,6 +145,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
elif tokenizer == "byte":
vocab_char_map = None
vocab_size = 256
elif tokenizer == "custom":
with open (dataset_name, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
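
For reference, a minimal sketch of calling the new "custom" branch (the vocab path below is hypothetical): the file is read line by line, each line minus its trailing newline becomes a token, and its line index becomes its id.

from model.utils import get_tokenizer  # adjust the import to your checkout

vocab_char_map, vocab_size = get_tokenizer("/path/to/my_finetune_set/vocab.txt", "custom")
print(vocab_size)                        # number of lines in vocab.txt
print(list(vocab_char_map.items())[:5])  # first few (token, id) pairs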

View File

@@ -9,10 +9,10 @@ target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
tokenizer = "pinyin"
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
dataset_name = "Emilia_ZH_EN"
# -------------------------- Training Settings -------------------------- #
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
@@ -44,8 +44,11 @@ elif exp_name == "E2TTS_Base":
# ----------------------------------------------------------------------- #
def main():
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
    if tokenizer == "custom":
        # use a distinct local name: re-assigning tokenizer_path inside main()
        # would shadow the module-level setting and raise UnboundLocalError
        tokenizer_source = tokenizer_path
    else:
        tokenizer_source = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_source, tokenizer)
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
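
Put together, pointing a fine-tuning run at your own vocab only needs the two settings above (the values below are placeholders):

tokenizer = "custom"
tokenizer_path = "/path/to/my_finetune_set/vocab.txt"  # forwarded to get_tokenizer() in place of dataset_name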