mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-25 20:34:27 -08:00
add ckpt load opt. for .safetensor
This commit is contained in:
@@ -545,3 +545,28 @@ def repetition_found(text, length = 2, tolerance = 10):
|
||||
if count > tolerance:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# load model checkpoint for inference
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
||||
from ema_pytorch import EMA
|
||||
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
checkpoint = load_file(ckpt_path, device=device)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
|
||||
if use_ema == True:
|
||||
ema_model = EMA(model, include_online_model = False).to(device)
|
||||
if ckpt_type == "safetensors":
|
||||
ema_model.load_state_dict(checkpoint)
|
||||
else:
|
||||
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
ema_model.copy_params_from_ema_to_model()
|
||||
else:
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
return model
|
||||
@@ -10,6 +10,7 @@ jiwer
|
||||
librosa
|
||||
matplotlib
|
||||
pypinyin
|
||||
safetensors
|
||||
# torch>=2.0
|
||||
# torchaudio>=2.3.0
|
||||
torchdiffeq
|
||||
|
||||
@@ -8,11 +8,11 @@ import torch
|
||||
import torchaudio
|
||||
from accelerate import Accelerator
|
||||
from einops import rearrange
|
||||
from ema_pytorch import EMA
|
||||
from vocos import Vocos
|
||||
|
||||
from model import CFM, UNetT, DiT
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
get_seedtts_testset_metainfo,
|
||||
get_librispeech_test_clean_metainfo,
|
||||
@@ -55,7 +55,7 @@ seed = args.seed
|
||||
dataset_name = args.dataset
|
||||
exp_name = args.expname
|
||||
ckpt_step = args.ckptstep
|
||||
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
|
||||
nfe_step = args.nfestep
|
||||
ode_method = args.odemethod
|
||||
@@ -152,12 +152,7 @@ model = CFM(
|
||||
vocab_char_map = vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
if use_ema == True:
|
||||
ema_model = EMA(model, include_online_model = False).to(device)
|
||||
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
ema_model.copy_params_from_ema_to_model()
|
||||
else:
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
||||
|
||||
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||
os.makedirs(output_dir)
|
||||
|
||||
@@ -4,11 +4,11 @@ import re
|
||||
import torch
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from ema_pytorch import EMA
|
||||
from vocos import Vocos
|
||||
|
||||
from model import CFM, UNetT, DiT, MMDiT
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
save_spectrogram,
|
||||
@@ -50,7 +50,7 @@ elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
|
||||
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
output_dir = "tests"
|
||||
|
||||
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
|
||||
@@ -101,12 +101,7 @@ model = CFM(
|
||||
vocab_char_map = vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
if use_ema == True:
|
||||
ema_model = EMA(model, include_online_model = False).to(device)
|
||||
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
ema_model.copy_params_from_ema_to_model()
|
||||
else:
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
||||
|
||||
# Audio
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
|
||||
@@ -4,11 +4,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from ema_pytorch import EMA
|
||||
from vocos import Vocos
|
||||
|
||||
from model import CFM, UNetT, DiT, MMDiT
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
save_spectrogram,
|
||||
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
|
||||
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
output_dir = "tests"
|
||||
|
||||
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
||||
@@ -112,12 +112,7 @@ model = CFM(
|
||||
vocab_char_map = vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
if use_ema == True:
|
||||
ema_model = EMA(model, include_online_model = False).to(device)
|
||||
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
ema_model.copy_params_from_ema_to_model()
|
||||
else:
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
||||
|
||||
# Audio
|
||||
audio, sr = torchaudio.load(audio_to_edit)
|
||||
|
||||
Reference in New Issue
Block a user