diff --git a/gradio_app.py b/gradio_app.py
index 1a37459..0bbb594 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -173,6 +173,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
                 sway_sampling_coef=sway_sampling_coef,
             )
 
+            generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
             generated_wave = vocos.decode(generated_mel_spec.cpu())
diff --git a/inference-cli.py b/inference-cli.py
index d7f7760..7162790 100644
--- a/inference-cli.py
+++ b/inference-cli.py
@@ -145,9 +145,9 @@ def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
     else:
         tokenizer="custom"
 
-    print("\nvocab : ",vocab_file,tokenizer)
-    print("tokenizer : ",tokenizer)
-    print("model : ",ckpt_path,"\n")
+    print("\nvocab : ", vocab_file,tokenizer)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path,"\n")
 
     vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(
@@ -265,6 +265,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_voca
                 sway_sampling_coef=sway_sampling_coef,
             )
 
+            generated = generated.to(torch.float32)
             generated = generated[:, ref_audio_len:, :]
             generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
             generated_wave = vocos.decode(generated_mel_spec.cpu())
diff --git a/model/cfm.py b/model/cfm.py
index 70a38a7..c494700 100644
--- a/model/cfm.py
+++ b/model/cfm.py
@@ -99,6 +99,8 @@ class CFM(nn.Module):
     ):
         self.eval()
 
+        cond = cond.half()
+
         # raw wave
 
         if cond.ndim == 2:
@@ -175,7 +177,7 @@
         for dur in duration:
             if exists(seed):
                 torch.manual_seed(seed)
-            y0.append(torch.randn(dur, self.num_channels, device = self.device))
+            y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
         y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
 
         t_start = 0
@@ -186,7 +188,7 @@
             y0 = (1 - t_start) * y0 + t_start * test_cond
             steps = int(steps * (1 - t_start))
 
-        t = torch.linspace(t_start, 1, steps, device = self.device)
+        t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
 
diff --git a/model/modules.py b/model/modules.py
index fa2e3b1..fffa4b4 100644
--- a/model/modules.py
+++ b/model/modules.py
@@ -571,5 +571,6 @@ class TimestepEmbedding(nn.Module):
 
     def forward(self, timestep: float['b']):
         time_hidden = self.time_embed(timestep)
+        time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
         return time
diff --git a/model/trainer.py b/model/trainer.py
index c5c956a..cf7c9e6 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -45,7 +45,8 @@ class Trainer:
         wandb_resume_id: str = None,
         last_per_steps = None,
         accelerate_kwargs: dict = dict(),
-        ema_kwargs: dict = dict()
+        ema_kwargs: dict = dict(),
+        bnb_optimizer: bool = False,
     ):
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
@@ -107,7 +108,11 @@ class Trainer:
 
         self.duration_predictor = duration_predictor
 
-        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
+        if bnb_optimizer:
+            import bitsandbytes as bnb
+            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
+        else:
+            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(
             self.model, self.optimizer
         )
diff --git a/model/utils.py b/model/utils.py
index 434c4ee..2d66167 100644
--- a/model/utils.py
+++ b/model/utils.py
@@ -557,23 +557,23 @@ def repetition_found(text, length = 2, tolerance = 10):
 
 # load model checkpoint for inference
 def load_checkpoint(model, ckpt_path, device, use_ema = True):
-    from ema_pytorch import EMA
+    model = model.half()
 
     ckpt_type = ckpt_path.split(".")[-1]
     if ckpt_type == "safetensors":
         from safetensors.torch import load_file
-        checkpoint = load_file(ckpt_path, device=device)
+        checkpoint = load_file(ckpt_path)
     else:
-        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
+        checkpoint = torch.load(ckpt_path, weights_only=True)
 
-    if use_ema == True:
-        ema_model = EMA(model, include_online_model = False).to(device)
+    if use_ema:
         if ckpt_type == "safetensors":
-            ema_model.load_state_dict(checkpoint)
-        else:
-            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-        ema_model.copy_params_from_ema_to_model()
-    else:
+            checkpoint = {'ema_model_state_dict': checkpoint}
+        checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
         model.load_state_dict(checkpoint['model_state_dict'])
-
-    return model
\ No newline at end of file
+    else:
+        if ckpt_type == "safetensors":
+            checkpoint = {'model_state_dict': checkpoint}
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+    return model.to(device)
diff --git a/requirements.txt b/requirements.txt
index debc498..8d2ef05 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 accelerate>=0.33.0
+bitsandbytes>0.37.0
 cached_path
 click
 datasets
diff --git a/speech_edit.py b/speech_edit.py
index 991eac4..97f031b 100644
--- a/speech_edit.py
+++ b/speech_edit.py
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -172,12 +172,13 @@
     print(f"Generated mel: {generated.shape}")
 
     # Final result
+    generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
     generated_wave = vocos.decode(generated_mel_spec.cpu())
    if rms < target_rms:
         generated_wave = generated_wave * rms / target_rms
 
-    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
-    torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
+    save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")
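
Note (not part of the patch): the load_checkpoint rewrite above drops the ema_pytorch round-trip and instead remaps EMA keys directly onto the bare model, stripping the "ema_model." prefix and the EMA bookkeeping entries. A minimal standalone sketch of that remapping for illustration only; the helper name and example paths are hypothetical, and the patch inlines this logic in model/utils.py:

import torch

def ema_to_model_state_dict(ema_state_dict: dict) -> dict:
    # EMA checkpoints store weights under "ema_model.<param_name>" plus two
    # bookkeeping tensors ("initted", "step") that the bare model never owns;
    # strip the prefix and drop the bookkeeping keys before load_state_dict.
    return {
        k.replace("ema_model.", ""): v
        for k, v in ema_state_dict.items()
        if k not in ["initted", "step"]
    }

# Hypothetical usage with a .pt checkpoint written by model/trainer.py:
# checkpoint = torch.load("path/to/model.pt", weights_only=True)
# model.load_state_dict(ema_to_model_state_dict(checkpoint["ema_model_state_dict"]))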