mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
support back avg upsampling for batch, cover up non-mask case
This commit is contained in:
@@ -43,7 +43,7 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
@@ -53,32 +53,33 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
def average_upsample_text_by_mask(self, text, text_mask):
    """Stretch each sample's valid text embeddings evenly over the full sequence length.

    For every batch row, the positions selected by ``text_mask`` are gathered
    and each one is repeated so that together they fill all ``text_len``
    output positions. When the length does not divide evenly, the trailing
    valid positions receive one extra repeat each.

    Args:
        text: Tensor of shape ``(batch, text_len, text_dim)``. Per the original
            comment, text is already padded to the audio sequence length, so
            ``text_len`` doubles as the target length.
        text_mask: Mask indexed as ``(batch, text_len)`` marking valid text
            positions (presumably boolean -- confirm against the caller).

    Returns:
        Tensor with the same shape, dtype and device as ``text``. Rows whose
        mask selects nothing are left as zeros.
    """
    batch, seq_len, _ = text.shape

    audio_len = seq_len  # text is already padded to the same length as the audio sequence
    valid_counts = text_mask.sum(dim=1)  # [batch]

    upsampled_text = torch.zeros_like(text)

    for i in range(batch):
        n_valid = int(valid_counts[i].item())

        # Nothing to spread out for this row: keep the zero fill.
        if n_valid == 0:
            continue

        valid_ind = torch.where(text_mask[i])[0]
        valid_data = text[i, valid_ind, :]  # [n_valid, text_dim]

        base_repeat, remainder = divmod(audio_len, n_valid)

        # Every valid position is repeated `base_repeat` times; the last
        # `remainder` positions get one more, so the repeats sum to audio_len.
        positions = torch.arange(n_valid, device=text.device)
        repeats = base_repeat + (positions >= n_valid - remainder).to(torch.long)
        indices = torch.repeat_interleave(positions, repeats)  # [audio_len]

        upsampled_text[i, :audio_len, :] = valid_data[indices]

    return upsampled_text
|
||||
|
||||
|
||||
Reference in New Issue
Block a user