mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-04-28 08:43:06 -07:00
Reuse resamplers and cache Vocos MelSpectrogram instances, which reduces some training cost
This commit is contained in:
@@ -37,6 +37,7 @@ class HFDataset(Dataset):
|
|||||||
target_sample_rate=target_sample_rate,
|
target_sample_rate=target_sample_rate,
|
||||||
mel_spec_type=mel_spec_type,
|
mel_spec_type=mel_spec_type,
|
||||||
)
|
)
|
||||||
|
self._resamplers = {}
|
||||||
|
|
||||||
def get_frame_len(self, index):
|
def get_frame_len(self, index):
|
||||||
row = self.data[index]
|
row = self.data[index]
|
||||||
@@ -51,8 +52,6 @@ class HFDataset(Dataset):
|
|||||||
row = self.data[index]
|
row = self.data[index]
|
||||||
audio = row["audio"]["array"]
|
audio = row["audio"]["array"]
|
||||||
|
|
||||||
# logger.info(f"Audio shape: {audio.shape}")
|
|
||||||
|
|
||||||
sample_rate = row["audio"]["sampling_rate"]
|
sample_rate = row["audio"]["sampling_rate"]
|
||||||
duration = audio.shape[-1] / sample_rate
|
duration = audio.shape[-1] / sample_rate
|
||||||
|
|
||||||
@@ -62,8 +61,9 @@ class HFDataset(Dataset):
|
|||||||
audio_tensor = torch.from_numpy(audio).float()
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
|
||||||
if sample_rate != self.target_sample_rate:
|
if sample_rate != self.target_sample_rate:
|
||||||
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
if sample_rate not in self._resamplers:
|
||||||
audio_tensor = resampler(audio_tensor)
|
self._resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
|
||||||
|
audio_tensor = self._resamplers[sample_rate](audio_tensor)
|
||||||
|
|
||||||
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
|
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
|
||||||
|
|
||||||
@@ -114,6 +114,7 @@ class CustomDataset(Dataset):
|
|||||||
mel_spec_type=mel_spec_type,
|
mel_spec_type=mel_spec_type,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self._resamplers = {}
|
||||||
|
|
||||||
def get_frame_len(self, index):
|
def get_frame_len(self, index):
|
||||||
if (
|
if (
|
||||||
@@ -149,8 +150,11 @@ class CustomDataset(Dataset):
|
|||||||
|
|
||||||
# resample if necessary
|
# resample if necessary
|
||||||
if source_sample_rate != self.target_sample_rate:
|
if source_sample_rate != self.target_sample_rate:
|
||||||
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
if source_sample_rate not in self._resamplers:
|
||||||
audio = resampler(audio)
|
self._resamplers[source_sample_rate] = torchaudio.transforms.Resample(
|
||||||
|
source_sample_rate, self.target_sample_rate
|
||||||
|
)
|
||||||
|
audio = self._resamplers[source_sample_rate](audio)
|
||||||
|
|
||||||
# to mel spectrogram
|
# to mel spectrogram
|
||||||
mel_spec = self.mel_spectrogram(audio)
|
mel_spec = self.mel_spectrogram(audio)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from f5_tts.model.utils import is_package_available
|
|||||||
|
|
||||||
mel_basis_cache = {}
|
mel_basis_cache = {}
|
||||||
hann_window_cache = {}
|
hann_window_cache = {}
|
||||||
|
vocos_mel_stft_cache = {}
|
||||||
|
|
||||||
|
|
||||||
def get_bigvgan_mel_spectrogram(
|
def get_bigvgan_mel_spectrogram(
|
||||||
@@ -84,23 +85,26 @@ def get_vocos_mel_spectrogram(
|
|||||||
hop_length=256,
|
hop_length=256,
|
||||||
win_length=1024,
|
win_length=1024,
|
||||||
):
|
):
|
||||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
device = waveform.device
|
||||||
sample_rate=target_sample_rate,
|
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{device}"
|
||||||
n_fft=n_fft,
|
if key not in vocos_mel_stft_cache:
|
||||||
win_length=win_length,
|
vocos_mel_stft_cache[key] = torchaudio.transforms.MelSpectrogram(
|
||||||
hop_length=hop_length,
|
sample_rate=target_sample_rate,
|
||||||
n_mels=n_mel_channels,
|
n_fft=n_fft,
|
||||||
power=1,
|
win_length=win_length,
|
||||||
center=True,
|
hop_length=hop_length,
|
||||||
normalized=False,
|
n_mels=n_mel_channels,
|
||||||
norm=None,
|
power=1,
|
||||||
).to(waveform.device)
|
center=True,
|
||||||
|
normalized=False,
|
||||||
|
norm=None,
|
||||||
|
).to(device)
|
||||||
if len(waveform.shape) == 3:
|
if len(waveform.shape) == 3:
|
||||||
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
||||||
|
|
||||||
assert len(waveform.shape) == 2
|
assert len(waveform.shape) == 2
|
||||||
|
|
||||||
mel = mel_stft(waveform)
|
mel = vocos_mel_stft_cache[key](waveform)
|
||||||
mel = mel.clamp(min=1e-5).log()
|
mel = mel.clamp(min=1e-5).log()
|
||||||
return mel
|
return mel
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user