diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index a4196f2..610eb3b 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -46,24 +46,30 @@ class F5TTS:
         )
 
         # Load models
-        self.load_vocoder_model(vocoder_name, local_path)
-        self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
+        self.load_vocoder_model(vocoder_name, local_path=local_path)
+        self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
 
     def load_vocoder_model(self, vocoder_name, local_path):
         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
-    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
         if model_type == "F5-TTS":
             if not ckpt_file:
                 if mel_spec_type == "vocos":
-                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
+                    )
                 elif mel_spec_type == "bigvgan":
-                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
+                    )
             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
             model_cls = DiT
         elif model_type == "E2-TTS":
             if not ckpt_file:
-                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+                ckpt_file = str(
+                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
+                )
             model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
             model_cls = UNetT
         else:
diff --git a/src/f5_tts/infer/SHARED.md b/src/f5_tts/infer/SHARED.md
index dbdd63c..a180bd8 100644
--- a/src/f5_tts/infer/SHARED.md
+++ b/src/f5_tts/infer/SHARED.md
@@ -40,6 +40,17 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
 ## Mandarin
 
+## Japanese
+
+#### F5-TTS Base @ pretrain/finetune @ ja
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
+
+```bash
+MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
+VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
+```
 
 ## English
 
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 469855f..11910b5 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 import torchaudio
 import tqdm
+from huggingface_hub import snapshot_download, hf_hub_download
 from pydub import AudioSegment, silence
 from transformers import pipeline
 from vocos import Vocos
@@ -93,8 +94,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
     if vocoder_name == "vocos":
         if is_local:
             print(f"Load vocos from local path {local_path}")
-            vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
-            state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
+            repo_id = "charactr/vocos-mel-24khz"
+            revision = None
+            config_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
+            )
+            model_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
+            )
+            vocoder = Vocos.from_hparams(config_path=config_path)
+            state_dict = torch.load(model_path, map_location="cpu")
             vocoder.load_state_dict(state_dict)
             vocoder = vocoder.eval().to(device)
         else:
@@ -107,6 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         print("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
             """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
+            local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         else:
             vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
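
Reviewer note: below is a minimal usage sketch of the `local_path` plumbing this diff introduces, assuming the branch is installed as the `f5_tts` package; the cache directory value is illustrative and not part of the change.

```python
# Reviewer sketch (not part of the diff): exercising the new local_path plumbing.
from f5_tts.api import F5TTS
from f5_tts.infer.utils_infer import load_vocoder

cache_dir = "./ckpts/hf_cache"  # illustrative cache location

# local_path is now forwarded as cache_dir to cached_path() for the EMA
# checkpoint and passed through to load_vocoder() for the vocoder weights.
tts = F5TTS(model_type="F5-TTS", vocoder_name="vocos", local_path=cache_dir)

# Standalone vocoder load: with is_local=True, config.yaml and pytorch_model.bin
# are fetched via hf_hub_download(..., cache_dir=local_path) instead of being
# read from a hand-managed local folder.
vocoder = load_vocoder(vocoder_name="vocos", is_local=True, local_path=cache_dir)
```

Leaving `local_path=None` preserves the previous behaviour: `cached_path` falls back to the default Hugging Face cache and `load_vocoder` takes its non-local branch.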