add device option for infer-cli

This commit is contained in:
SWivid
2025-03-22 17:30:23 +08:00
parent 4ae5347282
commit 1d82b7928e
3 changed files with 17 additions and 6 deletions

View File

@@ -119,7 +119,7 @@ class F5TTS:
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
wav, sr, spec = infer_process(
ref_file,

View File

@@ -162,6 +162,11 @@ parser.add_argument(
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
parser.add_argument(
"--device",
type=str,
help="Specify the device to run on",
)
args = parser.parse_args()
@@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
device = args.device
# patches for pip pkg user
@@ -239,7 +245,9 @@ if vocoder_name == "vocos":
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
@@ -270,7 +278,9 @@ if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
# inference process
@@ -326,6 +336,7 @@ def main():
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
generated_audio_segments.append(audio_segment)

View File

@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)