mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-12 13:15:21 -08:00
make happy all other language dont suport the symbols in vocab , now you can finetune by extend (#293)
* fix path * change name * change name * fix path * fix last per steps and add notes * change order tab add note in vocab check tab * add note in reduse checkpoint tab * note in reduse checkpoint tab update * extend vocab to train language miss symbols * change enten to , * hide the option create new vocab , change order tab , add some info
This commit is contained in:
@@ -23,7 +23,7 @@ from datasets.arrow_writer import ArrowWriter
|
||||
from safetensors.torch import save_file
|
||||
from scipy.io import wavfile
|
||||
from transformers import pipeline
|
||||
|
||||
from cached_path import cached_path
|
||||
from f5_tts.api import F5TTS
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
@@ -731,6 +731,97 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, s
|
||||
return f"An error occurred: {e}"
|
||||
|
||||
|
||||
def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
|
||||
seed = 666
|
||||
random.seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
ema_sd = ckpt.get("ema_model_state_dict", {})
|
||||
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
|
||||
old_embed_ema = ema_sd[embed_key_ema]
|
||||
|
||||
vocab_old = old_embed_ema.size(0)
|
||||
embed_dim = old_embed_ema.size(1)
|
||||
vocab_new = vocab_old + num_new_tokens
|
||||
|
||||
def expand_embeddings(old_embeddings):
|
||||
new_embeddings = torch.zeros((vocab_new, embed_dim))
|
||||
new_embeddings[:vocab_old] = old_embeddings
|
||||
new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
|
||||
return new_embeddings
|
||||
|
||||
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
|
||||
|
||||
torch.save(ckpt, new_ckpt_path)
|
||||
|
||||
return vocab_new
|
||||
|
||||
|
||||
def vocab_count(text):
|
||||
return str(len(text.split(",")))
|
||||
|
||||
|
||||
def vocab_extend(project_name, symbols, model_type):
|
||||
if symbols == "":
|
||||
return "Symbols empty!"
|
||||
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
file_vocab_project = os.path.join(path_project, "vocab.txt")
|
||||
|
||||
file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
if not os.path.isfile(file_vocab):
|
||||
return f"the file {file_vocab} not found !"
|
||||
|
||||
symbols = symbols.split(",")
|
||||
if symbols == []:
|
||||
return "Symbols to extend not found."
|
||||
|
||||
with open(file_vocab, "r", encoding="utf-8-sig") as f:
|
||||
data = f.read()
|
||||
vocab = data.split("\n")
|
||||
vocab_check = set(vocab)
|
||||
|
||||
miss_symbols = []
|
||||
for item in symbols:
|
||||
item = item.replace(" ", "")
|
||||
if item in vocab_check:
|
||||
continue
|
||||
miss_symbols.append(item)
|
||||
|
||||
if miss_symbols == []:
|
||||
return "Symbols are okay no need to extend."
|
||||
|
||||
size_vocab = len(vocab)
|
||||
|
||||
for item in miss_symbols:
|
||||
vocab.append(item)
|
||||
|
||||
with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
|
||||
f.write("\n".join(vocab))
|
||||
|
||||
if model_type == "F5-TTS":
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
else:
|
||||
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
|
||||
new_ckpt_path = os.path.join(path_project_ckpts, name_project)
|
||||
os.makedirs(new_ckpt_path, exist_ok=True)
|
||||
new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
|
||||
|
||||
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))
|
||||
|
||||
vocab_new = "\n".join(miss_symbols)
|
||||
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
|
||||
|
||||
|
||||
def vocab_check(project_name):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
@@ -739,7 +830,7 @@ def vocab_check(project_name):
|
||||
|
||||
file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
if not os.path.isfile(file_vocab):
|
||||
return f"the file {file_vocab} not found !"
|
||||
return f"the file {file_vocab} not found !", ""
|
||||
|
||||
with open(file_vocab, "r", encoding="utf-8-sig") as f:
|
||||
data = f.read()
|
||||
@@ -747,7 +838,7 @@ def vocab_check(project_name):
|
||||
vocab = set(vocab)
|
||||
|
||||
if not os.path.isfile(file_metadata):
|
||||
return f"the file {file_metadata} not found !"
|
||||
return f"the file {file_metadata} not found !", ""
|
||||
|
||||
with open(file_metadata, "r", encoding="utf-8-sig") as f:
|
||||
data = f.read()
|
||||
@@ -765,12 +856,15 @@ def vocab_check(project_name):
|
||||
if t not in vocab and t not in miss_symbols_keep:
|
||||
miss_symbols.append(t)
|
||||
miss_symbols_keep[t] = t
|
||||
|
||||
if miss_symbols == []:
|
||||
vocab_miss = ""
|
||||
info = "You can train using your language !"
|
||||
else:
|
||||
info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
|
||||
vocab_miss = ",".join(miss_symbols)
|
||||
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
|
||||
|
||||
return info
|
||||
return info, vocab_miss
|
||||
|
||||
|
||||
def get_random_sample_prepare(project_name):
|
||||
@@ -1009,6 +1103,38 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
||||
outputs=[random_text_transcribe, random_audio_transcribe],
|
||||
)
|
||||
|
||||
with gr.TabItem("vocab check"):
|
||||
gr.Markdown("""```plaintext
|
||||
check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
|
||||
```""")
|
||||
|
||||
check_button = gr.Button("check vocab")
|
||||
txt_info_check = gr.Text(label="info", value="")
|
||||
|
||||
gr.Markdown("""```plaintext
|
||||
Using the extended model, you can fine-tune to a new language that is missing symbols in the vocab , this create a new model with a new vocabulary size and save it in your ckpts/project folder.
|
||||
```""")
|
||||
|
||||
exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
||||
|
||||
with gr.Row():
|
||||
txt_extend = gr.Textbox(
|
||||
label="Symbols",
|
||||
value="",
|
||||
placeholder="To add new symbols, make sure to use ',' for each symbol",
|
||||
scale=6,
|
||||
)
|
||||
txt_count_symbol = gr.Textbox(label="new size vocab", value="", scale=1)
|
||||
|
||||
extend_button = gr.Button("Extended")
|
||||
txt_info_extend = gr.Text(label="info", value="")
|
||||
|
||||
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
||||
extend_button.click(
|
||||
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
|
||||
)
|
||||
|
||||
with gr.TabItem("prepare Data"):
|
||||
gr.Markdown(
|
||||
"""```plaintext
|
||||
@@ -1030,7 +1156,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
||||
|
||||
```"""
|
||||
)
|
||||
ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False)
|
||||
ch_tokenizern = gr.Checkbox(label="create vocabulary", value=False, visible=False)
|
||||
bt_prepare = bt_create = gr.Button("prepare")
|
||||
txt_info_prepare = gr.Text(label="info", value="")
|
||||
txt_vocab_prepare = gr.Text(label="vocab", value="")
|
||||
@@ -1048,14 +1174,6 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
|
||||
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
||||
)
|
||||
|
||||
with gr.TabItem("vocab check"):
|
||||
gr.Markdown("""```plaintext
|
||||
check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
|
||||
```""")
|
||||
check_button = gr.Button("check vocab")
|
||||
txt_info_check = gr.Text(label="info", value="")
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
|
||||
|
||||
with gr.TabItem("train Data"):
|
||||
gr.Markdown("""```plaintext
|
||||
The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.
|
||||
|
||||
Reference in New Issue
Block a user