mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Runtime (TRT-LLM): minor fixes. PyTorch: update the text_embedding logic to compute text embeddings one sample at a time, correcting misalignment in v0 batching.
This commit is contained in:
@@ -12,6 +12,7 @@ from __future__ import annotations
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from x_transformers.x_transformers import RotaryEmbedding
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
from f5_tts.model.modules import (
|
from f5_tts.model.modules import (
|
||||||
@@ -236,19 +237,30 @@ class DiT(nn.Module):
|
|||||||
cache: bool = True,
|
cache: bool = True,
|
||||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||||
):
|
):
|
||||||
seq_len = x.shape[1]
|
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||||
# TODO: compute text_embed one sample at a time (to avoid misalignment when batching), as done in the runtime implementation.
|
batch = x.shape[0]
|
||||||
|
seq_lens = audio_mask.sum(dim=1)
|
||||||
|
text_embed_list = []
|
||||||
|
for i in range(batch):
|
||||||
|
text_embed_i = self.text_embed(
|
||||||
|
text[i].unsqueeze(0),
|
||||||
|
seq_lens[i].item(),
|
||||||
|
drop_text=drop_text,
|
||||||
|
audio_mask=audio_mask,
|
||||||
|
)
|
||||||
|
text_embed_list.append(text_embed_i[0])
|
||||||
|
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||||
|
if cache:
|
||||||
|
if drop_text:
|
||||||
|
self.text_uncond = text_embed
|
||||||
|
else:
|
||||||
|
self.text_cond = text_embed
|
||||||
|
|
||||||
if cache:
|
if cache:
|
||||||
if drop_text:
|
if drop_text:
|
||||||
if self.text_uncond is None:
|
|
||||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
|
|
||||||
text_embed = self.text_uncond
|
text_embed = self.text_uncond
|
||||||
else:
|
else:
|
||||||
if self.text_cond is None:
|
|
||||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
|
|
||||||
text_embed = self.text_cond
|
text_embed = self.text_cond
|
||||||
else:
|
|
||||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
|
|
||||||
|
|
||||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
|
|||||||
@@ -42,13 +42,16 @@ class TextEmbedding(nn.Module):
|
|||||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
||||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||||
|
|
||||||
def forward(self, text, seq_len):
|
def forward(self, text, seq_len, drop_text=False):
|
||||||
text = text + 1
|
text = text + 1
|
||||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
|
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
|
||||||
if self.mask_padding:
|
if self.mask_padding:
|
||||||
text_mask = text == 0
|
text_mask = text == 0
|
||||||
|
|
||||||
|
if drop_text: # cfg for text
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
|
||||||
text = self.text_embed(text) # b n -> b n d
|
text = self.text_embed(text) # b n -> b n d
|
||||||
text = text + self.freqs_cis[:seq_len, :]
|
text = text + self.freqs_cis[:seq_len, :]
|
||||||
if self.mask_padding:
|
if self.mask_padding:
|
||||||
@@ -385,17 +388,17 @@ class F5TTS(object):
|
|||||||
# get text_embed one by one to avoid misalignment
|
# get text_embed one by one to avoid misalignment
|
||||||
text_and_drop_embedding_list = []
|
text_and_drop_embedding_list = []
|
||||||
for i in range(batch):
|
for i in range(batch):
|
||||||
text_and_drop_embedding_i = self.text_embedding(
|
text_embedding_i = self.text_embedding(
|
||||||
torch.cat(
|
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||||
(
|
|
||||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
|
||||||
torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
),
|
|
||||||
estimated_reference_target_mel_len[i],
|
estimated_reference_target_mel_len[i],
|
||||||
|
drop_text=False,
|
||||||
)
|
)
|
||||||
text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
|
text_embedding_drop_i = self.text_embedding(
|
||||||
|
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||||
|
estimated_reference_target_mel_len[i],
|
||||||
|
drop_text=True,
|
||||||
|
)
|
||||||
|
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
|
||||||
|
|
||||||
# pad separately computed text_embed to form batch with max_seq_len
|
# pad separately computed text_embed to form batch with max_seq_len
|
||||||
text_and_drop_embedding = pad_sequence(
|
text_and_drop_embedding = pad_sequence(
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ class TritonPythonModel:
|
|||||||
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
||||||
|
|
||||||
batch = len(requests)
|
batch = len(requests)
|
||||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
|
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
|
||||||
for i, mel in enumerate(mel_features_list):
|
for i, mel in enumerate(mel_features_list):
|
||||||
mel_features[i, : mel.shape[1], :] = mel
|
mel_features[i, : mel.shape[1], :] = mel
|
||||||
|
|
||||||
@@ -254,9 +254,9 @@ class TritonPythonModel:
|
|||||||
|
|
||||||
responses = []
|
responses = []
|
||||||
for i in range(batch):
|
for i in range(batch):
|
||||||
ref_me_len = reference_mel_len[i]
|
ref_mel_len = reference_mel_len[i]
|
||||||
estimated_mel_len = estimated_reference_target_mel_len[i]
|
estimated_mel_len = estimated_reference_target_mel_len[i]
|
||||||
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||||
audio = self.forward_vocoder(denoised_one_item)
|
audio = self.forward_vocoder(denoised_one_item)
|
||||||
if reference_rms_list[i] < self.target_rms:
|
if reference_rms_list[i] < self.target_rms:
|
||||||
audio = audio * reference_rms_list[i] / self.target_rms
|
audio = audio * reference_rms_list[i] / self.target_rms
|
||||||
|
|||||||
Reference in New Issue
Block a user