From d5c307b56aa85bbc9b210ebe459ea4a11317053f Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 29 Oct 2024 17:30:28 +0200 Subject: [PATCH] add logger --- src/f5_tts/train/finetune_gradio.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 3693f2f..b5a45c7 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -69,6 +69,7 @@ def save_settings( tokenizer_type, tokenizer_file, mixed_precision, + logger, ): path_project = os.path.join(path_project_ckpts, project_name) os.makedirs(path_project, exist_ok=True) @@ -91,6 +92,7 @@ def save_settings( "tokenizer_type": tokenizer_type, "tokenizer_file": tokenizer_file, "mixed_precision": mixed_precision, + "logger": logger, } with open(file_setting, "w") as f: json.dump(settings, f, indent=4) @@ -121,6 +123,7 @@ def load_settings(project_name): "tokenizer_type": "pinyin", "tokenizer_file": "", "mixed_precision": "none", + "logger": "wandb", } return ( settings["exp_name"], @@ -139,6 +142,7 @@ def load_settings(project_name): settings["tokenizer_type"], settings["tokenizer_file"], settings["mixed_precision"], + settings["logger"], ) with open(file_setting, "r") as f: @@ -160,6 +164,7 @@ def load_settings(project_name): settings["tokenizer_type"], settings["tokenizer_file"], settings["mixed_precision"], + settings["logger"], ) @@ -374,6 +379,7 @@ def start_training( tokenizer_file="", mixed_precision="fp16", stream=False, + logger="wandb", ): global training_process, tts_api, stop_signal @@ -447,7 +453,7 @@ def start_training( cmd += f" --tokenizer {tokenizer_type} " - cmd += " --export_samples True --logger wandb " + cmd += f" --export_samples True --logger {logger} " print(cmd) @@ -469,6 +475,7 @@ def start_training( tokenizer_type, tokenizer_file, mixed_precision, + logger, ) try: @@ -1508,6 +1515,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none") + cd_logger = gr.Radio(label="logger", choices=["none", "wandb", "tensorboard"], value="wandb") start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) @@ -1529,6 +1537,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_typev, tokenizer_filev, mixed_precisionv, + cd_loggerv, ) = load_settings(projects_selelect) exp_name.value = exp_namev learning_rate.value = learning_ratev @@ -1546,6 +1555,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_type.value = tokenizer_typev tokenizer_file.value = tokenizer_filev mixed_precision.value = mixed_precisionv + cd_logger.value = cd_loggerv ch_stream = gr.Checkbox(label="stream output experiment.", value=True) txt_info_train = gr.Text(label="info", value="") @@ -1611,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_file, mixed_precision, ch_stream, + cd_logger, ], outputs=[txt_info_train, start_button, stop_button], ) @@ -1662,6 +1673,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle tokenizer_type, tokenizer_file, mixed_precision, + cd_logger, ] return output_components