From e1e3b269870a81a3ccb395c0bde99439b0322eb2 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 18:01:46 +0200 Subject: [PATCH 01/12] fix space curse problem with utf-8-sig --- src/f5_tts/train/finetune_gradio.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 22fd209..3506a7c 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -801,11 +801,11 @@ def vocab_extend(project_name, symbols, model_type): return "Symbols are okay no need to extend." size_vocab = len(vocab) - vocab.pop() # fix empty space leave + for item in miss_symbols: vocab.append(item) - with open(file_vocab_project, "w", encoding="utf-8-sig") as f: + with open(file_vocab_project, "w", encoding="utf-8") as f: f.write("\n".join(vocab)) if model_type == "F5-TTS": @@ -813,14 +813,17 @@ def vocab_extend(project_name, symbols, model_type): else: ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) - new_ckpt_path = os.path.join(path_project_ckpts, name_project) + vocab_size_new = len(miss_symbols) + + dataset_name = name_project.replace("_pinyin", "").replace("_char", "") + new_ckpt_path = os.path.join(path_project_ckpts, dataset_name) os.makedirs(new_ckpt_path, exist_ok=True) new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt") - size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols)) + size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new) vocab_new = "\n".join(miss_symbols) - return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}" + return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}" def vocab_check(project_name): From 0de2e531d4f4e7c7e2fa081d8fb6b70cc7796b1b Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 19:14:04 +0200 Subject: [PATCH 02/12] fix extend --- src/f5_tts/train/finetune_gradio.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 3506a7c..a33e4d1 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -801,10 +801,12 @@ def vocab_extend(project_name, symbols, model_type): return "Symbols are okay no need to extend." size_vocab = len(vocab) - + vocab.pop() for item in miss_symbols: vocab.append(item) + vocab.append("") + with open(file_vocab_project, "w", encoding="utf-8") as f: f.write("\n".join(vocab)) From 2eae16b4a3bc980a86636c81420f6011e368e778 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 19:30:04 +0200 Subject: [PATCH 03/12] Do not overwrite the vocab if it already exists ! --- src/f5_tts/train/finetune_gradio.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index a33e4d1..fec3e6b 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -564,10 +564,11 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): new_vocal = "" if not ch_tokenizer: - file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") - if not os.path.isfile(file_vocab_finetune): - return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", "" - shutil.copy2(file_vocab_finetune, file_vocab) + if not os.path.isfile(file_vocab): + file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") + if not os.path.isfile(file_vocab_finetune): + return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", "" + shutil.copy2(file_vocab_finetune, file_vocab) with open(file_vocab, "r", encoding="utf-8-sig") as f: vocab_char_map = {} From 0f2a9230ec645384577e84d433e7ba4cc01a7f5b Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 20:16:21 +0200 Subject: [PATCH 04/12] add settings --- src/f5_tts/train/finetune_cli.py | 6 +- src/f5_tts/train/finetune_gradio.py | 192 ++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index bfd0499..1dcd9a8 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -89,7 +89,11 @@ def main(): if args.finetune: if not os.path.isdir(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) - shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path))) + + file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path)) + if os.path.isfile(file_checkpoint) == False: + shutil.copy2(ckpt_path, file_checkpoint) + print("copy checkpoint for finetune") # Use the tokenizer and tokenizer_path provided in the command line arguments tokenizer = args.tokenizer diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index fec3e6b..faaea23 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -46,6 +46,119 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is pipe = None +# Save settings from a JSON file +def save_settings( + project_name, + exp_name, + learning_rate, + batch_size_per_gpu, + batch_size_type, + max_samples, + grad_accumulation_steps, + max_grad_norm, + epochs, + num_warmup_updates, + save_per_updates, + last_per_steps, + finetune, + file_checkpoint_train, + tokenizer_type, + tokenizer_file, + mixed_precision, +): + path_project = os.path.join(path_project_ckpts, project_name) + os.makedirs(path_project, exist_ok=True) + file_setting = os.path.join(path_project, "setting.json") + + settings = { + "exp_name": exp_name, + "learning_rate": learning_rate, + "batch_size_per_gpu": batch_size_per_gpu, + "batch_size_type": batch_size_type, + "max_samples": max_samples, + "grad_accumulation_steps": grad_accumulation_steps, + "max_grad_norm": max_grad_norm, + "epochs": epochs, + "num_warmup_updates": num_warmup_updates, + "save_per_updates": save_per_updates, + "last_per_steps": last_per_steps, + "finetune": finetune, + "file_checkpoint_train": file_checkpoint_train, + "tokenizer_type": tokenizer_type, + "tokenizer_file": tokenizer_file, + "mixed_precision": mixed_precision, + } + with open(file_setting, "w") as f: + json.dump(settings, f, indent=4) + return "Settings saved!" + + +# Load settings from a JSON file +def load_settings(project_name): + project_name = project_name.replace("_pinyin", "").replace("_char", "") + path_project = os.path.join(path_project_ckpts, project_name) + file_setting = os.path.join(path_project, "setting.json") + + if os.path.isfile(file_setting) == False: + settings = { + "exp_name": "F5TTS_Base", + "learning_rate": 1e-05, + "batch_size_per_gpu": 1000, + "batch_size_type": "frame", + "max_samples": 64, + "grad_accumulation_steps": 1, + "max_grad_norm": 1, + "epochs": 100, + "num_warmup_updates": 2, + "save_per_updates": 300, + "last_per_steps": 200, + "finetune": True, + "file_checkpoint_train": "", + "tokenizer_type": "pinyin", + "tokenizer_file": "", + "mixed_precision": "none", + } + return ( + settings["exp_name"], + settings["learning_rate"], + settings["batch_size_per_gpu"], + settings["batch_size_type"], + settings["max_samples"], + settings["grad_accumulation_steps"], + settings["max_grad_norm"], + settings["epochs"], + settings["num_warmup_updates"], + settings["save_per_updates"], + settings["last_per_steps"], + settings["finetune"], + settings["file_checkpoint_train"], + settings["tokenizer_type"], + settings["tokenizer_file"], + settings["mixed_precision"], + ) + + with open(file_setting, "r") as f: + settings = json.load(f) + return ( + settings["exp_name"], + settings["learning_rate"], + settings["batch_size_per_gpu"], + settings["batch_size_type"], + settings["max_samples"], + settings["grad_accumulation_steps"], + settings["max_grad_norm"], + settings["epochs"], + settings["num_warmup_updates"], + settings["save_per_updates"], + settings["last_per_steps"], + settings["finetune"], + settings["file_checkpoint_train"], + settings["tokenizer_type"], + settings["tokenizer_file"], + settings["mixed_precision"], + ) + + # Load metadata def get_audio_duration(audio_path): """Calculate the duration of an audio file.""" @@ -330,6 +443,26 @@ def start_training( print(cmd) + save_settings( + dataset_name, + exp_name, + learning_rate, + batch_size_per_gpu, + batch_size_type, + max_samples, + grad_accumulation_steps, + max_grad_norm, + epochs, + num_warmup_updates, + save_per_updates, + last_per_steps, + finetune, + file_checkpoint_train, + tokenizer_type, + tokenizer_file, + mixed_precision, + ) + try: # Start the training process training_process = subprocess.Popen(cmd, shell=True) @@ -1225,6 +1358,42 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) + if projects_selelect is not None: + ( + exp_namev, + learning_ratev, + batch_size_per_gpuv, + batch_size_typev, + max_samplesv, + grad_accumulation_stepsv, + max_grad_normv, + epochsv, + num_warmupv_updatesv, + save_per_updatesv, + last_per_stepsv, + finetunev, + file_checkpoint_trainv, + tokenizer_typev, + tokenizer_filev, + mixed_precisionv, + ) = load_settings(projects_selelect) + exp_name.value = exp_namev + learning_rate.value = learning_ratev + batch_size_per_gpu.value = batch_size_per_gpuv + batch_size_type.value = batch_size_typev + max_samples.value = max_samplesv + grad_accumulation_steps.value = grad_accumulation_stepsv + max_grad_norm.value = max_grad_normv + epochs.value = epochsv + num_warmup_updates.value = num_warmupv_updatesv + save_per_updates.value = save_per_updatesv + last_per_steps.value = last_per_stepsv + ch_finetune.value = finetunev + file_checkpoint_train.value = file_checkpoint_train + tokenizer_type.value = tokenizer_typev + tokenizer_file.value = tokenizer_filev + mixed_precision.value = mixed_precisionv + txt_info_train = gr.Text(label="info", value="") start_button.click( fn=start_training, @@ -1279,6 +1448,29 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type] ) + cm_project.change( + fn=load_settings, + inputs=[cm_project], + outputs=[ + exp_name, + learning_rate, + batch_size_per_gpu, + batch_size_type, + max_samples, + grad_accumulation_steps, + max_grad_norm, + epochs, + num_warmup_updates, + save_per_updates, + last_per_steps, + ch_finetune, + file_checkpoint_train, + tokenizer_type, + tokenizer_file, + mixed_precision, + ], + ) + with gr.TabItem("test model"): exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) From 48e3eb1c5768351bd621555b818f7523f0554c8d Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 20:19:59 +0200 Subject: [PATCH 05/12] add settings --- src/f5_tts/train/finetune_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index 1dcd9a8..c4da990 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -91,7 +91,7 @@ def main(): os.makedirs(checkpoint_path, exist_ok=True) file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path)) - if os.path.isfile(file_checkpoint) == False: + if not os.path.isfile(file_checkpoint): shutil.copy2(ckpt_path, file_checkpoint) print("copy checkpoint for finetune") From 3af98f2a52b0ffbfc39968dc20ad79befaf785fb Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 20:21:12 +0200 Subject: [PATCH 06/12] add settings --- src/f5_tts/train/finetune_gradio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index faaea23..3f2f342 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -99,7 +99,7 @@ def load_settings(project_name): path_project = os.path.join(path_project_ckpts, project_name) file_setting = os.path.join(path_project, "setting.json") - if os.path.isfile(file_setting) == False: + if not os.path.isfile(file_setting): settings = { "exp_name": "F5TTS_Base", "learning_rate": 1e-05, From eb19d9d928e9d4cb7b14e3628ac7f04caaa26255 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 20:39:04 +0200 Subject: [PATCH 07/12] fix path --- src/f5_tts/train/finetune_gradio.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 3f2f342..bfe1fae 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -26,7 +26,7 @@ from transformers import pipeline from cached_path import cached_path from f5_tts.api import F5TTS from f5_tts.model.utils import convert_char_to_pinyin - +from importlib.resources import files training_process = None system = platform.system() @@ -36,9 +36,9 @@ last_checkpoint = "" last_device = "" last_ema = None -path_basic = os.path.abspath(os.path.join(__file__, "../../../..")) -path_data = os.path.join(path_basic, "data") -path_project_ckpts = os.path.join(path_basic, "ckpts") + +path_data = str(files("f5_tts").joinpath("../../data")) +path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts")) file_train = "src/f5_tts/train/finetune_cli.py" device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" From 8c3810a66cd27b288205a6040202fb0697ca8742 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 20:48:45 +0200 Subject: [PATCH 08/12] change name make more clear the preetain need path --- src/f5_tts/train/finetune_cli.py | 2 +- src/f5_tts/train/finetune_gradio.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index c4da990..3a867cf 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -45,7 +45,7 @@ def parse_args(): parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps") parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps") parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune") - parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune") + parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") parser.add_argument( "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" ) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index bfe1fae..7136253 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1331,7 +1331,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True) tokenizer_file = gr.Textbox(label="Tokenizer File", value="") - file_checkpoint_train = gr.Textbox(label="Pretrain Model", value="") + file_checkpoint_train = gr.Textbox(label="Path to the preetrain checkpoint ", value="") with gr.Row(): exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") From 9db5de651bad9323ad46280a6542d91ff5a7455d Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 27 Oct 2024 21:10:26 +0200 Subject: [PATCH 09/12] add note about ema --- src/f5_tts/train/finetune_gradio.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 7136253..1b42c6a 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1472,6 +1472,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ) with gr.TabItem("test model"): + gr.Markdown("""```plaintext +SOS : check the use_ema setting (True or False) for your model to see what works best for you. +```""") exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) From 5427f28a6d0614cb858ddf7a5e9fc044bdfdfb54 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 28 Oct 2024 10:26:19 +0200 Subject: [PATCH 10/12] fix wrong value print vocab --- src/f5_tts/train/finetune_cli.py | 2 ++ src/f5_tts/train/finetune_gradio.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index 3a867cf..9a95647 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -106,6 +106,8 @@ def main(): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) + print("\nvocab : ", vocab_size) + mel_spec_kwargs = dict( target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 1b42c6a..1c73931 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1389,7 +1389,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle save_per_updates.value = save_per_updatesv last_per_steps.value = last_per_stepsv ch_finetune.value = finetunev - file_checkpoint_train.value = file_checkpoint_train + file_checkpoint_train.value = file_checkpoint_trainv tokenizer_type.value = tokenizer_typev tokenizer_file.value = tokenizer_filev mixed_precision.value = mixed_precisionv From 41eb33c5c6d78ed88a3297f96b70e5b2122e43d8 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 28 Oct 2024 12:51:51 +0200 Subject: [PATCH 11/12] add stream output --- src/f5_tts/train/finetune_gradio.py | 164 ++++++++++++++++++++++++---- 1 file changed, 145 insertions(+), 19 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 1c73931..016640d 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1,3 +1,7 @@ +import threading +import queue +import re + import gc import json import os @@ -111,7 +115,7 @@ def load_settings(project_name): "epochs": 100, "num_warmup_updates": 2, "save_per_updates": 300, - "last_per_steps": 200, + "last_per_steps": 100, "finetune": True, "file_checkpoint_train": "", "tokenizer_type": "pinyin", @@ -369,8 +373,9 @@ def start_training( tokenizer_type="pinyin", tokenizer_file="", mixed_precision="fp16", + stream=False, ): - global training_process, tts_api + global training_process, tts_api, stop_signal if tts_api is not None: del tts_api @@ -430,6 +435,7 @@ def start_training( f"--last_per_steps {last_per_steps} " f"--dataset_name {dataset_name}" ) + if finetune: cmd += f" --finetune {finetune}" @@ -464,14 +470,112 @@ def start_training( ) try: - # Start the training process - training_process = subprocess.Popen(cmd, shell=True) + if not stream: + # Start the training process + training_process = subprocess.Popen(cmd, shell=True) - time.sleep(5) - yield "train start", gr.update(interactive=False), gr.update(interactive=True) + time.sleep(5) + yield "train start", gr.update(interactive=False), gr.update(interactive=True) + + # Wait for the training process to finish + training_process.wait() + else: + + def stream_output(pipe, output_queue): + try: + for line in iter(pipe.readline, ""): + output_queue.put(line) + except Exception as e: + output_queue.put(f"Error reading pipe: {str(e)}") + finally: + pipe.close() + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + + training_process = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env + ) + yield "Training started...", gr.update(interactive=False), gr.update(interactive=True) + + stdout_queue = queue.Queue() + stderr_queue = queue.Queue() + + stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue)) + stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue)) + stdout_thread.daemon = True + stderr_thread.daemon = True + stdout_thread.start() + stderr_thread.start() + stop_signal = False + while True: + if stop_signal: + training_process.terminate() + time.sleep(0.5) + if training_process.poll() is None: + training_process.kill() + yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False) + break + + process_status = training_process.poll() + + # Handle stdout + try: + while True: + output = stdout_queue.get_nowait() + print(output, end="") + match = re.search( + r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)", output + ) + if match: + current_epoch = match.group(1) + total_epochs = match.group(2) + percent_complete = match.group(3) + elapsed_time = match.group(4) + loss = match.group(5) + current_step = match.group(6) + message = ( + f"Epoch: {current_epoch}/{total_epochs}, " + f"Progress: {percent_complete}%, " + f"Elapsed Time: {elapsed_time}, " + f"Loss: {loss}, " + f"Step: {current_step}" + ) + yield message, gr.update(interactive=False), gr.update(interactive=True) + elif output.strip(): + yield output, gr.update(interactive=False), gr.update(interactive=True) + except queue.Empty: + pass + + # Handle stderr + try: + while True: + error_output = stderr_queue.get_nowait() + print(error_output, end="") + if error_output.strip(): + yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True) + except queue.Empty: + pass + + if process_status is not None and stdout_queue.empty() and stderr_queue.empty(): + if process_status != 0: + yield ( + f"Process crashed with exit code {process_status}!", + gr.update(interactive=False), + gr.update(interactive=True), + ) + else: + yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True) + break + + # Small sleep to prevent CPU thrashing + time.sleep(0.1) + + # Clean up + training_process.stdout.close() + training_process.stderr.close() + training_process.wait() - # Wait for the training process to finish - training_process.wait() time.sleep(1) if training_process is None: @@ -489,11 +593,13 @@ def start_training( def stop_training(): - global training_process + global training_process, stop_signal + if training_process is None: return "Train not run !", gr.update(interactive=True), gr.update(interactive=False) terminate_process_tree(training_process.pid) - training_process = None + # training_process = None + stop_signal = True return "train stop", gr.update(interactive=True), gr.update(interactive=False) @@ -1202,7 +1308,11 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion project_name = gr.Textbox(label="project name", value="my_speak") bt_create = gr.Button("create new project") - cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True) + with gr.Row(): + cm_project = gr.Dropdown( + choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6 + ) + ch_refresh_project = gr.Button("refresh", scale=1) bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project]) @@ -1304,6 +1414,7 @@ Using the extended model, you can fine-tune to a new language that is missing sy bt_prepare = bt_create = gr.Button("prepare") txt_info_prepare = gr.Text(label="info", value="") txt_vocab_prepare = gr.Text(label="vocab", value="") + bt_prepare.click( fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare] ) @@ -1347,11 +1458,11 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): epochs = gr.Number(label="Epochs", value=10) - num_warmup_updates = gr.Number(label="Warmup Updates", value=5) + num_warmup_updates = gr.Number(label="Warmup Updates", value=2) with gr.Row(): - save_per_updates = gr.Number(label="Save per Updates", value=10) - last_per_steps = gr.Number(label="Last per Steps", value=50) + save_per_updates = gr.Number(label="Save per Updates", value=300) + last_per_steps = gr.Number(label="Last per Steps", value=100) with gr.Row(): mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none") @@ -1394,6 +1505,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_file.value = tokenizer_filev mixed_precision.value = mixed_precisionv + ch_stream = gr.Checkbox(label="stream output experiment.", value=True) txt_info_train = gr.Text(label="info", value="") start_button.click( fn=start_training, @@ -1415,6 +1527,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_type, tokenizer_file, mixed_precision, + ch_stream, ], outputs=[txt_info_train, start_button, stop_button], ) @@ -1448,10 +1561,8 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type] ) - cm_project.change( - fn=load_settings, - inputs=[cm_project], - outputs=[ + def setup_load_settings(): + output_components = [ exp_name, learning_rate, batch_size_per_gpu, @@ -1468,7 +1579,22 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_type, tokenizer_file, mixed_precision, - ], + ] + + return output_components + + outputs = setup_load_settings() + + cm_project.change( + fn=load_settings, + inputs=[cm_project], + outputs=outputs, + ) + + ch_refresh_project.click( + fn=load_settings, + inputs=[cm_project], + outputs=outputs, ) with gr.TabItem("test model"): From 2dddb10c369c40f35787550b3dd3d18bd54bbcfd Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 28 Oct 2024 21:09:45 +0200 Subject: [PATCH 12/12] fix vocab file take from the project --- src/f5_tts/train/finetune_gradio.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 016640d..0637b38 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1161,7 +1161,7 @@ def get_random_sample_infer(project_name): ) -def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema): +def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema): global last_checkpoint, last_device, tts_api, last_ema if not os.path.isfile(file_checkpoint): @@ -1182,7 +1182,11 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, us if last_ema != use_ema: last_ema = use_ema - tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test, use_ema=use_ema) + vocab_file = os.path.join(path_data, project, "vocab.txt") + + tts_api = F5TTS( + model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema + ) print("update >> ", device_test, file_checkpoint, use_ema) @@ -1630,7 +1634,7 @@ SOS : check the use_ema setting (True or False) for your model to see what works check_button_infer.click( fn=infer, - inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema], + inputs=[cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema], outputs=[gen_audio, txt_info_gpu], )