From 2a844ae064deb2ff0e6d99b8d0aafa20eeb5e253 Mon Sep 17 00:00:00 2001 From: SWivid Date: Fri, 15 Nov 2024 19:15:34 +0800 Subject: [PATCH] minor update patch-1 --- src/f5_tts/api.py | 19 +++++++++++-------- src/f5_tts/infer/utils_infer.py | 12 ++++++------ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index 21a71e5..a7152bf 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -32,6 +32,7 @@ class F5TTS: vocoder_name="vocos", local_path=None, device=None, + hf_cache_dir=None, ): # Initialize parameters self.final_wave = None @@ -46,29 +47,31 @@ class F5TTS: ) # Load models - self.load_vocoder_model(vocoder_name, local_path=local_path) - self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path) + self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir) + self.load_ema_model( + model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir + ) - def load_vocoder_model(self, vocoder_name, local_path=None): - self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device) + def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None): + self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir) - def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path=None): + def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None): if model_type == "F5-TTS": if not ckpt_file: if mel_spec_type == "vocos": ckpt_file = str( - cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path) + cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) ) elif mel_spec_type == "bigvgan": ckpt_file = str( - cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path) + cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir) ) model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) model_cls = DiT elif model_type == "E2-TTS": if not ckpt_file: ckpt_file = str( - cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path) + cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir) ) model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) model_cls = UNetT diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 5182754..1de9048 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -90,18 +90,18 @@ def chunk_text(text, max_chars=135): # load vocoder -def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=device): +def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) - if is_local and local_path is not None: + if is_local: print(f"Load vocos from local path {local_path}") config_path = f"{local_path}/config.yaml" model_path = f"{local_path}/pytorch_model.bin" else: print("Download Vocos from huggingface charactr/vocos-mel-24khz") repo_id = "charactr/vocos-mel-24khz" - config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml") - model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin") + config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") vocoder = Vocos.from_hparams(config_path) state_dict = torch.load(model_path, map_location="cpu", weights_only=True) from vocos.feature_extractors import EncodecFeatures @@ -119,11 +119,11 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=d from third_party.BigVGAN import bigvgan except ImportError: print("You need to follow the README to init submodule and change the BigVGAN source code.") - if is_local and local_path is not None: + if is_local: """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main""" vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) else: - local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path) + local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir) vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) vocoder.remove_weight_norm()