import json
import random

from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, Sampler

# alias the Hugging Face loader so that the local load_dataset() defined below
# does not shadow it (calling the unaliased name inside it would recurse)
from datasets import load_dataset as hf_load_dataset
from datasets import load_from_disk
from datasets import Dataset as Dataset_

from einops import rearrange

from model.modules import MelSpec


class HFDataset(Dataset):
    def __init__(
        self,
        hf_dataset: Dataset,
        target_sample_rate=24_000,
        n_mel_channels=100,
        hop_length=256,
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.mel_spectrogram = MelSpec(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        )

    def get_frame_len(self, index):
        row = self.data[index]
        audio = row['audio']['array']
        sample_rate = row['audio']['sampling_rate']
        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio = row['audio']['array']

        # logger.info(f"Audio shape: {audio.shape}")

        sample_rate = row['audio']['sampling_rate']
        duration = audio.shape[-1] / sample_rate

        # skip clips outside the 0.3 s - 30 s range, falling through to the next sample
        if duration > 30 or duration < 0.3:
            return self.__getitem__((index + 1) % len(self.data))

        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            audio_tensor = resampler(audio_tensor)

        audio_tensor = rearrange(audio_tensor, 't -> 1 t')

        mel_spec = self.mel_spectrogram(audio_tensor)
        mel_spec = rearrange(mel_spec, '1 d t -> d t')

        text = row['text']

        return dict(
            mel_spec=mel_spec,
            text=text,
        )

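
# A minimal usage sketch for HFDataset. The repo id and split below are
# illustrative assumptions, not datasets referenced by this project:
#
#     hf_ds = hf_load_dataset("some_user/some_tts_corpus", split="train")  # hypothetical repo id
#     dataset = HFDataset(hf_ds, target_sample_rate=24_000, n_mel_channels=100, hop_length=256)
#     sample = dataset[0]  # dict with 'mel_spec' of shape [n_mel, T] and 'text'
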

class CustomDataset(Dataset):
    def __init__(
        self,
        custom_dataset: Dataset,
        durations=None,
        target_sample_rate=24_000,
        hop_length=256,
        n_mel_channels=100,
        preprocessed_mel=False,
    ):
        self.data = custom_dataset
        self.durations = durations
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.preprocessed_mel = preprocessed_mel
        if not preprocessed_mel:
            self.mel_spectrogram = MelSpec(
                target_sample_rate=target_sample_rate,
                hop_length=hop_length,
                n_mel_channels=n_mel_channels,
            )

    def get_frame_len(self, index):
        if self.durations is not None:  # make sure separately provided durations are correct, otherwise an OOM is almost guaranteed
            return self.durations[index] * self.target_sample_rate / self.hop_length
        return self.data[index]["duration"] * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio_path = row["audio_path"]
        text = row["text"]
        duration = row["duration"]

        if self.preprocessed_mel:
            mel_spec = torch.tensor(row["mel_spec"])

        else:
            # skip clips outside the 0.3 s - 30 s range before paying the decode cost
            if duration > 30 or duration < 0.3:
                return self.__getitem__((index + 1) % len(self.data))

            audio, source_sample_rate = torchaudio.load(audio_path)

            if source_sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                audio = resampler(audio)

            mel_spec = self.mel_spectrogram(audio)
            mel_spec = rearrange(mel_spec, '1 d t -> d t')

        return dict(
            mel_spec=mel_spec,
            text=text,
        )

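
# Row format CustomDataset expects, inferred from the fields read above
# (the literal values are made up for illustration):
#
#     {"audio_path": "wavs/0001.wav", "text": "hello world", "duration": 3.2}        # raw audio
#     {"mel_spec": <precomputed [n_mel, T] array>, "text": "...", "duration": 3.2}   # preprocessed_mel=True
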

# Dynamic Batch Sampler

class DynamicBatchSampler(Sampler[list[int]]):
    """Extension of Sampler that does the following:
    1. Changes the batch size (essentially the number of sequences in a batch)
       to ensure that the total number of frames stays below a certain threshold.
    2. Keeps padding efficiency within each batch high, by sorting samples by
       frame length before batching.
    """

    def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
        self.sampler = sampler
        self.frames_threshold = frames_threshold
        self.max_samples = max_samples

        indices, batches = [], []
        data_source = self.sampler.data_source

        for idx in tqdm(self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"):
            indices.append((idx, data_source.get_frame_len(idx)))
        indices.sort(key=lambda elem: elem[1])

        batch = []
        batch_frames = 0
        for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
            if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
                batch.append(idx)
                batch_frames += frame_len
            else:
                if len(batch) > 0:
                    batches.append(batch)
                if frame_len <= self.frames_threshold:
                    batch = [idx]
                    batch_frames = frame_len
                else:
                    # a single sample longer than the threshold is dropped entirely
                    batch = []
                    batch_frames = 0

        if not drop_last and len(batch) > 0:
            batches.append(batch)

        del indices

        # To get different batches between epochs, set a seed and log it in the checkpoint:
        # during multi-GPU training, even though the per-GPU batches do not change between
        # epochs, the assembled global minibatch does, e.g. use (random_seed + n) for epoch n.
        random.seed(random_seed)
        random.shuffle(batches)

        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

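
# Sketch of wiring the sampler into a DataLoader; the threshold, seed, and
# worker count are illustrative assumptions, not tuned settings:
#
#     from torch.utils.data import DataLoader, SequentialSampler
#     base_sampler = SequentialSampler(train_dataset)
#     batch_sampler = DynamicBatchSampler(base_sampler, frames_threshold=3000, max_samples=64, random_seed=666)
#     loader = DataLoader(train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=4)
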

# Load dataset

def load_dataset(
    dataset_name: str,
    tokenizer: str,
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:

    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        if audio_type == "raw":
            try:
                train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
            except FileNotFoundError:  # fall back to loading the bare arrow file
                train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
            preprocessed_mel = True
        with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

    elif dataset_type == "HFDataset":
        print("Should manually modify the path of the Hugging Face dataset to your need.\n" +
              "May also need to update the corresponding script, since different datasets can have different formats.")
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(hf_load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"))

    return train_dataset

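
# Example call, plus the on-disk layout the CustomDataset branch reads from
# (the dataset and tokenizer names here are illustrative assumptions):
#
#     train_dataset = load_dataset(
#         "MyCorpus", "pinyin",
#         mel_spec_kwargs=dict(target_sample_rate=24_000, n_mel_channels=100, hop_length=256),
#     )
#
#     data/MyCorpus_pinyin/
#         raw/            (or raw.arrow; mel.arrow if audio_type="mel")
#         duration.json   # {"duration": [3.2, 5.1, ...]}
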

# collation

def collate_fn(batch):
    mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
    max_mel_length = mel_lengths.amax()

    # zero-pad every spectrogram along the time axis to the longest one in the batch
    padded_mel_specs = []
    for spec in mel_specs:  # TODO: maybe also record an attention mask here
        padding = (0, max_mel_length - spec.size(-1))
        padded_spec = F.pad(spec, padding, value=0)
        padded_mel_specs.append(padded_spec)

    mel_specs = torch.stack(padded_mel_specs)

    text = [item['text'] for item in batch]
    text_lengths = torch.LongTensor([len(item) for item in text])

    return dict(
        mel=mel_specs,
        mel_lengths=mel_lengths,
        text=text,
        text_lengths=text_lengths,
    )
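
# Shape sketch of one collated batch (B = batch size, T_max = longest
# spectrogram in the batch; n_mel_channels=100 assumed):
#
#     batch["mel"]           FloatTensor [B, 100, T_max], zero-padded along time
#     batch["mel_lengths"]   LongTensor  [B]
#     batch["text"]          list[str] of length B
#     batch["text_lengths"]  LongTensor  [B]  (character counts of each text)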