update. compatibility with mps device #477 thanks to @aboutmydreams

This commit is contained in:
SWivid
2024-11-17 18:57:28 +08:00
parent 0f80f25c5f
commit cb8ce3306d
4 changed files with 10 additions and 7 deletions

View File

@@ -15,7 +15,7 @@ classifiers = [
]
dependencies = [
"accelerate>=0.33.0",
"bitsandbytes>0.37.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"cached_path",
"click",
"datasets",

View File

@@ -3,7 +3,6 @@ import sys
from importlib.resources import files
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
@@ -43,9 +42,12 @@ class F5TTS:
self.mel_spec_type = vocoder_name
# Set device
self.device = device or (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
if device is not None:
self.device = device
else:
import torch
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)

View File

@@ -1,5 +1,7 @@
import os
os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
import torch
import torch.nn.functional as F
import torchaudio

View File

@@ -3,6 +3,7 @@
import os
import sys
os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
import hashlib
@@ -33,8 +34,6 @@ from f5_tts.model.utils import (
_ref_audio_cache = {}
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1"
# -----------------------------------------