From 3941e8102fa8a72d3b1544abd8a092c42c3c217d Mon Sep 17 00:00:00 2001
From: lpscr <147736764+lpscr@users.noreply.github.com>
Date: Sun, 27 Oct 2024 15:03:22 +0200
Subject: [PATCH] make all other languages happy that don't support the symbols in vocab; you can now fine-tune by extending it (#293)

* fix path
* change name
* change name
* fix path
* fix "last per steps" and add notes
* change tab order, add note in vocab check tab
* add note in reduce checkpoint tab
* update note in reduce checkpoint tab
* extend vocab to train languages with missing symbols
* change "enten" to ","
* hide the "create new vocab" option, change tab order, add some info

---
 src/f5_tts/train/finetune_gradio.py | 146 +++++++++++++++++++++++++---
 1 file changed, 132 insertions(+), 14 deletions(-)
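The heart of the change is `expand_model_embeddings()` in the diff below: it grows the checkpoint's text-embedding table row-wise, so every pretrained symbol keeps its learned weights and each missing symbol gets a freshly initialized row that is trained during fine-tuning. A minimal standalone sketch of that idea, using plain torch tensors rather than the real F5-TTS checkpoint layout (the helper name and toy shapes here are illustrative, not part of the patch):

```python
import torch


def expand_embedding_rows(old_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    # Keep every pretrained row; append freshly initialized rows for the
    # new symbols. Only the new rows have to be learned from scratch.
    vocab_old, embed_dim = old_embed.shape
    new_embed = torch.zeros((vocab_old + num_new_tokens, embed_dim), dtype=old_embed.dtype)
    new_embed[:vocab_old] = old_embed                                 # pretrained rows
    new_embed[vocab_old:] = torch.randn((num_new_tokens, embed_dim))  # new symbols
    return new_embed


# Toy example: a 5-symbol vocab grown by 3 new symbols.
old = torch.randn(5, 8)
new = expand_embedding_rows(old, 3)
assert new.shape == (8, 8)
assert torch.equal(new[:5], old)  # pretrained rows are intact
```

In the patch itself the same expansion is applied to the EMA state dict under the key `ema_model.transformer.text_embed.text_embed.weight`, and the grown checkpoint is saved into the project's ckpts folder so training can resume from it.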
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index cce6e30..bf540d7 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -23,7 +23,7 @@ from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import save_file
 from scipy.io import wavfile
 from transformers import pipeline
-
+from cached_path import cached_path
 from f5_tts.api import F5TTS
 from f5_tts.model.utils import convert_char_to_pinyin
 
@@ -731,6 +731,97 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, s
         return f"An error occurred: {e}"
 
 
+def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
+    seed = 666
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    ckpt = torch.load(ckpt_path, map_location="cpu")
+
+    ema_sd = ckpt.get("ema_model_state_dict", {})
+    embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
+    old_embed_ema = ema_sd[embed_key_ema]
+
+    vocab_old = old_embed_ema.size(0)
+    embed_dim = old_embed_ema.size(1)
+    vocab_new = vocab_old + num_new_tokens
+
+    def expand_embeddings(old_embeddings):
+        new_embeddings = torch.zeros((vocab_new, embed_dim))
+        new_embeddings[:vocab_old] = old_embeddings
+        new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
+        return new_embeddings
+
+    ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
+
+    torch.save(ckpt, new_ckpt_path)
+
+    return vocab_new
+
+
+def vocab_count(text):
+    return str(len(text.split(",")))
+
+
+def vocab_extend(project_name, symbols, model_type):
+    if symbols == "":
+        return "Symbols empty!"
+
+    name_project = project_name
+    path_project = os.path.join(path_data, name_project)
+    file_vocab_project = os.path.join(path_project, "vocab.txt")
+
+    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+    if not os.path.isfile(file_vocab):
+        return f"The file {file_vocab} was not found!"
+
+    symbols = symbols.split(",")
+    if symbols == []:
+        return "Symbols to extend not found."
+
+    with open(file_vocab, "r", encoding="utf-8-sig") as f:
+        data = f.read()
+        vocab = data.split("\n")
+    vocab_check = set(vocab)
+
+    miss_symbols = []
+    for item in symbols:
+        item = item.replace(" ", "")
+        if item in vocab_check:
+            continue
+        miss_symbols.append(item)
+
+    if miss_symbols == []:
+        return "Symbols are okay, no need to extend."
+
+    size_vocab = len(vocab)
+
+    for item in miss_symbols:
+        vocab.append(item)
+
+    with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
+        f.write("\n".join(vocab))
+
+    if model_type == "F5-TTS":
+        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+    else:
+        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+    new_ckpt_path = os.path.join(path_project_ckpts, name_project)
+    os.makedirs(new_ckpt_path, exist_ok=True)
+    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+
+    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))
+
+    vocab_new = "\n".join(miss_symbols)
+    return f"vocab old size: {size_vocab}\nvocab new size: {size}\nvocab added: {len(miss_symbols)}\nnew symbols:\n{vocab_new}"
+
+
 def vocab_check(project_name):
     name_project = project_name
     path_project = os.path.join(path_data, name_project)
@@ -739,7 +830,7 @@ def vocab_check(project_name):
 
     file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
     if not os.path.isfile(file_vocab):
-        return f"the file {file_vocab} not found !"
+        return f"the file {file_vocab} not found !", ""
 
     with open(file_vocab, "r", encoding="utf-8-sig") as f:
         data = f.read()
@@ -747,7 +838,7 @@ def vocab_check(project_name):
     vocab = set(vocab)
 
     if not os.path.isfile(file_metadata):
-        return f"the file {file_metadata} not found !"
+        return f"the file {file_metadata} not found !", ""
 
     with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
@@ -765,12 +856,15 @@ def vocab_check(project_name):
             if t not in vocab and t not in miss_symbols_keep:
                 miss_symbols.append(t)
                 miss_symbols_keep[t] = t
+
     if miss_symbols == []:
+        vocab_miss = ""
         info = "You can train using your language !"
     else:
-        info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
+        vocab_miss = ",".join(miss_symbols)
+        info = f"The following symbols are missing in your language: {len(miss_symbols)}\n\n"
 
-    return info
+    return info, vocab_miss
 
 
 def get_random_sample_prepare(project_name):
+```""") + + exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") + + with gr.Row(): + txt_extend = gr.Textbox( + label="Symbols", + value="", + placeholder="To add new symbols, make sure to use ',' for each symbol", + scale=6, + ) + txt_count_symbol = gr.Textbox(label="new size vocab", value="", scale=1) + + extend_button = gr.Button("Extended") + txt_info_extend = gr.Text(label="info", value="") + + txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol]) + check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend]) + extend_button.click( + fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend] + ) + with gr.TabItem("prepare Data"): gr.Markdown( """```plaintext @@ -1030,7 +1156,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion ```""" ) - ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False) + ch_tokenizern = gr.Checkbox(label="create vocabulary", value=False, visible=False) bt_prepare = bt_create = gr.Button("prepare") txt_info_prepare = gr.Text(label="info", value="") txt_vocab_prepare = gr.Text(label="vocab", value="") @@ -1048,14 +1174,6 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] ) - with gr.TabItem("vocab check"): - gr.Markdown("""```plaintext -check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language -```""") - check_button = gr.Button("check vocab") - txt_info_check = gr.Text(label="info", value="") - check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check]) - with gr.TabItem("train Data"): gr.Markdown("""```plaintext The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.