Support all other languages whose symbols are missing from the vocab: you can now fine-tune by extending it (#293)

* fix path

* change name

* change name

* fix path

* fix last per steps and add notes

* change tab order, add note in vocab check tab

* add note in reduce checkpoint tab

* update note in reduce checkpoint tab

* extend vocab to train languages with missing symbols

* change enter to ','

* hide the option to create a new vocab, change tab order, add some info
Author: lpscr
Date: 2024-10-27 15:03:22 +02:00 (committed by GitHub)
Parent: 2056f5de41
Commit: 3941e8102f
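In short, the commit adds a "vocab check" flow: report which dataset symbols are missing from the base vocab, append them to a project vocab.txt, and expand the pretrained checkpoint's embedding table to match. Below is a minimal sketch of that flow using the two helpers added in this diff; the project name and symbols are illustrative, and both functions rely on the module's path globals (`path_data`, `path_project_ckpts`):

```python
# Hypothetical driver code; vocab_check / vocab_extend are defined in the diff below.
info, missing = vocab_check("my_project")      # report + comma-separated missing symbols
if missing:
    # appends the symbols to the project's vocab.txt and saves an
    # embedding-expanded copy of the base model under ckpts/my_project
    print(vocab_extend("my_project", missing, "F5-TTS"))
else:
    print(info)  # "You can train using your language !"
```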


@@ -23,7 +23,7 @@ from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from transformers import pipeline
+from cached_path import cached_path

from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
@@ -731,6 +731,97 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, s
        return f"An error occurred: {e}"
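+
+# Expand the checkpoint's text-embedding table: original rows are copied unchanged and
+# num_new_tokens randomly initialized rows are appended, so symbols newly added to the
+# vocab can be trained during fine-tuning.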
+def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
+    seed = 666
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
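+    # every RNG is seeded above so the randomly initialized rows are reproducible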
+
+    ckpt = torch.load(ckpt_path, map_location="cpu")
+
+    ema_sd = ckpt.get("ema_model_state_dict", {})
+    embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
+    old_embed_ema = ema_sd[embed_key_ema]
+
+    vocab_old = old_embed_ema.size(0)
+    embed_dim = old_embed_ema.size(1)
+    vocab_new = vocab_old + num_new_tokens
+
+    def expand_embeddings(old_embeddings):
+        new_embeddings = torch.zeros((vocab_new, embed_dim))
+        new_embeddings[:vocab_old] = old_embeddings
+        new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
+        return new_embeddings
+
+    ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
+
+    torch.save(ckpt, new_ckpt_path)
+
+    return vocab_new
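+
+# UI helper: the Symbols textbox holds a comma-separated list, so the "new size vocab"
+# field simply counts its comma-separated entries.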
+def vocab_count(text):
+    return str(len(text.split(",")))
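+
+# Append the user's missing symbols to the base Emilia_ZH_EN_pinyin vocab, write the
+# merged list to the project's vocab.txt, and save an embedding-expanded copy of the
+# selected base checkpoint under the project's checkpoint folder.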
+def vocab_extend(project_name, symbols, model_type):
+    if symbols == "":
+        return "Symbols empty!"
+
+    name_project = project_name
+    path_project = os.path.join(path_data, name_project)
+    file_vocab_project = os.path.join(path_project, "vocab.txt")
+
+    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+    if not os.path.isfile(file_vocab):
+        return f"the file {file_vocab} not found !"
+
+    symbols = symbols.split(",")
+    if symbols == []:
+        return "Symbols to extend not found."
+
+    with open(file_vocab, "r", encoding="utf-8-sig") as f:
+        data = f.read()
+        vocab = data.split("\n")
+    vocab_check = set(vocab)
+
+    miss_symbols = []
+    for item in symbols:
+        item = item.replace(" ", "")
+        if item in vocab_check:
+            continue
+        miss_symbols.append(item)
+
+    if miss_symbols == []:
+        return "Symbols are okay no need to extend."
+
+    size_vocab = len(vocab)
+    for item in miss_symbols:
+        vocab.append(item)
+
+    with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
+        f.write("\n".join(vocab))
+
+    if model_type == "F5-TTS":
+        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+    else:
+        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+    new_ckpt_path = os.path.join(path_project_ckpts, name_project)
+    os.makedirs(new_ckpt_path, exist_ok=True)
+    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+
+    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))
+
+    vocab_new = "\n".join(miss_symbols)
+    return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
def vocab_check(project_name):
    name_project = project_name
    path_project = os.path.join(path_data, name_project)
@@ -739,7 +830,7 @@ def vocab_check(project_name):
    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
-        return f"the file {file_vocab} not found !"
+        return f"the file {file_vocab} not found !", ""

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        data = f.read()
@@ -747,7 +838,7 @@ def vocab_check(project_name):
    vocab = set(vocab)

    if not os.path.isfile(file_metadata):
-        return f"the file {file_metadata} not found !"
+        return f"the file {file_metadata} not found !", ""

    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()
@@ -765,12 +856,15 @@ def vocab_check(project_name):
            if t not in vocab and t not in miss_symbols_keep:
                miss_symbols.append(t)
                miss_symbols_keep[t] = t

    if miss_symbols == []:
+        vocab_miss = ""
        info = "You can train using your language !"
    else:
-        info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
+        vocab_miss = ",".join(miss_symbols)
+        info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
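+
+    # vocab_miss is comma-separated so the UI can drop it straight into the Symbols
+    # textbox of the extend step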
-    return info
+    return info, vocab_miss
def get_random_sample_prepare(project_name):
@@ -1009,6 +1103,38 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
            outputs=[random_text_transcribe, random_audio_transcribe],
        )
+    with gr.TabItem("vocab check"):
+        gr.Markdown("""```plaintext
+Check the vocabulary used to fine-tune Emilia_ZH_EN and make sure every symbol in your dataset is included before fine-tuning a new language.
+```""")
+
+        check_button = gr.Button("check vocab")
+        txt_info_check = gr.Text(label="info", value="")
+
+        gr.Markdown("""```plaintext
+Using the extended model, you can fine-tune on a new language whose symbols are missing from the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
+```""")
+
+        exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
+
+        with gr.Row():
+            txt_extend = gr.Textbox(
+                label="Symbols",
+                value="",
+                placeholder="To add new symbols, separate each one with ','",
+                scale=6,
+            )
+            txt_count_symbol = gr.Textbox(label="new size vocab", value="", scale=1)
+
+        extend_button = gr.Button("Extend")
+        txt_info_extend = gr.Text(label="info", value="")
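+
+        # wiring: typing symbols updates the count box; "check vocab" writes its report
+        # to txt_info_check and pre-fills txt_extend with the missing symbols it found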
+        txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
+        check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
+        extend_button.click(
+            fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
+        )
    with gr.TabItem("prepare Data"):
        gr.Markdown(
            """```plaintext
@@ -1030,7 +1156,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
```"""
        )

-        ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False)
+        ch_tokenizern = gr.Checkbox(label="create vocabulary", value=False, visible=False)
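+        # hidden rather than removed: this commit stops exposing per-dataset vocab
+        # creation; missing symbols are added via the vocab check tab instead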
        bt_prepare = bt_create = gr.Button("prepare")
        txt_info_prepare = gr.Text(label="info", value="")
        txt_vocab_prepare = gr.Text(label="vocab", value="")
@@ -1048,14 +1174,6 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
            fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
        )
-    with gr.TabItem("vocab check"):
-        gr.Markdown("""```plaintext
-check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
-```""")
-
-        check_button = gr.Button("check vocab")
-        txt_info_check = gr.Text(label="info", value="")
-
-        check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
    with gr.TabItem("train Data"):
        gr.Markdown("""```plaintext
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.