mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-06 01:58:14 -08:00
Added intel XPU support
This commit is contained in:
@@ -47,7 +47,15 @@ class F5TTS:
|
||||
else:
|
||||
import torch
|
||||
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
self.device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# Load models
|
||||
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
|
||||
|
||||
@@ -13,7 +13,7 @@ def main():
|
||||
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
|
||||
|
||||
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
||||
predictor = predictor.to(device)
|
||||
|
||||
@@ -10,7 +10,15 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
|
||||
from f5_tts.model import CFM, DiT, UNetT
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
|
||||
# --------------------- Dataset Settings -------------------- #
|
||||
|
||||
@@ -33,7 +33,15 @@ from f5_tts.model.utils import (
|
||||
|
||||
_ref_audio_cache = {}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
|
||||
@@ -17,7 +17,13 @@ from model.backbones.dit import DiT
|
||||
class TTSStreamingProcessor:
|
||||
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
|
||||
self.device = device or (
|
||||
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
# Load the model using the provided checkpoint and vocab files
|
||||
|
||||
@@ -46,7 +46,15 @@ path_data = str(files("f5_tts").joinpath("../../data"))
|
||||
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
|
||||
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "xpu"
|
||||
if torch.xpu.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
|
||||
# Save settings from a JSON file
|
||||
@@ -889,6 +897,13 @@ def calculate_train(
|
||||
gpu_properties = torch.cuda.get_device_properties(i)
|
||||
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
||||
|
||||
elif torch.xpu.is_available():
|
||||
gpu_count = torch.xpu.device_count()
|
||||
total_memory = 0
|
||||
for i in range(gpu_count):
|
||||
gpu_properties = torch.xpu.get_device_properties(i)
|
||||
total_memory += gpu_properties.total_memory / (1024**3)
|
||||
|
||||
elif torch.backends.mps.is_available():
|
||||
gpu_count = 1
|
||||
total_memory = psutil.virtual_memory().available / (1024**3)
|
||||
@@ -1284,7 +1299,21 @@ def get_gpu_stats():
|
||||
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
|
||||
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
|
||||
)
|
||||
elif torch.xpu.is_available():
|
||||
gpu_count = torch.xpu.device_count()
|
||||
for i in range(gpu_count):
|
||||
gpu_name = torch.xpu.get_device_name(i)
|
||||
gpu_properties = torch.xpu.get_device_properties(i)
|
||||
total_memory = gpu_properties.total_memory / (1024**3) # in GB
|
||||
allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB
|
||||
reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB
|
||||
|
||||
gpu_stats += (
|
||||
f"GPU {i} Name: {gpu_name}\n"
|
||||
f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
|
||||
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
|
||||
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
|
||||
)
|
||||
elif torch.backends.mps.is_available():
|
||||
gpu_count = 1
|
||||
gpu_stats += "MPS GPU\n"
|
||||
|
||||
Reference in New Issue
Block a user