mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-10 20:25:52 -08:00
Update inference-cli.py add load vocos from local path
This commit is contained in:
@@ -1,26 +1,24 @@
|
||||
import argparse
|
||||
import codecs
|
||||
import re
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
import tempfile
|
||||
from einops import rearrange
|
||||
from vocos import Vocos
|
||||
from pydub import AudioSegment, silence
|
||||
from model import CFM, UNetT, DiT, MMDiT
|
||||
from cached_path import cached_path
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
save_spectrogram,
|
||||
)
|
||||
from transformers import pipeline
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import tomli
|
||||
import argparse
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
from pathlib import Path
|
||||
import codecs
|
||||
from cached_path import cached_path
|
||||
from einops import rearrange
|
||||
from pydub import AudioSegment, silence
|
||||
from transformers import pipeline
|
||||
from vocos import Vocos
|
||||
|
||||
from model import CFM, DiT, MMDiT, UNetT
|
||||
from model.utils import (convert_char_to_pinyin, get_tokenizer,
|
||||
load_checkpoint, save_spectrogram)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="python3 inference-cli.py",
|
||||
@@ -73,6 +71,11 @@ parser.add_argument(
|
||||
"--remove_silence",
|
||||
help="Remove silence.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_vocoder_from_local",
|
||||
action="store_true",
|
||||
help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = tomli.load(open(args.config, "rb"))
|
||||
@@ -88,6 +91,7 @@ model = args.model if args.model else config["model"]
|
||||
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
||||
wave_path = Path(output_dir)/"out.wav"
|
||||
spectrogram_path = Path(output_dir)/"out.png"
|
||||
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
|
||||
SPLIT_WORDS = [
|
||||
"but", "however", "nevertheless", "yet", "still",
|
||||
@@ -105,7 +109,16 @@ device = (
|
||||
if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
)
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
if args.load_vocoder_from_local:
|
||||
print(f"Load vocos from local path {vocos_local_path}")
|
||||
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
||||
vocos.load_state_dict(state_dict)
|
||||
vocos.eval()
|
||||
else:
|
||||
print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
print(f"Using {device} device")
|
||||
|
||||
@@ -124,8 +137,9 @@ speed = 1.0
|
||||
fix_duration = None
|
||||
|
||||
def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
||||
if not Path(ckpt_path).exists():
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
||||
model = CFM(
|
||||
transformer=model_cls(
|
||||
@@ -385,4 +399,4 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
||||
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
||||
|
||||
|
||||
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|
||||
infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
|
||||
|
||||
Reference in New Issue
Block a user