diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index 0ead776..7bf872e 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -119,7 +119,7 @@ class F5TTS: seed_everything(seed) self.seed = seed - ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) + ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text) wav, sr, spec = infer_process( ref_file, diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py index acabf6f..ebe7a43 100644 --- a/src/f5_tts/infer/infer_cli.py +++ b/src/f5_tts/infer/infer_cli.py @@ -162,6 +162,11 @@ parser.add_argument( type=float, help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}", ) +parser.add_argument( + "--device", + type=str, + help="Specify the device to run on", +) args = parser.parse_args() @@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength) sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef) speed = args.speed or config.get("speed", speed) fix_duration = args.fix_duration or config.get("fix_duration", fix_duration) +device = args.device # patches for pip pkg user @@ -239,7 +245,9 @@ if vocoder_name == "vocos": elif vocoder_name == "bigvgan": vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" -vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path) +vocoder = load_vocoder( + vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device +) # load TTS model @@ -270,7 +278,9 @@ if not ckpt_file: ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) print(f"Using {model}...") -ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) +ema_model = load_model( + model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device +) # inference process @@ -326,6 +336,7 @@ def main(): sway_sampling_coef=sway_sampling_coef, speed=speed, fix_duration=fix_duration, + device=device, ) generated_audio_segments.append(audio_segment) diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index ac1a778..6b33bb8 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None): dtype = ( torch.float16 if "cuda" in device - and torch.cuda.get_device_properties(device).major >= 6 + and torch.cuda.get_device_properties(device).major >= 7 and not torch.cuda.get_device_name().endswith("[ZLUDA]") else torch.float32 ) @@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True): dtype = ( torch.float16 if "cuda" in device - and torch.cuda.get_device_properties(device).major >= 6 + and torch.cuda.get_device_properties(device).major >= 7 and not torch.cuda.get_device_name().endswith("[ZLUDA]") else torch.float32 ) @@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42): # preprocess reference audio and text -def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device): +def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print): show_info("Converting audio...") with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: aseg = AudioSegment.from_file(ref_audio_orig)