add ckpt load opt. for .safetensor

SWivid
2024-10-13 10:55:18 +08:00
parent edc189fa96
commit 9395289d7a
5 changed files with 35 additions and 24 deletions

View File

@@ -545,3 +545,28 @@ def repetition_found(text, length = 2, tolerance = 10):
         if count > tolerance:
             return True
     return False
+
+
+# load model checkpoint for inference
+
+def load_checkpoint(model, ckpt_path, device, use_ema = True):
+    from ema_pytorch import EMA
+
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+        checkpoint = load_file(ckpt_path, device=device)
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+
+    if use_ema == True:
+        ema_model = EMA(model, include_online_model = False).to(device)
+        if ckpt_type == "safetensors":
+            ema_model.load_state_dict(checkpoint)
+        else:
+            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+        ema_model.copy_params_from_ema_to_model()
+    else:
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+    return model
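[Editor's note, not part of the commit: the new helper dispatches on the file extension, so the inference scripts below can pass either a .pt or a .safetensors path through one call. A minimal usage sketch, assuming model is a CFM instance built as in those scripts; the checkpoint paths and step number are illustrative placeholders.]

from model.utils import load_checkpoint

device = "cuda"

# .pt checkpoints carry the nested dict written by torch.save during training;
# .safetensors files carry the flat EMA state dict. Same call either way.
model = load_checkpoint(model, "ckpts/F5TTS_Base/model_1200000.pt", device, use_ema = True)
model = load_checkpoint(model, "ckpts/F5TTS_Base/model_1200000.safetensors", device, use_ema = True)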

View File

@@ -10,6 +10,7 @@ jiwer
 librosa
 matplotlib
 pypinyin
+safetensors
 # torch>=2.0
 # torchaudio>=2.3.0
 torchdiffeq
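[Editor's note, not part of the commit: with safetensors now a dependency, an existing .pt checkpoint can be rewritten in the format the new loader expects. torch.save nests the EMA weights under 'ema_model_state_dict', while a .safetensors file stores a flat {name: tensor} mapping, which is why load_checkpoint feeds it straight into EMA.load_state_dict. A hedged conversion sketch; paths are placeholders.]

import torch
from safetensors.torch import save_file

ckpt = torch.load("ckpts/F5TTS_Base/model_1200000.pt", map_location="cpu")

# keep only tensor entries: save_file rejects anything that is not a tensor
ema_state = {k: v for k, v in ckpt["ema_model_state_dict"].items() if isinstance(v, torch.Tensor)}
save_file(ema_state, "ckpts/F5TTS_Base/model_1200000.safetensors")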

View File

@@ -8,11 +8,11 @@ import torch
 import torchaudio
 from accelerate import Accelerator
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     get_seedtts_testset_metainfo,
     get_librispeech_test_clean_metainfo,
@@ -55,7 +55,7 @@ seed = args.seed
 dataset_name = args.dataset
 exp_name = args.expname
 ckpt_step = args.ckptstep
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 nfe_step = args.nfestep
 ode_method = args.odemethod
@@ -152,12 +152,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 if not os.path.exists(output_dir) and accelerator.is_main_process:
     os.makedirs(output_dir)

View File

@@ -4,11 +4,11 @@ import re
 import torch
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -50,7 +50,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
@@ -101,12 +101,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(ref_audio)

View File

@@ -4,11 +4,11 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -112,12 +112,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)
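
[Editor's note, not part of the commit: a quick equivalence check for the two loading paths. This runnable sketch builds a toy EMA checkpoint, saves it in both formats, and confirms load_checkpoint yields identical weights either way; the toy module and file names are placeholders.]

import torch
from ema_pytorch import EMA
from safetensors.torch import save_file

from model.utils import load_checkpoint

device = "cpu"

def make_model():
    torch.manual_seed(0)  # identical init for both runs
    return torch.nn.Linear(4, 4)

ema = EMA(make_model(), include_online_model = False).to(device)

# same EMA state dict, stored nested (.pt) and flat (.safetensors)
torch.save({"ema_model_state_dict": ema.state_dict()}, "toy.pt")
save_file(ema.state_dict(), "toy.safetensors")

m_pt = load_checkpoint(make_model(), "toy.pt", device)
m_st = load_checkpoint(make_model(), "toy.safetensors", device)
assert all(torch.equal(a, b) for a, b in zip(m_pt.state_dict().values(), m_st.state_dict().values()))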