mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-13 21:47:14 -08:00
allow for passing in custom mel spec module (#200)
This commit is contained in:
@@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler
|
||||
import torchaudio
|
||||
from datasets import load_from_disk
|
||||
from datasets import Dataset as Dataset_
|
||||
from torch import nn
|
||||
|
||||
from model.modules import MelSpec
|
||||
from model.utils import default
|
||||
|
||||
|
||||
class HFDataset(Dataset):
|
||||
@@ -77,15 +79,22 @@ class CustomDataset(Dataset):
|
||||
hop_length=256,
|
||||
n_mel_channels=100,
|
||||
preprocessed_mel=False,
|
||||
mel_spec_module: nn.Module | None = None,
|
||||
):
|
||||
self.data = custom_dataset
|
||||
self.durations = durations
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.hop_length = hop_length
|
||||
self.preprocessed_mel = preprocessed_mel
|
||||
|
||||
if not preprocessed_mel:
|
||||
self.mel_spectrogram = MelSpec(
|
||||
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
|
||||
self.mel_spectrogram = default(
|
||||
mel_spec_module,
|
||||
MelSpec(
|
||||
target_sample_rate=target_sample_rate,
|
||||
hop_length=hop_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
),
|
||||
)
|
||||
|
||||
def get_frame_len(self, index):
|
||||
@@ -201,6 +210,7 @@ def load_dataset(
|
||||
tokenizer: str = "pinyin",
|
||||
dataset_type: str = "CustomDataset",
|
||||
audio_type: str = "raw",
|
||||
mel_spec_module: nn.Module | None = None,
|
||||
mel_spec_kwargs: dict = dict(),
|
||||
) -> CustomDataset | HFDataset:
|
||||
"""
|
||||
@@ -224,7 +234,11 @@ def load_dataset(
|
||||
data_dict = json.load(f)
|
||||
durations = data_dict["duration"]
|
||||
train_dataset = CustomDataset(
|
||||
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
|
||||
train_dataset,
|
||||
durations=durations,
|
||||
preprocessed_mel=preprocessed_mel,
|
||||
mel_spec_module=mel_spec_module,
|
||||
**mel_spec_kwargs,
|
||||
)
|
||||
|
||||
elif dataset_type == "CustomDatasetPath":
|
||||
|
||||
Reference in New Issue
Block a user