mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-30 06:31:54 -08:00
Update dataset.py, formatting
This commit is contained in:
@@ -133,31 +133,30 @@ class CustomDataset(Dataset):
|
||||
text = row["text"]
|
||||
duration = row["duration"]
|
||||
|
||||
# Check if the duration is within the acceptable range
|
||||
# filter by given length
|
||||
if 0.3 <= duration <= 30:
|
||||
break # Valid sample found, exit the loop
|
||||
|
||||
# Move to the next index and wrap around if necessary
|
||||
break # valid
|
||||
|
||||
index = (index + 1) % len(self.data)
|
||||
|
||||
|
||||
if self.preprocessed_mel:
|
||||
mel_spec = torch.tensor(row["mel_spec"])
|
||||
else:
|
||||
audio, source_sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
# If the audio has multiple channels, convert it to mono
|
||||
|
||||
# make sure mono input
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
|
||||
# Resample the audio if necessary
|
||||
|
||||
# resample if necessary
|
||||
if source_sample_rate != self.target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
|
||||
# Compute the mel spectrogram
|
||||
|
||||
# to mel spectrogram
|
||||
mel_spec = self.mel_spectrogram(audio)
|
||||
mel_spec = mel_spec.squeeze(0) # Convert from (1, D, T) to (D, T)
|
||||
|
||||
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
|
||||
|
||||
return {
|
||||
"mel_spec": mel_spec,
|
||||
"text": text,
|
||||
|
||||
Reference in New Issue
Block a user