Update inference-cli.py: add loading Vocos from a local path

Zhikang Niu
2024-10-15 16:48:54 +08:00
committed by GitHub
parent 2d26fba7bf
commit 830a2fe19e


@@ -1,26 +1,24 @@
import argparse
import codecs
import re
import torch
import torchaudio
import numpy as np
import tempfile
from einops import rearrange
from vocos import Vocos
from pydub import AudioSegment, silence
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)
from transformers import pipeline
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
import argparse
import torch
import torchaudio
import tqdm
from pathlib import Path
import codecs
from cached_path import cached_path
from einops import rearrange
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from model import CFM, DiT, MMDiT, UNetT
from model.utils import (convert_char_to_pinyin, get_tokenizer,
                         load_checkpoint, save_spectrogram)
parser = argparse.ArgumentParser(
prog="python3 inference-cli.py",
@@ -73,6 +71,11 @@ parser.add_argument(
"--remove_silence",
help="Remove silence.",
)
parser.add_argument(
"--load_vocoder_from_local",
action="store_true",
help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
args = parser.parse_args()
config = tomli.load(open(args.config, "rb"))
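
A minimal sketch (not part of this commit) that sanity-checks the default local vocoder directory before the new flag is used; the file names mirror the ones the script reads further below:

# Hedged sketch: verify the expected local Vocos files exist before passing
# --load_vocoder_from_local (the path follows the script's default).
from pathlib import Path

local_vocos_dir = Path("../checkpoints/charactr/vocos-mel-24khz")
missing = [name for name in ("config.yaml", "pytorch_model.bin")
           if not (local_vocos_dir / name).exists()]
if missing:
    raise FileNotFoundError(f"Local vocoder files missing: {missing}")

With the files in place, the flag is simply added to the usual invocation, e.g. python3 inference-cli.py --config <your config> --load_vocoder_from_local.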
@@ -88,6 +91,7 @@ model = args.model if args.model else config["model"]
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
wave_path = Path(output_dir)/"out.wav"
spectrogram_path = Path(output_dir)/"out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
SPLIT_WORDS = [
"but", "however", "nevertheless", "yet", "still",
@@ -105,7 +109,16 @@ device = (
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
if args.load_vocoder_from_local:
    print(f"Load Vocos from local path {vocos_local_path}")
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    print("Download Vocos from huggingface charactr/vocos-mel-24khz")
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
print(f"Using {device} device")
@@ -124,8 +137,9 @@ speed = 1.0
fix_duration = None
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
    if not Path(ckpt_path).exists():
        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(
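
The hunk above changes the checkpoint resolution order: a local ckpts/.../*.pt is tried first, and the .safetensors download from the Hugging Face Hub is only a fallback. A hypothetical helper (not in the diff) restating that logic:

# Hedged sketch: prefer a local checkpoint, otherwise fall back to the Hub.
from pathlib import Path
from cached_path import cached_path

def resolve_ckpt_path(repo_name: str, exp_name: str, ckpt_step: int) -> str:
    local = Path(f"ckpts/{exp_name}/model_{ckpt_step}.pt")
    if local.exists():
        return str(local)
    return str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))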
@@ -385,4 +399,4 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))