mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-09 03:43:19 -08:00
Use default fp16 inference
This commit is contained in:
@@ -173,6 +173,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
|
||||
@@ -145,9 +145,9 @@ def load_model(model_cls, model_cfg, ckpt_path,file_vocab):
|
||||
else:
|
||||
tokenizer="custom"
|
||||
|
||||
print("\nvocab : ",vocab_file,tokenizer)
|
||||
print("tokenizer : ",tokenizer)
|
||||
print("model : ",ckpt_path,"\n")
|
||||
print("\nvocab : ", vocab_file,tokenizer)
|
||||
print("tokenizer : ", tokenizer)
|
||||
print("model : ", ckpt_path,"\n")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
|
||||
model = CFM(
|
||||
@@ -265,6 +265,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model,ckpt_file,file_voca
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
|
||||
@@ -99,6 +99,8 @@ class CFM(nn.Module):
|
||||
):
|
||||
self.eval()
|
||||
|
||||
cond = cond.half()
|
||||
|
||||
# raw wave
|
||||
|
||||
if cond.ndim == 2:
|
||||
@@ -175,7 +177,7 @@ class CFM(nn.Module):
|
||||
for dur in duration:
|
||||
if exists(seed):
|
||||
torch.manual_seed(seed)
|
||||
y0.append(torch.randn(dur, self.num_channels, device = self.device))
|
||||
y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
|
||||
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
|
||||
|
||||
t_start = 0
|
||||
@@ -186,7 +188,7 @@ class CFM(nn.Module):
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(t_start, 1, steps, device = self.device)
|
||||
t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
|
||||
@@ -571,5 +571,6 @@ class TimestepEmbedding(nn.Module):
|
||||
|
||||
def forward(self, timestep: float['b']):
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
return time
|
||||
|
||||
@@ -45,7 +45,8 @@ class Trainer:
|
||||
wandb_resume_id: str = None,
|
||||
last_per_steps = None,
|
||||
accelerate_kwargs: dict = dict(),
|
||||
ema_kwargs: dict = dict()
|
||||
ema_kwargs: dict = dict(),
|
||||
bnb_optimizer: bool = False,
|
||||
):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
||||
@@ -107,7 +108,11 @@ class Trainer:
|
||||
|
||||
self.duration_predictor = duration_predictor
|
||||
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
||||
if bnb_optimizer:
|
||||
import bitsandbytes as bnb
|
||||
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
||||
else:
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
||||
self.model, self.optimizer = self.accelerator.prepare(
|
||||
self.model, self.optimizer
|
||||
)
|
||||
|
||||
@@ -557,23 +557,23 @@ def repetition_found(text, length = 2, tolerance = 10):
|
||||
# load model checkpoint for inference
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
||||
from ema_pytorch import EMA
|
||||
model = model.half()
|
||||
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
checkpoint = load_file(ckpt_path, device=device)
|
||||
checkpoint = load_file(ckpt_path)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True)
|
||||
|
||||
if use_ema == True:
|
||||
ema_model = EMA(model, include_online_model = False).to(device)
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
ema_model.load_state_dict(checkpoint)
|
||||
else:
|
||||
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
ema_model.copy_params_from_ema_to_model()
|
||||
else:
|
||||
checkpoint = {'ema_model_state_dict': checkpoint}
|
||||
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
return model
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {'model_state_dict': checkpoint}
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
return model.to(device)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
accelerate>=0.33.0
|
||||
bitsandbytes>0.37.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
|
||||
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
output_dir = "tests"
|
||||
|
||||
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
||||
@@ -172,12 +172,13 @@ with torch.inference_mode():
|
||||
print(f"Generated mel: {generated.shape}")
|
||||
|
||||
# Final result
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
|
||||
torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
|
||||
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
||||
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
|
||||
print(f"Generated wav: {generated_wave.shape}")
|
||||
|
||||
Reference in New Issue
Block a user