From 254e5e6d30c5035caffc7d71e2b45729e14cfd64 Mon Sep 17 00:00:00 2001 From: SWivid Date: Thu, 24 Oct 2024 15:23:55 +0800 Subject: [PATCH] update finetune-cli -gradio --- README.md | 1 + pyproject.toml | 1 + src/f5_tts/eval/eval_infer_batch.py | 3 - src/f5_tts/eval/eval_infer_batch.sh | 12 +- src/f5_tts/infer/infer_gradio.py | 173 +++++++++++++- src/f5_tts/train/datasets/prepare_csv_wavs.py | 2 +- src/f5_tts/train/finetune_cli.py | 34 ++- src/f5_tts/train/finetune_gradio.py | 226 ++++++++++++++---- 8 files changed, 388 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index a08a2f1..6f7bf15 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,7 @@ Currently supported features: - Chunk inference - Podcast Generation - Multiple Speech-Type Generation +- Voice Chat powered by Qwen2.5-3B-Instruct You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may also use local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`. diff --git a/pyproject.toml b/pyproject.toml index f6d3cc4..1f974d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "torchdiffeq", "tqdm>=4.65.0", "transformers", + "transformers_stream_generator", "vocos", "wandb", "x_transformers>=1.31.14", diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index 8d48661..b973c07 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -4,7 +4,6 @@ import os sys.path.append(os.getcwd()) import time -import random from tqdm import tqdm import argparse from importlib.resources import files @@ -97,8 +96,6 @@ def main(): metainfo = get_seedtts_testset_metainfo(metalst) # path to save genereted wavs - if seed is None: - seed = random.randint(-10000, 10000) output_dir = ( f"results/{exp_name}_{ckpt_step}/{testset}/" f"seed{seed}_{ode_method}_nfe{nfe_step}" diff --git a/src/f5_tts/eval/eval_infer_batch.sh b/src/f5_tts/eval/eval_infer_batch.sh index 45b0717..12dd2ff 100644 --- a/src/f5_tts/eval/eval_infer_batch.sh +++ b/src/f5_tts/eval/eval_infer_batch.sh @@ -1,13 +1,13 @@ #!/bin/bash # e.g. F5-TTS, 16 NFE -accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 -accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 -accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16 # e.g. Vanilla E2 TTS, 32 NFE -accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 -accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 -accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0 +accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0 # etc. 
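Note on the seed handling above: eval_infer_batch.py no longer falls back to a random seed, and eval_infer_batch.sh now passes `-s 0` explicitly, so the seed that tags the results directory is always the one actually used. A minimal sketch of that pattern follows, with illustrative flag and path names only (not the real script):

# Sketch only: an explicit, required seed keeps the eval run reproducible.
import argparse
import random

import torch

parser = argparse.ArgumentParser(description="batch eval (seed-handling sketch)")
parser.add_argument("-s", "--seed", type=int, required=True, help="random seed, e.g. 0")
parser.add_argument("-nfe", "--nfestep", type=int, default=16)
args = parser.parse_args()

random.seed(args.seed)
torch.manual_seed(args.seed)

# the results path carries the seed, mirroring results/{exp}_{step}/{testset}/seed{seed}_...
output_dir = f"results/example_exp/testset/seed{args.seed}_euler_nfe{args.nfestep}"
print(output_dir)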
diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 130e8fc..de48928 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -11,6 +11,7 @@ import soundfile as sf import torchaudio from cached_path import cached_path from pydub import AudioSegment +from transformers import AutoModelForCausalLM, AutoTokenizer try: import spaces @@ -51,6 +52,33 @@ E2TTS_ema_model = load_model( UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) ) +# Initialize Qwen model and tokenizer +model_name = "Qwen/Qwen2.5-3B-Instruct" +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") +tokenizer = AutoTokenizer.from_pretrained(model_name) + + +def generate_response(messages): + """Generate response using Qwen""" + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + generated_ids = model.generate( + **model_inputs, + max_new_tokens=512, + temperature=0.7, + top_p=0.95, + ) + + generated_ids = [ + output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + @gpu_decorator def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1): @@ -490,6 +518,146 @@ with gr.Blocks() as app_emotional: outputs=generate_emotional_btn, ) + +with gr.Blocks() as app_chat: + gr.Markdown( + """ +# Voice Chat +Have a conversation with an AI using your reference voice! +1. Upload a reference audio clip and optionally its transcript. +2. Record your message through your microphone. +3. The AI will respond using the reference voice. +""" + ) + + with gr.Row(): + with gr.Column(): + ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath") + + with gr.Column(): + with gr.Accordion("Advanced Settings", open=False): + model_choice_chat = gr.Radio( + choices=["F5-TTS", "E2-TTS"], + label="TTS Model", + value="F5-TTS", + ) + remove_silence_chat = gr.Checkbox( + label="Remove Silences", + value=True, + ) + ref_text_chat = gr.Textbox( + label="Reference Text", + info="Optional: Leave blank to auto-transcribe", + lines=2, + ) + system_prompt_chat = gr.Textbox( + label="System Prompt", + value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.", + lines=2, + ) + + chatbot_interface = gr.Chatbot(label="Conversation") + + with gr.Row(): + with gr.Column(): + audio_output_chat = gr.Audio(autoplay=True) + with gr.Column(): + audio_input_chat = gr.Microphone( + label="Or speak your message", + type="filepath", + ) + + clear_btn_chat = gr.Button("Clear Conversation") + + conversation_state = gr.State( + value=[ + { + "role": "system", + "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. 
Keep your responses concise since they will be spoken out loud.", + } + ] + ) + + def process_audio_input(audio_path, history, conv_state): + """Handle audio input from user""" + if not audio_path: + return history, conv_state, "" + + text = "" + text = preprocess_ref_audio_text(audio_path, text)[1] + + if not text.strip(): + return history, conv_state, "" + + conv_state.append({"role": "user", "content": text}) + history.append((text, None)) + + response = generate_response(conv_state) + + conv_state.append({"role": "assistant", "content": response}) + history[-1] = (text, response) + + return history, conv_state, "" + + def generate_audio_response(history, ref_audio, ref_text, model, remove_silence): + """Generate TTS audio for AI response""" + if not history or not ref_audio: + return None + + last_user_message, last_ai_response = history[-1] + if not last_ai_response: + return None + + audio_result, _ = infer( + ref_audio, + ref_text, + last_ai_response, + model, + remove_silence, + cross_fade_duration=0.15, + speed=1.0, + ) + return audio_result + + def clear_conversation(): + """Reset the conversation""" + return [], [ + { + "role": "system", + "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.", + } + ] + + def update_system_prompt(new_prompt): + """Update the system prompt and reset the conversation""" + new_conv_state = [{"role": "system", "content": new_prompt}] + return [], new_conv_state + + # Handle audio input + audio_input_chat.stop_recording( + process_audio_input, + inputs=[audio_input_chat, chatbot_interface, conversation_state], + outputs=[chatbot_interface, conversation_state], + ).then( + generate_audio_response, + inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat], + outputs=audio_output_chat, + ) + + # Handle clear button + clear_btn_chat.click( + clear_conversation, + outputs=[chatbot_interface, conversation_state], + ) + + # Handle system prompt change and reset conversation + system_prompt_chat.change( + update_system_prompt, + inputs=system_prompt_chat, + outputs=[chatbot_interface, conversation_state], + ) + + with gr.Blocks() as app: gr.Markdown( """ @@ -507,7 +675,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). 
Ensure the audio is fully uploaded before generating.**
 """
     )
-    gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
+    gr.TabbedInterface(
+        [app_tts, app_podcast, app_emotional, app_chat, app_credits],
+        ["TTS", "Podcast", "Multi-Style", "Voice-Chat", "Credits"],
+    )
 
 
 @click.command()
diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py
index e68b053..ece8fca 100644
--- a/src/f5_tts/train/datasets/prepare_csv_wavs.py
+++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -60,7 +60,7 @@ def read_audio_text_pairs(csv_file_path):
     audio_text_pairs = []
 
     parent = Path(csv_file_path).parent
-    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
         reader = csv.reader(csvfile, delimiter="|")
         next(reader)  # Skip the header row
         for row in reader:
diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py
index 162cf5d..ec41d41 100644
--- a/src/f5_tts/train/finetune_cli.py
+++ b/src/f5_tts/train/finetune_cli.py
@@ -15,26 +15,35 @@ hop_length = 256
 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
+    # batch_size_per_gpu = 1000 setting for an 8GB GPU
+    # batch_size_per_gpu = 1600 setting for a 12GB GPU
+    # batch_size_per_gpu = 2000 setting for a 16GB GPU
+    # batch_size_per_gpu = 3200 setting for a 24GB GPU
+
+    # num_warmup_updates = 500 for a dataset of around 10000 samples
+
+    # adjust save_per_updates and last_per_steps as needed
+
     parser = argparse.ArgumentParser(description="Train CFM Model")
     parser.add_argument(
         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
     )
     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
-    parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
     parser.add_argument(
         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
     )
-    parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
+    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
+    parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
+    parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
     parser.add_argument("--finetune",
type=bool, default=True, help="Use Finetune") - + parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune") parser.add_argument( "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type" ) @@ -60,13 +69,19 @@ def main(): model_cls = DiT model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) if args.finetune: - ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) + if args.pretrain is None: + ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) + else: + ckpt_path = args.pretrain elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) if args.finetune: - ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + if args.pretrain is None: + ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + else: + ckpt_path = args.pretrain if args.finetune: path_ckpt = os.path.join("ckpts", args.dataset_name) @@ -118,6 +133,7 @@ def main(): ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) + trainer.train( train_dataset, resumable_with_seed=666, # seed for shuffling dataset diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 8f6db5f..182913a 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -20,6 +20,7 @@ import torch import torchaudio from datasets import Dataset as Dataset_ from datasets.arrow_writer import ArrowWriter +from safetensors.torch import save_file from scipy.io import wavfile from transformers import pipeline @@ -247,6 +248,9 @@ def start_training( save_per_updates=400, last_per_steps=800, finetune=True, + file_checkpoint_train="", + tokenizer_type="pinyin", + tokenizer_file="", ): global training_process, tts_api @@ -256,7 +260,7 @@ def start_training( torch.cuda.empty_cache() tts_api = None - path_project = os.path.join(path_data, dataset_name + "_pinyin") + path_project = os.path.join(path_data, dataset_name) if not os.path.isdir(path_project): yield ( @@ -278,6 +282,7 @@ def start_training( yield "start train", gr.update(interactive=False), gr.update(interactive=False) # Command to run the training script with the specified arguments + dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "") cmd = ( f"accelerate launch finetune-cli.py --exp_name {exp_name} " f"--learning_rate {learning_rate} " @@ -295,6 +300,13 @@ def start_training( if finetune: cmd += f" --finetune {finetune}" + if file_checkpoint_train != "": + cmd += f" --file_checkpoint_train {file_checkpoint_train}" + + if tokenizer_file != "": + cmd += f" --tokenizer_path {tokenizer_file}" + cmd += f" --tokenizer {tokenizer_type} " + print(cmd) try: @@ -331,10 +343,28 @@ def stop_training(): return "train stop", gr.update(interactive=True), gr.update(interactive=False) -def create_data_project(name): - name += "_pinyin" +def get_list_projects(): + project_list = [] + for folder in os.listdir("data"): + path_folder = os.path.join("data", folder) + if not os.path.isdir(path_folder): + continue + folder = folder.lower() + if folder == "emilia_zh_en_pinyin": + continue + project_list.append(folder) + + projects_selelect = None if not project_list else project_list[-1] + + return project_list, projects_selelect + + +def create_data_project(name, tokenizer_type): + name += "_" + tokenizer_type 
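    # the tokenizer type becomes a directory suffix, e.g. a project "my_speak" with the "pinyin" tokenizer lands under <path_data>/my_speak_pinyin (illustrative example)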
os.makedirs(os.path.join(path_data, name), exist_ok=True) os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True) + project_list, projects_selelect = get_list_projects() + return gr.update(choices=project_list, value=name) def transcribe(file_audio, language="english"): @@ -359,14 +389,14 @@ def transcribe(file_audio, language="english"): def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()): - name_project += "_pinyin" path_project = os.path.join(path_data, name_project) path_dataset = os.path.join(path_project, "dataset") path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") - if audio_files is None: - return "You need to load an audio file." + if not user: + if audio_files is None: + return "You need to load an audio file." if os.path.isdir(path_project_wavs): shutil.rmtree(path_project_wavs) @@ -418,7 +448,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr. except: # noqa: E722 error_num += 1 - with open(file_metadata, "w", encoding="utf-8") as f: + with open(file_metadata, "w", encoding="utf-8-sig") as f: f.write(data) if error_num != []: @@ -437,7 +467,6 @@ def format_seconds_to_hms(seconds): def create_metadata(name_project, progress=gr.Progress()): - name_project += "_pinyin" path_project = os.path.join(path_data, name_project) path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") @@ -448,7 +477,7 @@ def create_metadata(name_project, progress=gr.Progress()): if not os.path.isfile(file_metadata): return "The file was not found in " + file_metadata - with open(file_metadata, "r", encoding="utf-8") as f: + with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() audio_path_list = [] @@ -499,7 +528,7 @@ def create_metadata(name_project, progress=gr.Progress()): for line in progress.tqdm(result, total=len(result), desc="prepare data"): writer.write(line) - with open(file_duration, "w", encoding="utf-8") as f: + with open(file_duration, "w") as f: json.dump({"duration": duration_list}, f, ensure_ascii=False) file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt" @@ -529,7 +558,6 @@ def calculate_train( last_per_steps, finetune, ): - name_project += "_pinyin" path_project = os.path.join(path_data, name_project) file_duraction = os.path.join(path_project, "duration.json") @@ -548,8 +576,8 @@ def calculate_train( data = json.load(file) duration_list = data["duration"] - samples = len(duration_list) + hours = sum(duration_list) / 3600 if torch.cuda.is_available(): gpu_properties = torch.cuda.get_device_properties(0) @@ -583,34 +611,67 @@ def calculate_train( save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates) last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps) + total_hours = hours + mel_hop_length = 256 + mel_sampling_rate = 24000 + + # target + wanted_max_updates = 1000000 + + # train params + gpus = 1 + frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200 + grad_accum = 1 + + # intermediate + mini_batch_frames = frames_per_gpu * grad_accum * gpus + mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600 + updates_per_epoch = total_hours / mini_batch_hours + # steps_per_epoch = updates_per_epoch * grad_accum + epochs = wanted_max_updates / updates_per_epoch + if finetune: learning_rate = 1e-5 else: learning_rate = 7.5e-5 - return batch_size_per_gpu, max_samples, 
num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate + return ( + batch_size_per_gpu, + max_samples, + num_warmup_updates, + save_per_updates, + last_per_steps, + samples, + learning_rate, + int(epochs), + ) -def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None: +def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str: try: checkpoint = torch.load(checkpoint_path) print("Original Checkpoint Keys:", checkpoint.keys()) ema_model_state_dict = checkpoint.get("ema_model_state_dict", None) + if ema_model_state_dict is None: + return "No 'ema_model_state_dict' found in the checkpoint." - if ema_model_state_dict is not None: + if safetensors: + new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors") + save_file(ema_model_state_dict, new_checkpoint_path) + else: + new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt") new_checkpoint = {"ema_model_state_dict": ema_model_state_dict} torch.save(new_checkpoint, new_checkpoint_path) - return f"New checkpoint saved at: {new_checkpoint_path}" - else: - return "No 'ema_model_state_dict' found in the checkpoint." + + return f"New checkpoint saved at: {new_checkpoint_path}" except Exception as e: return f"An error occurred: {e}" def vocab_check(project_name): - name_project = project_name + "_pinyin" + name_project = project_name path_project = os.path.join(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") @@ -619,15 +680,15 @@ def vocab_check(project_name): if not os.path.isfile(file_vocab): return f"the file {file_vocab} not found !" - with open(file_vocab, "r", encoding="utf-8") as f: + with open(file_vocab, "r", encoding="utf-8-sig") as f: data = f.read() - - vocab = data.split("\n") + vocab = data.split("\n") + vocab = set(vocab) if not os.path.isfile(file_metadata): return f"the file {file_metadata} not found !" - with open(file_metadata, "r", encoding="utf-8") as f: + with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() miss_symbols = [] @@ -652,7 +713,7 @@ def vocab_check(project_name): def get_random_sample_prepare(project_name): - name_project = project_name + "_pinyin" + name_project = project_name path_project = os.path.join(path_data, name_project) file_arrow = os.path.join(path_project, "raw.arrow") if not os.path.isfile(file_arrow): @@ -665,14 +726,14 @@ def get_random_sample_prepare(project_name): def get_random_sample_transcribe(project_name): - name_project = project_name + "_pinyin" + name_project = project_name path_project = os.path.join(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") if not os.path.isfile(file_metadata): return "", None data = "" - with open(file_metadata, "r", encoding="utf-8") as f: + with open(file_metadata, "r", encoding="utf-8-sig") as f: data = f.read() list_data = [] @@ -703,13 +764,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step): global last_checkpoint, last_device, tts_api if not os.path.isfile(file_checkpoint): - return None + return None, "checkpoint not found!" 
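    # device for test inference: CPU while a training run is active, and the added override just below currently forces CPU in all cases so the GPU stays free for training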
if training_process is not None: device_test = "cpu" else: device_test = None + device_test = "cpu" if last_checkpoint != file_checkpoint or last_device != device_test: if last_checkpoint != file_checkpoint: last_checkpoint = file_checkpoint @@ -722,19 +784,67 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step): with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name) - return f.name + return f.name, tts_api.device + + +def check_finetune(finetune): + return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune) + + +def get_checkpoints_project(project_name, is_gradio=True): + if project_name is None: + return [], "" + project_name = project_name.replace("_pinyin", "").replace("_char", "") + path_project_ckpts = os.path.join("ckpts", project_name) + + if os.path.isdir(path_project_ckpts): + files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt")) + files_checkpoints = sorted( + files_checkpoints, + key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]) + if os.path.basename(x) != "model_last.pt" + else float("inf"), + ) + else: + files_checkpoints = [] + + selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0] + + if is_gradio: + return gr.update(choices=files_checkpoints, value=selelect_checkpoint) + + return files_checkpoints, selelect_checkpoint with gr.Blocks() as app: + gr.Markdown( + """ +# E2/F5 TTS AUTOMATIC FINETUNE + +This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models: + +* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching) +* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS) + +The checkpoints support English and Chinese. 
+ +for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143) +""" + ) + with gr.Row(): + projects, projects_selelect = get_list_projects() + tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin") project_name = gr.Textbox(label="project name", value="my_speak") bt_create = gr.Button("create new project") - bt_create.click(fn=create_data_project, inputs=[project_name]) + cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True) + + bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project]) with gr.Tabs(): with gr.TabItem("transcribe Data"): - ch_manual = gr.Checkbox(label="user", value=False) + ch_manual = gr.Checkbox(label="audio from path", value=False) mark_info_transcribe = gr.Markdown( """```plaintext @@ -756,7 +866,7 @@ with gr.Blocks() as app: txt_info_transcribe = gr.Text(label="info", value="") bt_transcribe.click( fn=transcribe_all, - inputs=[project_name, audio_speaker, txt_lang, ch_manual], + inputs=[cm_project, audio_speaker, txt_lang, ch_manual], outputs=[txt_info_transcribe], ) ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe]) @@ -769,7 +879,7 @@ with gr.Blocks() as app: random_sample_transcribe.click( fn=get_random_sample_transcribe, - inputs=[project_name], + inputs=[cm_project], outputs=[random_text_transcribe, random_audio_transcribe], ) @@ -797,7 +907,7 @@ with gr.Blocks() as app: bt_prepare = bt_create = gr.Button("prepare") txt_info_prepare = gr.Text(label="info", value="") - bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare]) + bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare]) random_sample_prepare = gr.Button("random sample") @@ -806,16 +916,20 @@ with gr.Blocks() as app: random_audio_prepare = gr.Audio(label="Audio", type="filepath") random_sample_prepare.click( - fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare] + fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare] ) with gr.TabItem("train Data"): with gr.Row(): bt_calculate = bt_create = gr.Button("Auto Settings") - ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True) lb_samples = gr.Label(label="samples") batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame") + with gr.Row(): + ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True) + tokenizer_file = gr.Textbox(label="Tokenizer File", value="") + file_checkpoint_train = gr.Textbox(label="Checkpoint", value="") + with gr.Row(): exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5) @@ -844,7 +958,7 @@ with gr.Blocks() as app: start_button.click( fn=start_training, inputs=[ - project_name, + cm_project, exp_name, learning_rate, batch_size_per_gpu, @@ -857,14 +971,18 @@ with gr.Blocks() as app: save_per_updates, last_per_steps, ch_finetune, + file_checkpoint_train, + tokenizer_type, + tokenizer_file, ], outputs=[txt_info_train, start_button, stop_button], ) stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button]) + bt_calculate.click( fn=calculate_train, inputs=[ - project_name, + cm_project, batch_size_type, max_samples, learning_rate, @@ -881,29 +999,42 @@ with gr.Blocks() as 
app: last_per_steps, lb_samples, learning_rate, + epochs, ], ) + ch_finetune.change( + check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type] + ) + with gr.TabItem("reduse checkpoint"): txt_path_checkpoint = gr.Text(label="path checkpoint :") txt_path_checkpoint_small = gr.Text(label="path output :") + ch_safetensors = gr.Checkbox(label="safetensors", value="") txt_info_reduse = gr.Text(label="info", value="") reduse_button = gr.Button("reduse") reduse_button.click( fn=extract_and_save_ema_model, - inputs=[txt_path_checkpoint, txt_path_checkpoint_small], + inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors], outputs=[txt_info_reduse], ) with gr.TabItem("vocab check experiment"): check_button = gr.Button("check vocab") txt_info_check = gr.Text(label="info", value="") - check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check]) + check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check]) with gr.TabItem("test model"): exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") + list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False) + nfe_step = gr.Number(label="n_step", value=32) - file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="") + + with gr.Row(): + cm_checkpoint = gr.Dropdown( + choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True + ) + bt_checkpoint_refresh = gr.Button("refresh") random_sample_infer = gr.Button("random sample") @@ -911,17 +1042,24 @@ with gr.Blocks() as app: ref_audio = gr.Audio(label="audio ref", type="filepath") gen_text = gr.Textbox(label="gen text") random_sample_infer.click( - fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio] + fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio] ) - check_button_infer = gr.Button("infer") + + with gr.Row(): + txt_info_gpu = gr.Textbox("", label="device") + check_button_infer = gr.Button("infer") + gen_audio = gr.Audio(label="audio gen", type="filepath") check_button_infer.click( fn=infer, - inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step], - outputs=[gen_audio], + inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step], + outputs=[gen_audio, txt_info_gpu], ) + bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) + cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint]) + @click.command() @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
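The reduce-checkpoint tab above now writes a `.safetensors` file when the checkbox is ticked, using `save_file` from `safetensors.torch`. A standalone sketch of that conversion, assuming a checkpoint that keeps its weights under `ema_model_state_dict` (paths are illustrative):

# Sketch only: export the EMA weights from a training checkpoint to safetensors.
import torch
from safetensors.torch import save_file

ckpt_path = "ckpts/my_speak/model_last.pt"          # illustrative input path
out_path = "ckpts/my_speak/model_last.safetensors"  # illustrative output path

checkpoint = torch.load(ckpt_path, map_location="cpu")
ema_state = checkpoint.get("ema_model_state_dict")
if ema_state is None:
    raise KeyError("No 'ema_model_state_dict' found in the checkpoint.")

# safetensors stores a flat dict of tensors, so keep only tensor entries.
save_file({k: v for k, v in ema_state.items() if isinstance(v, torch.Tensor)}, out_path)
print(f"New checkpoint saved at: {out_path}")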