diff --git a/model/utils.py b/model/utils.py
index 692f2f7..e836477 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -545,3 +545,28 @@ def repetition_found(text, length = 2, tolerance = 10):
         if count > tolerance:
             return True
     return False
+
+
+# load model checkpoint for inference
+
+def load_checkpoint(model, ckpt_path, device, use_ema = True):
+    from ema_pytorch import EMA
+
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+        checkpoint = load_file(ckpt_path, device=device)
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+
+    if use_ema == True:
+        ema_model = EMA(model, include_online_model = False).to(device)
+        if ckpt_type == "safetensors":
+            ema_model.load_state_dict(checkpoint)
+        else:
+            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+        ema_model.copy_params_from_ema_to_model()
+    else:
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+    return model
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b88c7ff..337da40 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ jiwer
 librosa
 matplotlib
 pypinyin
+safetensors
 # torch>=2.0
 # torchaudio>=2.3.0
 torchdiffeq
diff --git a/test_infer_batch.py b/test_infer_batch.py
index 19dba50..afd1b28 100644
--- a/test_infer_batch.py
+++ b/test_infer_batch.py
@@ -8,11 +8,11 @@ import torch
 import torchaudio
 from accelerate import Accelerator
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     get_seedtts_testset_metainfo,
     get_librispeech_test_clean_metainfo,
@@ -55,7 +55,7 @@ seed = args.seed
 dataset_name = args.dataset
 exp_name = args.expname
 ckpt_step = args.ckptstep
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 nfe_step = args.nfestep
 ode_method = args.odemethod
@@ -152,12 +152,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 if not os.path.exists(output_dir) and accelerator.is_main_process:
     os.makedirs(output_dir)
diff --git a/test_infer_single.py b/test_infer_single.py
index c5579b6..a5e940a 100644
--- a/test_infer_single.py
+++ b/test_infer_single.py
@@ -4,11 +4,11 @@ import re
 import torch
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -50,7 +50,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 output_dir = "tests"
 ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
@@ -101,12 +101,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(ref_audio)
diff --git a/test_infer_single_edit.py b/test_infer_single_edit.py
index e8b3094..037d711 100644
--- a/test_infer_single_edit.py
+++ b/test_infer_single_edit.py
@@ -4,11 +4,11 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 output_dir = "tests"
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -112,12 +112,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)
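For context, a minimal usage sketch of the new helper. `load_checkpoint` dispatches on the file extension: `.safetensors` files are read with `safetensors.torch.load_file` (a flat EMA state dict), anything else with `torch.load` (a training checkpoint holding `ema_model_state_dict` / `model_state_dict`), and with `use_ema = True` the EMA weights are copied onto the online model via `ema_pytorch`. The model construction below is abbreviated: the `text_num_embeds` / `mel_dim` kwargs, the `get_tokenizer` arguments, and the concrete checkpoint path are illustrative assumptions, not part of this diff; only `load_checkpoint`'s signature and the `model_cfg` values come from the patch above.

```python
import torch

from model import CFM, UNetT
from model.utils import get_tokenizer, load_checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"

# build the model as in the test scripts; other CFM kwargs
# (mel_spec_kwargs, odeint_kwargs, ...) are elided in this sketch
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")  # illustrative dataset/tokenizer
model = CFM(
    transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4,
                        text_num_embeds = vocab_size, mel_dim = 100),  # extra kwargs assumed
    vocab_char_map = vocab_char_map,
).to(device)

# the extension picks the loader; use_ema = True loads the EMA weights
# and copies them onto the online model before returning it
ckpt_path = "ckpts/E2TTS_Base/model_1200000.pt"  # illustrative step number
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
```

One upside of this refactor: the EMA-vs-plain loading logic previously duplicated across the three test scripts now lives in one place, so the new safetensors branch works everywhere at once, and `ema_pytorch` is only imported when a checkpoint is actually loaded.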