mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 09:39:52 -08:00
auto settings and reduce new tab
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os,sys
|
||||
os.chdir(r"C:\PythonApps\ff5ttsmy\F5-TTS")
|
||||
|
||||
from transformers import pipeline
|
||||
import gradio as gr
|
||||
@@ -22,6 +23,10 @@ import platform
|
||||
import subprocess
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
|
||||
import json
|
||||
|
||||
|
||||
|
||||
training_process = None
|
||||
system = platform.system()
|
||||
python_executable = sys.executable or "python"
|
||||
@@ -79,10 +84,10 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.p
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 5000,
|
||||
min_length: int = 2000,
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 5000,
|
||||
max_sil_kept: int = 2000,
|
||||
):
|
||||
if not min_length >= min_interval >= hop_size:
|
||||
raise ValueError(
|
||||
@@ -245,19 +250,19 @@ def terminate_process(pid):
|
||||
else:
|
||||
terminate_process_tree(pid)
|
||||
|
||||
def start_training(
|
||||
dataset_name="",
|
||||
exp_name="F5TTS_Base", # Default experiment name
|
||||
learning_rate=1e-4, # Default learning rate
|
||||
batch_size_per_gpu=400, # Default batch size per GPU
|
||||
batch_size_type="frame", # Default batch size type
|
||||
max_samples=64, # Default max sequences per batch
|
||||
grad_accumulation_steps=1, # Default gradient accumulation steps
|
||||
max_grad_norm=1.0, # Default max gradient norm
|
||||
epochs=11, # Default number of training epochs
|
||||
num_warmup_updates=200, # Default number of warmup updates
|
||||
save_per_updates=400, # Default save interval for checkpoints
|
||||
last_per_steps=800, # Default save interval for last checkpoint
|
||||
def start_training(dataset_name="",
|
||||
exp_name="F5TTS_Base",
|
||||
learning_rate=1e-4,
|
||||
batch_size_per_gpu=400,
|
||||
batch_size_type="frame",
|
||||
max_samples=64,
|
||||
grad_accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
epochs=11,
|
||||
num_warmup_updates=200,
|
||||
save_per_updates=400,
|
||||
last_per_steps=800,
|
||||
finetune=True,
|
||||
):
|
||||
|
||||
global training_process
|
||||
@@ -280,8 +285,9 @@ def start_training(
|
||||
f"--num_warmup_updates {num_warmup_updates} " \
|
||||
f"--save_per_updates {save_per_updates} " \
|
||||
f"--last_per_steps {last_per_steps} " \
|
||||
f"--dataset_name {dataset_name}"
|
||||
print(cmd)
|
||||
f"--dataset_name {dataset_name}"
|
||||
if finetune:cmd += f" --finetune {finetune}"
|
||||
print(cmd)
|
||||
try:
|
||||
# Start the training process
|
||||
training_process = subprocess.Popen(cmd, shell=True)
|
||||
@@ -451,6 +457,65 @@ def create_metadata(name_project,progress=gr.Progress()):
|
||||
def check_user(value):
    """Build two Gradio visibility updates that are opposite to each other.

    The first update is visible when ``value`` is falsy, the second is
    visible with ``value`` itself.
    """
    first_update = gr.update(visible=not value)
    second_update = gr.update(visible=value)
    return first_update, second_update
|
||||
|
||||
def _round_up_to_even(num):
    """Return ``num`` unchanged when it is even, otherwise ``num + 1``."""
    return num + 1 if num % 2 != 0 else num


def calculate_train(name_project, batch_size_type, max_samples, learning_rate, num_warmup_updates, save_per_updates, last_per_steps, finetune):
    """Suggest training settings from the dataset size and the memory of GPU 0.

    Reads ``duration.json`` from the ``<name_project>_pinyin`` folder under
    ``path_data`` and counts the entries of its ``"duration"`` list.

    Args:
        name_project: Project name; ``"_pinyin"`` is appended to find the folder.
        batch_size_type: ``"frame"`` selects the frame-based batch formula;
            any other value selects the sample-based one.
        max_samples: Current value; replaced when there are fewer than 64 samples.
        learning_rate: Ignored; the returned rate depends only on ``finetune``.
        num_warmup_updates: Current value; replaced when there are fewer than 64 samples.
        save_per_updates: Current value; replaced when there are fewer than 64 samples.
        last_per_steps: Current value; replaced when there are fewer than 64 samples.
        finetune: Truthy selects a learning rate of 1e-4, otherwise 7.5e-5.

    Returns:
        Tuple ``(batch_size_per_gpu, max_samples, num_warmup_updates,
        save_per_updates, last_per_steps, samples, learning_rate)``.
    """
    name_project += "_pinyin"
    path_project = os.path.join(path_data, name_project)
    file_duraction = os.path.join(path_project, "duration.json")

    with open(file_duraction, 'r') as file:
        data = json.load(file)

    duration_list = data['duration']
    samples = len(duration_list)

    # Memory of CUDA device 0, converted from bytes to GiB.
    gpu_properties = torch.cuda.get_device_properties(0)
    total_memory = gpu_properties.total_memory / (1024 ** 3)

    if batch_size_type == "frame":
        batch = _round_up_to_even(int(total_memory * 0.5))
        # A GPU with less than 2 GiB yields 0 here; clamp to the smallest
        # even divisor so the division below cannot raise ZeroDivisionError.
        batch = max(batch, 2)
        batch_size_per_gpu = int(36800 / batch)
    else:
        batch_size_per_gpu = _round_up_to_even(int(total_memory / 8))

    if batch_size_per_gpu <= 0:
        batch_size_per_gpu = 1

    if samples < 64:
        # Small dataset: derive the schedule from the sample count instead of
        # keeping the values passed in.
        # NOTE(review): for very small datasets (fewer than 4 samples) these
        # evaluate to 0 — confirm the trainer accepts zero intervals.
        max_samples = int(samples * 0.25)
        num_warmup_updates = int(samples * 0.10)
        save_per_updates = int(samples * 0.25)
        last_per_steps = int(save_per_updates * 5)

    max_samples = _round_up_to_even(max_samples)
    num_warmup_updates = _round_up_to_even(num_warmup_updates)
    save_per_updates = _round_up_to_even(save_per_updates)
    last_per_steps = _round_up_to_even(last_per_steps)

    learning_rate = 1e-4 if finetune else 7.5e-5

    return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
|
||||
|
||||
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
    """Write a reduced checkpoint that keeps only the EMA model weights.

    Loads the checkpoint at ``checkpoint_path`` and, when it contains an
    ``'ema_model_state_dict'`` entry, saves a new checkpoint holding only that
    entry to ``new_checkpoint_path``. Progress and any error are printed;
    nothing is raised to the caller and nothing is returned.

    Args:
        checkpoint_path: Path of the full checkpoint to read.
        new_checkpoint_path: Path where the reduced checkpoint is written.
    """
    try:
        # Load onto the CPU: this is a pure file-to-file conversion, so it
        # should neither require CUDA nor take GPU memory.
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print("Original Checkpoint Keys:", checkpoint.keys())

        ema_model_state_dict = checkpoint.get('ema_model_state_dict', None)

        if ema_model_state_dict is None:
            print("No 'ema_model_state_dict' found in the checkpoint.")
            return

        new_checkpoint = {'ema_model_state_dict': ema_model_state_dict}
        torch.save(new_checkpoint, new_checkpoint_path)
        print(f"New checkpoint saved at: {new_checkpoint_path}")

    except Exception as e:
        # Best-effort UI action: report the failure instead of propagating it.
        print(f"An error occurred: {e}")
|
||||
|
||||
with gr.Blocks() as app:
|
||||
|
||||
with gr.Row():
|
||||
@@ -513,14 +578,20 @@ with gr.Blocks() as app:
|
||||
|
||||
with gr.TabItem("train Data"):
|
||||
|
||||
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
||||
learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
|
||||
|
||||
with gr.Row():
|
||||
bt_calculate=bt_create=gr.Button("Auto Settings")
|
||||
ch_finetune=bt_create=gr.Checkbox(label="finetune",value=True)
|
||||
lb_samples = gr.Label(label="samples")
|
||||
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
|
||||
|
||||
with gr.Row():
|
||||
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
||||
learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
|
||||
|
||||
with gr.Row():
|
||||
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=256)
|
||||
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
||||
max_samples = gr.Number(label="Max Samples", value=16)
|
||||
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
||||
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
|
||||
@@ -531,17 +602,23 @@ with gr.Blocks() as app:
|
||||
|
||||
with gr.Row():
|
||||
save_per_updates = gr.Number(label="Save per Updates", value=10)
|
||||
last_per_steps = gr.Number(label="Last per Steps", value=10)
|
||||
last_per_steps = gr.Number(label="Last per Steps", value=50)
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start Training")
|
||||
stop_button = gr.Button("Stop Training",interactive=False)
|
||||
|
||||
txt_info_train=gr.Text(label="info",value="")
|
||||
start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps],outputs=[txt_info_train,start_button,stop_button])
|
||||
start_button.click(fn=start_training,inputs=[project_name,exp_name,learning_rate,batch_size_per_gpu,batch_size_type,max_samples,grad_accumulation_steps,max_grad_norm,epochs,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[txt_info_train,start_button,stop_button])
|
||||
stop_button.click(fn=stop_training,outputs=[txt_info_train,start_button,stop_button])
|
||||
|
||||
|
||||
bt_calculate.click(fn=calculate_train,inputs=[project_name,batch_size_type,max_samples,learning_rate,num_warmup_updates,save_per_updates,last_per_steps,ch_finetune],outputs=[batch_size_per_gpu,max_samples,num_warmup_updates,save_per_updates,last_per_steps,lb_samples,learning_rate])
|
||||
|
||||
with gr.TabItem("reduse checkpoint"):
|
||||
txt_path_checkpoint = gr.Text(label="path checkpoint :")
|
||||
txt_path_checkpoint_small = gr.Text(label="path output :")
|
||||
reduse_button = gr.Button("reduse")
|
||||
reduse_button.click(fn=extract_and_save_ema_model,inputs=[txt_path_checkpoint,txt_path_checkpoint_small])
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
||||
@click.option("--host", "-H", default=None, help="Host to run the app on")
|
||||
|
||||
Reference in New Issue
Block a user