Merge pull request #469 from JarodMica/main_repo_update

Allow for local path specification of HF models/repos
This commit is contained in:
Yushen CHEN
2024-11-15 18:22:06 +08:00
committed by GitHub
3 changed files with 35 additions and 8 deletions

View File

@@ -46,24 +46,30 @@ class F5TTS:
)
# Load models
self.load_vocoder_model(vocoder_name, local_path)
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
self.load_vocoder_model(vocoder_name, local_path=local_path)
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
def load_vocoder_model(self, vocoder_name, local_path):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
else:

View File

@@ -40,6 +40,17 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
## Mandarin
## Japanese
#### F5-TTS Base @ pretrain/finetune @ ja
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
```
## English

View File

@@ -19,6 +19,7 @@ import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import snapshot_download, hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
@@ -93,8 +94,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
if vocoder_name == "vocos":
if is_local:
print(f"Load vocos from local path {local_path}")
vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
repo_id = "charactr/vocos-mel-24khz"
revision = None
config_path = hf_hub_download(
repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
)
model_path = hf_hub_download(
repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
)
vocoder = Vocos.from_hparams(config_path=config_path)
state_dict = torch.load(model_path, map_location="cpu")
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
else:
@@ -107,6 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)