mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-13 21:47:14 -08:00
fix. #213 correct device initialization
This commit is contained in:
@@ -19,8 +19,14 @@ from model.utils import (
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
print(f"Using {device} device")
|
||||
# get device
|
||||
|
||||
|
||||
def get_device():
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
# print(f"Using {device} device")
|
||||
return device
|
||||
|
||||
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):
|
||||
|
||||
|
||||
# load vocoder
|
||||
def load_vocoder(is_local=False, local_path="", device=device):
|
||||
def load_vocoder(is_local=False, local_path="", device=None):
|
||||
if device is None:
|
||||
device = get_device()
|
||||
if is_local:
|
||||
print(f"Load vocos from local path {local_path}")
|
||||
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
||||
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
|
||||
asr_pipe = None
|
||||
|
||||
|
||||
def initialize_asr_pipeline(device=device):
|
||||
def initialize_asr_pipeline(device=None):
|
||||
global asr_pipe
|
||||
if device is None:
|
||||
device = get_device()
|
||||
|
||||
asr_pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
@@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device):
|
||||
# load model for inference
|
||||
|
||||
|
||||
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
|
||||
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
|
||||
if device is None:
|
||||
device = get_device()
|
||||
if vocab_file == "":
|
||||
vocab_file = "Emilia_ZH_EN"
|
||||
tokenizer = "pinyin"
|
||||
@@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
|
||||
device = get_device(device)
|
||||
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
@@ -243,7 +257,11 @@ def infer_batch_process(
|
||||
sway_sampling_coef=-1,
|
||||
speed=1,
|
||||
fix_duration=None,
|
||||
device=None,
|
||||
):
|
||||
if device is None:
|
||||
device = get_device()
|
||||
|
||||
audio, sr = ref_audio
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
@@ -254,7 +272,7 @@ def infer_batch_process(
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
audio = audio.to(device)
|
||||
audio = audio.to()
|
||||
|
||||
generated_waves = []
|
||||
spectrograms = []
|
||||
|
||||
Reference in New Issue
Block a user