fix. #213 correct device initialization

This commit is contained in:
lpscr
2024-10-22 12:48:48 +03:00
committed by GitHub
parent f8eb8ab740
commit cd3c4afa69

View File

@@ -19,8 +19,14 @@ from model.utils import (
convert_char_to_pinyin,
)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
# get device
def get_device():
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# print(f"Using {device} device")
return device
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):
# load vocoder
def load_vocoder(is_local=False, local_path="", device=device):
def load_vocoder(is_local=False, local_path="", device=None):
if device is None:
device = get_device()
if is_local:
print(f"Load vocos from local path {local_path}")
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
asr_pipe = None
def initialize_asr_pipeline(device=device):
def initialize_asr_pipeline(device=None):
global asr_pipe
if device is None:
device = get_device()
asr_pipe = pipeline(
"automatic-speech-recognition",
@@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device):
# load model for inference
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
if device is None:
device = get_device()
if vocab_file == "":
vocab_file = "Emilia_ZH_EN"
tokenizer = "pinyin"
@@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
device = get_device(device)
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
@@ -243,7 +257,11 @@ def infer_batch_process(
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
device=None,
):
if device is None:
device = get_device()
audio, sr = ref_audio
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
@@ -254,7 +272,7 @@ def infer_batch_process(
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
audio = audio.to()
generated_waves = []
spectrograms = []