diff --git a/model/utils_infer.py b/model/utils_infer.py index 9c1a4db..0355220 100644 --- a/model/utils_infer.py +++ b/model/utils_infer.py @@ -24,7 +24,6 @@ from model.utils import ( def get_device(): device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - # print(f"Using {device} device") return device @@ -273,7 +272,7 @@ def infer_batch_process( if sr != target_sample_rate: resampler = torchaudio.transforms.Resample(sr, target_sample_rate) audio = resampler(audio) - audio = audio.to() + audio = audio.to(device) generated_waves = [] spectrograms = []