From 14e923a427bcdd082ab54c8e29d25da34c355f0f Mon Sep 17 00:00:00 2001
From: atlonxp <38250872+atlonxp@users.noreply.github.com>
Date: Mon, 18 Nov 2024 20:53:47 +0700
Subject: [PATCH] Update dataset.py

change recursive approach to while loop, avoiding potential memory leak.
---
 src/f5_tts/model/dataset.py | 60 +++++++++++++++++++++++++------------
 1 file changed, 41 insertions(+), 19 deletions(-)

diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py
index 937836d..2c6201c 100644
--- a/src/f5_tts/model/dataset.py
+++ b/src/f5_tts/model/dataset.py
@@ -127,38 +127,60 @@ class CustomDataset(Dataset):
         return len(self.data)
 
     def __getitem__(self, index):
-        row = self.data[index]
-        audio_path = row["audio_path"]
-        text = row["text"]
-        duration = row["duration"]
-
+        """Return the mel spectrogram and text of the first usable sample.
+
+        Starting at ``index``, rows whose duration is outside [0.3, 30] are
+        skipped, wrapping around the dataset, until a usable row is found.
+
+        Raises:
+            RuntimeError: If no row has a duration within [0.3, 30].
+        """
+        num_samples = len(self.data)
+        num_checked = 0
+        while True:
+            row = self.data[index]
+            audio_path = row["audio_path"]
+            text = row["text"]
+            duration = row["duration"]
+
+            # Check if the duration is within the acceptable range
+            if 0.3 <= duration <= 30:
+                break  # Valid sample found, exit the loop
+
+            # Give up once every row has been checked: a dataset with no usable
+            # sample must fail loudly instead of looping forever
+            num_checked += 1
+            if num_checked >= num_samples:
+                raise RuntimeError("No sample with a duration in [0.3, 30] was found in the dataset")
+
+            # Move to the next index and wrap around if necessary
+            index = (index + 1) % num_samples
+
         if self.preprocessed_mel:
             mel_spec = torch.tensor(row["mel_spec"])
-
         else:
             audio, source_sample_rate = torchaudio.load(audio_path)
+
+            # If the audio has multiple channels, convert it to mono
             if audio.shape[0] > 1:
                 audio = torch.mean(audio, dim=0, keepdim=True)
-
-            if duration > 30 or duration < 0.3:
-                return self.__getitem__((index + 1) % len(self.data))
-
+
+            # Resample the audio if necessary
             if source_sample_rate != self.target_sample_rate:
                 resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                 audio = resampler(audio)
-
+
+            # Compute the mel spectrogram
             mel_spec = self.mel_spectrogram(audio)
-            mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
-
-        return dict(
-            mel_spec=mel_spec,
-            text=text,
-        )
+            mel_spec = mel_spec.squeeze(0)  # Convert from (1, D, T) to (D, T)
+
+        return {
+            "mel_spec": mel_spec,
+            "text": text,
+        }
 
 
 # Dynamic Batch Sampler
-
-
 class DynamicBatchSampler(Sampler[list[int]]):
     """Extension of Sampler that will do the following:
     1. Change the batch size (essentially number of sequences)