From 5f7944a748384334ccaa18fdcf19f2d48e6559bd Mon Sep 17 00:00:00 2001
From: Yushen CHEN <45333109+SWivid@users.noreply.github.com>
Date: Mon, 18 Nov 2024 22:28:03 +0800
Subject: [PATCH] Update dataset.py, formatting

---
 src/f5_tts/model/dataset.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py
index 2c6201c..f86aec2 100644
--- a/src/f5_tts/model/dataset.py
+++ b/src/f5_tts/model/dataset.py
@@ -133,31 +133,30 @@ class CustomDataset(Dataset):
             text = row["text"]
             duration = row["duration"]
 
-            # Check if the duration is within the acceptable range
+            # filter by given length
             if 0.3 <= duration <= 30:
-                break  # Valid sample found, exit the loop
-
-            # Move to the next index and wrap around if necessary
+                break  # valid
+
             index = (index + 1) % len(self.data)
-
+
         if self.preprocessed_mel:
             mel_spec = torch.tensor(row["mel_spec"])
         else:
             audio, source_sample_rate = torchaudio.load(audio_path)
-
-            # If the audio has multiple channels, convert it to mono
+
+            # make sure mono input
             if audio.shape[0] > 1:
                 audio = torch.mean(audio, dim=0, keepdim=True)
-
-            # Resample the audio if necessary
+
+            # resample if necessary
             if source_sample_rate != self.target_sample_rate:
                 resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                 audio = resampler(audio)
-
-            # Compute the mel spectrogram
+
+            # to mel spectrogram
             mel_spec = self.mel_spectrogram(audio)
-            mel_spec = mel_spec.squeeze(0)  # Convert from (1, D, T) to (D, T)
-
+            mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
+
         return {
             "mel_spec": mel_spec,
             "text": text,
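
For context, here is a minimal standalone sketch of the preprocessing path this hunk touches (mono downmix, optional resampling, mel extraction), outside the Dataset class. The file name and the mel extractor settings (24 kHz target rate, n_fft 1024, hop 256, 100 mel bins) are illustrative assumptions, not values taken from this patch.

    import torch
    import torchaudio

    target_sample_rate = 24000  # assumed target rate, not from this diff

    # hypothetical input file
    audio, source_sample_rate = torchaudio.load("example.wav")

    # make sure mono input
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # resample if necessary
    if source_sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(source_sample_rate, target_sample_rate)
        audio = resampler(audio)

    # to mel spectrogram, assumed extractor settings
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate, n_fft=1024, hop_length=256, n_mels=100
    )
    mel_spec = mel_spectrogram(audio).squeeze(0)  # '1 d t -> d t'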