update finetune-cli and finetune-gradio

This commit is contained in:
SWivid
2024-10-24 15:23:55 +08:00
parent b4abb3cbd6
commit 254e5e6d30
8 changed files with 388 additions and 64 deletions

View File

@@ -4,7 +4,6 @@ import os
sys.path.append(os.getcwd())
import time
import random
from tqdm import tqdm
import argparse
from importlib.resources import files
@@ -97,8 +96,6 @@ def main():
metainfo = get_seedtts_testset_metainfo(metalst)
# path to save generated wavs
if seed is None:
seed = random.randint(-10000, 10000)
output_dir = (
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"

View File

@@ -1,13 +1,13 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
# etc.

View File

@@ -11,6 +11,7 @@ import soundfile as sf
import torchaudio
from cached_path import cached_path
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
@@ -51,6 +52,33 @@ E2TTS_ema_model = load_model(
UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
)
# Initialize Qwen model and tokenizer
model_name = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_response(messages):
"""Generate response using Qwen"""
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=512,
temperature=0.7,
top_p=0.95,
)
generated_ids = [
output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
@gpu_decorator
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
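A quick usage sketch of the `generate_response()` helper added above; the message roles mirror how the chat tab seeds its conversation state, and the content strings are illustrative only:

```python
messages = [
    {"role": "system", "content": "You are a concise, friendly character. Keep replies short."},
    {"role": "user", "content": "Introduce yourself in one sentence."},
]
reply = generate_response(messages)  # single decoded string, prompt tokens already stripped
print(reply)
```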
@@ -490,6 +518,146 @@ with gr.Blocks() as app_emotional:
outputs=generate_emotional_btn,
)
with gr.Blocks() as app_chat:
gr.Markdown(
"""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Record your message through your microphone.
3. The AI will respond using the reference voice.
"""
)
with gr.Row():
with gr.Column():
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
model_choice_chat = gr.Radio(
choices=["F5-TTS", "E2-TTS"],
label="TTS Model",
value="F5-TTS",
)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)
chatbot_interface = gr.Chatbot(label="Conversation")
with gr.Row():
with gr.Column():
audio_output_chat = gr.Audio(autoplay=True)
with gr.Column():
audio_input_chat = gr.Microphone(
label="Or speak your message",
type="filepath",
)
clear_btn_chat = gr.Button("Clear Conversation")
conversation_state = gr.State(
value=[
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
)
def process_audio_input(audio_path, history, conv_state):
"""Handle audio input from user"""
if not audio_path:
return history, conv_state, ""
text = ""
text = preprocess_ref_audio_text(audio_path, text)[1]
if not text.strip():
return history, conv_state, ""
conv_state.append({"role": "user", "content": text})
history.append((text, None))
response = generate_response(conv_state)
conv_state.append({"role": "assistant", "content": response})
history[-1] = (text, response)
return history, conv_state, ""
def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None
last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None
audio_result, _ = infer(
ref_audio,
ref_text,
last_ai_response,
model,
remove_silence,
cross_fade_duration=0.15,
speed=1.0,
)
return audio_result
def clear_conversation():
"""Reset the conversation"""
return [], [
{
"role": "system",
"content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
def update_system_prompt(new_prompt):
"""Update the system prompt and reset the conversation"""
new_conv_state = [{"role": "system", "content": new_prompt}]
return [], new_conv_state
# Handle audio input
audio_input_chat.stop_recording(
process_audio_input,
inputs=[audio_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
outputs=audio_output_chat,
)
# Handle clear button
clear_btn_chat.click(
clear_conversation,
outputs=[chatbot_interface, conversation_state],
)
# Handle system prompt change and reset conversation
system_prompt_chat.change(
update_system_prompt,
inputs=system_prompt_chat,
outputs=[chatbot_interface, conversation_state],
)
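To make the data flow above concrete, after one spoken user turn the two pieces of state look roughly like this (contents illustrative):

```python
# conversation_state: full message list that generate_response() consumes.
conversation_state_value = [
    {"role": "system", "content": "You are not an AI assistant, you are whoever the user says you are. ..."},
    {"role": "user", "content": "Tell me about your day."},
    {"role": "assistant", "content": "It flew by, thanks for asking!"},
]
# chatbot_interface history: (user, assistant) tuples shown in the UI.
history = [("Tell me about your day.", "It flew by, thanks for asking!")]
# generate_audio_response() then reads history[-1][1] and synthesizes it with infer().
```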
with gr.Blocks() as app:
gr.Markdown(
"""
@@ -507,7 +675,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
"""
)
gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
gr.TabbedInterface(
[app_tts, app_podcast, app_emotional, app_chat, app_credits],
["TTS", "Podcast", "Multi-Style", "Voice-Chat", "Credits"],
)
@click.command()

View File

@@ -60,7 +60,7 @@ def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []
parent = Path(csv_file_path).parent
with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
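The switch to `utf-8-sig` matters for metadata files saved by tools that prepend a UTF-8 byte-order mark; a minimal illustration of the difference:

```python
# Simulate a metadata.csv written by a tool that prepends a UTF-8 BOM.
with open("metadata.csv", "w", encoding="utf-8-sig") as f:
    f.write("audio_0001.wav|hello world\n")

print(repr(open("metadata.csv", encoding="utf-8").read()[:1]))      # '\ufeff' — BOM leaks into the first field
print(repr(open("metadata.csv", encoding="utf-8-sig").read()[:1]))  # 'a' — BOM stripped transparently
```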

View File

@@ -15,26 +15,35 @@ hop_length = 256
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
# batch_size_per_gpu = 1000 setting for gpu 8GB
# batch_size_per_gpu = 1600 setting for gpu 12GB
# batch_size_per_gpu = 2000 setting for gpu 16GB
# batch_size_per_gpu = 3200 setting for gpu 24GB
# num_warmup_updates = 500 for 10000 samples
# change save_per_updates and last_per_steps as needed
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
parser.add_argument(
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
)
parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
parser.add_argument(
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
)
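The GPU-memory guidance in the comments above, restated as a small helper; the function name and thresholds are illustrative, not part of the script:

```python
def suggest_batch_size_per_gpu(gpu_mem_gb: float) -> int:
    # Rough frame-count batch sizes per GPU, per the comments above (frame-type batching).
    if gpu_mem_gb >= 24:
        return 3200
    if gpu_mem_gb >= 16:
        return 2000
    if gpu_mem_gb >= 12:
        return 1600
    return 1000  # ~8 GB GPUs

print(suggest_batch_size_per_gpu(24))  # 3200, matching the new default
```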
@@ -60,13 +69,19 @@ def main():
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
if args.finetune:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if args.finetune:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
if args.finetune:
path_ckpt = os.path.join("ckpts", args.dataset_name)
@@ -118,6 +133,7 @@ def main():
)
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
trainer.train(
train_dataset,
resumable_with_seed=666, # seed for shuffling dataset

View File

@@ -20,6 +20,7 @@ import torch
import torchaudio
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from transformers import pipeline
@@ -247,6 +248,9 @@ def start_training(
save_per_updates=400,
last_per_steps=800,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
tokenizer_file="",
):
global training_process, tts_api
@@ -256,7 +260,7 @@ def start_training(
torch.cuda.empty_cache()
tts_api = None
path_project = os.path.join(path_data, dataset_name + "_pinyin")
path_project = os.path.join(path_data, dataset_name)
if not os.path.isdir(path_project):
yield (
@@ -278,6 +282,7 @@ def start_training(
yield "start train", gr.update(interactive=False), gr.update(interactive=False)
# Command to run the training script with the specified arguments
dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
cmd = (
f"accelerate launch finetune-cli.py --exp_name {exp_name} "
f"--learning_rate {learning_rate} "
@@ -295,6 +300,13 @@ def start_training(
if finetune:
cmd += f" --finetune {finetune}"
if file_checkpoint_train != "":
cmd += f" --file_checkpoint_train {file_checkpoint_train}"
if tokenizer_file != "":
cmd += f" --tokenizer_path {tokenizer_file}"
cmd += f" --tokenizer {tokenizer_type} "
print(cmd)
try:
@@ -331,10 +343,28 @@ def stop_training():
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
def create_data_project(name):
name += "_pinyin"
def get_list_projects():
project_list = []
for folder in os.listdir("data"):
path_folder = os.path.join("data", folder)
if not os.path.isdir(path_folder):
continue
folder = folder.lower()
if folder == "emilia_zh_en_pinyin":
continue
project_list.append(folder)
projects_selelect = None if not project_list else project_list[-1]
return project_list, projects_selelect
def create_data_project(name, tokenizer_type):
name += "_" + tokenizer_type
os.makedirs(os.path.join(path_data, name), exist_ok=True)
os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
project_list, projects_selelect = get_list_projects()
return gr.update(choices=project_list, value=name)
def transcribe(file_audio, language="english"):
@@ -359,14 +389,14 @@ def transcribe(file_audio, language="english"):
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
name_project += "_pinyin"
path_project = os.path.join(path_data, name_project)
path_dataset = os.path.join(path_project, "dataset")
path_project_wavs = os.path.join(path_project, "wavs")
file_metadata = os.path.join(path_project, "metadata.csv")
if audio_files is None:
return "You need to load an audio file."
if not user:
if audio_files is None:
return "You need to load an audio file."
if os.path.isdir(path_project_wavs):
shutil.rmtree(path_project_wavs)
@@ -418,7 +448,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
except: # noqa: E722
error_num += 1
with open(file_metadata, "w", encoding="utf-8") as f:
with open(file_metadata, "w", encoding="utf-8-sig") as f:
f.write(data)
if error_num != 0:
@@ -437,7 +467,6 @@ def format_seconds_to_hms(seconds):
def create_metadata(name_project, progress=gr.Progress()):
name_project += "_pinyin"
path_project = os.path.join(path_data, name_project)
path_project_wavs = os.path.join(path_project, "wavs")
file_metadata = os.path.join(path_project, "metadata.csv")
@@ -448,7 +477,7 @@ def create_metadata(name_project, progress=gr.Progress()):
if not os.path.isfile(file_metadata):
return "The file was not found in " + file_metadata
with open(file_metadata, "r", encoding="utf-8") as f:
with open(file_metadata, "r", encoding="utf-8-sig") as f:
data = f.read()
audio_path_list = []
@@ -499,7 +528,7 @@ def create_metadata(name_project, progress=gr.Progress()):
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
writer.write(line)
with open(file_duration, "w", encoding="utf-8") as f:
with open(file_duration, "w") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
@@ -529,7 +558,6 @@ def calculate_train(
last_per_steps,
finetune,
):
name_project += "_pinyin"
path_project = os.path.join(path_data, name_project)
file_duraction = os.path.join(path_project, "duration.json")
@@ -548,8 +576,8 @@ def calculate_train(
data = json.load(file)
duration_list = data["duration"]
samples = len(duration_list)
hours = sum(duration_list) / 3600
if torch.cuda.is_available():
gpu_properties = torch.cuda.get_device_properties(0)
@@ -583,34 +611,67 @@ def calculate_train(
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
total_hours = hours
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
# train params
gpus = 1
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum
epochs = wanted_max_updates / updates_per_epoch
if finetune:
learning_rate = 1e-5
else:
learning_rate = 7.5e-5
return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
return (
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_steps,
samples,
learning_rate,
int(epochs),
)
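A quick numeric check of the update/epoch arithmetic above, assuming 3200 frames per GPU, 1 GPU, no gradient accumulation, and 10 hours of audio (values are illustrative):

```python
mel_hop_length = 256
mel_sampling_rate = 24000
total_hours = 10

mini_batch_frames = 3200 * 1 * 1                    # frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours  # ~1055 updates per epoch
epochs = 1_000_000 / updates_per_epoch              # ~948 epochs to hit the 1M-update target
```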
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
try:
checkpoint = torch.load(checkpoint_path)
print("Original Checkpoint Keys:", checkpoint.keys())
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
if ema_model_state_dict is None:
return "No 'ema_model_state_dict' found in the checkpoint."
if ema_model_state_dict is not None:
if safetensors:
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
save_file(ema_model_state_dict, new_checkpoint_path)
else:
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
torch.save(new_checkpoint, new_checkpoint_path)
return f"New checkpoint saved at: {new_checkpoint_path}"
else:
return "No 'ema_model_state_dict' found in the checkpoint."
return f"New checkpoint saved at: {new_checkpoint_path}"
except Exception as e:
return f"An error occurred: {e}"
def vocab_check(project_name):
name_project = project_name + "_pinyin"
name_project = project_name
path_project = os.path.join(path_data, name_project)
file_metadata = os.path.join(path_project, "metadata.csv")
@@ -619,15 +680,15 @@ def vocab_check(project_name):
if not os.path.isfile(file_vocab):
return f"the file {file_vocab} not found !"
with open(file_vocab, "r", encoding="utf-8") as f:
with open(file_vocab, "r", encoding="utf-8-sig") as f:
data = f.read()
vocab = data.split("\n")
vocab = data.split("\n")
vocab = set(vocab)
if not os.path.isfile(file_metadata):
return f"the file {file_metadata} not found !"
with open(file_metadata, "r", encoding="utf-8") as f:
with open(file_metadata, "r", encoding="utf-8-sig") as f:
data = f.read()
miss_symbols = []
@@ -652,7 +713,7 @@ def vocab_check(project_name):
def get_random_sample_prepare(project_name):
name_project = project_name + "_pinyin"
name_project = project_name
path_project = os.path.join(path_data, name_project)
file_arrow = os.path.join(path_project, "raw.arrow")
if not os.path.isfile(file_arrow):
@@ -665,14 +726,14 @@ def get_random_sample_prepare(project_name):
def get_random_sample_transcribe(project_name):
name_project = project_name + "_pinyin"
name_project = project_name
path_project = os.path.join(path_data, name_project)
file_metadata = os.path.join(path_project, "metadata.csv")
if not os.path.isfile(file_metadata):
return "", None
data = ""
with open(file_metadata, "r", encoding="utf-8") as f:
with open(file_metadata, "r", encoding="utf-8-sig") as f:
data = f.read()
list_data = []
@@ -703,13 +764,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
global last_checkpoint, last_device, tts_api
if not os.path.isfile(file_checkpoint):
return None
return None, "checkpoint not found!"
if training_process is not None:
device_test = "cpu"
else:
device_test = None
device_test = "cpu"
if last_checkpoint != file_checkpoint or last_device != device_test:
if last_checkpoint != file_checkpoint:
last_checkpoint = file_checkpoint
@@ -722,19 +784,67 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
return f.name
return f.name, tts_api.device
def check_finetune(finetune):
return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
def get_checkpoints_project(project_name, is_gradio=True):
if project_name is None:
return [], ""
project_name = project_name.replace("_pinyin", "").replace("_char", "")
path_project_ckpts = os.path.join("ckpts", project_name)
if os.path.isdir(path_project_ckpts):
files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt"))
files_checkpoints = sorted(
files_checkpoints,
key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
if os.path.basename(x) != "model_last.pt"
else float("inf"),
)
else:
files_checkpoints = []
selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0]
if is_gradio:
return gr.update(choices=files_checkpoints, value=selelect_checkpoint)
return files_checkpoints, selelect_checkpoint
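The sort key above orders numbered checkpoints by step and pushes model_last.pt to the end; for example (file names illustrative):

```python
import os

files_checkpoints = [
    "ckpts/my_speak/model_last.pt",
    "ckpts/my_speak/model_1200.pt",
    "ckpts/my_speak/model_600.pt",
]
files_checkpoints = sorted(
    files_checkpoints,
    key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
    if os.path.basename(x) != "model_last.pt"
    else float("inf"),
)
# -> model_600.pt, model_1200.pt, model_last.pt
```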
with gr.Blocks() as app:
gr.Markdown(
"""
# E2/F5 TTS AUTOMATIC FINETUNE
This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints support English and Chinese.
For tutorials and updates, check https://github.com/SWivid/F5-TTS/discussions/143
"""
)
with gr.Row():
projects, projects_selelect = get_list_projects()
tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin")
project_name = gr.Textbox(label="project name", value="my_speak")
bt_create = gr.Button("create new project")
bt_create.click(fn=create_data_project, inputs=[project_name])
cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
with gr.Tabs():
with gr.TabItem("transcribe Data"):
ch_manual = gr.Checkbox(label="user", value=False)
ch_manual = gr.Checkbox(label="audio from path", value=False)
mark_info_transcribe = gr.Markdown(
"""```plaintext
@@ -756,7 +866,7 @@ with gr.Blocks() as app:
txt_info_transcribe = gr.Text(label="info", value="")
bt_transcribe.click(
fn=transcribe_all,
inputs=[project_name, audio_speaker, txt_lang, ch_manual],
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
outputs=[txt_info_transcribe],
)
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
@@ -769,7 +879,7 @@ with gr.Blocks() as app:
random_sample_transcribe.click(
fn=get_random_sample_transcribe,
inputs=[project_name],
inputs=[cm_project],
outputs=[random_text_transcribe, random_audio_transcribe],
)
@@ -797,7 +907,7 @@ with gr.Blocks() as app:
bt_prepare = bt_create = gr.Button("prepare")
txt_info_prepare = gr.Text(label="info", value="")
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])
random_sample_prepare = gr.Button("random sample")
@@ -806,16 +916,20 @@ with gr.Blocks() as app:
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
)
with gr.TabItem("train Data"):
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
lb_samples = gr.Label(label="samples")
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")
with gr.Row():
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
@@ -844,7 +958,7 @@ with gr.Blocks() as app:
start_button.click(
fn=start_training,
inputs=[
project_name,
cm_project,
exp_name,
learning_rate,
batch_size_per_gpu,
@@ -857,14 +971,18 @@ with gr.Blocks() as app:
save_per_updates,
last_per_steps,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
],
outputs=[txt_info_train, start_button, stop_button],
)
stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
bt_calculate.click(
fn=calculate_train,
inputs=[
project_name,
cm_project,
batch_size_type,
max_samples,
learning_rate,
@@ -881,29 +999,42 @@ with gr.Blocks() as app:
last_per_steps,
lb_samples,
learning_rate,
epochs,
],
)
ch_finetune.change(
check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
)
with gr.TabItem("reduse checkpoint"):
txt_path_checkpoint = gr.Text(label="path checkpoint :")
txt_path_checkpoint_small = gr.Text(label="path output :")
ch_safetensors = gr.Checkbox(label="safetensors", value="")
txt_info_reduse = gr.Text(label="info", value="")
reduse_button = gr.Button("reduse")
reduse_button.click(
fn=extract_and_save_ema_model,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
outputs=[txt_info_reduse],
)
with gr.TabItem("vocab check experiment"):
check_button = gr.Button("check vocab")
txt_info_check = gr.Text(label="info", value="")
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
with gr.TabItem("test model"):
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
nfe_step = gr.Number(label="n_step", value=32)
file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
with gr.Row():
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
)
bt_checkpoint_refresh = gr.Button("refresh")
random_sample_infer = gr.Button("random sample")
@@ -911,17 +1042,24 @@ with gr.Blocks() as app:
ref_audio = gr.Audio(label="audio ref", type="filepath")
gen_text = gr.Textbox(label="gen text")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)
check_button_infer = gr.Button("infer")
with gr.Row():
txt_info_gpu = gr.Textbox("", label="device")
check_button_infer = gr.Button("infer")
gen_audio = gr.Audio(label="audio gen", type="filepath")
check_button_infer.click(
fn=infer,
inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
outputs=[gen_audio],
inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
outputs=[gen_audio, txt_info_gpu],
)
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")