From 712d52772ef496b6cd191ba6197bac6e112fddd8 Mon Sep 17 00:00:00 2001
From: ZhikangNiu
Date: Thu, 31 Oct 2024 20:06:36 +0800
Subject: [PATCH] update BigVGAN vocoder and F5-bigvgan version, trained on Emilia ZH&EN, 1.25m updates

---
 .gitmodules                         |   3 +
 README.md                           |  13 ++-
 src/f5_tts/api.py                   |  23 ++--
 src/f5_tts/eval/eval_infer_batch.py |  66 ++++++-----
 src/f5_tts/eval/utils_eval.py       |  14 ++-
 src/f5_tts/infer/infer_cli.py       |  29 +++--
 src/f5_tts/infer/speech_edit.py     |  53 ++++-----
 src/f5_tts/infer/utils_infer.py     |  78 ++++++++-----
 src/f5_tts/model/cfm.py             |  22 ++--
 src/f5_tts/model/dataset.py         |  28 ++++-
 src/f5_tts/model/modules.py         | 172 +++++++++++++++++++++++-----
 src/f5_tts/model/trainer.py         |  24 ++--
 src/f5_tts/train/train.py           |  16 ++-
 src/third_party/BigVGAN             |   1 +
 14 files changed, 365 insertions(+), 177 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 src/third_party/BigVGAN

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..1f572cc
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "src/third_party/BigVGAN"]
+	path = src/third_party/BigVGAN
+	url = https://github.com/NVIDIA/BigVGAN.git
diff --git a/README.md b/README.md
index 239b2c1..3fcc8e4 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,18 @@ cd F5-TTS
 pip install -e .
 ```
 
-### 3. Docker usage
+### 3. Init submodule (optional, only if you want to switch the vocoder from Vocos to BigVGAN)
+
+```bash
+git submodule update --init --recursive
+```
+After that, modify `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
+```python
+import os, sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+```
+
+### 4. Docker usage
 ```bash
 # Build from Dockerfile
 docker build -t f5tts:v1 .
diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index 823067d..71d60e2 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -1,24 +1,18 @@
 import random
 import sys
-import tqdm
 from importlib.resources import files
 
 import soundfile as sf
 import torch
+import tqdm
 from cached_path import cached_path
 
+from f5_tts.infer.utils_infer import (hop_length, infer_process, load_model,
+                                      load_vocoder, preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav,
+                                      save_spectrogram, target_sample_rate)
 from f5_tts.model import DiT, UNetT
 from f5_tts.model.utils import seed_everything
-from f5_tts.infer.utils_infer import (
-    load_vocoder,
-    load_model,
-    infer_process,
-    remove_silence_for_generated_wav,
-    save_spectrogram,
-    preprocess_ref_audio_text,
-    target_sample_rate,
-    hop_length,
-)
 
 
 class F5TTS:
@@ -29,6 +23,7 @@ class F5TTS:
         vocab_file="",
         ode_method="euler",
         use_ema=True,
+        vocoder_name="vocos",
         local_path=None,
         device=None,
     ):
@@ -44,11 +39,11 @@ class F5TTS:
         )
 
         # Load models
-        self.load_vocoder_model(local_path)
+        self.load_vocoder_model(vocoder_name, local_path)
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
-    def load_vocoder_model(self, local_path):
-        self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
+    def load_vocoder_model(self, vocoder_name, local_path):
+        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index dda1dac..c9067aa 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -1,26 +1,23 @@
-import sys
 import os
+import sys sys.path.append(os.getcwd()) -import time -from tqdm import tqdm import argparse +import time from importlib.resources import files import torch import torchaudio from accelerate import Accelerator -from vocos import Vocos +from tqdm import tqdm -from f5_tts.model import CFM, UNetT, DiT +from f5_tts.eval.utils_eval import (get_inference_prompt, + get_librispeech_test_clean_metainfo, + get_seedtts_testset_metainfo) +from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder +from f5_tts.model import CFM, DiT, UNetT from f5_tts.model.utils import get_tokenizer -from f5_tts.infer.utils_infer import load_checkpoint -from f5_tts.eval.utils_eval import ( - get_seedtts_testset_metainfo, - get_librispeech_test_clean_metainfo, - get_inference_prompt, -) accelerator = Accelerator() device = f"cuda:{accelerator.process_index}" @@ -31,8 +28,12 @@ device = f"cuda:{accelerator.process_index}" target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 +win_length = 1024 +n_fft = 1024 +extract_backend = "bigvgan" # 'vocos' or 'bigvgan' target_rms = 0.1 + tokenizer = "pinyin" rel_path = str(files("f5_tts").joinpath("../../")) @@ -123,14 +124,11 @@ def main(): # Vocoder model local = False - if local: - vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" - vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device) - vocos.load_state_dict(state_dict) - vocos.eval() - else: - vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") + if extract_backend == "vocos": + vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" + elif extract_backend == "bigvgan": + vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" + vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) @@ -139,9 +137,12 @@ def main(): model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( - target_sample_rate=target_sample_rate, - n_mel_channels=n_mel_channels, + n_fft=n_fft, hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ), odeint_kwargs=dict( method=ode_method, @@ -149,7 +150,8 @@ def main(): vocab_char_map=vocab_char_map, ).to(device) - model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema) + dtype = torch.float16 if extract_backend == "vocos" else torch.float32 + model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema) if not os.path.exists(output_dir) and accelerator.is_main_process: os.makedirs(output_dir) @@ -178,14 +180,18 @@ def main(): no_ref_audio=no_ref_audio, seed=seed, ) - # Final result - for i, gen in enumerate(generated): - gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) - gen_mel_spec = gen.permute(0, 2, 1) - generated_wave = vocos.decode(gen_mel_spec.cpu()) - if ref_rms_list[i] < target_rms: - generated_wave = generated_wave * ref_rms_list[i] / target_rms - torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate) + # Final result + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1) + if extract_backend == "vocos": + generated_wave = vocoder.decode(gen_mel_spec.cpu()) + elif extract_backend == "bigvgan": + 
generated_wave = vocoder(gen_mel_spec) + + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate) accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py index c2cf38e..3b79268 100644 --- a/src/f5_tts/eval/utils_eval.py +++ b/src/f5_tts/eval/utils_eval.py @@ -2,15 +2,15 @@ import math import os import random import string -from tqdm import tqdm import torch import torch.nn.functional as F import torchaudio +from tqdm import tqdm +from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL from f5_tts.model.modules import MelSpec from f5_tts.model.utils import convert_char_to_pinyin -from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav @@ -74,8 +74,11 @@ def get_inference_prompt( tokenizer="pinyin", polyphone=True, target_sample_rate=24000, + n_fft=1024, + win_length=1024, n_mel_channels=100, hop_length=256, + extract_backend="bigvgan", target_rms=0.1, use_truth_duration=False, infer_batch_size=1, @@ -94,7 +97,12 @@ def get_inference_prompt( ) mel_spectrogram = MelSpec( - target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ) for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."): diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py index 1d9b319..6c0deb7 100644 --- a/src/f5_tts/infer/infer_cli.py +++ b/src/f5_tts/infer/infer_cli.py @@ -2,23 +2,18 @@ import argparse import codecs import os import re -from pathlib import Path from importlib.resources import files +from pathlib import Path import numpy as np import soundfile as sf import tomli from cached_path import cached_path +from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder, + preprocess_ref_audio_text, + remove_silence_for_generated_wav) from f5_tts.model import DiT, UNetT -from f5_tts.infer.utils_infer import ( - load_vocoder, - load_model, - preprocess_ref_audio_text, - infer_process, - remove_silence_for_generated_wav, -) - parser = argparse.ArgumentParser( prog="python3 infer-cli.py", @@ -70,6 +65,7 @@ parser.add_argument( "--remove_silence", help="Remove silence.", ) +parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name") parser.add_argument( "--load_vocoder_from_local", action="store_true", @@ -111,9 +107,14 @@ remove_silence = args.remove_silence if args.remove_silence else config["remove_ speed = args.speed wave_path = Path(output_dir) / "infer_cli_out.wav" # spectrogram_path = Path(output_dir) / "infer_cli_out.png" -vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" +if args.vocoder_name == "vocos": + vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" +elif args.vocoder_name == "bigvgan": + vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" -vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path) +vocoder = load_vocoder( + vocoder_name=args.vocoder_name, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path +) # load models @@ -136,6 +137,12 @@ elif model == "E2-TTS": ckpt_step = 1200000 
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path + elif args.vocoder_name == "bigvgan": # TODO: need to test + repo_name = "F5-TTS" + exp_name = "F5TTS_Base_bigvgan" + ckpt_step = 1250000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")) + print(f"Using {model}...") ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file) diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py index 5a8176a..4edbc11 100644 --- a/src/f5_tts/infer/speech_edit.py +++ b/src/f5_tts/infer/speech_edit.py @@ -3,17 +3,11 @@ import os import torch import torch.nn.functional as F import torchaudio -from vocos import Vocos -from f5_tts.model import CFM, UNetT, DiT -from f5_tts.model.utils import ( - get_tokenizer, - convert_char_to_pinyin, -) -from f5_tts.infer.utils_infer import ( - load_checkpoint, - save_spectrogram, -) +from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder, + save_spectrogram) +from f5_tts.model import CFM, DiT, UNetT +from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" @@ -23,6 +17,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 +win_length = 1024 +n_fft = 1024 +extract_backend = "bigvgan" # 'vocos' or 'bigvgan' target_rms = 0.1 tokenizer = "pinyin" @@ -89,15 +86,11 @@ if not os.path.exists(output_dir): # Vocoder model local = False -if local: - vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" - vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device) - vocos.load_state_dict(state_dict) - - vocos.eval() -else: - vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") +if extract_backend == "vocos": + vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz" +elif extract_backend == "bigvgan": + vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x" +vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path) # Tokenizer vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) @@ -106,9 +99,12 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer) model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( - target_sample_rate=target_sample_rate, - n_mel_channels=n_mel_channels, + n_fft=n_fft, hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ), odeint_kwargs=dict( method=ode_method, @@ -116,7 +112,8 @@ model = CFM( vocab_char_map=vocab_char_map, ).to(device) -model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema) +dtype = torch.float16 if extract_backend == "vocos" else torch.float32 +model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema) # Audio audio, sr = torchaudio.load(audio_to_edit) @@ -181,11 +178,15 @@ print(f"Generated mel: {generated.shape}") # Final result generated = generated.to(torch.float32) generated = generated[:, ref_audio_len:, :] -generated_mel_spec = generated.permute(0, 2, 1) -generated_wave = 
vocos.decode(generated_mel_spec.cpu()) +gen_mel_spec = generated.permute(0, 2, 1) +if extract_backend == "vocos": + generated_wave = vocoder.decode(gen_mel_spec.cpu()) +elif extract_backend == "bigvgan": + generated_wave = vocoder(gen_mel_spec) + if rms < target_rms: generated_wave = generated_wave * rms / target_rms -save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") -torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate) +save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png") +torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate) print(f"Generated wav: {generated_wave.shape}") diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index be48b5a..7d1af69 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -1,6 +1,10 @@ # A unified script for inference process # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format +import os +import sys +sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/") +from third_party.BigVGAN import bigvgan import hashlib import re import tempfile @@ -34,6 +38,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 +win_length = 1024 +n_fft = 1024 +extract_backend = "bigvgan" # 'vocos' or 'bigvgan' target_rms = 0.1 cross_fade_duration = 0.15 ode_method = "euler" @@ -80,17 +87,28 @@ def chunk_text(text, max_chars=135): # load vocoder -def load_vocoder(is_local=False, local_path="", device=device): - if is_local: - print(f"Load vocos from local path {local_path}") - vocos = Vocos.from_hparams(f"{local_path}/config.yaml") - state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device) - vocos.load_state_dict(state_dict) - vocos.eval() - else: - print("Download Vocos from huggingface charactr/vocos-mel-24khz") - vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") - return vocos +def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device): + if vocoder_name == "vocos": + if is_local: + print(f"Load vocos from local path {local_path}") + vocoder = Vocos.from_hparams(f"{local_path}/config.yaml") + state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu") + vocoder.load_state_dict(state_dict) + vocoder.eval() + vocoder = vocoder.eval().to(device) + else: + print("Download Vocos from huggingface charactr/vocos-mel-24khz") + vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") + elif vocoder_name == "bigvgan": + if is_local: + """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main""" + vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) + else: + vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False) + + vocoder.remove_weight_norm() + vocoder = vocoder.eval().to(device) + return vocoder # load asr pipeline @@ -111,9 +129,8 @@ def initialize_asr_pipeline(device=device): # load model checkpoint for inference -def load_checkpoint(model, ckpt_path, device, use_ema=True): - if device == "cuda": - model = model.half() +def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True): + model = model.to(dtype) ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors": @@ -156,9 +173,12 @@ def 
load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me model = CFM( transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), mel_spec_kwargs=dict( - target_sample_rate=target_sample_rate, - n_mel_channels=n_mel_channels, + n_fft=n_fft, hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ), odeint_kwargs=dict( method=ode_method, @@ -166,7 +186,8 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me vocab_char_map=vocab_char_map, ).to(device) - model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema) + dtype = torch.float16 if extract_backend == "vocos" else torch.float32 + model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema) return model @@ -359,18 +380,21 @@ def infer_batch_process( sway_sampling_coef=sway_sampling_coef, ) - generated = generated.to(torch.float32) - generated = generated[:, ref_audio_len:, :] - generated_mel_spec = generated.permute(0, 2, 1) - generated_wave = vocoder.decode(generated_mel_spec.cpu()) - if rms < target_rms: - generated_wave = generated_wave * rms / target_rms + generated = generated.to(torch.float32) + generated = generated[:, ref_audio_len:, :] + generated_mel_spec = generated.permute(0, 2, 1) + if extract_backend == "vocos": + generated_wave = vocoder.decode(generated_mel_spec.cpu()) + elif extract_backend == "bigvgan": + generated_wave = vocoder(generated_mel_spec) + if rms < target_rms: + generated_wave = generated_wave * rms / target_rms - # wav -> numpy - generated_wave = generated_wave.squeeze().cpu().numpy() + # wav -> numpy + generated_wave = generated_wave.squeeze().cpu().numpy() - generated_waves.append(generated_wave) - spectrograms.append(generated_mel_spec[0].cpu().numpy()) + generated_waves.append(generated_wave) + spectrograms.append(generated_mel_spec[0].cpu().numpy()) # Combine all generated waves with cross-fading if cross_fade_duration <= 0: diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 2a300a5..7d3f639 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -8,25 +8,19 @@ d - dimension """ from __future__ import annotations -from typing import Callable + from random import random +from typing import Callable import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.nn.utils.rnn import pad_sequence - from torchdiffeq import odeint from f5_tts.model.modules import MelSpec -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) class CFM(nn.Module): @@ -99,8 +93,10 @@ class CFM(nn.Module): ): self.eval() - if next(self.parameters()).dtype == torch.float16: - cond = cond.half() + assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print( + "Only support fp16 and fp32 inference currently" + ) + cond = cond.to(next(self.parameters()).dtype) # raw wave diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py index 48e245e..93ddbe0 100644 --- a/src/f5_tts/model/dataset.py +++ b/src/f5_tts/model/dataset.py @@ -1,15 +1,15 @@ import json import random from importlib.resources import files -from tqdm import tqdm import torch import torch.nn.functional as F import torchaudio +from 
datasets import Dataset as Dataset_ +from datasets import load_from_disk from torch import nn from torch.utils.data import Dataset, Sampler -from datasets import load_from_disk -from datasets import Dataset as Dataset_ +from tqdm import tqdm from f5_tts.model.modules import MelSpec from f5_tts.model.utils import default @@ -22,12 +22,21 @@ class HFDataset(Dataset): target_sample_rate=24_000, n_mel_channels=100, hop_length=256, + n_fft=1024, + win_length=1024, + extract_backend="vocos", ): self.data = hf_dataset self.target_sample_rate = target_sample_rate self.hop_length = hop_length + self.mel_spectrogram = MelSpec( - target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ) def get_frame_len(self, index): @@ -79,6 +88,9 @@ class CustomDataset(Dataset): target_sample_rate=24_000, hop_length=256, n_mel_channels=100, + n_fft=1024, + win_length=1024, + extract_backend="vocos", preprocessed_mel=False, mel_spec_module: nn.Module | None = None, ): @@ -86,15 +98,21 @@ class CustomDataset(Dataset): self.durations = durations self.target_sample_rate = target_sample_rate self.hop_length = hop_length + self.n_fft = n_fft + self.win_length = win_length + self.extract_backend = extract_backend self.preprocessed_mel = preprocessed_mel if not preprocessed_mel: self.mel_spectrogram = default( mel_spec_module, MelSpec( - target_sample_rate=target_sample_rate, + n_fft=n_fft, hop_length=hop_length, + win_length=win_length, n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ), ) diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py index c026eff..27ca19b 100644 --- a/src/f5_tts/model/modules.py +++ b/src/f5_tts/model/modules.py @@ -8,61 +8,173 @@ d - dimension """ from __future__ import annotations -from typing import Optional + import math +from typing import Optional import torch -from torch import nn import torch.nn.functional as F import torchaudio - +from librosa.filters import mel as librosa_mel_fn +from torch import nn from x_transformers.x_transformers import apply_rotary_pos_emb - # raw wav to mel spec +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +# BigVGAN extract mel spectrogram +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """Copy from https://github.com/NVIDIA/BigVGAN/tree/main""" + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()? 
+ hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec + + +def get_bigvgan_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + return mel_spectrogram( + waveform, + n_fft, # 1024 + n_mel_channels, # 100 + target_sample_rate, # 24000 + hop_length, # 256 + win_length, # 1024 + fmin=0, # 0 + fmax=None, # null + ) + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + + class MelSpec(nn.Module): def __init__( self, - filter_length=1024, + n_fft=1024, hop_length=256, win_length=1024, n_mel_channels=100, target_sample_rate=24_000, - normalize=False, - power=1, - norm=None, - center=True, + extract_backend="vocos", ): super().__init__() - self.n_mel_channels = n_mel_channels - - self.mel_stft = torchaudio.transforms.MelSpectrogram( - sample_rate=target_sample_rate, - n_fft=filter_length, - win_length=win_length, - hop_length=hop_length, - n_mels=n_mel_channels, - power=power, - center=center, - normalized=normalize, - norm=norm, + assert extract_backend in ["vocos", "bigvgan"], print( + "We only support two extract mel backend: vocos or bigvgan" ) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + if extract_backend == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif extract_backend == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + self.register_buffer("dummy", torch.tensor(0), persistent=False) - def forward(self, inp): - if len(inp.shape) == 3: - inp = inp.squeeze(1) # 'b 1 nw -> b nw' + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) - assert len(inp.shape) == 2 + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) - if self.dummy.device != inp.device: - self.to(inp.device) - - mel = self.mel_stft(inp) - mel = mel.clamp(min=1e-5).log() return mel diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 08903d6..363f0a1 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -1,25 +1,22 @@ from __future__ import annotations -import os import gc -from tqdm import tqdm -import wandb +import os import torch import torchaudio 
-from torch.optim import AdamW -from torch.utils.data import DataLoader, Dataset, SequentialSampler -from torch.optim.lr_scheduler import LinearLR, SequentialLR - +import wandb from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs - from ema_pytorch import EMA +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR, SequentialLR +from torch.utils.data import DataLoader, Dataset, SequentialSampler +from tqdm import tqdm from f5_tts.model import CFM -from f5_tts.model.utils import exists, default from f5_tts.model.dataset import DynamicBatchSampler, collate_fn - +from f5_tts.model.utils import default, exists # trainer @@ -49,6 +46,7 @@ class Trainer: accelerate_kwargs: dict = dict(), ema_kwargs: dict = dict(), bnb_optimizer: bool = False, + extract_backend: str = "vocos", # "vocos" | "bigvgan" ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -110,6 +108,7 @@ class Trainer: self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps self.max_grad_norm = max_grad_norm + self.vocoder_name = extract_backend self.noise_scheduler = noise_scheduler @@ -188,9 +187,10 @@ class Trainer: def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None): if self.log_samples: - from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef + from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder, + nfe_step, sway_sampling_coef) - vocoder = load_vocoder() + vocoder = load_vocoder(vocoder_name=self.vocoder_name) target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.mel_stft.sample_rate log_samples_path = f"{self.checkpoint_path}/samples" os.makedirs(log_samples_path, exist_ok=True) diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index 94fe9b5..44e8cb4 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -2,16 +2,18 @@ from importlib.resources import files -from f5_tts.model import CFM, UNetT, DiT, Trainer -from f5_tts.model.utils import get_tokenizer +from f5_tts.model import CFM, DiT, Trainer, UNetT from f5_tts.model.dataset import load_dataset - +from f5_tts.model.utils import get_tokenizer # -------------------------- Dataset Settings --------------------------- # target_sample_rate = 24000 n_mel_channels = 100 hop_length = 256 +win_length = 1024 +n_fft = 1024 +extract_backend = "bigvgan" # 'vocos' or 'bigvgan' tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) @@ -56,9 +58,12 @@ def main(): vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) mel_spec_kwargs = dict( - target_sample_rate=target_sample_rate, - n_mel_channels=n_mel_channels, + n_fft=n_fft, hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + extract_backend=extract_backend, ) model = CFM( @@ -84,6 +89,7 @@ def main(): wandb_resume_id=wandb_resume_id, last_per_steps=last_per_steps, log_samples=True, + extract_backend=extract_backend, ) train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) diff --git a/src/third_party/BigVGAN b/src/third_party/BigVGAN new file mode 160000 index 0000000..7d2b454 --- /dev/null +++ b/src/third_party/BigVGAN @@ -0,0 +1 @@ +Subproject commit 7d2b454564a6c7d014227f635b7423881f14bdac
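
Usage note (not part of the patch): a minimal sketch of how the reworked `load_vocoder` helper and the two decode paths are exercised, following the call patterns in `eval_infer_batch.py` and `speech_edit.py` above. The dummy mel tensor, batch size, and CPU device are illustrative assumptions.

```python
import torch

from f5_tts.infer.utils_infer import load_vocoder

# Dummy mel batch purely for illustration: (batch, n_mel_channels=100, frames)
mel = torch.randn(1, 100, 256)

# Vocos (default backend): pulled from charactr/vocos-mel-24khz unless is_local=True
vocos = load_vocoder(vocoder_name="vocos", device="cpu")
wav_vocos = vocos.decode(mel)  # Vocos exposes a decode() method

# BigVGAN: requires the initialized submodule; weights come from
# nvidia/bigvgan_v2_24khz_100band_256x, or a local checkpoint via is_local/local_path
bigvgan = load_vocoder(vocoder_name="bigvgan", device="cpu")
wav_bigvgan = bigvgan(mel)  # BigVGAN is called directly on the mel spectrogram
```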
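Similarly, a minimal sketch (again not part of the patch) of the new `extract_backend` switch on `MelSpec`, mirroring the `mel_spec_kwargs` used in `train.py` and `utils_infer.py`; the one-second random waveform is an assumption for illustration. The chosen backend is meant to match the vocoder that will later decode the mels.

```python
import torch

from f5_tts.model.modules import MelSpec

# One second of dummy 24 kHz audio, shape (batch, num_samples) -- illustration only
wav = torch.randn(1, 24_000)

# Mirrors the mel_spec_kwargs used throughout this patch
mel_spec = MelSpec(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24_000,
    extract_backend="bigvgan",  # or "vocos"; pick the backend matching your vocoder
)

mel = mel_spec(wav)  # -> (batch, n_mel_channels, frames)
```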