mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-04-28 08:43:06 -07:00
Merge pull request #1285 from ZhikangNiu/main
reuse resamplers and cache vocos MelSpectrogram instances
This commit is contained in:
@@ -37,6 +37,7 @@ class HFDataset(Dataset):
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
)
|
||||
self._resamplers = {}
|
||||
|
||||
def get_frame_len(self, index):
|
||||
row = self.data[index]
|
||||
@@ -51,8 +52,6 @@ class HFDataset(Dataset):
|
||||
row = self.data[index]
|
||||
audio = row["audio"]["array"]
|
||||
|
||||
# logger.info(f"Audio shape: {audio.shape}")
|
||||
|
||||
sample_rate = row["audio"]["sampling_rate"]
|
||||
duration = audio.shape[-1] / sample_rate
|
||||
|
||||
@@ -62,8 +61,9 @@ class HFDataset(Dataset):
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
|
||||
if sample_rate != self.target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
||||
audio_tensor = resampler(audio_tensor)
|
||||
if sample_rate not in self._resamplers:
|
||||
self._resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
||||
audio_tensor = self._resamplers[sample_rate](audio_tensor)
|
||||
|
||||
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
|
||||
|
||||
@@ -114,6 +114,7 @@ class CustomDataset(Dataset):
|
||||
mel_spec_type=mel_spec_type,
|
||||
),
|
||||
)
|
||||
self._resamplers = {}
|
||||
|
||||
def get_frame_len(self, index):
|
||||
if (
|
||||
@@ -149,8 +150,11 @@ class CustomDataset(Dataset):
|
||||
|
||||
# resample if necessary
|
||||
if source_sample_rate != self.target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
if source_sample_rate not in self._resamplers:
|
||||
self._resamplers[source_sample_rate] = torchaudio.transforms.Resample(
|
||||
source_sample_rate, self.target_sample_rate
|
||||
)
|
||||
audio = self._resamplers[source_sample_rate](audio)
|
||||
|
||||
# to mel spectrogram
|
||||
mel_spec = self.mel_spectrogram(audio)
|
||||
|
||||
@@ -29,6 +29,7 @@ from f5_tts.model.utils import is_package_available
|
||||
|
||||
mel_basis_cache = {}
|
||||
hann_window_cache = {}
|
||||
vocos_mel_stft_cache = {}
|
||||
|
||||
|
||||
def get_bigvgan_mel_spectrogram(
|
||||
@@ -84,23 +85,26 @@ def get_vocos_mel_spectrogram(
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
):
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=target_sample_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mel_channels,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(waveform.device)
|
||||
device = waveform.device
|
||||
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{device}"
|
||||
if key not in vocos_mel_stft_cache:
|
||||
vocos_mel_stft_cache[key] = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=target_sample_rate,
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mel_channels,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(device)
|
||||
if len(waveform.shape) == 3:
|
||||
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
||||
|
||||
assert len(waveform.shape) == 2
|
||||
|
||||
mel = mel_stft(waveform)
|
||||
mel = vocos_mel_stft_cache[key](waveform)
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user