From 59761ec9f87c854c8eba79fff186bd1f8df3ec94 Mon Sep 17 00:00:00 2001
From: SWivid
Date: Mon, 11 Nov 2024 11:13:11 +0800
Subject: [PATCH] Update. Cache last used custom model path #447

---
 src/f5_tts/infer/SHARED.md       | 22 +++++++-----
 src/f5_tts/infer/infer_gradio.py | 62 ++++++++++++++++++++++----------
 2 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/src/f5_tts/infer/SHARED.md b/src/f5_tts/infer/SHARED.md
index 881b519..9311fc6 100644
--- a/src/f5_tts/infer/SHARED.md
+++ b/src/f5_tts/infer/SHARED.md
@@ -1,21 +1,25 @@
 # Shared Model Cards
 
-- This document is serving as a quick lookup table for the community training/finetuning result, with various language support.
-- The models in this repository are open source and are based on voluntary contributions from contributors.
-- The use of models must be conditioned on respect for the respective creators. The convenience brought comes from their efforts.
-- Welcome to pull request sharing your result here.
+- **Prerequisites of use**
+  - This document serves as a quick lookup table for community training/finetuning results, covering various languages.
+  - The models in this repository are open source, based on voluntary contributions from the community.
+  - Use of these models must respect their respective creators; the convenience they bring comes from the creators' efforts.
+- **How to share a model here**
+  - Have a pretrained/finetuned result: a model checkpoint (preferably pruned to ease inference, i.e. keeping only `ema_model_state_dict`) and the corresponding vocab file (for tokenization).
+  - Host a public [Hugging Face model repository](https://huggingface.co/new) and upload the model-related files.
+  - Make a pull request adding a model card to the current page, i.e. `src/f5_tts/infer/SHARED.md`.
 
 ### Support Language
 - [Multilingual](#multilingual)
-  - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
+    - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
 - [Mandarin](#mandarin)
 - [English](#english)
 
 
-### Multilingual
+## Multilingual
 
 #### F5-TTS Base @ pretrain @ zh & en
 |Model|🤗Hugging Face|Data (Hours)|Model License|
@@ -26,10 +30,10 @@
 MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
 VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 ```
 
-*Other infos, e.g. Link to some sampled results, Github repo, Usage instruction, Tutorial (Blog, Video, etc.) ...*
+*Other info, e.g. author info, GitHub repo, links to sample results, usage instructions, tutorials (blog, video, etc.) ...*
 
 
-### Mandarin
+## Mandarin
 
 
-### English
+## English
 
diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py
index b322097..3de5b55 100644
--- a/src/f5_tts/infer/infer_gradio.py
+++ b/src/f5_tts/infer/infer_gradio.py
@@ -4,6 +4,7 @@
 import re
 import tempfile
 from collections import OrderedDict
+from importlib.resources import files
 
 import click
 import gradio as gr
@@ -71,6 +72,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
 
 F5TTS_ema_model = load_f5tts()
 E2TTS_ema_model = load_e2tts() if USING_SPACES else None
+custom_ema_model, pre_custom_path = None, ""
 
 chat_model_state = None
 chat_tokenizer_state = None
@@ -115,8 +117,11 @@ def infer(
         ema_model = E2TTS_ema_model
     elif isinstance(model, list) and model[0] == "Custom":
         assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
-        show_info("Loading Custom TTS model...")
-        custom_ema_model = load_custom(model[1], vocab_path=model[2])
+        global custom_ema_model, pre_custom_path
+        if pre_custom_path != model[1]:
+            show_info("Loading Custom TTS model...")
+            custom_ema_model = load_custom(model[1], vocab_path=model[2])
+            pre_custom_path = model[1]
         ema_model = custom_ema_model
 
     final_wave, final_sample_rate, combined_spectrogram = infer_process(
@@ -739,14 +744,29 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 """
     )
 
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
+
+    def load_last_used_custom():
+        try:
+            with open(last_used_custom, "r") as f:
+                return f.read().split(",")
+        except FileNotFoundError:
+            last_used_custom.parent.mkdir(parents=True, exist_ok=True)
+            return [
+                "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
+                "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+            ]
+
     def switch_tts_model(new_choice, custom_ckpt_path, custom_vocab_path):
         global tts_model_choice
         if new_choice == "Custom":
             tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
-            return gr.update(visible=True)
+            with open(last_used_custom, "w") as f:
+                f.write(f"{custom_ckpt_path},{custom_vocab_path}")
+            return gr.update(visible=True), gr.update(visible=True)
         else:
             tts_model_choice = new_choice
-            return gr.update(visible=False)
+            return gr.update(visible=False), gr.update(visible=False)
 
     with gr.Row():
         if not USING_SPACES:
@@ -757,32 +777,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
             choose_tts_model = gr.Radio(
                 choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
             )
-        with gr.Column(visible=False) as choose_custom_tts_model:
-            custom_ckpt_path = gr.Textbox(
-                placeholder="MODEL_CKPT: local_path | hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-                show_label=False,
-                min_width=200,
-            )
-            custom_vocab_path = gr.Textbox(
-                placeholder="VOCAB_FILE: local_path | hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt | leave blank to use default",
-                show_label=False,
-                min_width=200,
-            )
+        custom_ckpt_path = gr.Dropdown(
+            choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
+            value=load_last_used_custom()[0],
+            allow_custom_value=True,
+            label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
+            visible=False,
+        )
+        custom_vocab_path = gr.Dropdown(
+            choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
+            value=load_last_used_custom()[1],
+            allow_custom_value=True,
+            label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
+            visible=False,
+        )
 
     choose_tts_model.change(
         switch_tts_model,
         inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path],
-        outputs=[choose_custom_tts_model],
+        outputs=[custom_ckpt_path, custom_vocab_path],
+        show_progress="hidden",
     )
     custom_ckpt_path.change(
         switch_tts_model,
         inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path],
-        outputs=[choose_custom_tts_model],
+        outputs=[custom_ckpt_path, custom_vocab_path],
+        show_progress="hidden",
     )
     custom_vocab_path.change(
         switch_tts_model,
        inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path],
-        outputs=[choose_custom_tts_model],
+        outputs=[custom_ckpt_path, custom_vocab_path],
+        show_progress="hidden",
    )
 
     gr.TabbedInterface(
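A note on the `infer` hunk: it replaces the unconditional `load_custom` call with a path-keyed cache, so repeated generations with the same custom checkpoint no longer reload the model. A minimal standalone sketch of the pattern, with a stub standing in for the real loader:

```python
# Minimal sketch of the cache-on-path-change pattern used in the `infer` hunk.
# `load_custom` is a stub standing in for the real loader in infer_gradio.py.
def load_custom(ckpt_path: str, vocab_path: str = ""):
    print(f"Loading model from {ckpt_path} ...")
    return object()  # placeholder for the loaded EMA model

custom_ema_model, pre_custom_path = None, ""

def get_custom_model(ckpt_path: str, vocab_path: str = ""):
    """Reload the custom model only when the checkpoint path changes."""
    global custom_ema_model, pre_custom_path
    if pre_custom_path != ckpt_path:  # cache miss: a different checkpoint was chosen
        custom_ema_model = load_custom(ckpt_path, vocab_path=vocab_path)
        pre_custom_path = ckpt_path
    return custom_ema_model
```

Note that the cache keys only on the checkpoint path (`model[1]`), so changing just the vocab file for the same checkpoint will not trigger a reload.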
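The `last_used_custom` cache file is plain text with the checkpoint and vocab paths joined by a comma, located inside the installed package via `importlib.resources`. A condensed sketch of the round trip follows; the `save_last_used_custom` helper is hypothetical (the patch writes the file inline in `switch_tts_model`), and it assumes `f5_tts` is installed as a regular, non-zipped package so the returned traversable behaves like a filesystem path:

```python
from importlib.resources import files

# Cache file inside the installed package tree (assumes a regular, non-zipped install).
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")

def load_last_used_custom() -> list[str]:
    """Return [ckpt_path, vocab_path], falling back to the defaults on first run."""
    try:
        with open(last_used_custom, "r") as f:
            return f.read().split(",")
    except FileNotFoundError:
        last_used_custom.parent.mkdir(parents=True, exist_ok=True)
        return [
            "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
            "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
        ]

def save_last_used_custom(ckpt_path: str, vocab_path: str) -> None:  # hypothetical helper
    # Comma-separated to match the split(",") above, so paths must not contain commas.
    with open(last_used_custom, "w") as f:
        f.write(f"{ckpt_path},{vocab_path}")
```

One design consequence of the comma-separated format: a path containing a comma would corrupt the round trip. A JSON file would be more robust, but the patch keeps the format minimal.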
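Finally, the UI change: the hidden `gr.Column` wrapper is dropped in favor of two `gr.Dropdown` components whose visibility `switch_tts_model` toggles directly, returning one `gr.update` per output component. A minimal standalone sketch of that wiring (not the full app; the radio choices and labels are abbreviated here):

```python
import gradio as gr

def switch_tts_model(new_choice, custom_ckpt_path, custom_vocab_path):
    # One gr.update per component listed in `outputs`, shown only for "Custom".
    visible = new_choice == "Custom"
    return gr.update(visible=visible), gr.update(visible=visible)

with gr.Blocks() as demo:
    choose_tts_model = gr.Radio(
        choices=["F5-TTS", "E2-TTS", "Custom"], value="F5-TTS", label="Choose TTS Model"
    )
    custom_ckpt_path = gr.Dropdown(
        choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
        allow_custom_value=True,  # users may type a local path or hf:// URL not in choices
        label="MODEL CKPT",
        visible=False,
    )
    custom_vocab_path = gr.Dropdown(
        choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
        allow_custom_value=True,
        label="VOCAB FILE",
        visible=False,
    )
    choose_tts_model.change(
        switch_tts_model,
        inputs=[choose_tts_model, custom_ckpt_path, custom_vocab_path],
        outputs=[custom_ckpt_path, custom_vocab_path],
        show_progress="hidden",  # avoid flashing a progress indicator on every toggle
    )

demo.launch()
```

Registering the same `switch_tts_model` callback on the dropdowns' own `.change` events, as the patch does, keeps `tts_model_choice` and the cache file in sync whenever the user edits either path.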