diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index bcf4efe..f3ebc33 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -32,6 +32,8 @@ from f5_tts.model.utils import ( _ref_audio_cache = {} device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +if device == "mps": + os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # -----------------------------------------