diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py index 50448c3..1f729f9 100644 --- a/src/f5_tts/model/dataset.py +++ b/src/f5_tts/model/dataset.py @@ -37,6 +37,7 @@ class HFDataset(Dataset): target_sample_rate=target_sample_rate, mel_spec_type=mel_spec_type, ) + self._resamplers = {} def get_frame_len(self, index): row = self.data[index] @@ -51,8 +52,6 @@ class HFDataset(Dataset): row = self.data[index] audio = row["audio"]["array"] - # logger.info(f"Audio shape: {audio.shape}") - sample_rate = row["audio"]["sampling_rate"] duration = audio.shape[-1] / sample_rate @@ -62,8 +61,9 @@ class HFDataset(Dataset): audio_tensor = torch.from_numpy(audio).float() if sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) - audio_tensor = resampler(audio_tensor) + if sample_rate not in self._resamplers: + self._resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate) + audio_tensor = self._resamplers[sample_rate](audio_tensor) audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t') @@ -114,6 +114,7 @@ class CustomDataset(Dataset): mel_spec_type=mel_spec_type, ), ) + self._resamplers = {} def get_frame_len(self, index): if ( @@ -149,8 +150,11 @@ class CustomDataset(Dataset): # resample if necessary if source_sample_rate != self.target_sample_rate: - resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate) - audio = resampler(audio) + if source_sample_rate not in self._resamplers: + self._resamplers[source_sample_rate] = torchaudio.transforms.Resample( + source_sample_rate, self.target_sample_rate + ) + audio = self._resamplers[source_sample_rate](audio) # to mel spectrogram mel_spec = self.mel_spectrogram(audio) diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py index bc1dad6..7165632 100644 --- a/src/f5_tts/model/modules.py +++ b/src/f5_tts/model/modules.py @@ -29,6 +29,7 @@ from f5_tts.model.utils import is_package_available mel_basis_cache = {} hann_window_cache = {} +vocos_mel_stft_cache = {} def get_bigvgan_mel_spectrogram( @@ -84,23 +85,26 @@ def get_vocos_mel_spectrogram( hop_length=256, win_length=1024, ): - mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate=target_sample_rate, - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - n_mels=n_mel_channels, - power=1, - center=True, - normalized=False, - norm=None, - ).to(waveform.device) + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{device}" + if key not in vocos_mel_stft_cache: + vocos_mel_stft_cache[key] = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(device) if len(waveform.shape) == 3: waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' assert len(waveform.shape) == 2 - mel = mel_stft(waveform) + mel = vocos_mel_stft_cache[key](waveform) mel = mel.clamp(min=1e-5).log() return mel