mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-27 23:34:17 -08:00
add device option for infer-cli
This commit is contained in:
@@ -119,7 +119,7 @@ class F5TTS:
|
||||
seed_everything(seed)
|
||||
self.seed = seed
|
||||
|
||||
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
|
||||
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
|
||||
|
||||
wav, sr, spec = infer_process(
|
||||
ref_file,
|
||||
|
||||
@@ -162,6 +162,11 @@ parser.add_argument(
|
||||
type=float,
|
||||
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
help="Specify the device to run on",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
|
||||
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
||||
speed = args.speed or config.get("speed", speed)
|
||||
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
||||
device = args.device
|
||||
|
||||
|
||||
# patches for pip pkg user
|
||||
@@ -239,7 +245,9 @@ if vocoder_name == "vocos":
|
||||
elif vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
|
||||
)
|
||||
|
||||
|
||||
# load TTS model
|
||||
@@ -270,7 +278,9 @@ if not ckpt_file:
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
||||
ema_model = load_model(
|
||||
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
|
||||
)
|
||||
|
||||
|
||||
# inference process
|
||||
@@ -326,6 +336,7 @@ def main():
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
generated_audio_segments.append(audio_segment)
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
Reference in New Issue
Block a user