From 81ce1d8670b2603ba04475e12597aa91a344a273 Mon Sep 17 00:00:00 2001
From: 98440 <984400286@qq.com>
Date: Mon, 20 Jan 2025 00:47:57 +0800
Subject: [PATCH] Added Intel XPU support

---
 README.md                           |  4 ++++
 src/f5_tts/api.py                   |  2 +-
 src/f5_tts/eval/eval_utmos.py       |  2 +-
 src/f5_tts/infer/speech_edit.py     |  2 +-
 src/f5_tts/infer/utils_infer.py     |  2 +-
 src/f5_tts/socket_server.py         |  2 +-
 src/f5_tts/train/finetune_gradio.py | 23 ++++++++++++++++++++++-
 7 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 5f1bb40..23676aa 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,10 @@ pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https:/
 
 # AMD GPU: install pytorch with your ROCm version, e.g.
 pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
+
+# Intel GPU: install pytorch with your XPU version, e.g.
+# Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
+pip install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
 ```
 
 Then you can choose from a few options below:
diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index 9798a05..4fb1712 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -47,7 +47,7 @@ class F5TTS:
         else:
             import torch
 
-            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            self.device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
         # Load models
         self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
diff --git a/src/f5_tts/eval/eval_utmos.py b/src/f5_tts/eval/eval_utmos.py
index 9b069cd..c4e9449 100644
--- a/src/f5_tts/eval/eval_utmos.py
+++ b/src/f5_tts/eval/eval_utmos.py
@@ -13,7 +13,7 @@ def main():
     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index fc6505c..a4e276a 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -10,7 +10,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
 from f5_tts.model import CFM, DiT, UNetT
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # --------------------- Dataset Settings -------------------- #
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 5f31fe5..0c3d87f 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -33,7 +33,7 @@ from f5_tts.model.utils import (
 
 _ref_audio_cache = {}
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # -----------------------------------------
diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py
index a8b50a5..19d6f54 100644
--- a/src/f5_tts/socket_server.py
+++ b/src/f5_tts/socket_server.py
@@ -17,7 +17,7 @@ from model.backbones.dit import DiT
 class TTSStreamingProcessor:
     def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
         self.device = device or (
-            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+            "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         )
 
         # Load the model using the provided checkpoint and vocab files
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index e27ef3a..6c10b7b 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -46,7 +46,7 @@ path_data = str(files("f5_tts").joinpath("../../data"))
 path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
 file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
 # Save settings from a JSON file
@@ -889,6 +889,13 @@ def calculate_train(
             gpu_properties = torch.cuda.get_device_properties(i)
             total_memory += gpu_properties.total_memory / (1024**3)  # in GB
 
+    elif torch.xpu.is_available():
+        gpu_count = torch.xpu.device_count()
+        total_memory = 0
+        for i in range(gpu_count):
+            gpu_properties = torch.xpu.get_device_properties(i)
+            total_memory += gpu_properties.total_memory / (1024**3)
+
     elif torch.backends.mps.is_available():
         gpu_count = 1
         total_memory = psutil.virtual_memory().available / (1024**3)
@@ -1284,7 +1291,21 @@ def get_gpu_stats():
                 f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
                 f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
             )
+    elif torch.xpu.is_available():
+        gpu_count = torch.xpu.device_count()
+        for i in range(gpu_count):
+            gpu_name = torch.xpu.get_device_name(i)
+            gpu_properties = torch.xpu.get_device_properties(i)
+            total_memory = gpu_properties.total_memory / (1024**3)  # in GB
+            allocated_memory = torch.xpu.memory_allocated(i) / (1024**2)  # in MB
+            reserved_memory = torch.xpu.memory_reserved(i) / (1024**2)  # in MB
+            gpu_stats += (
+                f"GPU {i} Name: {gpu_name}\n"
+                f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
+                f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
+                f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
+            )
 
     elif torch.backends.mps.is_available():
         gpu_count = 1
         gpu_stats += "MPS GPU\n"
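
Reviewer note, not part of the patch: every touched call site extends one and
the same fallback chain (cuda -> xpu -> mps -> cpu), and the finetune_gradio.py
hunks port the existing torch.cuda.* memory queries to torch.xpu.*. Below is a
minimal standalone sketch of both, assuming a PyTorch build that ships the
torch.xpu namespace (e.g. the nightly/xpu wheels from the README). On such a
build without an Intel GPU, torch.xpu.is_available() simply returns False and
selection falls through to MPS or CPU.

    import torch

    # Same priority order the patch adds at each call site:
    # CUDA first, then Intel XPU, then Apple MPS, then CPU.
    device = (
        "cuda" if torch.cuda.is_available()
        else "xpu" if torch.xpu.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Selected device: {device}")

    if device == "xpu":
        # Same introspection calls the finetune_gradio.py hunks use.
        for i in range(torch.xpu.device_count()):
            props = torch.xpu.get_device_properties(i)
            print(
                f"XPU {i}: {torch.xpu.get_device_name(i)}, "
                f"{props.total_memory / (1024**3):.2f} GB total, "
                f"{torch.xpu.memory_allocated(i) / (1024**2):.2f} MB allocated, "
                f"{torch.xpu.memory_reserved(i) / (1024**2):.2f} MB reserved"
            )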