diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 3f2f342..bfe1fae 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -26,7 +26,7 @@ from transformers import pipeline from cached_path import cached_path from f5_tts.api import F5TTS from f5_tts.model.utils import convert_char_to_pinyin - +from importlib.resources import files training_process = None system = platform.system() @@ -36,9 +36,9 @@ last_checkpoint = "" last_device = "" last_ema = None -path_basic = os.path.abspath(os.path.join(__file__, "../../../..")) -path_data = os.path.join(path_basic, "data") -path_project_ckpts = os.path.join(path_basic, "ckpts") + +path_data = str(files("f5_tts").joinpath("../../data")) +path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts")) file_train = "src/f5_tts/train/finetune_cli.py" device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"