Update BigVGAN vocoder support and F5-TTS BigVGAN checkpoint (trained on Emilia ZH&EN, 1.25M updates)

This commit is contained in:
ZhikangNiu
2024-10-31 20:06:36 +08:00
parent dee0420b59
commit 712d52772e
14 changed files with 365 additions and 177 deletions

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "src/third_party/BigVGAN"]
path = src/third_party/BigVGAN
url = https://github.com/NVIDIA/BigVGAN.git

View File

@@ -46,7 +46,18 @@ cd F5-TTS
pip install -e .
```
### 3. Docker usage
### 3. Initialize the submodule (optional, only needed if you want to switch the vocoder from Vocos to BigVGAN)
```bash
git submodule update --init --recursive
```
After that, edit `src/third_party/BigVGAN/bigvgan.py` and add the following code at the beginning of the file so that its internal imports resolve:
```python
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
```
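Once the submodule is initialized and patched, the snippet below is a minimal sanity check that BigVGAN loads; it is a sketch run from the repo root, and it mirrors what `load_vocoder` does for BigVGAN elsewhere in this commit (the checkpoint id `nvidia/bigvgan_v2_24khz_100band_256x` is the one this commit targets).
```python
import os
import sys

import torch

# Assumes this is run from the repo root with the submodule initialized.
sys.path.append(os.path.abspath("src/third_party/BigVGAN"))
import bigvgan

model = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
model.remove_weight_norm()
model = model.eval()

mel = torch.randn(1, 100, 256)  # dummy (batch, n_mels, frames) mel at the 24 kHz / 100-band settings
with torch.inference_mode():
    wav = model(mel)            # (batch, 1, frames * hop_length) waveform
print(wav.shape)
```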
### 4. Docker usage
```bash
# Build from Dockerfile
docker build -t f5tts:v1 .

View File

@@ -1,24 +1,18 @@
import random
import sys
import tqdm
from importlib.resources import files
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
from f5_tts.infer.utils_infer import (hop_length, infer_process, load_model,
load_vocoder, preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram, target_sample_rate)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
preprocess_ref_audio_text,
target_sample_rate,
hop_length,
)
class F5TTS:
@@ -29,6 +23,7 @@ class F5TTS:
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_name="vocos",
local_path=None,
device=None,
):
@@ -44,11 +39,11 @@ class F5TTS:
)
# Load models
self.load_vocoder_model(local_path)
self.load_vocoder_model(vocoder_name, local_path)
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
def load_vocoder_model(self, local_path):
self.vocoder = load_vocoder(local_path is not None, local_path, self.device)
def load_vocoder_model(self, vocoder_name, local_path):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
if model_type == "F5-TTS":

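For users of the Python API, the practical effect is one extra constructor argument. A minimal usage sketch, assuming the remaining constructor defaults are unchanged from before this commit:

```python
from f5_tts.api import F5TTS

# Default behaviour is unchanged: Vocos vocoder.
tts = F5TTS(model_type="F5-TTS", vocoder_name="vocos")

# New option: BigVGAN. load_vocoder_model() forwards vocoder_name to
# load_vocoder(); set local_path to load the vocoder checkpoint from disk
# instead of downloading it.
tts_bigvgan = F5TTS(model_type="F5-TTS", vocoder_name="bigvgan", local_path=None)
```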
View File

@@ -1,26 +1,23 @@
import sys
import os
import sys
sys.path.append(os.getcwd())
import time
from tqdm import tqdm
import argparse
import time
from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos
from tqdm import tqdm
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.eval.utils_eval import (get_inference_prompt,
get_librispeech_test_clean_metainfo,
get_seedtts_testset_metainfo)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import get_tokenizer
from f5_tts.infer.utils_infer import load_checkpoint
from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
@@ -31,8 +28,12 @@ device = f"cuda:{accelerator.process_index}"
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
target_rms = 0.1
tokenizer = "pinyin"
rel_path = str(files("f5_tts").joinpath("../../"))
@@ -123,14 +124,11 @@ def main():
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
if extract_backend == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif extract_backend == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -139,9 +137,12 @@ def main():
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
),
odeint_kwargs=dict(
method=ode_method,
@@ -149,7 +150,8 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
dtype = torch.float16 if extract_backend == "vocos" else torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
@@ -178,14 +180,18 @@ def main():
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
if extract_backend == "vocos":
generated_wave = vocoder.decode(gen_mel_spec.cpu())
elif extract_backend == "bigvgan":
generated_wave = vocoder(gen_mel_spec)
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
accelerator.wait_for_everyone()
if accelerator.is_main_process:

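The per-utterance decode branch above is the pattern this commit repeats wherever audio is synthesized; it can be read as the small helper below (an illustrative sketch, not code from the commit).

```python
import torch

def decode_mel(vocoder, mel: torch.Tensor, backend: str) -> torch.Tensor:
    """Decode a (batch, n_mels, frames) mel spectrogram with the selected backend."""
    if backend == "vocos":
        # Vocos exposes an explicit .decode() and is fed the CPU copy of the mel,
        # matching the script above.
        return vocoder.decode(mel.cpu())
    if backend == "bigvgan":
        # BigVGAN is called like a regular nn.Module and returns (batch, 1, samples),
        # which is why the caller squeezes before torchaudio.save.
        return vocoder(mel)
    raise ValueError(f"unsupported vocoder backend: {backend}")
```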
View File

@@ -2,15 +2,15 @@ import math
import os
import random
import string
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
@@ -74,8 +74,11 @@ def get_inference_prompt(
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
extract_backend="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
@@ -94,7 +97,12 @@ def get_inference_prompt(
)
mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):

View File

@@ -2,23 +2,18 @@ import argparse
import codecs
import os
import re
from pathlib import Path
from importlib.resources import files
from pathlib import Path
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav)
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
)
parser = argparse.ArgumentParser(
prog="python3 infer-cli.py",
@@ -70,6 +65,7 @@ parser.add_argument(
"--remove_silence",
help="Remove silence.",
)
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
"--load_vocoder_from_local",
action="store_true",
@@ -111,9 +107,14 @@ remove_silence = args.remove_silence if args.remove_silence else config["remove_
speed = args.speed
wave_path = Path(output_dir) / "infer_cli_out.wav"
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
if args.vocoder_name == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif args.vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder_name, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
)
# load models
@@ -136,6 +137,12 @@ elif model == "E2-TTS":
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif args.vocoder_name == "bigvgan": # TODO: need to test
repo_name = "F5-TTS"
exp_name = "F5TTS_Base_bigvgan"
ckpt_step = 1250000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)

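Taken together, the CLI changes mean `--vocoder_name` selects both the vocoder checkpoint directory and which F5-TTS weights are fetched. A condensed sketch of that mapping (the BigVGAN path, experiment name, and step come from this diff; the Vocos-side defaults are shown only for contrast, and the real script keeps this logic inline with its model branches):

```python
from cached_path import cached_path

def pick_vocoder_local_path(vocoder_name: str) -> str:
    if vocoder_name == "vocos":
        return "../checkpoints/charactr/vocos-mel-24khz"
    if vocoder_name == "bigvgan":
        return "../checkpoints/bigvgan_v2_24khz_100band_256x"
    raise ValueError(vocoder_name)

def pick_f5tts_checkpoint(vocoder_name: str) -> str:
    if vocoder_name == "bigvgan":
        # new checkpoint trained on Emilia ZH&EN for 1.25M updates (still marked "TODO: need to test")
        return str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
    # existing Vocos-based F5-TTS checkpoint, unchanged by this commit
    return str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
```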
View File

@@ -3,17 +3,11 @@ import os
import torch
import torch.nn.functional as F
import torchaudio
from vocos import Vocos
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.infer.utils_infer import (
load_checkpoint,
save_spectrogram,
)
from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
save_spectrogram)
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -23,6 +17,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
target_rms = 0.1
tokenizer = "pinyin"
@@ -89,15 +86,11 @@ if not os.path.exists(output_dir):
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
if extract_backend == "vocos":
vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
elif extract_backend == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -106,9 +99,12 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
),
odeint_kwargs=dict(
method=ode_method,
@@ -116,7 +112,8 @@ model = CFM(
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
dtype = torch.float16 if extract_backend == "vocos" else torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
# Audio
audio, sr = torchaudio.load(audio_to_edit)
@@ -181,11 +178,15 @@ print(f"Generated mel: {generated.shape}")
# Final result
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
generated_wave = vocos.decode(generated_mel_spec.cpu())
gen_mel_spec = generated.permute(0, 2, 1)
if extract_backend == "vocos":
generated_wave = vocoder.decode(gen_mel_spec.cpu())
elif extract_backend == "bigvgan":
generated_wave = vocoder(gen_mel_spec)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")

View File

@@ -1,6 +1,10 @@
# A unified script for inference process
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os
import sys
sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
from third_party.BigVGAN import bigvgan
import hashlib
import re
import tempfile
@@ -34,6 +38,9 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
@@ -80,17 +87,28 @@ def chunk_text(text, max_chars=135):
# load vocoder
def load_vocoder(is_local=False, local_path="", device=device):
if is_local:
print(f"Load vocos from local path {local_path}")
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
return vocos
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device):
if vocoder_name == "vocos":
if is_local:
print(f"Load vocos from local path {local_path}")
vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
vocoder.load_state_dict(state_dict)
vocoder.eval()
vocoder = vocoder.eval().to(device)
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
elif vocoder_name == "bigvgan":
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
return vocoder
# load asr pipeline
@@ -111,9 +129,8 @@ def initialize_asr_pipeline(device=device):
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device, use_ema=True):
if device == "cuda":
model = model.half()
def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
model = model.to(dtype)
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
@@ -156,9 +173,12 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
),
odeint_kwargs=dict(
method=ode_method,
@@ -166,7 +186,8 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
dtype = torch.float16 if extract_backend == "vocos" else torch.float32
model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
return model
@@ -359,18 +380,21 @@ def infer_batch_process(
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
generated_wave = vocoder.decode(generated_mel_spec.cpu())
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
if extract_backend == "vocos":
generated_wave = vocoder.decode(generated_mel_spec.cpu())
elif extract_backend == "bigvgan":
generated_wave = vocoder(generated_mel_spec)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
# Combine all generated waves with cross-fading
if cross_fade_duration <= 0:

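With the refactor, callers pick the backend explicitly. A minimal usage sketch with the default paths used elsewhere in this commit:

```python
from f5_tts.infer.utils_infer import load_vocoder

# Vocos (default): downloads charactr/vocos-mel-24khz unless is_local=True
vocos = load_vocoder(vocoder_name="vocos")

# BigVGAN: from a local clone of nvidia/bigvgan_v2_24khz_100band_256x,
# or from the Hugging Face Hub when is_local=False
bigvgan_vocoder = load_vocoder(
    vocoder_name="bigvgan",
    is_local=True,
    local_path="../checkpoints/bigvgan_v2_24khz_100band_256x",
)
```

Note that `load_checkpoint` no longer calls `model.half()` unconditionally on CUDA; it now receives an explicit `dtype`, and the callers updated in this commit pass `torch.float16` for the Vocos setup and `torch.float32` for BigVGAN.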
View File

@@ -8,25 +8,19 @@ d - dimension
"""
from __future__ import annotations
from typing import Callable
from random import random
from typing import Callable
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
list_str_to_idx,
list_str_to_tensor,
lens_to_mask,
mask_from_frac_lengths,
)
from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
list_str_to_tensor, mask_from_frac_lengths)
class CFM(nn.Module):
@@ -99,8 +93,10 @@ class CFM(nn.Module):
):
self.eval()
if next(self.parameters()).dtype == torch.float16:
cond = cond.half()
assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
"Only support fp16 and fp32 inference currently"
)
cond = cond.to(next(self.parameters()).dtype)
# raw wave

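Sampling no longer hard-codes half precision; the conditioning tensor simply follows whatever dtype the checkpoint loader applied to the model. A toy illustration of the new cast (a stand-in linear layer, not the CFM class itself):

```python
import torch
from torch import nn

model = nn.Linear(100, 100).to(torch.float16)  # e.g. a Vocos-based checkpoint loaded in fp16
cond = torch.randn(1, 50, 100)                 # reference mel, created in fp32

param_dtype = next(model.parameters()).dtype
assert param_dtype in (torch.float32, torch.float16)
cond = cond.to(param_dtype)                    # fp16 here; stays fp32 for the BigVGAN setup
```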
View File

@@ -1,15 +1,15 @@
import json
import random
from importlib.resources import files
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchaudio
from datasets import Dataset as Dataset_
from datasets import load_from_disk
from torch import nn
from torch.utils.data import Dataset, Sampler
from datasets import load_from_disk
from datasets import Dataset as Dataset_
from tqdm import tqdm
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default
@@ -22,12 +22,21 @@ class HFDataset(Dataset):
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
n_fft=1024,
win_length=1024,
extract_backend="vocos",
):
self.data = hf_dataset
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
)
def get_frame_len(self, index):
@@ -79,6 +88,9 @@ class CustomDataset(Dataset):
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
n_fft=1024,
win_length=1024,
extract_backend="vocos",
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
@@ -86,15 +98,21 @@ class CustomDataset(Dataset):
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.extract_backend = extract_backend
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
target_sample_rate=target_sample_rate,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
),
)

View File

@@ -8,61 +8,173 @@ d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
return dynamic_range_compression_torch(magnitudes)
mel_basis_cache = {}
hann_window_cache = {}
# BigVGAN extract mel spectrogram
def mel_spectrogram(
y: torch.Tensor,
n_fft: int,
num_mels: int,
sampling_rate: int,
hop_size: int,
win_size: int,
fmin: int,
fmax: int = None,
center: bool = False,
) -> torch.Tensor:
"""Copy from https://github.com/NVIDIA/BigVGAN/tree/main"""
device = y.device
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
hann_window_cache[key] = torch.hann_window(win_size).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = spectral_normalize_torch(mel_spec)
return mel_spec
def get_bigvgan_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
):
return mel_spectrogram(
waveform,
n_fft, # 1024
n_mel_channels, # 100
target_sample_rate, # 24000
hop_length, # 256
win_length, # 1024
fmin=0, # 0
fmax=None, # null
)
def get_vocos_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
):
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
)
if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
assert len(waveform.shape) == 2
mel = mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel
class MelSpec(nn.Module):
def __init__(
self,
filter_length=1024,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
extract_backend="vocos",
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
assert extract_backend in ["vocos", "bigvgan"], print(
"We only support two extract mel backend: vocos or bigvgan"
)
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.n_mel_channels = n_mel_channels
self.target_sample_rate = target_sample_rate
if extract_backend == "vocos":
self.extractor = get_vocos_mel_spectrogram
elif extract_backend == "bigvgan":
self.extractor = get_bigvgan_mel_spectrogram
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
def forward(self, wav):
if self.dummy.device != wav.device:
self.to(wav.device)
assert len(inp.shape) == 2
mel = self.extractor(
waveform=wav,
n_fft=self.n_fft,
n_mel_channels=self.n_mel_channels,
target_sample_rate=self.target_sample_rate,
hop_length=self.hop_length,
win_length=self.win_length,
)
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min=1e-5).log()
return mel

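`MelSpec` is now a thin dispatcher over the two extractors above. A minimal shape-check sketch using this commit's defaults:

```python
import torch
from f5_tts.model.modules import MelSpec

wav = torch.randn(2, 24000)  # (batch, samples): one second of audio at 24 kHz

for backend in ("vocos", "bigvgan"):
    mel_spec = MelSpec(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24000,
        extract_backend=backend,
    )
    mel = mel_spec(wav)  # (batch, n_mel_channels, frames)
    print(backend, tuple(mel.shape))
```

The two extractors differ in padding (centered STFT for Vocos, reflect-padded non-centered STFT for BigVGAN) and in how the mel filterbank is built, so mels produced with one backend should only be decoded by the matching vocoder.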
View File

@@ -1,25 +1,22 @@
from __future__ import annotations
import os
import gc
from tqdm import tqdm
import wandb
import os
import torch
import torchaudio
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.optim.lr_scheduler import LinearLR, SequentialLR
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm
from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
@@ -49,6 +46,7 @@ class Trainer:
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
extract_backend: str = "vocos", # "vocos" | "bigvgan"
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -110,6 +108,7 @@ class Trainer:
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
self.max_grad_norm = max_grad_norm
self.vocoder_name = extract_backend
self.noise_scheduler = noise_scheduler
@@ -188,9 +187,10 @@ class Trainer:
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if self.log_samples:
from f5_tts.infer.utils_infer import load_vocoder, nfe_step, cfg_strength, sway_sampling_coef
from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
nfe_step, sway_sampling_coef)
vocoder = load_vocoder()
vocoder = load_vocoder(vocoder_name=self.vocoder_name)
target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate  # MelSpec no longer wraps a torchaudio MelSpectrogram, so read the attribute directly
log_samples_path = f"{self.checkpoint_path}/samples"
os.makedirs(log_samples_path, exist_ok=True)

View File

@@ -2,16 +2,18 @@
from importlib.resources import files
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
@@ -56,9 +58,12 @@ def main():
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
mel_spec_kwargs = dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
extract_backend=extract_backend,
)
model = CFM(
@@ -84,6 +89,7 @@ def main():
wandb_resume_id=wandb_resume_id,
last_per_steps=last_per_steps,
log_samples=True,
extract_backend=extract_backend,
)
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)

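On the training side, the same mel settings have to be threaded through the dataset, the CFM model, and the trainer so that extraction and vocoding stay consistent. A condensed sketch of what `train.py` now wires together (values are the file's new defaults; the dataset name is assumed):

```python
from f5_tts.model.dataset import load_dataset

extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'; must match the vocoder used at inference

mel_spec_kwargs = dict(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    extract_backend=extract_backend,
)

# The dataset extracts mels with these settings, CFM receives the same dict via
# mel_spec_kwargs, and Trainer(..., extract_backend=extract_backend) uses it to
# pick the vocoder for logged audio samples.
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin", mel_spec_kwargs=mel_spec_kwargs)  # dataset name assumed
```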
1
src/third_party/BigVGAN vendored Submodule

Submodule src/third_party/BigVGAN added at 7d2b454564