From cb8ce3306d70dfbee0e7d2423cc7f06e1c2b9c60 Mon Sep 17 00:00:00 2001
From: SWivid
Date: Sun, 17 Nov 2024 18:57:28 +0800
Subject: [PATCH] update. compatibility with mps device #477 thanks to @aboutmydreams

- pyproject.toml: install bitsandbytes only when platform_machine is not
  'arm64' and platform_system is not 'Darwin'.
- api.py: drop the module-level torch import; torch is imported inside
  F5TTS.__init__ only when no device is passed and one must be detected.
- speech_edit.py, utils_infer.py: set PYTORCH_ENABLE_MPS_FALLBACK=1
  unconditionally at the top of the module instead of after the device
  has been detected.

Note: the environment variable is spelled PYTORCH_ENABLE_MPS_FALLBACK in the
added lines. The removed lines in utils_infer.py keep the misspelled
"PYTOCH_..." name because they must match the existing file content.
---
 pyproject.toml                  |  2 +-
 src/f5_tts/api.py               | 10 ++++++----
 src/f5_tts/infer/speech_edit.py |  2 ++
 src/f5_tts/infer/utils_infer.py |  3 +--
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 67c3d02..2a1c002 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
 ]
 dependencies = [
     "accelerate>=0.33.0",
-    "bitsandbytes>0.37.0",
+    "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
     "cached_path",
     "click",
     "datasets",
diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index 96f49c1..9798a05 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -3,7 +3,6 @@ import sys
 from importlib.resources import files
 
 import soundfile as sf
-import torch
 import tqdm
 from cached_path import cached_path
 
@@ -43,9 +42,12 @@ class F5TTS:
         self.mel_spec_type = vocoder_name
 
         # Set device
-        self.device = device or (
-            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-        )
+        if device is not None:
+            self.device = device
+        else:
+            import torch
+
+            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
         # Load models
         self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index 07bb6d6..4eee068 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -1,5 +1,7 @@
 import os
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
+
 import torch
 import torch.nn.functional as F
 import torchaudio
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index d6e7901..f985282 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -3,6 +3,7 @@
 import os
 import sys
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
 
 import hashlib
@@ -33,8 +34,6 @@ from f5_tts.model.utils import (
 _ref_audio_cache = {}
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-if device == "mps":
-    os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1"
 
 # -----------------------------------------
 