From 795cb19e4fb37fe2c9ade85dc6a80f6d6286b775 Mon Sep 17 00:00:00 2001 From: Haitao Date: Mon, 21 Oct 2024 17:00:48 +0800 Subject: [PATCH] allow for passing in custom mel spec module (#200) --- model/dataset.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/model/dataset.py b/model/dataset.py index 03ed473..c293fe2 100644 --- a/model/dataset.py +++ b/model/dataset.py @@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler import torchaudio from datasets import load_from_disk from datasets import Dataset as Dataset_ +from torch import nn from model.modules import MelSpec +from model.utils import default class HFDataset(Dataset): @@ -77,15 +79,22 @@ class CustomDataset(Dataset): hop_length=256, n_mel_channels=100, preprocessed_mel=False, + mel_spec_module: nn.Module | None = None, ): self.data = custom_dataset self.durations = durations self.target_sample_rate = target_sample_rate self.hop_length = hop_length self.preprocessed_mel = preprocessed_mel + if not preprocessed_mel: - self.mel_spectrogram = MelSpec( - target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels + self.mel_spectrogram = default( + mel_spec_module, + MelSpec( + target_sample_rate=target_sample_rate, + hop_length=hop_length, + n_mel_channels=n_mel_channels, + ), ) def get_frame_len(self, index): @@ -201,6 +210,7 @@ def load_dataset( tokenizer: str = "pinyin", dataset_type: str = "CustomDataset", audio_type: str = "raw", + mel_spec_module: nn.Module | None = None, mel_spec_kwargs: dict = dict(), ) -> CustomDataset | HFDataset: """ @@ -224,7 +234,11 @@ def load_dataset( data_dict = json.load(f) durations = data_dict["duration"] train_dataset = CustomDataset( - train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs + train_dataset, + durations=durations, + preprocessed_mel=preprocessed_mel, + mel_spec_module=mel_spec_module, + **mel_spec_kwargs, ) elif dataset_type == "CustomDatasetPath":