Update dataset.py

Change the recursive approach to a while loop, avoiding a potential memory leak.
This commit is contained in:
atlonxp
2024-11-18 20:53:47 +07:00
committed by GitHub
parent 5cc02536a6
commit 14e923a427

View File

@@ -127,38 +127,44 @@ class CustomDataset(Dataset):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
while True:
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
# Check if the duration is within the acceptable range
if 0.3 <= duration <= 30:
break # Valid sample found, exit the loop
# Move to the next index and wrap around if necessary
index = (index + 1) % len(self.data)
if self.preprocessed_mel:
mel_spec = torch.tensor(row["mel_spec"])
else:
audio, source_sample_rate = torchaudio.load(audio_path)
# If the audio has multiple channels, convert it to mono
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
# Resample the audio if necessary
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
# Compute the mel spectrogram
mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
return dict(
mel_spec=mel_spec,
text=text,
)
mel_spec = mel_spec.squeeze(0) # Convert from (1, D, T) to (D, T)
return {
"mel_spec": mel_spec,
"text": text,
}
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
"""Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)