diff --git a/README.md b/README.md
index f727ebd..3802596 100644
--- a/README.md
+++ b/README.md
@@ -43,14 +43,10 @@ pip install git+https://github.com/SWivid/F5-TTS.git
 ```bash
 git clone https://github.com/SWivid/F5-TTS.git
 cd F5-TTS
+# git submodule update --init --recursive  # (optional, if you need bigvgan)
 pip install -e .
-
-# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
-# git submodule update --init --recursive
-# pip install -e .
 ```
-
-After init submodule, you need to change the `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
+If you initialized the submodule, add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`.
 ```python
 import os
 import sys
diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index 4e0a06b..bbccd4f 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -120,6 +120,7 @@ def main():
         target_sample_rate=target_sample_rate,
         n_mel_channels=n_mel_channels,
         hop_length=hop_length,
+        mel_spec_type=mel_spec_type,
         target_rms=target_rms,
         use_truth_duration=use_truth_duration,
         infer_batch_size=infer_batch_size,
@@ -153,12 +154,7 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-    if supports_fp16 and mel_spec_type == "vocos":
-        dtype = torch.float16
-    elif mel_spec_type == "bigvgan":
-        dtype = torch.float32
-
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py
index a03d262..00cd97a 100644
--- a/src/f5_tts/eval/utils_eval.py
+++ b/src/f5_tts/eval/utils_eval.py
@@ -78,7 +78,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    mel_spec_type="bigvgan",
+    mel_spec_type="vocos",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md
index 0e84484..8582da4 100644
--- a/src/f5_tts/infer/README.md
+++ b/src/f5_tts/infer/README.md
@@ -58,8 +58,8 @@ f5-tts_infer-cli \
 --gen_text "Some text you want TTS model generate for you."
 
 # Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file
+f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file
+f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file
 ```
 
 And a `.toml` file would help with more flexible usage.
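Note on the dtype simplification above (it recurs in `speech_edit.py` and `utils_infer.py` below): the removed branches never assigned `dtype` when running vocos on CPU, leaving a latent `NameError` at the `load_checkpoint` call, whereas passing `None` always resolves to a value. A minimal sketch of the combined call-site plus fallback logic; `pick_dtype` is a hypothetical helper, not part of the patch:

```python
import torch

def pick_dtype(device: str, mel_spec_type: str) -> torch.dtype:
    # Hypothetical helper mirroring the patched behavior: call sites pin
    # bigvgan to float32; everything else defers to load_checkpoint's
    # dtype=None fallback.
    if mel_spec_type == "bigvgan":
        return torch.float32
    if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6:
        return torch.float16  # fp16 on compute capability >= 6
    return torch.float32  # CPU (and older GPUs) fall back to fp32
```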
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index 90cb127..c33b21f 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -111,12 +111,7 @@ model = CFM(
     vocab_char_map=vocab_char_map,
 ).to(device)
 
-supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-if supports_fp16 and mel_spec_type == "vocos":
-    dtype = torch.float16
-elif mel_spec_type == "bigvgan":
-    dtype = torch.float32
-
+dtype = torch.float32 if mel_spec_type == "bigvgan" else None
 model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
 # Audio
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 5d897d2..8c2c04d 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -40,6 +40,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
+mel_spec_type = "vocos"
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -131,7 +132,7 @@ def initialize_asr_pipeline(device=device):
 
 
 # load model checkpoint for inference
-def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
+def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
     if dtype is None:
         dtype = (
             torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
@@ -175,7 +176,7 @@ def load_model(
     model_cls,
     model_cfg,
     ckpt_path,
-    mel_spec_type="vocos",
+    mel_spec_type=mel_spec_type,
     vocab_file="",
     ode_method=ode_method,
     use_ema=True,
@@ -206,12 +207,7 @@ def load_model(
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-    if supports_fp16 and mel_spec_type == "vocos":
-        dtype = torch.float16
-    elif mel_spec_type == "bigvgan":
-        dtype = torch.float32
-
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     return model
@@ -307,7 +303,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    mel_spec_type="vocos",
+    mel_spec_type=mel_spec_type,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
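One subtlety with reusing the new module-level `mel_spec_type = "vocos"` as a default parameter in `load_model` and `infer_process`: Python evaluates default arguments once, at function definition time, so rebinding the global later does not change the defaults. A toy illustration of that semantics (names are illustrative, not from the patch):

```python
mel_spec_type = "vocos"  # module-level default, as in utils_infer.py

def load(mel_spec_type=mel_spec_type):
    return mel_spec_type

mel_spec_type = "bigvgan"  # rebinding the global afterwards...
print(load())              # -> "vocos": the default was captured at def time
print(load("bigvgan"))     # -> "bigvgan": explicit arguments still win
```

Callers who want bigvgan must therefore pass `mel_spec_type="bigvgan"` explicitly rather than mutate the module attribute.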
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index d3da679..bf67fff 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -19,57 +19,44 @@ from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
+
 # raw wav to mel spec
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    return dynamic_range_compression_torch(magnitudes)
-
-
 mel_basis_cache = {}
 hann_window_cache = {}
 
-# BigVGAN extract mel spectrogram
-def mel_spectrogram(
-    y: torch.Tensor,
-    n_fft: int,
-    num_mels: int,
-    sampling_rate: int,
-    hop_size: int,
-    win_size: int,
-    fmin: int,
-    fmax: int = None,
-    center: bool = False,
-) -> torch.Tensor:
-    """Copy from https://github.com/NVIDIA/BigVGAN/tree/main"""
-    device = y.device
-    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+def get_bigvgan_mel_spectrogram(
+    waveform,
+    n_fft=1024,
+    n_mel_channels=100,
+    target_sample_rate=24000,
+    hop_length=256,
+    win_length=1024,
+    fmin=0,
+    fmax=None,
+    center=False,
+):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
+    device = waveform.device
+    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
 
     if key not in mel_basis_cache:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
         mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why they need .float()?
-        hann_window_cache[key] = torch.hann_window(win_size).to(device)
+        hann_window_cache[key] = torch.hann_window(win_length).to(device)
 
     mel_basis = mel_basis_cache[key]
     hann_window = hann_window_cache[key]
 
-    padding = (n_fft - hop_size) // 2
-    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
+    padding = (n_fft - hop_length) // 2
+    waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
 
     spec = torch.stft(
-        y,
+        waveform,
         n_fft,
-        hop_length=hop_size,
-        win_length=win_size,
+        hop_length=hop_length,
+        win_length=win_length,
         window=hann_window,
         center=center,
         pad_mode="reflect",
@@ -80,31 +67,11 @@ def mel_spectrogram(
     spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
 
     mel_spec = torch.matmul(mel_basis, spec)
-    mel_spec = spectral_normalize_torch(mel_spec)
+    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
 
     return mel_spec
 
 
-def get_bigvgan_mel_spectrogram(
-    waveform,
-    n_fft=1024,
-    n_mel_channels=100,
-    target_sample_rate=24000,
-    hop_length=256,
-    win_length=1024,
-):
-    return mel_spectrogram(
-        waveform,
-        n_fft,  # 1024
-        n_mel_channels,  # 100
-        target_sample_rate,  # 24000
-        hop_length,  # 256
-        win_length,  # 1024
-        fmin=0,  # 0
-        fmax=None,  # null
-    )
-
-
 def get_vocos_mel_spectrogram(
     waveform,
     n_fft=1024,
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index 9ef7db4..fac0fe5 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -13,7 +13,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
 tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
 tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
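The three helpers removed from `modules.py` collapse into a single expression because `spectral_normalize_torch` called `dynamic_range_compression_torch` with its defaults (`C=1`, `clip_val=1e-5`), and multiplying by `C=1` is a no-op. A quick self-check of that equivalence (the tensor shape is arbitrary):

```python
import torch

x = torch.rand(2, 100, 64)  # e.g. (batch, n_mel_channels, frames)

# former path: spectral_normalize_torch -> dynamic_range_compression_torch(x, C=1, clip_val=1e-5)
old = torch.log(torch.clamp(x, min=1e-5) * 1)

# inlined form from the patch
new = torch.log(torch.clamp(x, min=1e-5))

assert torch.equal(old, new)
```

`dynamic_range_decompression_torch` is dropped without replacement, which is safe assuming nothing outside this file imports it.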