mirror of https://github.com/SWivid/F5-TTS.git (synced 2026-01-13 05:27:40 -08:00)

Commit: update finetune-cli -gradio
@@ -4,7 +4,6 @@ import os
sys.path.append(os.getcwd())

import time
import random
from tqdm import tqdm
import argparse
from importlib.resources import files

@@ -97,8 +96,6 @@ def main():
    metainfo = get_seedtts_testset_metainfo(metalst)

    # path to save generated wavs
    if seed is None:
        seed = random.randint(-10000, 10000)
    output_dir = (
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}"
@@ -1,13 +1,13 @@
#!/bin/bash

# e.g. F5-TTS, 16 NFE
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16

# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0

# etc.
@@ -11,6 +11,7 @@ import soundfile as sf
import torchaudio
from cached_path import cached_path
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    import spaces

@@ -51,6 +52,33 @@ E2TTS_ema_model = load_model(
    UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
)

# Initialize Qwen model and tokenizer
model_name = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)


def generate_response(messages):
    """Generate response using Qwen"""
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
    )

    generated_ids = [
        output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


@gpu_decorator
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
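Editor's note: the new `generate_response` helper can be exercised on its own. A minimal sketch, not part of the commit, assuming the Qwen2.5-3B-Instruct weights download successfully and `generate_response` is in scope:

```python
# Hypothetical smoke test: build an OpenAI-style message list and let
# Qwen complete the assistant turn via the helper defined above.
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Say hello in five words."},
]
print(generate_response(messages))  # e.g. "Hello there, nice meeting you!"
```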
@@ -490,6 +518,146 @@ with gr.Blocks() as app_emotional:
        outputs=generate_emotional_btn,
    )


with gr.Blocks() as app_chat:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Record your message through your microphone.
3. The AI will respond using the reference voice.
"""
    )

    with gr.Row():
        with gr.Column():
            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")

        with gr.Column():
            with gr.Accordion("Advanced Settings", open=False):
                model_choice_chat = gr.Radio(
                    choices=["F5-TTS", "E2-TTS"],
                    label="TTS Model",
                    value="F5-TTS",
                )
                remove_silence_chat = gr.Checkbox(
                    label="Remove Silences",
                    value=True,
                )
                ref_text_chat = gr.Textbox(
                    label="Reference Text",
                    info="Optional: Leave blank to auto-transcribe",
                    lines=2,
                )
                system_prompt_chat = gr.Textbox(
                    label="System Prompt",
                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
                    lines=2,
                )

    chatbot_interface = gr.Chatbot(label="Conversation")

    with gr.Row():
        with gr.Column():
            audio_output_chat = gr.Audio(autoplay=True)
        with gr.Column():
            audio_input_chat = gr.Microphone(
                label="Or speak your message",
                type="filepath",
            )

    clear_btn_chat = gr.Button("Clear Conversation")

    conversation_state = gr.State(
        value=[
            {
                "role": "system",
                "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
            }
        ]
    )

    def process_audio_input(audio_path, history, conv_state):
        """Handle audio input from user"""
        if not audio_path:
            return history, conv_state, ""

        text = ""
        text = preprocess_ref_audio_text(audio_path, text)[1]

        if not text.strip():
            return history, conv_state, ""

        conv_state.append({"role": "user", "content": text})
        history.append((text, None))

        response = generate_response(conv_state)

        conv_state.append({"role": "assistant", "content": response})
        history[-1] = (text, response)

        return history, conv_state, ""

    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
        """Generate TTS audio for AI response"""
        if not history or not ref_audio:
            return None

        last_user_message, last_ai_response = history[-1]
        if not last_ai_response:
            return None

        audio_result, _ = infer(
            ref_audio,
            ref_text,
            last_ai_response,
            model,
            remove_silence,
            cross_fade_duration=0.15,
            speed=1.0,
        )
        return audio_result

    def clear_conversation():
        """Reset the conversation"""
        return [], [
            {
                "role": "system",
                "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
            }
        ]

    def update_system_prompt(new_prompt):
        """Update the system prompt and reset the conversation"""
        new_conv_state = [{"role": "system", "content": new_prompt}]
        return [], new_conv_state

    # Handle audio input
    audio_input_chat.stop_recording(
        process_audio_input,
        inputs=[audio_input_chat, chatbot_interface, conversation_state],
        outputs=[chatbot_interface, conversation_state],
    ).then(
        generate_audio_response,
        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
        outputs=audio_output_chat,
    )

    # Handle clear button
    clear_btn_chat.click(
        clear_conversation,
        outputs=[chatbot_interface, conversation_state],
    )

    # Handle system prompt change and reset conversation
    system_prompt_chat.change(
        update_system_prompt,
        inputs=system_prompt_chat,
        outputs=[chatbot_interface, conversation_state],
    )
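Editor's note: the wiring above relies on Gradio's event chaining. `.then()` fires only after the first callback has written the updated chat history, so the TTS step always sees the assistant's reply. A minimal sketch of the pattern outside this app (the component names here are hypothetical):

```python
import gradio as gr

with gr.Blocks() as demo:
    msg = gr.Textbox(label="message")
    step1 = gr.Textbox(label="step 1")
    step2 = gr.Textbox(label="step 2")
    # The second callback runs only after the first completes,
    # so it can safely consume the value the first one produced.
    msg.submit(lambda s: s.upper(), inputs=msg, outputs=step1).then(
        lambda s: s[::-1], inputs=step1, outputs=step2
    )
demo.launch()
```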
with gr.Blocks() as app:
    gr.Markdown(
        """

@@ -507,7 +675,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
"""
    )
    gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
    gr.TabbedInterface(
        [app_tts, app_podcast, app_emotional, app_chat, app_credits],
        ["TTS", "Podcast", "Multi-Style", "Voice-Chat", "Credits"],
    )


@click.command()
@@ -60,7 +60,7 @@ def read_audio_text_pairs(csv_file_path):
    audio_text_pairs = []

    parent = Path(csv_file_path).parent
    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        next(reader)  # Skip the header row
        for row in reader:
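Editor's note: the `utf-8` to `utf-8-sig` switch here (and in the metadata and vocab readers below) guards against byte-order marks that some Windows editors prepend. Decoded as plain UTF-8, the BOM survives as U+FEFF glued to the first field; the filename in this sketch is illustrative:

```python
# Why "utf-8-sig" matters for the metadata readers: a BOM-prefixed file
# decoded as plain UTF-8 keeps U+FEFF attached to the first token.
raw = b"\xef\xbb\xbfaudio_0001.wav|hello there"
print(repr(raw.decode("utf-8")))      # '\ufeffaudio_0001.wav|hello there'
print(repr(raw.decode("utf-8-sig")))  # 'audio_0001.wav|hello there'
```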
@@ -15,26 +15,35 @@ hop_length = 256

# -------------------------- Argument Parsing --------------------------- #
def parse_args():
    # batch_size_per_gpu = 1000 setting for gpu 8GB
    # batch_size_per_gpu = 1600 setting for gpu 12GB
    # batch_size_per_gpu = 2000 setting for gpu 16GB
    # batch_size_per_gpu = 3200 setting for gpu 24GB

    # num_warmup_updates = 500 for 10000 samples

    # change save_per_updates and last_per_steps as needed

    parser = argparse.ArgumentParser(description="Train CFM Model")

    parser.add_argument(
        "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
    )
    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
    parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
    parser.add_argument(
        "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
    )
    parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
    parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
    parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
    parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
    parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")

    parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
    parser.add_argument(
        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
    )
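Editor's note: with `batch_size_type="frame"`, `batch_size_per_gpu` counts mel frames rather than utterances, so the per-GPU suggestions in the comments translate directly into seconds of audio per batch. A quick back-of-envelope check, using the hop length from the file header above and F5-TTS's 24 kHz mel sampling rate:

```python
hop_length = 256        # samples per mel frame (from the file header above)
sampling_rate = 24000   # Hz
for frames in (1000, 1600, 2000, 3200):  # the suggested per-GPU settings
    print(frames, f"frames = {frames * hop_length / sampling_rate:.1f} s of audio per batch")
# 3200 frames = 34.1 s, which is what the 24 GB recommendation amounts to
```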
@@ -60,13 +69,19 @@ def main():
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
        if args.finetune:
            ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
            if args.pretrain is None:
                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
            else:
                ckpt_path = args.pretrain
    elif args.exp_name == "E2TTS_Base":
        wandb_resume_id = None
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
        if args.finetune:
            ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
            if args.pretrain is None:
                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
            else:
                ckpt_path = args.pretrain

    if args.finetune:
        path_ckpt = os.path.join("ckpts", args.dataset_name)

@@ -118,6 +133,7 @@ def main():
    )

    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)

    trainer.train(
        train_dataset,
        resumable_with_seed=666,  # seed for shuffling dataset
@@ -20,6 +20,7 @@ import torch
import torchaudio
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from transformers import pipeline
@@ -247,6 +248,9 @@ def start_training(
    save_per_updates=400,
    last_per_steps=800,
    finetune=True,
    file_checkpoint_train="",
    tokenizer_type="pinyin",
    tokenizer_file="",
):
    global training_process, tts_api

@@ -256,7 +260,7 @@ def start_training(
        torch.cuda.empty_cache()
        tts_api = None

    path_project = os.path.join(path_data, dataset_name + "_pinyin")
    path_project = os.path.join(path_data, dataset_name)

    if not os.path.isdir(path_project):
        yield (

@@ -278,6 +282,7 @@ def start_training(
    yield "start train", gr.update(interactive=False), gr.update(interactive=False)

    # Command to run the training script with the specified arguments
    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
    cmd = (
        f"accelerate launch finetune-cli.py --exp_name {exp_name} "
        f"--learning_rate {learning_rate} "

@@ -295,6 +300,13 @@ def start_training(
    if finetune:
        cmd += f" --finetune {finetune}"

    if file_checkpoint_train != "":
        cmd += f" --file_checkpoint_train {file_checkpoint_train}"

    if tokenizer_file != "":
        cmd += f" --tokenizer_path {tokenizer_file}"
    cmd += f" --tokenizer {tokenizer_type} "

    print(cmd)

    try:
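Editor's note: one caveat worth flagging in the command assembly above. `--finetune {finetune}` serializes the boolean to the string "True"/"False", and on the receiving side the CLI declares the flag with `type=bool`; argparse then applies `bool()` to the raw string, and any non-empty string is truthy. This is standard-library behavior, not specific to this repo:

```python
# The argparse type=bool pitfall: the string "False" still parses as True.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--finetune", type=bool, default=True)
print(p.parse_args(["--finetune", "False"]).finetune)  # True, because bool("False") is True
```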
@@ -331,10 +343,28 @@ def stop_training():
    return "train stop", gr.update(interactive=True), gr.update(interactive=False)


def create_data_project(name):
    name += "_pinyin"
def get_list_projects():
    project_list = []
    for folder in os.listdir("data"):
        path_folder = os.path.join("data", folder)
        if not os.path.isdir(path_folder):
            continue
        folder = folder.lower()
        if folder == "emilia_zh_en_pinyin":
            continue
        project_list.append(folder)

    projects_selelect = None if not project_list else project_list[-1]

    return project_list, projects_selelect


def create_data_project(name, tokenizer_type):
    name += "_" + tokenizer_type
    os.makedirs(os.path.join(path_data, name), exist_ok=True)
    os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
    project_list, projects_selelect = get_list_projects()
    return gr.update(choices=project_list, value=name)
def transcribe(file_audio, language="english"):

@@ -359,14 +389,14 @@ def transcribe(file_audio, language="english"):


def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
    name_project += "_pinyin"
    path_project = os.path.join(path_data, name_project)
    path_dataset = os.path.join(path_project, "dataset")
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")

    if audio_files is None:
        return "You need to load an audio file."
    if not user:
        if audio_files is None:
            return "You need to load an audio file."

    if os.path.isdir(path_project_wavs):
        shutil.rmtree(path_project_wavs)

@@ -418,7 +448,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
    except:  # noqa: E722
        error_num += 1

    with open(file_metadata, "w", encoding="utf-8") as f:
    with open(file_metadata, "w", encoding="utf-8-sig") as f:
        f.write(data)

    if error_num != []:
@@ -437,7 +467,6 @@ def format_seconds_to_hms(seconds):


def create_metadata(name_project, progress=gr.Progress()):
    name_project += "_pinyin"
    path_project = os.path.join(path_data, name_project)
    path_project_wavs = os.path.join(path_project, "wavs")
    file_metadata = os.path.join(path_project, "metadata.csv")

@@ -448,7 +477,7 @@ def create_metadata(name_project, progress=gr.Progress()):
    if not os.path.isfile(file_metadata):
        return "The file was not found in " + file_metadata

    with open(file_metadata, "r", encoding="utf-8") as f:
    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()

    audio_path_list = []

@@ -499,7 +528,7 @@ def create_metadata(name_project, progress=gr.Progress()):
    for line in progress.tqdm(result, total=len(result), desc="prepare data"):
        writer.write(line)

    with open(file_duration, "w", encoding="utf-8") as f:
    with open(file_duration, "w") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
@@ -529,7 +558,6 @@ def calculate_train(
    last_per_steps,
    finetune,
):
    name_project += "_pinyin"
    path_project = os.path.join(path_data, name_project)
    file_duraction = os.path.join(path_project, "duration.json")

@@ -548,8 +576,8 @@ def calculate_train(
        data = json.load(file)

    duration_list = data["duration"]

    samples = len(duration_list)
    hours = sum(duration_list) / 3600

    if torch.cuda.is_available():
        gpu_properties = torch.cuda.get_device_properties(0)
@@ -583,34 +611,67 @@ def calculate_train(
    save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
    last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)

    total_hours = hours
    mel_hop_length = 256
    mel_sampling_rate = 24000

    # target
    wanted_max_updates = 1000000

    # train params
    gpus = 1
    frames_per_gpu = batch_size_per_gpu  # 8 * 38400 = 307200
    grad_accum = 1

    # intermediate
    mini_batch_frames = frames_per_gpu * grad_accum * gpus
    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
    updates_per_epoch = total_hours / mini_batch_hours
    # steps_per_epoch = updates_per_epoch * grad_accum
    epochs = wanted_max_updates / updates_per_epoch

    if finetune:
        learning_rate = 1e-5
    else:
        learning_rate = 7.5e-5

    return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
    return (
        batch_size_per_gpu,
        max_samples,
        num_warmup_updates,
        save_per_updates,
        last_per_steps,
        samples,
        learning_rate,
        int(epochs),
    )
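Editor's note: plugging numbers into the estimate above makes the newly returned epoch count concrete. A sketch assuming 100 hours of data and the defaults shown (one GPU, no gradient accumulation, 3200-frame batches):

```python
total_hours = 100                                   # assumed dataset size
mini_batch_hours = 3200 * 256 / 24000 / 3600        # ~0.0095 h of audio per update
updates_per_epoch = total_hours / mini_batch_hours  # ~10,547 updates per pass over the data
epochs = 1_000_000 / updates_per_epoch              # ~95 epochs to reach the 1M-update target
print(round(updates_per_epoch), round(epochs))      # 10547 95
```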
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
    try:
        checkpoint = torch.load(checkpoint_path)
        print("Original Checkpoint Keys:", checkpoint.keys())

        ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
        if ema_model_state_dict is None:
            return "No 'ema_model_state_dict' found in the checkpoint."

        if ema_model_state_dict is not None:
            if safetensors:
                new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
                save_file(ema_model_state_dict, new_checkpoint_path)
            else:
                new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
                new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
                torch.save(new_checkpoint, new_checkpoint_path)
            return f"New checkpoint saved at: {new_checkpoint_path}"
        else:
            return "No 'ema_model_state_dict' found in the checkpoint."

        return f"New checkpoint saved at: {new_checkpoint_path}"

    except Exception as e:
        return f"An error occurred: {e}"
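Editor's note on the new `safetensors` switch: `save_file` writes raw tensor data with no pickle layer, so the exported EMA weights can later be loaded without executing arbitrary code, and only a flat dict of tensors (not a nested checkpoint) fits the format. A minimal round-trip sketch:

```python
import torch
from safetensors.torch import load_file, save_file

state = {"weight": torch.randn(2, 2)}  # stands in for the EMA state dict
save_file(state, "ema.safetensors")    # raw tensors, no pickle
restored = load_file("ema.safetensors")
assert torch.equal(state["weight"], restored["weight"])
```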
def vocab_check(project_name):
    name_project = project_name + "_pinyin"
    name_project = project_name
    path_project = os.path.join(path_data, name_project)

    file_metadata = os.path.join(path_project, "metadata.csv")

@@ -619,15 +680,15 @@ def vocab_check(project_name):
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !"

    with open(file_vocab, "r", encoding="utf-8") as f:
    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        data = f.read()

    vocab = data.split("\n")
    vocab = data.split("\n")
    vocab = set(vocab)

    if not os.path.isfile(file_metadata):
        return f"the file {file_metadata} not found !"

    with open(file_metadata, "r", encoding="utf-8") as f:
    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()

    miss_symbols = []
@@ -652,7 +713,7 @@ def vocab_check(project_name):


def get_random_sample_prepare(project_name):
    name_project = project_name + "_pinyin"
    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_arrow = os.path.join(path_project, "raw.arrow")
    if not os.path.isfile(file_arrow):

@@ -665,14 +726,14 @@ def get_random_sample_prepare(project_name):


def get_random_sample_transcribe(project_name):
    name_project = project_name + "_pinyin"
    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_metadata = os.path.join(path_project, "metadata.csv")
    if not os.path.isfile(file_metadata):
        return "", None

    data = ""
    with open(file_metadata, "r", encoding="utf-8") as f:
    with open(file_metadata, "r", encoding="utf-8-sig") as f:
        data = f.read()

    list_data = []
@@ -703,13 +764,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
    global last_checkpoint, last_device, tts_api

    if not os.path.isfile(file_checkpoint):
        return None
        return None, "checkpoint not found!"

    if training_process is not None:
        device_test = "cpu"
    else:
        device_test = None

    device_test = "cpu"
    if last_checkpoint != file_checkpoint or last_device != device_test:
    if last_checkpoint != file_checkpoint:
        last_checkpoint = file_checkpoint

@@ -722,19 +784,67 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
        return f.name
        return f.name, tts_api.device


def check_finetune(finetune):
    return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
def get_checkpoints_project(project_name, is_gradio=True):
    if project_name is None:
        return [], ""
    project_name = project_name.replace("_pinyin", "").replace("_char", "")
    path_project_ckpts = os.path.join("ckpts", project_name)

    if os.path.isdir(path_project_ckpts):
        files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt"))
        files_checkpoints = sorted(
            files_checkpoints,
            key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
            if os.path.basename(x) != "model_last.pt"
            else float("inf"),
        )
    else:
        files_checkpoints = []

    selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0]

    if is_gradio:
        return gr.update(choices=files_checkpoints, value=selelect_checkpoint)

    return files_checkpoints, selelect_checkpoint
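Editor's note: the sort key above orders `model_<step>.pt` files numerically by training step while pinning the rolling `model_last.pt` to the end. In miniature:

```python
names = ["model_last.pt", "model_10000.pt", "model_2000.pt"]
ordered = sorted(
    names,
    # numeric step for regular checkpoints, +inf for the rolling "last" file
    key=lambda x: int(x.split("_")[1].split(".")[0]) if x != "model_last.pt" else float("inf"),
)
print(ordered)  # ['model_2000.pt', 'model_10000.pt', 'model_last.pt']
```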
with gr.Blocks() as app:
    gr.Markdown(
        """
# E2/F5 TTS AUTOMATIC FINETUNE

This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:

* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)

The checkpoints support English and Chinese.

For tutorials and updates, check here (https://github.com/SWivid/F5-TTS/discussions/143)
"""
    )

    with gr.Row():
        projects, projects_selelect = get_list_projects()
        tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin")
        project_name = gr.Textbox(label="project name", value="my_speak")
        bt_create = gr.Button("create new project")

    bt_create.click(fn=create_data_project, inputs=[project_name])
    cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)

    bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])

    with gr.Tabs():
        with gr.TabItem("transcribe Data"):
            ch_manual = gr.Checkbox(label="user", value=False)
            ch_manual = gr.Checkbox(label="audio from path", value=False)

            mark_info_transcribe = gr.Markdown(
                """```plaintext
@@ -756,7 +866,7 @@ with gr.Blocks() as app:
            txt_info_transcribe = gr.Text(label="info", value="")
            bt_transcribe.click(
                fn=transcribe_all,
                inputs=[project_name, audio_speaker, txt_lang, ch_manual],
                inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
                outputs=[txt_info_transcribe],
            )
            ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])

@@ -769,7 +879,7 @@ with gr.Blocks() as app:

            random_sample_transcribe.click(
                fn=get_random_sample_transcribe,
                inputs=[project_name],
                inputs=[cm_project],
                outputs=[random_text_transcribe, random_audio_transcribe],
            )

@@ -797,7 +907,7 @@ with gr.Blocks() as app:

            bt_prepare = bt_create = gr.Button("prepare")
            txt_info_prepare = gr.Text(label="info", value="")
            bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
            bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])

            random_sample_prepare = gr.Button("random sample")

@@ -806,16 +916,20 @@ with gr.Blocks() as app:
            random_audio_prepare = gr.Audio(label="Audio", type="filepath")

            random_sample_prepare.click(
                fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
                fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
            )
        with gr.TabItem("train Data"):
            with gr.Row():
                bt_calculate = bt_create = gr.Button("Auto Settings")
                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
                lb_samples = gr.Label(label="samples")
                batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")

            with gr.Row():
                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
                tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
                file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")

            with gr.Row():
                exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
                learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)

@@ -844,7 +958,7 @@ with gr.Blocks() as app:
            start_button.click(
                fn=start_training,
                inputs=[
                    project_name,
                    cm_project,
                    exp_name,
                    learning_rate,
                    batch_size_per_gpu,

@@ -857,14 +971,18 @@ with gr.Blocks() as app:
                    save_per_updates,
                    last_per_steps,
                    ch_finetune,
                    file_checkpoint_train,
                    tokenizer_type,
                    tokenizer_file,
                ],
                outputs=[txt_info_train, start_button, stop_button],
            )
            stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])

            bt_calculate.click(
                fn=calculate_train,
                inputs=[
                    project_name,
                    cm_project,
                    batch_size_type,
                    max_samples,
                    learning_rate,

@@ -881,29 +999,42 @@ with gr.Blocks() as app:
                    last_per_steps,
                    lb_samples,
                    learning_rate,
                    epochs,
                ],
            )

            ch_finetune.change(
                check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
            )
        with gr.TabItem("reduse checkpoint"):
            txt_path_checkpoint = gr.Text(label="path checkpoint :")
            txt_path_checkpoint_small = gr.Text(label="path output :")
            ch_safetensors = gr.Checkbox(label="safetensors", value="")
            txt_info_reduse = gr.Text(label="info", value="")
            reduse_button = gr.Button("reduse")
            reduse_button.click(
                fn=extract_and_save_ema_model,
                inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
                outputs=[txt_info_reduse],
            )

        with gr.TabItem("vocab check experiment"):
            check_button = gr.Button("check vocab")
            txt_info_check = gr.Text(label="info", value="")
            check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
        with gr.TabItem("test model"):
            exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
            list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)

            nfe_step = gr.Number(label="n_step", value=32)
            file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")

            with gr.Row():
                cm_checkpoint = gr.Dropdown(
                    choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
                )
                bt_checkpoint_refresh = gr.Button("refresh")

            random_sample_infer = gr.Button("random sample")

@@ -911,17 +1042,24 @@ with gr.Blocks() as app:
                ref_audio = gr.Audio(label="audio ref", type="filepath")
                gen_text = gr.Textbox(label="gen text")
            random_sample_infer.click(
                fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
                fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
            )
            check_button_infer = gr.Button("infer")

            with gr.Row():
                txt_info_gpu = gr.Textbox("", label="device")
                check_button_infer = gr.Button("infer")

            gen_audio = gr.Audio(label="audio gen", type="filepath")

            check_button_infer.click(
                fn=infer,
                inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
                outputs=[gen_audio],
                inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
                outputs=[gen_audio, txt_info_gpu],
            )

            bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
            cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])


@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")