Merge pull request #67 from chigkim/cli

Command Line Interface for Inference
This commit is contained in:
Yushen CHEN
2024-10-14 22:37:48 +08:00
committed by GitHub
3 changed files with 401 additions and 0 deletions

391
inference-cli.py Normal file
View File

@@ -0,0 +1,391 @@
import os
import re
import torch
import torchaudio
import numpy as np
import tempfile
from einops import rearrange
from vocos import Vocos
from pydub import AudioSegment, silence
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,
save_spectrogram,
)
from transformers import pipeline
import librosa
import click
import soundfile as sf
import tomllib
import argparse
import tqdm
from pathlib import Path
parser = argparse.ArgumentParser(
prog="python3 inference-cli.py",
description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
"-c",
"--config",
help="Configuration file. Default=cli-config.toml",
default="inference-cli.toml",
)
parser.add_argument(
"-m",
"--model",
help="F5-TTS | E2-TTS",
)
parser.add_argument(
"-r",
"--reference",
type=str,
help="Reference audio file < 15 seconds."
)
parser.add_argument(
"-s",
"--subtitle",
type=str,
help="Subtitle for the reference audio."
)
parser.add_argument(
"-t",
"--text",
type=str,
help="Text to generate.",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
help="Path to output folder..",
)
parser.add_argument(
"--remove_silence",
help="Remove silence.",
)
args = parser.parse_args()
config = tomllib.load(open(args.config, "rb"))
ref_audio = args.reference if args.reference else config["reference"]
ref_text = args.subtitle if args.subtitle else config["subtitle"]
gen_text = args.text if args.text else config["text"]
output_dir = args.output_dir if args.output_dir else config["output_dir"]
exp_name = args.model if args.model else config["model"]
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
wave_path = Path(output_dir)/"out.wav"
spectrogram_path = Path(output_dir)/"out.png"
SPLIT_WORDS = [
"but", "however", "nevertheless", "yet", "still",
"therefore", "thus", "hence", "consequently",
"moreover", "furthermore", "additionally",
"meanwhile", "alternatively", "otherwise",
"namely", "specifically", "for example", "such as",
"in fact", "indeed", "notably",
"in contrast", "on the other hand", "conversely",
"in conclusion", "to summarize", "finally"
]
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device")
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)
# --------------------- Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27 # None or float (duration in seconds)
fix_duration = None
def load_model(exp_name, model_cls, model_cfg, ckpt_step):
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
return model
# load models
F5TTS_model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
F5TTS_ema_model = load_model(
"F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
)
E2TTS_ema_model = load_model(
"E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
)
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
if len(text.encode('utf-8')) <= max_chars:
return [text]
if text[-1] not in ['', '.', '!', '', '?', '']:
text += '.'
sentences = re.split('([。.!?])', text)
sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
batches = []
current_batch = ""
def split_by_words(text):
words = text.split()
current_word_part = ""
word_batches = []
for word in words:
if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
current_word_part += word + ' '
else:
if current_word_part:
# Try to find a suitable split word
for split_word in split_words:
split_index = current_word_part.rfind(' ' + split_word + ' ')
if split_index != -1:
word_batches.append(current_word_part[:split_index].strip())
current_word_part = current_word_part[split_index:].strip() + ' '
break
else:
# If no suitable split word found, just append the current part
word_batches.append(current_word_part.strip())
current_word_part = ""
current_word_part += word + ' '
if current_word_part:
word_batches.append(current_word_part.strip())
return word_batches
for sentence in sentences:
if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
current_batch += sentence
else:
# If adding this sentence would exceed the limit
if current_batch:
batches.append(current_batch)
current_batch = ""
# If the sentence itself is longer than max_chars, split it
if len(sentence.encode('utf-8')) > max_chars:
# First, try to split by colon
colon_parts = sentence.split(':')
if len(colon_parts) > 1:
for part in colon_parts:
if len(part.encode('utf-8')) <= max_chars:
batches.append(part)
else:
# If colon part is still too long, split by comma
comma_parts = re.split('[,]', part)
if len(comma_parts) > 1:
current_comma_part = ""
for comma_part in comma_parts:
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
current_comma_part += comma_part + ','
else:
if current_comma_part:
batches.append(current_comma_part.rstrip(','))
current_comma_part = comma_part + ','
if current_comma_part:
batches.append(current_comma_part.rstrip(','))
else:
# If no comma, split by words
batches.extend(split_by_words(part))
else:
# If no colon, split by comma
comma_parts = re.split('[,]', sentence)
if len(comma_parts) > 1:
current_comma_part = ""
for comma_part in comma_parts:
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
current_comma_part += comma_part + ','
else:
if current_comma_part:
batches.append(current_comma_part.rstrip(','))
current_comma_part = comma_part + ','
if current_comma_part:
batches.append(current_comma_part.rstrip(','))
else:
# If no comma, split by words
batches.extend(split_by_words(sentence))
else:
current_batch = sentence
if current_batch:
batches.append(current_batch)
return batches
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
if exp_name == "F5-TTS":
ema_model = F5TTS_ema_model
elif exp_name == "E2-TTS":
ema_model = E2TTS_ema_model
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
generated_waves = []
spectrograms = []
for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
# Prepare the text
if len(ref_text[-1].encode('utf-8')) == 1:
ref_text = ref_text + " "
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
# Calculate duration
ref_audio_len = audio.shape[-1] // hop_length
zh_pause_punc = r"。,、;:?!"
ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
# inference
with torch.inference_mode():
generated, _ = ema_model.sample(
cond=audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
# Combine all generated waves
final_wave = np.concatenate(generated_waves)
# Remove silence
if remove_silence:
with open(wave_path, "wb") as f:
sf.write(f.name, final_wave, target_sample_rate)
aseg = AudioSegment.from_file(f.name)
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
non_silent_wave += non_silent_seg
aseg = non_silent_wave
aseg.export(f.name, format="wav")
print(f.name)
# Create a combined spectrogram
combined_spectrogram = np.concatenate(spectrograms, axis=1)
save_spectrogram(combined_spectrogram, spectrogram_path)
print(spectrogram_path)
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words):
if not custom_split_words.strip():
custom_words = [word.strip() for word in custom_split_words.split(',')]
global SPLIT_WORDS
SPLIT_WORDS = custom_words
print(gen_text)
print("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
non_silent_wave += non_silent_seg
aseg = non_silent_wave
audio_duration = len(aseg)
if audio_duration > 15000:
print("Audio is over 15s, clipping to only first 15s.")
aseg = aseg[:15000]
aseg.export(f.name, format="wav")
ref_audio = f.name
if not ref_text.strip():
print("No reference text provided, transcribing reference audio...")
ref_text = pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
print("Finished transcription")
else:
print("Using custom reference text...")
# Split the input text into batches
if len(ref_text.encode('utf-8')) == len(ref_text) and len(gen_text.encode('utf-8')) == len(gen_text):
max_chars = 400-len(ref_text.encode('utf-8'))
else:
max_chars = 300-len(ref_text.encode('utf-8'))
gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
print('ref_text', ref_text)
for i, gen_text in enumerate(gen_text_batches):
print(f'gen_text {i}', gen_text)
print(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
return infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence)
infer(ref_audio, ref_text, gen_text, exp_name, remove_silence, ",".join(SPLIT_WORDS))

8
inference-cli.toml Normal file
View File

@@ -0,0 +1,8 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
reference = "tests/ref_audio/test_en_1_ref_short.wav"
# If an empty "", transcribes the reference audio automatically.
subtitle = "Some call me nature, others call me mother nature."
text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
remove_silence = true
output_dir = "tests"

View File

@@ -21,3 +21,5 @@ wandb
x_transformers>=1.31.14
zhconv
zhon
pydub
cached_path