allow for passing in custom mel spec module (#200)

This commit is contained in:
Haitao
2024-10-21 17:00:48 +08:00
committed by GitHub
parent 25cdc5182f
commit 795cb19e4f

View File

@@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler
import torchaudio
from datasets import load_from_disk
from datasets import Dataset as Dataset_
from torch import nn
from model.modules import MelSpec
from model.utils import default
class HFDataset(Dataset):
@@ -77,15 +79,22 @@ class CustomDataset(Dataset):
hop_length=256,
n_mel_channels=100,
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
target_sample_rate=target_sample_rate,
hop_length=hop_length,
n_mel_channels=n_mel_channels,
),
)
def get_frame_len(self, index):
@@ -201,6 +210,7 @@ def load_dataset(
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
"""
@@ -224,7 +234,11 @@ def load_dataset(
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
train_dataset,
durations=durations,
preprocessed_mel=preprocessed_mel,
mel_spec_module=mel_spec_module,
**mel_spec_kwargs,
)
elif dataset_type == "CustomDatasetPath":