fix inference-cli; clean-up

SWivid
2024-10-14 23:40:31 +08:00
parent 9ec24868a9
commit 9d2b8cb3da
12 changed files with 61 additions and 235 deletions

@@ -1,4 +1,3 @@
-import os
 import re
 import torch
 import torchaudio
@@ -16,10 +15,8 @@ from model.utils import (
     save_spectrogram,
 )
 from transformers import pipeline
-import librosa
-import click
 import soundfile as sf
-import tomllib
+import tomli
 import argparse
 import tqdm
 from pathlib import Path
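
The tomllib-to-tomli swap is a compatibility fix: tomllib only exists in the standard library from Python 3.11 on, while tomli is the third-party backport with the same API, so pinning on tomli keeps the script working on older interpreters. A common shim, sketched here as an assumption rather than anything this commit does, supports both:

    # Hypothetical compatibility shim, not part of this commit: prefer the
    # stdlib parser on Python >= 3.11, fall back to the tomli backport.
    try:
        import tomllib
    except ModuleNotFoundError:
        import tomli as tomllib  # same load()/loads() API as the stdlib module
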
@@ -42,19 +39,19 @@ parser.add_argument(
 )
 parser.add_argument(
     "-r",
-    "--reference",
+    "--ref_audio",
     type=str,
     help="Reference audio file < 15 seconds."
 )
 parser.add_argument(
     "-s",
-    "--subtitle",
+    "--ref_text",
     type=str,
     help="Subtitle for the reference audio."
 )
 parser.add_argument(
     "-t",
-    "--text",
+    "--gen_text",
     type=str,
     help="Text to generate.",
 )
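
With the renamed flags, an invocation would look roughly like the following; the script name and paths are placeholders for illustration, only the flags come from the diff above:

    python inference-cli.py \
        --ref_audio ref.wav \
        --ref_text "Transcript of the reference clip." \
        --gen_text "Text to synthesize." \
        --model F5-TTS
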
@@ -70,11 +67,11 @@ parser.add_argument(
 )
 args = parser.parse_args()
-config = tomllib.load(open(args.config, "rb"))
+config = tomli.load(open(args.config, "rb"))
-ref_audio = args.reference if args.reference else config["reference"]
-ref_text = args.subtitle if args.subtitle else config["subtitle"]
-gen_text = args.text if args.text else config["text"]
+ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
+ref_text = args.ref_text if args.ref_text else config["ref_text"]
+gen_text = args.gen_text if args.gen_text else config["gen_text"]
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
 exp_name = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
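
Every flag falls back to a key of the same name in the TOML file referenced by args.config, so the config must supply whichever keys are not given on the command line. A minimal sketch with placeholder values (the key names mirror the lookups above; the values are assumptions):

    import tomli

    # Hypothetical minimal config; key names match what the script reads.
    config = tomli.loads("""
    ref_audio = "ref.wav"
    ref_text = "Transcript of the reference clip."
    gen_text = "Text to synthesize."
    output_dir = "output"
    model = "F5-TTS"
    remove_silence = false
    """)
    assert config["model"] == "F5-TTS"
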
@@ -100,13 +97,6 @@ device = (
 print(f"Using {device} device")
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16,
-    device=device,
-)
 # --------------------- Settings -------------------- #
 target_sample_rate = 24000
@@ -151,13 +141,6 @@ F5TTS_model_cfg = dict(
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-F5TTS_ema_model = load_model(
-    "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
-)
-E2TTS_ema_model = load_model(
-    "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
-)
 def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
     if len(text.encode('utf-8')) <= max_chars:
         return [text]
@@ -256,9 +239,9 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
 def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
     if exp_name == "F5-TTS":
-        ema_model = F5TTS_ema_model
+        ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
     elif exp_name == "E2-TTS":
-        ema_model = E2TTS_ema_model
+        ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
     audio, sr = torchaudio.load(ref_audio)
     if audio.shape[0] > 1:
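
Replacing the module-level models with a load_model call inside infer_batch means only the requested checkpoint is ever loaded, but also that it is reloaded on every call. If repeated calls mattered, a memoized loader would keep the laziness without the reloads; a hypothetical sketch reusing names from the diff, not something this commit adds:

    # Hypothetical memoizing loader: load each checkpoint at most once,
    # on first request, then reuse the cached instance.
    _models = {}

    def get_model(exp_name):
        if exp_name not in _models:
            if exp_name == "F5-TTS":
                _models[exp_name] = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
            elif exp_name == "E2-TTS":
                _models[exp_name] = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
            else:
                raise ValueError(f"unknown model: {exp_name}")
        return _models[exp_name]
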
@@ -363,6 +346,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
     if not ref_text.strip():
         print("No reference text provided, transcribing reference audio...")
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model="openai/whisper-large-v3-turbo",
+            torch_dtype=torch.float16,
+            device=device,
+        )
         ref_text = pipe(
             ref_audio,
             chunk_length_s=30,