mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 01:27:55 -08:00
Update to make passing in custom paths easier for finetuning/training
This commit is contained in:
@@ -184,11 +184,19 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
||||
|
||||
def load_dataset(
|
||||
dataset_name: str,
|
||||
tokenizer: str,
|
||||
tokenizer: str = "pinyon",
|
||||
dataset_type: str = "CustomDataset",
|
||||
audio_type: str = "raw",
|
||||
mel_spec_kwargs: dict = dict()
|
||||
<<<<<<< HEAD
|
||||
) -> CustomDataset | HFDataset:
|
||||
'''
|
||||
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
|
||||
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
|
||||
'''
|
||||
=======
|
||||
) -> CustomDataset:
|
||||
>>>>>>> 0297be2541f9a062f9d54103926bdb88d63440ea
|
||||
|
||||
print("Loading dataset ...")
|
||||
|
||||
@@ -206,7 +214,18 @@ def load_dataset(
|
||||
data_dict = json.load(f)
|
||||
durations = data_dict["duration"]
|
||||
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
|
||||
|
||||
|
||||
elif dataset_type == "CustomDatasetPath":
|
||||
try:
|
||||
train_dataset = load_from_disk(f"{dataset_name}/raw")
|
||||
except:
|
||||
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
|
||||
|
||||
with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
|
||||
data_dict = json.load(f)
|
||||
durations = data_dict["duration"]
|
||||
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
|
||||
|
||||
elif dataset_type == "HFDataset":
|
||||
print("Should manually modify the path of huggingface dataset to your need.\n" +
|
||||
"May also the corresponding script cuz different dataset may have different format.")
|
||||
|
||||
@@ -129,6 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||
- "byte" for utf-8 tokenizer
|
||||
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
@@ -144,6 +145,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
elif tokenizer == "byte":
|
||||
vocab_char_map = None
|
||||
vocab_size = 256
|
||||
elif tokenizer == "custom":
|
||||
with open (dataset_name, "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
vocab_size = len(vocab_char_map)
|
||||
|
||||
return vocab_char_map, vocab_size
|
||||
|
||||
|
||||
11
train.py
11
train.py
@@ -9,10 +9,10 @@ target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
|
||||
tokenizer = "pinyin"
|
||||
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
||||
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
||||
dataset_name = "Emilia_ZH_EN"
|
||||
|
||||
|
||||
# -------------------------- Training Settings -------------------------- #
|
||||
|
||||
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
||||
@@ -44,8 +44,11 @@ elif exp_name == "E2TTS_Base":
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
def main():
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
||||
if tokenizer == "custom":
|
||||
tokenizer_path = tokenizer_path
|
||||
else:
|
||||
tokenizer_path = dataset_name
|
||||
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
||||
|
||||
mel_spec_kwargs = dict(
|
||||
target_sample_rate = target_sample_rate,
|
||||
|
||||
Reference in New Issue
Block a user