reuse resamplers and cache vocos MelSpectrogram instances; this reduces some training cost

This commit is contained in:
ZhikangNiu
2026-04-04 14:31:00 +08:00
parent 82fc4fe622
commit 5486a158d4
2 changed files with 26 additions and 18 deletions

View File

@@ -37,6 +37,7 @@ class HFDataset(Dataset):
target_sample_rate=target_sample_rate, target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type, mel_spec_type=mel_spec_type,
) )
self._resamplers = {}
def get_frame_len(self, index): def get_frame_len(self, index):
row = self.data[index] row = self.data[index]
@@ -51,8 +52,6 @@ class HFDataset(Dataset):
row = self.data[index] row = self.data[index]
audio = row["audio"]["array"] audio = row["audio"]["array"]
# logger.info(f"Audio shape: {audio.shape}")
sample_rate = row["audio"]["sampling_rate"] sample_rate = row["audio"]["sampling_rate"]
duration = audio.shape[-1] / sample_rate duration = audio.shape[-1] / sample_rate
@@ -62,8 +61,9 @@ class HFDataset(Dataset):
audio_tensor = torch.from_numpy(audio).float() audio_tensor = torch.from_numpy(audio).float()
if sample_rate != self.target_sample_rate: if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) if sample_rate not in self._resamplers:
audio_tensor = resampler(audio_tensor) self._resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = self._resamplers[sample_rate](audio_tensor)
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
@@ -114,6 +114,7 @@ class CustomDataset(Dataset):
mel_spec_type=mel_spec_type, mel_spec_type=mel_spec_type,
), ),
) )
self._resamplers = {}
def get_frame_len(self, index): def get_frame_len(self, index):
if ( if (
@@ -149,8 +150,11 @@ class CustomDataset(Dataset):
# resample if necessary # resample if necessary
if source_sample_rate != self.target_sample_rate: if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) if source_sample_rate not in self._resamplers:
audio = resampler(audio) self._resamplers[source_sample_rate] = torchaudio.transforms.Resample(
source_sample_rate, self.target_sample_rate
)
audio = self._resamplers[source_sample_rate](audio)
# to mel spectrogram # to mel spectrogram
mel_spec = self.mel_spectrogram(audio) mel_spec = self.mel_spectrogram(audio)

View File

@@ -29,6 +29,7 @@ from f5_tts.model.utils import is_package_available
mel_basis_cache = {} mel_basis_cache = {}
hann_window_cache = {} hann_window_cache = {}
vocos_mel_stft_cache = {}
def get_bigvgan_mel_spectrogram( def get_bigvgan_mel_spectrogram(
@@ -84,23 +85,26 @@ def get_vocos_mel_spectrogram(
hop_length=256, hop_length=256,
win_length=1024, win_length=1024,
): ):
mel_stft = torchaudio.transforms.MelSpectrogram( device = waveform.device
sample_rate=target_sample_rate, key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{device}"
n_fft=n_fft, if key not in vocos_mel_stft_cache:
win_length=win_length, vocos_mel_stft_cache[key] = torchaudio.transforms.MelSpectrogram(
hop_length=hop_length, sample_rate=target_sample_rate,
n_mels=n_mel_channels, n_fft=n_fft,
power=1, win_length=win_length,
center=True, hop_length=hop_length,
normalized=False, n_mels=n_mel_channels,
norm=None, power=1,
).to(waveform.device) center=True,
normalized=False,
norm=None,
).to(device)
if len(waveform.shape) == 3: if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
assert len(waveform.shape) == 2 assert len(waveform.shape) == 2
mel = mel_stft(waveform) mel = vocos_mel_stft_cache[key](waveform)
mel = mel.clamp(min=1e-5).log() mel = mel.clamp(min=1e-5).log()
return mel return mel