run pre-commit

This commit is contained in:
Jarod Mica
2024-11-15 01:50:20 -08:00
parent d1d8139bab
commit 929b5ae313
2 changed files with 15 additions and 5 deletions

View File

@@ -56,14 +56,20 @@ class F5TTS:
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path))
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path))
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path))
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
else:

View File

@@ -96,8 +96,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
print(f"Load vocos from local path {local_path}")
repo_id = "charactr/vocos-mel-24khz"
revision = None
config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision)
model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision)
config_path = hf_hub_download(
repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
)
model_path = hf_hub_download(
repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
)
vocoder = Vocos.from_hparams(config_path=config_path)
state_dict = torch.load(model_path, map_location="cpu")
vocoder.load_state_dict(state_dict)