diff --git a/inference-cli.py b/inference-cli.py
index 60a69bc..480fc5c 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -1,26 +1,24 @@
+import argparse
+import codecs
 import re
-import torch
-import torchaudio
-import numpy as np
 import tempfile
-from einops import rearrange
-from vocos import Vocos
-from pydub import AudioSegment, silence
-from model import CFM, UNetT, DiT, MMDiT
-from cached_path import cached_path
-from model.utils import (
-    load_checkpoint,
-    get_tokenizer,
-    convert_char_to_pinyin,
-    save_spectrogram,
-)
-from transformers import pipeline
+from pathlib import Path
+
+import numpy as np
 import soundfile as sf
 import tomli
-import argparse
+import torch
+import torchaudio
 import tqdm
-from pathlib import Path
-import codecs
+from cached_path import cached_path
+from einops import rearrange
+from pydub import AudioSegment, silence
+from transformers import pipeline
+from vocos import Vocos
+
+from model import CFM, DiT, MMDiT, UNetT
+from model.utils import (convert_char_to_pinyin, get_tokenizer,
+                         load_checkpoint, save_spectrogram)
 
 parser = argparse.ArgumentParser(
     prog="python3 inference-cli.py",
@@ -73,6 +71,11 @@ parser.add_argument(
     "--remove_silence",
     help="Remove silence.",
 )
+parser.add_argument(
+    "--load_vocoder_from_local",
+    action="store_true",
+    help="Load vocoder from a local directory. Default: ../checkpoints/charactr/vocos-mel-24khz",
+)
 args = parser.parse_args()
 
 config = tomli.load(open(args.config, "rb"))
@@ -88,6 +91,7 @@ model = args.model if args.model else config["model"]
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
+vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
 
 SPLIT_WORDS = [
     "but", "however", "nevertheless", "yet", "still",
@@ -105,7 +109,16 @@ device = (
     if torch.cuda.is_available()
     else "mps" if torch.backends.mps.is_available() else "cpu"
 )
-vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+if args.load_vocoder_from_local:
+    print(f"Load Vocos from local path {vocos_local_path}")
+    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+    vocos.load_state_dict(state_dict)
+    vocos.eval()
+else:
+    print("Download Vocos from huggingface charactr/vocos-mel-24khz")
+    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 print(f"Using {device} device")
 
@@ -124,8 +137,9 @@ speed = 1.0
 fix_duration = None
 
 def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
+    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
+    if not Path(ckpt_path).exists():
+        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -385,4 +399,4 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
     return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
 
-infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
\ No newline at end of file
+infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
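
The new `--load_vocoder_from_local` flag makes the CLI usable offline: when set, the vocoder is built from a local `config.yaml` plus `pytorch_model.bin` instead of being fetched from the Hugging Face Hub. Below is a minimal standalone sketch of that pattern, assuming the same `../checkpoints/charactr/vocos-mel-24khz` layout the diff expects; the `resolve_vocos` helper name is illustrative and not part of the repo:

```python
from pathlib import Path

import torch
from vocos import Vocos

VOCOS_LOCAL_PATH = Path("../checkpoints/charactr/vocos-mel-24khz")

def resolve_vocos(device: str, use_local: bool) -> Vocos:
    """Prefer an offline copy of the vocoder; otherwise pull from the Hub."""
    if use_local:
        # Build the architecture from the local config, then load local weights.
        vocos = Vocos.from_hparams(str(VOCOS_LOCAL_PATH / "config.yaml"))
        state_dict = torch.load(VOCOS_LOCAL_PATH / "pytorch_model.bin", map_location=device)
        vocos.load_state_dict(state_dict)
        return vocos.eval()
    # Online path: downloads once, then serves from the local HF cache.
    return Vocos.from_pretrained("charactr/vocos-mel-24khz")
```

Invoked as `python3 inference-cli.py --load_vocoder_from_local ...`; since the flag is a plain `store_true`, it defaults to off and existing workflows keep downloading `charactr/vocos-mel-24khz` unchanged.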
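
`load_model` now resolves the model checkpoint the same local-first way: a `ckpts/{exp_name}/model_{ckpt_step}.pt` on disk wins, and the `.safetensors` release under `hf://SWivid/...` is fetched only as a fallback. A sketch of that resolution order, with the `resolve_ckpt` name again being illustrative:

```python
from pathlib import Path

from cached_path import cached_path

def resolve_ckpt(repo_name: str, exp_name: str, ckpt_step: int) -> str:
    """Return a local .pt checkpoint if present, else fetch .safetensors from HF."""
    local = Path(f"ckpts/{exp_name}/model_{ckpt_step}.pt")
    if local.exists():
        return str(local)
    # cached_path downloads on first use, then returns the cached file path.
    return str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
```

Note that the two branches hand different formats (`.pt` vs `.safetensors`) to `load_checkpoint`, so the loader is assumed to dispatch on the file extension.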