From 929b5ae313f27563dfae421adc9fa3eb00f49f49 Mon Sep 17 00:00:00 2001 From: Jarod Mica Date: Fri, 15 Nov 2024 01:50:20 -0800 Subject: [PATCH] run pre-commit --- src/f5_tts/api.py | 12 +++++++++--- src/f5_tts/infer/utils_infer.py | 8 ++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index 66dea81..610eb3b 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -56,14 +56,20 @@ class F5TTS: if model_type == "F5-TTS": if not ckpt_file: if mel_spec_type == "vocos": - ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)) + ckpt_file = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path) + ) elif mel_spec_type == "bigvgan": - ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)) + ckpt_file = str( + cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path) + ) model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) model_cls = DiT elif model_type == "E2-TTS": if not ckpt_file: - ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)) + ckpt_file = str( + cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path) + ) model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) model_cls = UNetT else: diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 8773269..11910b5 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -96,8 +96,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev print(f"Load vocos from local path {local_path}") repo_id = "charactr/vocos-mel-24khz" revision = None - config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision) - model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision) + config_path = hf_hub_download( + repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision + ) + model_path = hf_hub_download( + repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision + ) vocoder = Vocos.from_hparams(config_path=config_path) state_dict = torch.load(model_path, map_location="cpu") vocoder.load_state_dict(state_dict)