diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index 9ecd4c3..3a428f2 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -8,8 +8,10 @@ from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer +os.chdir(str(files("f5_tts").joinpath("../.."))) -@hydra.main(config_path=os.path.join("..", "configs"), config_name=None) + +@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) def main(cfg): tokenizer = cfg.model.tokenizer mel_spec_type = cfg.model.mel_spec.mel_spec_type