Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-26 12:51:16 -08:00)
Compare commits
33 Commits
SHA1: b9156c0ad5, 3ad3211915, f6726a78cc, 1d0cf2b8ba, 1d82b7928e, 4ae5347282, 621559cbbe, 526b09eebd, 9afa80f204, c6b3189bbd, c87ce39515, 10ef27065b, f374640f34, d5f4c88aa4, f968e13b6d, 339b17fed3, 79302b694a, a1e88c2a9e, 1ab90505a4, 7e4985ca56, f05ceda4cb, 2bd39dd813, f017815083, 297755fac3, d05075205f, 8722cf0766, 48d1a9312e, 128f4e4bf3, 2695e9305d, 69909ac167, 79bbde5d76, bf651d541e, ca6e49adaa
.github/ISSUE_TEMPLATE/bug_report.yml (vendored): 12 changed lines
@@ -1,6 +1,6 @@
name: "Bug Report"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
@@ -15,13 +15,13 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
@@ -39,12 +39,12 @@ body:
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen.
placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened.
placeholder: Describe in detail what actually happened.
validations:
required: false
.github/ISSUE_TEMPLATE/feature_request.yml (vendored): 2 changed lines
@@ -15,7 +15,7 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and found not discussion yet.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
.github/ISSUE_TEMPLATE/help_wanted.yml (vendored): 16 changed lines
@@ -1,6 +1,6 @@
name: "Help Wanted"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
@@ -15,36 +15,40 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
placeholder: |
e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen, e.g. output a generated audio
placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened, failure messages, etc.
placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false
.github/ISSUE_TEMPLATE/question.yml (vendored): 6 changed lines
@@ -1,6 +1,6 @@
name: "Question"
description: |
Pure question or inquiry about the project, usage issue goes with "help wanted".
Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
@@ -9,13 +9,13 @@ body:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for question, not feature requests or bug reports.
- label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
.github/workflows/publish-pypi.yaml (vendored, new file): 66 added lines
@@ -0,0 +1,66 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.

name: Upload Python Package

on:
release:
types: [published]

permissions:
contents: read

jobs:
release-build:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.x"

- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build

- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/

pypi-publish:
runs-on: ubuntu-latest

needs:
- release-build

permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write

# Dedicated environments with protections for publishing are strongly recommended.
environment:
name: pypi
# OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
# url: https://pypi.org/p/YOURPROJECT

steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/

- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
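The release job above only runs a generic `python -m build` before handing the artifacts in `dist/` to the PyPI publish action. A minimal sketch for reproducing that build step locally before tagging a release (assumes only that the `build` package can be installed into the current environment):

```python
import subprocess
import sys
from pathlib import Path

# Mirror the workflow's "Build release distributions" step:
#   python -m pip install build && python -m build
subprocess.run([sys.executable, "-m", "pip", "install", "build"], check=True)
subprocess.run([sys.executable, "-m", "build"], check=True)

# The wheel and sdist land in dist/, which is what the workflow uploads as the release artifact.
print(sorted(p.name for p in Path("dist").iterdir()))
```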
.gitignore (vendored): 2 changed lines
@@ -7,8 +7,6 @@ ckpts/
wandb/
results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
.pre-commit-config.yaml

@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
@@ -9,6 +9,6 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-yaml
Dockerfile

@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \

ENV SHELL=/bin/bash

VOLUME /root/.cache/huggingface/hub/

EXPOSE 7860

WORKDIR /workspace/F5-TTS
README.md: 29 changed lines
@@ -18,6 +18,7 @@
### Thanks to all the contributors !

## News
- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [Few demo](https://swivid.github.io/F5-TTS_updates).
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).

## Installation
@@ -37,7 +38,7 @@ conda activate f5-tts

> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> ```

</details>
@@ -82,7 +83,7 @@ conda activate f5-tts
> ### 1. As a pip package (if just for inference)
>
> ```bash
> pip install git+https://github.com/SWivid/F5-TTS.git
> pip install f5-tts
> ```
>
> ### 2. Local editable (if also do training, finetuning)
@@ -99,8 +100,11 @@ conda activate f5-tts
# Build from Dockerfile
docker build -t f5tts:v1 .

# Or pull from GitHub Container Registry
docker pull ghcr.io/swivid/f5-tts:main
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main

# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```

@@ -158,9 +162,8 @@ volumes:
```bash
# Run with flags
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--ref_audio "ref_audio.wav" \
f5-tts_infer-cli --model F5TTS_v1_Base \
--ref_audio "provide_prompt_wav_path_here.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

@@ -181,22 +184,26 @@ f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml

## Training

### 1. Gradio App
### 1. With Hugging Face Accelerate

Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.

### 2. With Gradio App

```bash
# Quick start with Gradio web interface
f5-tts_finetune-gradio
```

Read [training & finetuning guidance](src/f5_tts/train) for more instructions.


## [Evaluation](src/f5_tts/eval)


## Development

Use pre-commit to ensure code quality (will run linters and formatters automatically)
Use pre-commit to ensure code quality (will run linters and formatters automatically):

```bash
pip install pre-commit
@@ -209,7 +216,7 @@ When making a pull request, before each commit, run:
pre-commit run --all-files
```

Note: Some model components have linting exceptions for E722 to accommodate tensor notation
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.


## Acknowledgements
ckpts/README.md

@@ -1,10 +1,3 @@
The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.

Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS

```
ckpts/
E2TTS_Base/
model_1200000.pt
F5TTS_Base/
model_1200000.pt
```
Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
version = "0.6.2"
version = "1.0.8"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -25,7 +25,6 @@ dependencies = [
"jieba",
"librosa",
"matplotlib",
"nltk",
"numpy<=1.26.4",
"pydub",
"pypinyin",
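With the project now published on PyPI as `f5-tts` (the README's `pip install f5-tts`), a quick way to confirm which release is installed, for example whether you are on the 1.0.x line pinned above, is a one-liner. A minimal sketch, assuming nothing beyond the package name:

```python
from importlib.metadata import version

# Prints the installed distribution version, e.g. "1.0.8" for this pyproject.toml.
print(version("f5-tts"))
```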
src/f5_tts/api.py

@@ -5,43 +5,43 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
hop_length,
infer_process,
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
target_sample_rate,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything


class F5TTS:
def __init__(
self,
model_type="F5-TTS",
model="F5TTS_v1_Base",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_name="vocos",
local_path=None,
vocoder_local_path=None,
device=None,
hf_cache_dir=None,
):
# Initialize parameters
self.final_wave = None
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.seed = -1
self.mel_spec_type = vocoder_name
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

self.ode_method = ode_method
self.use_ema = use_ema

# Set device
if device is not None:
self.device = device
else:
@@ -58,39 +58,29 @@ class F5TTS:
)

# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
self.load_ema_model(
model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
self.vocoder = load_vocoder(
self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
)

def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
else:
raise ValueError(f"Unknown model type: {model_type}")
# override for previous models
if model == "F5TTS_Base":
if self.mel_spec_type == "vocos":
ckpt_step = 1200000
elif self.mel_spec_type == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000

if not ckpt_file:
ckpt_file = str(
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
)
self.ema_model = load_model(
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
)

def transcribe(self, ref_audio, language=None):
@@ -102,8 +92,8 @@ class F5TTS:
if remove_silence:
remove_silence_for_generated_wav(file_wave)

def export_spectrogram(self, spect, file_spect):
save_spectrogram(spect, file_spect)
def export_spectrogram(self, spec, file_spec):
save_spectrogram(spec, file_spec)

def infer(
self,
@@ -121,17 +111,17 @@ class F5TTS:
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
seed=-1,
file_spec=None,
seed=None,
):
if seed == -1:
if seed is None:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed

ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

wav, sr, spect = infer_process(
wav, sr, spec = infer_process(
ref_file,
ref_text,
gen_text,
@@ -153,22 +143,22 @@ class F5TTS:
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence)

if file_spect is not None:
self.export_spectrogram(spect, file_spect)
if file_spec is not None:
self.export_spectrogram(spec, file_spec)

return wav, sr, spect
return wav, sr, spec


if __name__ == "__main__":
f5tts = F5TTS()

wav, sr, spect = f5tts.infer(
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=-1,  # random seed = -1
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)

print("seed :", f5tts.seed)
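The reworked `F5TTS` class now takes a model name resolved through `configs/<model>.yaml` instead of a `model_type` string, uses `seed=None` to request a random seed, and renames the spectrogram arguments from `spect`/`file_spect` to `spec`/`file_spec`. A minimal usage sketch mirroring the `__main__` block above; the audio and output paths here are placeholders:

```python
from f5_tts.api import F5TTS

# Config, checkpoint and vocoder are pulled automatically for the named model.
f5tts = F5TTS(model="F5TTS_v1_Base")

wav, sr, spec = f5tts.infer(
    ref_file="ref_audio.wav",  # prompt audio (placeholder path)
    ref_text="The transcription of the reference audio.",
    gen_text="Some text you want the TTS model to generate for you.",
    file_wave="out.wav",       # optional: write the generated waveform
    file_spec="out.png",       # optional: write the spectrogram plot
    seed=None,                 # None -> a random seed is drawn and stored on the instance
)
print("seed :", f5tts.seed)
```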
src/f5_tts/configs/E2TTS_Base.yaml

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16

optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Base
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 1024
depth: 24
heads: 16
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path

ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/E2TTS_Small.yaml

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16

optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 768
depth: 20
heads: 12
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path

ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/F5TTS_Base.yaml

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16

optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path

ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
src/f5_tts/configs/F5TTS_Small.yaml

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16

optim:
epochs: 15
epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 768
depth: 18
heads: 12
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,14 +38,15 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path

ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_v1_Base.yaml (new file): 53 added lines
@@ -0,0 +1,53 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16

optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

model:
name: F5TTS_v1_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: True
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path

ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
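Both `api.py` and `eval_infer_batch.py` in this changeset resolve the model from such a YAML file rather than from hard-coded dicts: the config is loaded with OmegaConf and the backbone class is looked up by name. A minimal sketch of that pattern, using only names that appear in the diff above:

```python
from importlib.resources import files

from hydra.utils import get_class
from omegaconf import OmegaConf

# Load the packaged config by model name.
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))

model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")  # e.g. f5_tts.model.DiT
model_arc = model_cfg.model.arch                                   # dim, depth, heads, ...
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type             # vocos | bigvgan

print(model_cls, dict(model_arc), mel_spec_type)
```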
src/f5_tts/eval/eval_infer_batch.py

@@ -10,6 +10,8 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm

from f5_tts.eval.utils_eval import (
@@ -18,36 +20,26 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
use_ema = True
target_rms = 0.1


rel_path = str(files("f5_tts").joinpath("../../"))


def main():
# ---------------------- infer setting ---------------------- #

parser = argparse.ArgumentParser(description="batch inference")

parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
parser.add_argument("-c", "--ckptstep", default=1250000, type=int)

parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +50,8 @@ def main():
args = parser.parse_args()

seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
mel_spec_type = args.mel_spec_type
tokenizer = args.tokenizer

nfe_step = args.nfestep
ode_method = args.odemethod
@@ -77,13 +65,19 @@ def main():
use_truth_duration = False
no_ref_audio = False

if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer

mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft

if testset == "ls_pc_test_clean":
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +105,6 @@ def main():

# -------------------------------------------------#

use_ema = True

prompts_all = get_inference_prompt(
metainfo,
speed=speed,
@@ -139,7 +131,7 @@ def main():

# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
@@ -154,6 +146,10 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)

ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

@@ -200,7 +196,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
print(f"Done batch inference in {timediff / 60:.2f} minutes.")


if __name__ == "__main__":
@@ -1,13 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
|
||||
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
|
||||
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
|
||||
|
||||
# etc.
|
||||
|
||||
@@ -53,43 +53,37 @@ def main():
|
||||
asr_ckpt_dir = "" # auto download to cache dir
|
||||
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
|
||||
|
||||
# --------------------------- WER ---------------------------
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
full_results = []
|
||||
metrics = []
|
||||
|
||||
if eval_task == "wer":
|
||||
wer_results = []
|
||||
wers = []
|
||||
|
||||
with mp.Pool(processes=len(gpus)) as pool:
|
||||
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
||||
results = pool.map(run_asr_wer, args)
|
||||
for r in results:
|
||||
wer_results.extend(r)
|
||||
|
||||
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
|
||||
with open(wer_result_path, "w") as f:
|
||||
for line in wer_results:
|
||||
wers.append(line["wer"])
|
||||
json_line = json.dumps(line, ensure_ascii=False)
|
||||
f.write(json_line + "\n")
|
||||
|
||||
wer = round(np.mean(wers) * 100, 3)
|
||||
print(f"\nTotal {len(wers)} samples")
|
||||
print(f"WER : {wer}%")
|
||||
print(f"Results have been saved to {wer_result_path}")
|
||||
|
||||
# --------------------------- SIM ---------------------------
|
||||
|
||||
if eval_task == "sim":
|
||||
sims = []
|
||||
full_results.extend(r)
|
||||
elif eval_task == "sim":
|
||||
with mp.Pool(processes=len(gpus)) as pool:
|
||||
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
|
||||
results = pool.map(run_sim, args)
|
||||
for r in results:
|
||||
sims.extend(r)
|
||||
full_results.extend(r)
|
||||
else:
|
||||
raise ValueError(f"Unknown metric type: {eval_task}")
|
||||
|
||||
sim = round(sum(sims) / len(sims), 3)
|
||||
print(f"\nTotal {len(sims)} samples")
|
||||
print(f"SIM : {sim}")
|
||||
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
|
||||
with open(result_path, "w") as f:
|
||||
for line in full_results:
|
||||
metrics.append(line[eval_task])
|
||||
f.write(json.dumps(line, ensure_ascii=False) + "\n")
|
||||
metric = round(np.mean(metrics), 5)
|
||||
f.write(f"\n{eval_task.upper()}: {metric}\n")
|
||||
|
||||
print(f"\nTotal {len(metrics)} samples")
|
||||
print(f"{eval_task.upper()}: {metric}")
|
||||
print(f"{eval_task.upper()} results saved to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -52,43 +52,37 @@ def main():
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

# --------------------------- WER ---------------------------
# --------------------------------------------------------------------------

full_results = []
metrics = []

if eval_task == "wer":
wer_results = []
wers = []

with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
wer_results.extend(r)

wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")

wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")

# --------------------------- SIM ---------------------------

if eval_task == "sim":
sims = []
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
sims.extend(r)
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")

sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")

print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")


if __name__ == "__main__":
||||
src/f5_tts/eval/eval_utmos.py

@@ -19,25 +19,23 @@ def main():
predictor = predictor.to(device)

audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
utmos_results = {}
utmos_score = 0

for audio_path in tqdm(audio_paths, desc="Processing"):
wav_name = audio_path.stem
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
utmos_results[str(wav_name)] = score.item()
utmos_score += score.item()

avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
print(f"UTMOS: {avg_score}")

utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
with open(utmos_result_path, "w", encoding="utf-8") as f:
json.dump(utmos_results, f, ensure_ascii=False, indent=4)
for audio_path in tqdm(audio_paths, desc="Processing"):
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
line = {}
line["wav"], line["utmos"] = str(audio_path.stem), score.item()
utmos_score += score.item()
f.write(json.dumps(line, ensure_ascii=False) + "\n")
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
f.write(f"\nUTMOS: {avg_score:.4f}\n")

print(f"Results have been saved to {utmos_result_path}")
print(f"UTMOS: {avg_score:.4f}")
print(f"UTMOS results saved to {utmos_result_path}")


if __name__ == "__main__":
||||
src/f5_tts/eval/utils_eval.py

@@ -148,9 +148,9 @@ def get_inference_prompt(

# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
assert min_tokens <= total_mel_len <= max_tokens, (
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

utts[bucket_i].append(utt)
@@ -389,10 +389,10 @@ def run_sim(args):
model = model.cuda(device)
model.eval()

sims = []
for wav1, wav2, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(wav1)
wav2, sr2 = torchaudio.load(wav2)
sim_results = []
for gen_wav, prompt_wav, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(gen_wav)
wav2, sr2 = torchaudio.load(prompt_wav)

resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):

sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sims.append(sim)
sim_results.append(
{
"wav": Path(gen_wav).stem,
"sim": sim,
}
)

return sims
return sim_results
||||
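The `run_sim` change above keeps the same speaker-similarity metric, cosine similarity between the two speaker embeddings, but now records a per-utterance dict instead of a bare float. A minimal sketch of the scoring step, with random tensors standing in for the WavLM-based embeddings of the generated and prompt audio:

```python
import torch
import torch.nn.functional as F

# Toy embeddings standing in for the speaker-verification model outputs.
emb1 = torch.randn(1, 256)  # embedding of the generated wav
emb2 = torch.randn(1, 256)  # embedding of the prompt wav

sim = F.cosine_similarity(emb1, emb2)[0].item()  # in [-1.0, 1.0], higher = more similar
result = {"wav": "sample_0001", "sim": sim}      # per-utterance record, as run_sim now returns
print(result)
```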
src/f5_tts/infer/README.md

@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h

**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**

Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.

To avoid possible inference failures, make sure you have seen through the following instructions.

- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).


## Gradio App
@@ -68,14 +69,16 @@ Basically you can inference with flags:
```bash
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local

# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors

# More instructions
f5-tts_infer-cli --help
@@ -90,8 +93,8 @@ f5-tts_infer-cli -c custom.toml
For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
@@ -105,8 +108,8 @@ output_dir = "tests"
You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
@@ -126,6 +129,22 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.

## Socket Real-time Service

Real-time voice output with chunk stream:

```bash
# Start socket server
python src/f5_tts/socket_server.py

# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio

# Communicate with socket client
python src/f5_tts/socket_client.py
```

## Speech Editing

To test speech editing capabilities, use the following command:
@@ -134,86 +153,3 @@ To test speech editing capabilities, use the following command:
python src/f5_tts/infer/speech_edit.py
```

## Socket Realtime Client

To communicate with socket server you need to run
```bash
python src/f5_tts/socket_server.py
```

<details>
<summary>Then create client to communicate</summary>

```bash
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
```

``` python
# Create the socket_client.py
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))

    start_time = time.time()
    first_chunk_time = None

    async def play_audio_stream():
        nonlocal first_chunk_time
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)

        try:
            while True:
                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
                if not data:
                    break
                if data == b"END":
                    logger.info("End of audio received.")
                    break

                audio_array = np.frombuffer(data, dtype=np.float32)
                stream.write(audio_array.tobytes())

                if first_chunk_time is None:
                    first_chunk_time = time.time()

        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")

    try:
        data_to_send = f"{text}".encode("utf-8")
        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
        await play_audio_stream()

    except Exception as e:
        logger.error(f"Error in listen_to_F5TTS: {e}")

    finally:
        client_socket.close()


if __name__ == "__main__":
    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"

    asyncio.run(listen_to_F5TTS(text_to_send))
```

</details>
|
||||
@@ -16,7 +16,7 @@
<!-- omit in toc -->
### Supported Languages
- [Multilingual](#multilingual)
- [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [English](#english)
- [Finnish](#finnish)
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,18 @@

## Multilingual

#### F5-TTS Base @ zh & en @ F5-TTS
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|

```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
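
As a usage sketch (not part of the original listing): when no `--ckpt_file` is given, `infer_cli.py` resolves exactly this `hf://` checkpoint and vocab for the default `F5TTS_v1_Base` model, so a quick test can be as simple as:

```bash
# Minimal sketch; the -m flag is the one defined in infer_cli.py further down this page.
python src/f5_tts/infer/infer_cli.py -m F5TTS_v1_Base
```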

|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +56,7 @@
```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```

*Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +75,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
```bash
|
||||
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
|
||||
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
|
||||
@@ -78,7 +89,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
```bash
|
||||
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
|
||||
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
|
||||
@@ -96,7 +107,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
```bash
|
||||
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
|
||||
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
|
||||
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Authors: SPRING Lab, Indian Institute of Technology, Madras
|
||||
@@ -113,7 +124,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
|
||||
```bash
|
||||
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
|
||||
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Trained by [Mithril Man](https://github.com/MithrilMan)
|
||||
@@ -131,7 +142,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
```bash
|
||||
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
|
||||
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
|
||||
@@ -148,7 +159,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
```bash
|
||||
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
|
||||
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
|
||||
- Any improvements are welcome
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# F5-TTS | E2-TTS
|
||||
model = "F5-TTS"
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/basic/basic_ref_en.wav"
|
||||
# If left as an empty "", the reference audio is transcribed automatically.
|
||||
ref_text = "Some call me nature, others call me mother nature."
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# F5-TTS | E2-TTS
|
||||
model = "F5-TTS"
|
||||
# F5TTS_v1_Base | E2TTS_Base
|
||||
model = "F5TTS_v1_Base"
|
||||
ref_audio = "infer/examples/multi/main.flac"
|
||||
# If left as an empty "", the reference audio is transcribed automatically.
|
||||
ref_text = ""
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import soundfile as sf
|
||||
import tomli
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
@@ -21,13 +22,13 @@ from f5_tts.infer.utils_infer import (
|
||||
sway_sampling_coef,
|
||||
speed,
|
||||
fix_duration,
|
||||
device,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -50,7 +51,7 @@ parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
type=str,
|
||||
help="The model name: F5-TTS | E2-TTS",
|
||||
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-mc",
|
||||
@@ -162,6 +163,11 @@ parser.add_argument(
|
||||
type=float,
|
||||
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
help="Specify the device to run on",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -172,8 +178,7 @@ config = tomli.load(open(args.config, "rb"))
|
||||
|
||||
# command-line interface parameters
|
||||
|
||||
model = args.model or config.get("model", "F5-TTS")
|
||||
model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
|
||||
model = args.model or config.get("model", "F5TTS_v1_Base")
|
||||
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
|
||||
vocab_file = args.vocab_file or config.get("vocab_file", "")
|
||||
|
||||
@@ -203,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
|
||||
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
||||
speed = args.speed or config.get("speed", speed)
|
||||
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
||||
device = args.device or config.get("device", device)
|
||||
|
||||
|
||||
# patches for pip pkg user
|
||||
@@ -240,41 +246,42 @@ if vocoder_name == "vocos":
|
||||
elif vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
|
||||
)
|
||||
|
||||
|
||||
# load TTS model
|
||||
|
||||
if model == "F5-TTS":
|
||||
model_cls = DiT
|
||||
model_cfg = OmegaConf.load(model_cfg).model.arch
|
||||
if not ckpt_file: # path not specified, download from repo
|
||||
if vocoder_name == "vocos":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base"
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
elif vocoder_name == "bigvgan":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base_bigvgan"
|
||||
ckpt_step = 1250000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
||||
model_cfg = OmegaConf.load(
|
||||
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
||||
)
|
||||
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
|
||||
elif model == "E2-TTS":
|
||||
assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
|
||||
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
if not ckpt_file: # path not specified, download from repo
|
||||
repo_name = "E2-TTS"
|
||||
exp_name = "E2TTS_Base"
|
||||
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
||||
|
||||
if model != "F5TTS_Base":
|
||||
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
|
||||
|
||||
# override for previous models
|
||||
if model == "F5TTS_Base":
|
||||
if vocoder_name == "vocos":
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
elif vocoder_name == "bigvgan":
|
||||
model = "F5TTS_Base_bigvgan"
|
||||
ckpt_type = "pt"
|
||||
elif model == "E2TTS_Base":
|
||||
repo_name = "E2-TTS"
|
||||
ckpt_step = 1200000
|
||||
|
||||
if not ckpt_file:
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
||||
ema_model = load_model(
|
||||
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
|
||||
)
|
||||
|
||||
|
||||
# inference process
|
||||
@@ -330,6 +337,7 @@ def main():
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
generated_audio_segments.append(audio_segment)
|
||||
|
||||
@@ -337,7 +345,7 @@ def main():
|
||||
if len(gen_text_) > 200:
|
||||
gen_text_ = gen_text_[:200] + " ... "
|
||||
sf.write(
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
||||
audio_segment,
|
||||
final_sample_rate,
|
||||
)
|
||||
|
||||
@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_TTS_MODEL = "F5-TTS"
|
||||
DEFAULT_TTS_MODEL = "F5-TTS_v1"
|
||||
tts_model_choice = DEFAULT_TTS_MODEL
|
||||
|
||||
DEFAULT_TTS_MODEL_CFG = [
|
||||
"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
|
||||
"hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
|
||||
"hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
|
||||
"hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
|
||||
json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
|
||||
]
|
||||
|
||||
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
|
||||
vocoder = load_vocoder()
|
||||
|
||||
|
||||
def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
|
||||
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
def load_f5tts():
|
||||
ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
|
||||
F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
||||
return load_model(DiT, F5TTS_model_cfg, ckpt_path)
|
||||
|
||||
|
||||
def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
|
||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
def load_e2tts():
|
||||
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
|
||||
return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
|
||||
|
||||
|
||||
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
|
||||
if vocab_path.startswith("hf://"):
|
||||
vocab_path = str(cached_path(vocab_path))
|
||||
if model_cfg is None:
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
||||
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
||||
|
||||
|
||||
@@ -130,7 +132,7 @@ def infer(
|
||||
|
||||
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
||||
|
||||
if model == "F5-TTS":
|
||||
if model == DEFAULT_TTS_MODEL:
|
||||
ema_model = F5TTS_ema_model
|
||||
elif model == "E2-TTS":
|
||||
global E2TTS_ema_model
|
||||
@@ -756,13 +758,13 @@ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not
|
||||
|
||||
The checkpoints currently support English and Chinese.
|
||||
|
||||
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
|
||||
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
|
||||
|
||||
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
|
||||
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
|
||||
"""
|
||||
)
|
||||
|
||||
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
|
||||
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
|
||||
|
||||
def load_last_used_custom():
|
||||
try:
|
||||
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
custom_model_cfg = gr.Dropdown(
|
||||
choices=[
|
||||
DEFAULT_TTS_MODEL_CFG[2],
|
||||
json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
|
||||
json.dumps(
|
||||
dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
text_mask_padding=False,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
)
|
||||
),
|
||||
json.dumps(
|
||||
dict(
|
||||
dim=768,
|
||||
depth=18,
|
||||
heads=12,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
text_mask_padding=False,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
)
|
||||
),
|
||||
],
|
||||
value=load_last_used_custom()[2],
|
||||
allow_custom_value=True,
|
||||
|
||||
@@ -2,12 +2,16 @@ import os
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
|
||||
from f5_tts.model import CFM, DiT, UNetT
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
device = (
|
||||
@@ -21,44 +25,40 @@ device = (
|
||||
)
|
||||
|
||||
|
||||
# --------------------- Dataset Settings -------------------- #
|
||||
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
n_fft = 1024
|
||||
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
|
||||
target_rms = 0.1
|
||||
|
||||
tokenizer = "pinyin"
|
||||
dataset_name = "Emilia_ZH_EN"
|
||||
|
||||
|
||||
# ---------------------- infer setting ---------------------- #
|
||||
|
||||
seed = None # int | None
|
||||
|
||||
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
||||
ckpt_step = 1200000
|
||||
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
|
||||
ckpt_step = 1250000
|
||||
|
||||
nfe_step = 32 # 16, 32
|
||||
cfg_strength = 2.0
|
||||
ode_method = "euler" # euler | midpoint
|
||||
sway_sampling_coef = -1.0
|
||||
speed = 1.0
|
||||
target_rms = 0.1
|
||||
|
||||
if exp_name == "F5TTS_Base":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
|
||||
elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
|
||||
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
dataset_name = model_cfg.datasets.name
|
||||
tokenizer = model_cfg.model.tokenizer
|
||||
|
||||
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
||||
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
|
||||
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
|
||||
hop_length = model_cfg.model.mel_spec.hop_length
|
||||
win_length = model_cfg.model.mel_spec.win_length
|
||||
n_fft = model_cfg.model.mel_spec.n_fft
|
||||
|
||||
|
||||
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
output_dir = "tests"
|
||||
|
||||
|
||||
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
|
||||
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
|
||||
# [write the origin_text into a file, e.g. tests/test_edit.txt]
|
||||
@@ -67,7 +67,7 @@ output_dir = "tests"
|
||||
# [--language "zho" for Chinese, "eng" for English]
|
||||
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
|
||||
|
||||
audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
|
||||
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
|
||||
origin_text = "Some call me nature, others call me mother nature."
|
||||
target_text = "Some call me optimist, others call me realist."
|
||||
parts_to_edit = [
|
||||
@@ -106,7 +106,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=dict(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
|
||||
@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
||||
dtype = (
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and torch.cuda.get_device_properties(device).major >= 7
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
@@ -301,29 +301,29 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
||||
show_info("Audio is over 15s, clipping short. (1)")
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 15000:
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
||||
show_info("Audio is over 15s, clipping short. (2)")
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
aseg = non_silent_wave
|
||||
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 15000:
|
||||
aseg = aseg[:15000]
|
||||
show_info("Audio is over 15s, clipping short. (3)")
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
|
||||
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
||||
aseg.export(f.name, format="wav")
|
||||
@@ -383,7 +383,7 @@ def infer_process(
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
@@ -479,14 +479,15 @@ def infer_batch_process(
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
del _
|
||||
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated.to(torch.float32) # generated mel spectrogram
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = generated.permute(0, 2, 1)
|
||||
generated = generated.permute(0, 2, 1)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(generated_mel_spec)
|
||||
generated_wave = vocoder.decode(generated)
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(generated_mel_spec)
|
||||
generated_wave = vocoder(generated)
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
@@ -497,7 +498,9 @@ def infer_batch_process(
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
else:
|
||||
yield generated_wave, generated_mel_spec[0].cpu().numpy()
|
||||
generated_cpu = generated[0].cpu().numpy()
|
||||
del generated
|
||||
yield generated_wave, generated_cpu
|
||||
|
||||
if streaming:
|
||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
### unett.py
|
||||
- flat unet transformer
|
||||
- structure same as in e2-tts & voicebox paper except using rotary pos emb
|
||||
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
|
||||
- possible abs pos emb & convnextv2 blocks for embedded text before concat
|
||||
|
||||
### dit.py
|
||||
- adaln-zero dit
|
||||
@@ -14,7 +14,7 @@
|
||||
- possible long skip connection (first layer to last layer)
|
||||
|
||||
### mmdit.py
|
||||
- sd3 structure
|
||||
- stable diffusion 3 block structure
|
||||
- timestep as condition
|
||||
- left stream: text embedded and applied an abs pos emb
|
||||
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
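
As a rough illustration of how the backbones listed above are used, an arch dict like the ones shown elsewhere in this compare view can be unpacked straight into the backbone class (a minimal sketch; the `mel_dim` and `text_num_embeds` values here are assumptions for illustration):

```python
# Minimal sketch, not part of this README: build a DiT backbone from an arch dict.
# mel_dim / text_num_embeds are assumed values chosen only for illustration.
from f5_tts.model import DiT

arch = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
backbone = DiT(**arch, text_num_embeds=256, mel_dim=100)
```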
|
||||
|
||||
@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
AdaLayerNorm_Final,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
|
||||
text = text + text_pos_embed
|
||||
|
||||
# convnextv2 blocks
|
||||
text = self.text_blocks(text)
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
return text
|
||||
|
||||
@@ -103,7 +113,10 @@ class DiT(nn.Module):
|
||||
mel_dim=100,
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
text_mask_padding=True,
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
long_skip_connection=False,
|
||||
checkpoint_activations=False,
|
||||
):
|
||||
@@ -112,7 +125,10 @@ class DiT(nn.Module):
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
||||
self.text_embed = TextEmbedding(
|
||||
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
||||
)
|
||||
self.text_cond, self.text_uncond = None, None # text cache
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
@@ -121,15 +137,40 @@ class DiT(nn.Module):
|
||||
self.depth = depth
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||
[
|
||||
DiTBlock(
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
ff_mult=ff_mult,
|
||||
dropout=dropout,
|
||||
qk_norm=qk_norm,
|
||||
pe_attn_head=pe_attn_head,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
self.checkpoint_activations = checkpoint_activations
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
# Zero-out AdaLN layers in DiT blocks:
|
||||
for block in self.transformer_blocks:
|
||||
nn.init.constant_(block.attn_norm.linear.weight, 0)
|
||||
nn.init.constant_(block.attn_norm.linear.bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.norm_out.linear.weight, 0)
|
||||
nn.init.constant_(self.norm_out.linear.bias, 0)
|
||||
nn.init.constant_(self.proj_out.weight, 0)
|
||||
nn.init.constant_(self.proj_out.bias, 0)
|
||||
|
||||
def ckpt_wrapper(self, module):
|
||||
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
||||
def ckpt_forward(*inputs):
|
||||
@@ -138,6 +179,9 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
@@ -147,14 +191,25 @@ class DiT(nn.Module):
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
||||
t = self.time_embed(time)
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
@@ -164,7 +219,8 @@ class DiT(nn.Module):
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.checkpoint_activations:
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
|
||||
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
|
||||
else:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvPositionEmbedding,
|
||||
MMDiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
AdaLayerNorm_Final,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, out_dim, text_num_embeds):
|
||||
def __init__(self, out_dim, text_num_embeds, mask_padding=True):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
||||
|
||||
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
||||
|
||||
self.precompute_max_pos = 1024
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
||||
|
||||
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
||||
text = text + 1
|
||||
if drop_text:
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
text = self.text_embed(text)
|
||||
|
||||
text = self.text_embed(text) # b nt -> b nt d
|
||||
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
||||
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
text = text + text_pos_embed
|
||||
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
text_num_embeds=256,
|
||||
mel_dim=100,
|
||||
text_num_embeds=256,
|
||||
text_mask_padding=True,
|
||||
qk_norm=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
self.text_embed = TextEmbedding(dim, text_num_embeds)
|
||||
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
|
||||
self.text_cond, self.text_uncond = None, None # text cache
|
||||
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
|
||||
dropout=dropout,
|
||||
ff_mult=ff_mult,
|
||||
context_pre_only=i == depth - 1,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
# Zero-out AdaLN layers in MMDiT blocks:
|
||||
for block in self.transformer_blocks:
|
||||
nn.init.constant_(block.attn_norm_x.linear.weight, 0)
|
||||
nn.init.constant_(block.attn_norm_x.linear.bias, 0)
|
||||
nn.init.constant_(block.attn_norm_c.linear.weight, 0)
|
||||
nn.init.constant_(block.attn_norm_c.linear.bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.norm_out.linear.weight, 0)
|
||||
nn.init.constant_(self.norm_out.linear.bias, 0)
|
||||
nn.init.constant_(self.proj_out.weight, 0)
|
||||
nn.init.constant_(self.proj_out.bias, 0)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
c = self.text_embed(text, drop_text=drop_text)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, drop_text=True)
|
||||
c = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, drop_text=False)
|
||||
c = self.text_cond
|
||||
else:
|
||||
c = self.text_embed(text, drop_text=drop_text)
|
||||
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
|
||||
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
|
||||
text = text + text_pos_embed
|
||||
|
||||
# convnextv2 blocks
|
||||
text = self.text_blocks(text)
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
return text
|
||||
|
||||
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
|
||||
mel_dim=100,
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
text_mask_padding=True,
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
||||
self.text_embed = TextEmbedding(
|
||||
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
||||
)
|
||||
self.text_cond, self.text_uncond = None, None # text cache
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
|
||||
|
||||
attn_norm = RMSNorm(dim)
|
||||
attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
|
||||
ff_norm = RMSNorm(dim)
|
||||
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
|
||||
self.norm_out = RMSNorm(dim)
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
# postfix time t to input x, [b n d] -> [b n+1 d]
|
||||
|
||||
@@ -162,13 +162,13 @@ class CFM(nn.Module):
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
|
||||
)
|
||||
if cfg_strength < 1e-5:
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
|
||||
)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
@@ -195,6 +195,7 @@ class CFM(nn.Module):
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
||||
self.transformer.clear_cache()
|
||||
|
||||
sampled = trajectory[-1]
|
||||
out = sampled
|
||||
|
||||
@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
|
||||
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
|
||||
):
|
||||
self.sampler = sampler
|
||||
self.frames_threshold = frames_threshold
|
||||
@@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
|
||||
batch = []
|
||||
batch_frames = 0
|
||||
|
||||
if not drop_last and len(batch) > 0:
|
||||
if not drop_residual and len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
del indices
|
||||
self.batches = batches
|
||||
|
||||
# Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
|
||||
self.drop_last = True
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Sets the epoch for this sampler."""
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
|
||||
return residual + x
|
||||
|
||||
|
||||
# AdaLayerNormZero
|
||||
# RMSNorm
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
|
||||
|
||||
def forward(self, x):
|
||||
if self.native_rms_norm:
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
x = x.to(self.weight.dtype)
|
||||
x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
|
||||
else:
|
||||
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.eps)
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
x = x.to(self.weight.dtype)
|
||||
x = x * self.weight
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# AdaLayerNorm
|
||||
# return with modulated x for attn input, and params for later mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
# AdaLayerNormZero for final layer
|
||||
# AdaLayerNorm for final layer
|
||||
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero_Final(nn.Module):
|
||||
class AdaLayerNorm_Final(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
@@ -341,7 +366,8 @@ class Attention(nn.Module):
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||
context_pre_only=None,
|
||||
context_pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -362,18 +388,32 @@ class Attention(nn.Module):
|
||||
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||
|
||||
if qk_norm is None:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
elif qk_norm == "rms_norm":
|
||||
self.q_norm = RMSNorm(dim_head, eps=1e-6)
|
||||
self.k_norm = RMSNorm(dim_head, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
|
||||
|
||||
if self.context_dim is not None:
|
||||
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||
if qk_norm is None:
|
||||
self.c_q_norm = None
|
||||
self.c_k_norm = None
|
||||
elif qk_norm == "rms_norm":
|
||||
self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
|
||||
self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||
if self.context_dim is not None and not self.context_pre_only:
|
||||
self.to_out_c = nn.Linear(self.inner_dim, context_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -393,8 +433,11 @@ class Attention(nn.Module):
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: int | None = None, # number of attention heads to apply rope to, None for all
|
||||
):
|
||||
self.pe_attn_head = pe_attn_head
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -405,19 +448,11 @@ class AttnProcessor:
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
# `sample` projections
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# apply rotary position embedding
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -425,6 +460,25 @@ class AttnProcessor:
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# qk norm
|
||||
if attn.q_norm is not None:
|
||||
query = attn.q_norm(query)
|
||||
if attn.k_norm is not None:
|
||||
key = attn.k_norm(key)
|
||||
|
||||
# apply rotary position embedding
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
|
||||
if self.pe_attn_head is not None:
|
||||
pn = self.pe_attn_head
|
||||
query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
|
||||
key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
|
||||
else:
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
@@ -470,16 +524,36 @@ class JointAttnProcessor:
|
||||
|
||||
batch_size = c.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
# `sample` projections
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# `context` projections.
|
||||
# `context` projections
|
||||
c_query = attn.to_q_c(c)
|
||||
c_key = attn.to_k_c(c)
|
||||
c_value = attn.to_v_c(c)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# qk norm
|
||||
if attn.q_norm is not None:
|
||||
query = attn.q_norm(query)
|
||||
if attn.k_norm is not None:
|
||||
key = attn.k_norm(key)
|
||||
if attn.c_q_norm is not None:
|
||||
c_query = attn.c_q_norm(c_query)
|
||||
if attn.c_k_norm is not None:
|
||||
c_key = attn.c_k_norm(c_key)
|
||||
|
||||
# apply rope for context and noised input independently
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
@@ -492,16 +566,10 @@ class JointAttnProcessor:
|
||||
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, c_query], dim=1)
|
||||
key = torch.cat([key, c_key], dim=1)
|
||||
value = torch.cat([value, c_value], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# joint attention
|
||||
query = torch.cat([query, c_query], dim=2)
|
||||
key = torch.cat([key, c_key], dim=2)
|
||||
value = torch.cat([value, c_value], dim=2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
@@ -540,16 +608,17 @@ class JointAttnProcessor:
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn_norm = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
|
||||
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
|
||||
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||
"""
|
||||
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
||||
def __init__(
|
||||
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if context_dim is None:
|
||||
context_dim = dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
|
||||
self.attn_norm_x = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
processor=JointAttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=dim,
|
||||
context_dim=context_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
|
||||
if not context_pre_only:
|
||||
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
else:
|
||||
self.ff_norm_c = None
|
||||
self.ff_c = None
|
||||
|
||||
@@ -32,7 +32,7 @@ class Trainer:
|
||||
save_per_updates=1000,
|
||||
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
checkpoint_path=None,
|
||||
batch_size=32,
|
||||
batch_size_per_gpu=32,
|
||||
batch_size_type: str = "sample",
|
||||
max_samples=32,
|
||||
grad_accumulation_steps=1,
|
||||
@@ -40,7 +40,7 @@ class Trainer:
|
||||
noise_scheduler: str | None = None,
|
||||
duration_predictor: torch.nn.Module | None = None,
|
||||
logger: str | None = "wandb", # "wandb" | "tensorboard" | None
|
||||
wandb_project="test_e2-tts",
|
||||
wandb_project="test_f5-tts",
|
||||
wandb_run_name="test_run",
|
||||
wandb_resume_id: str = None,
|
||||
log_samples: bool = False,
|
||||
@@ -51,6 +51,7 @@ class Trainer:
|
||||
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
||||
is_local_vocoder: bool = False, # use local path vocoder
|
||||
local_vocoder_path: str = "", # local vocoder path
|
||||
cfg_dict: dict = dict(), # training config
|
||||
):
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
|
||||
@@ -72,21 +73,23 @@ class Trainer:
|
||||
else:
|
||||
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
||||
|
||||
self.accelerator.init_trackers(
|
||||
project_name=wandb_project,
|
||||
init_kwargs=init_kwargs,
|
||||
config={
|
||||
if not cfg_dict:
|
||||
cfg_dict = {
|
||||
"epochs": epochs,
|
||||
"learning_rate": learning_rate,
|
||||
"num_warmup_updates": num_warmup_updates,
|
||||
"batch_size": batch_size,
|
||||
"batch_size_per_gpu": batch_size_per_gpu,
|
||||
"batch_size_type": batch_size_type,
|
||||
"max_samples": max_samples,
|
||||
"grad_accumulation_steps": grad_accumulation_steps,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"gpus": self.accelerator.num_processes,
|
||||
"noise_scheduler": noise_scheduler,
|
||||
},
|
||||
}
|
||||
cfg_dict["gpus"] = self.accelerator.num_processes
|
||||
self.accelerator.init_trackers(
|
||||
project_name=wandb_project,
|
||||
init_kwargs=init_kwargs,
|
||||
config=cfg_dict,
|
||||
)
|
||||
|
||||
elif self.logger == "tensorboard":
|
||||
@@ -111,9 +114,9 @@ class Trainer:
|
||||
self.save_per_updates = save_per_updates
|
||||
self.keep_last_n_checkpoints = keep_last_n_checkpoints
|
||||
self.last_per_updates = default(last_per_updates, save_per_updates)
|
||||
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
||||
self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.batch_size_per_gpu = batch_size_per_gpu
|
||||
self.batch_size_type = batch_size_type
|
||||
self.max_samples = max_samples
|
||||
self.grad_accumulation_steps = grad_accumulation_steps
|
||||
@@ -179,7 +182,7 @@ class Trainer:
|
||||
if (
|
||||
not exists(self.checkpoint_path)
|
||||
or not os.path.exists(self.checkpoint_path)
|
||||
or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
|
||||
or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
|
||||
):
|
||||
return 0
|
||||
|
||||
@@ -191,7 +194,7 @@ class Trainer:
|
||||
all_checkpoints = [
|
||||
f
|
||||
for f in os.listdir(self.checkpoint_path)
|
||||
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
|
||||
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
|
||||
]
|
||||
|
||||
# First try to find regular training checkpoints
|
||||
@@ -205,8 +208,16 @@ class Trainer:
|
||||
# If no training checkpoints, use pretrained model
|
||||
latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
|
||||
|
||||
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
||||
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
||||
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
elif latest_checkpoint.endswith(".pt"):
|
||||
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
||||
checkpoint = torch.load(
|
||||
f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
|
||||
)
|
||||
|
||||
# patch for backward compatibility, 305e3ea
|
||||
for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
|
||||
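For reference, the new format-aware loading branch above reduces to the following standalone sketch; the checkpoint path is a placeholder, not a value from this diff.

```python
# Load either a .safetensors (pretrained) or .pt (training) checkpoint onto CPU,
# mirroring the logic in load_checkpoint above.
import torch
from safetensors.torch import load_file

ckpt_file = "ckpts/test_f5-tts/pretrained_model_1250000.safetensors"  # hypothetical path
if ckpt_file.endswith(".safetensors"):
    checkpoint = {"ema_model_state_dict": load_file(ckpt_file, device="cpu")}
elif ckpt_file.endswith(".pt"):
    checkpoint = torch.load(ckpt_file, weights_only=True, map_location="cpu")
print(list(checkpoint.keys()))
```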
@@ -271,7 +282,7 @@ class Trainer:
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
batch_size=self.batch_size,
|
||||
batch_size=self.batch_size_per_gpu,
|
||||
shuffle=True,
|
||||
generator=generator,
|
||||
)
|
||||
@@ -280,10 +291,10 @@ class Trainer:
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
batch_sampler = DynamicBatchSampler(
|
||||
sampler,
|
||||
self.batch_size,
|
||||
self.batch_size_per_gpu,
|
||||
max_samples=self.max_samples,
|
||||
random_seed=resumable_with_seed, # This enables reproducible shuffling
|
||||
drop_last=False,
|
||||
drop_residual=False,
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
@@ -339,7 +350,7 @@ class Trainer:
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
|
||||
desc=f"Epoch {epoch+1}/{self.epochs}",
|
||||
desc=f"Epoch {epoch + 1}/{self.epochs}",
|
||||
unit="update",
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
initial=progress_bar_initial,
|
||||
@@ -417,6 +428,7 @@ class Trainer:
|
||||
torchaudio.save(
|
||||
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
|
||||
)
|
||||
self.model.train()
|
||||
|
||||
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
|
||||
self.save_checkpoint(global_update, last=True)
|
||||
|
||||
@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):

# convert char to pinyin

jieba.initialize()
print("Word segmentation module jieba initialized.\n")


def convert_char_to_pinyin(text_list, polyphone=True):
    if jieba.dt.initialized is False:
        jieba.default_logger.setLevel(50)  # CRITICAL
        jieba.initialize()

    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
@@ -9,7 +9,7 @@ mel_hop_length = 256
mel_sampling_rate = 24000

# target
wanted_max_updates = 1000000
wanted_max_updates = 1200000

# train params
gpus = 8
@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours

# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
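As a quick sanity check of the arithmetic above, here is the same calculation with illustrative numbers plugged in; the dataset size and per-GPU frame budget below are assumptions for the example, not values taken from this script.

```python
# Illustrative numbers only; the real script reads these from its config section.
mel_hop_length = 256
mel_sampling_rate = 24000
total_hours = 95_000            # assumed dataset size in hours
frames_per_gpu = 38_400         # assumed batch_size_per_gpu in mel frames
gpus, grad_accum = 8, 1
wanted_max_updates = 1_200_000

mini_batch_frames = frames_per_gpu * grad_accum * gpus                            # 307_200 frames
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600  # ~0.91 hours per update
updates_per_epoch = total_hours / mini_batch_hours                                # ~104_000 updates
epochs = wanted_max_updates / updates_per_epoch                                   # ~11.5 -> set epochs to 12
print(f"epochs should be set to: {epochs:.0f}")
```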
61  src/f5_tts/socket_client.py  Normal file
@@ -0,0 +1,61 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
    # Connect to the streaming TTS socket server (see the TTSStreamingProcessor changes further below).
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))

    start_time = time.time()
    first_chunk_time = None

    async def play_audio_stream():
        nonlocal first_chunk_time
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)

        try:
            while True:
                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
                if not data:
                    break
                if data == b"END":
                    logger.info("End of audio received.")
                    break

                # Raw float32 PCM chunks are played back as they arrive.
                audio_array = np.frombuffer(data, dtype=np.float32)
                stream.write(audio_array.tobytes())

                if first_chunk_time is None:
                    first_chunk_time = time.time()

        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")

    try:
        data_to_send = f"{text}".encode("utf-8")
        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
        await play_audio_stream()

    except Exception as e:
        logger.error(f"Error in listen_to_F5TTS: {e}")

    finally:
        client_socket.close()


if __name__ == "__main__":
    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"

    asyncio.run(listen_to_F5TTS(text_to_send))
@@ -13,8 +13,9 @@ from importlib.resources import files
|
||||
import torch
|
||||
import torchaudio
|
||||
from huggingface_hub import hf_hub_download
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.model.backbones.dit import DiT
|
||||
from f5_tts.infer.utils_infer import (
|
||||
chunk_text,
|
||||
preprocess_ref_audio_text,
|
||||
@@ -68,7 +69,7 @@ class AudioFileWriterThread(threading.Thread):
|
||||
|
||||
|
||||
class TTSStreamingProcessor:
|
||||
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
|
||||
def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
|
||||
self.device = device or (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
@@ -78,21 +79,24 @@ class TTSStreamingProcessor:
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
self.mel_spec_type = "vocos"
|
||||
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
||||
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
self.model_arc = model_cfg.model.arch
|
||||
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
||||
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
|
||||
|
||||
self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
|
||||
self.vocoder = self.load_vocoder_model()
|
||||
self.sampling_rate = 24000
|
||||
|
||||
self.update_reference(ref_audio, ref_text)
|
||||
self._warm_up()
|
||||
self.file_writer_thread = None
|
||||
self.first_package = True
|
||||
|
||||
def load_ema_model(self, ckpt_file, vocab_file, dtype):
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
model_cls = DiT
|
||||
return load_model(
|
||||
model_cls=model_cls,
|
||||
model_cfg=model_cfg,
|
||||
self.model_cls,
|
||||
self.model_arc,
|
||||
ckpt_path=ckpt_file,
|
||||
mel_spec_type=self.mel_spec_type,
|
||||
vocab_file=vocab_file,
|
||||
@@ -212,9 +216,14 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", default=9998)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="F5TTS_v1_Base",
|
||||
help="The model name, e.g. F5TTS_v1_Base",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt_file",
|
||||
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.safetensors")),
|
||||
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
|
||||
help="Path to the model checkpoint file",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -242,6 +251,7 @@ if __name__ == "__main__":
|
||||
try:
|
||||
# Initialize the processor with the model and vocoder
|
||||
processor = TTSStreamingProcessor(
|
||||
model=args.model,
|
||||
ckpt_file=args.ckpt_file,
|
||||
vocab_file=args.vocab_file,
|
||||
ref_audio=args.ref_audio,
|
||||
|
||||
@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
accelerate config

# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml

# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```

### 2. Finetuning practice
@@ -51,9 +51,13 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio

Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
If you want to finetune a variant version, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in its path accordingly on the web interface.

### 3. Wandb Logging
If using tensorboard as the logger, install it first with `pip install tensorboard`.

<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (ones trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see if that gives better results. A minimal A/B sketch follows below.
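To make the EMA comparison concrete, here is a hedged sketch. It assumes `load_model` from `f5_tts.infer.utils_infer` accepts the arguments used by the socket server changes elsewhere in this diff (model class, architecture dict, `ckpt_path`, `vocab_file`, `use_ema`); the checkpoint and vocab paths are placeholders, not values from this repo.

```python
# Hedged sketch: compare an early-stage finetuned checkpoint with and without EMA weights.
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT

model_arc = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ckpt = "ckpts/my_project/model_last.pt"   # hypothetical finetuned checkpoint
vocab = "data/my_project/vocab.txt"       # hypothetical vocab file

model_ema = load_model(DiT, model_arc, ckpt_path=ckpt, vocab_file=vocab, use_ema=True)
model_raw = load_model(DiT, model_arc, ckpt_path=ckpt, vocab_file=vocab, use_ema=False)
# Synthesize the same sentence with both models and keep whichever sounds better
# for your current finetuning stage.
```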
### 3. W&B Logging

The `wandb/` dir will be created under the path from which you run the training/finetuning scripts.

@@ -62,7 +66,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
To turn on wandb logging, you can either:

1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:

On Mac & Linux:

@@ -75,7 +79,7 @@ On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:

```
export WANDB_MODE=offline
@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
|
||||
for future in tqdm(
|
||||
chunk_futures,
|
||||
total=len(chunk),
|
||||
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
|
||||
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
|
||||
):
|
||||
try:
|
||||
result = future.result()
|
||||
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
||||
dataset_name = out_dir.stem
|
||||
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
|
||||
|
||||
@@ -198,7 +198,7 @@ def main():
|
||||
|
||||
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
if "ZH" in langs:
|
||||
print(f"Bad zh transcription case: {total_bad_case_zh}")
|
||||
if "EN" in langs:
|
||||
|
||||
@@ -72,7 +72,7 @@ def main():
|
||||
|
||||
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -50,7 +50,7 @@ def main():
|
||||
|
||||
print(f"\nFor {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from importlib.resources import files
|
||||
|
||||
from cached_path import cached_path
|
||||
|
||||
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from importlib.resources import files
|
||||
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
|
||||
@@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
|
||||
|
||||
# -------------------------- Argument Parsing --------------------------- #
|
||||
def parse_args():
|
||||
# batch_size_per_gpu = 1000 settting for gpu 8GB
|
||||
# batch_size_per_gpu = 1600 settting for gpu 12GB
|
||||
# batch_size_per_gpu = 2000 settting for gpu 16GB
|
||||
# batch_size_per_gpu = 3200 settting for gpu 24GB
|
||||
|
||||
# num_warmup_updates = 300 for 5000 sample about 10 hours
|
||||
|
||||
# change save_per_updates , last_per_updates change this value what you need ,
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train CFM Model")
|
||||
|
||||
parser.add_argument(
|
||||
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
|
||||
"--exp_name",
|
||||
type=str,
|
||||
default="F5TTS_v1_Base",
|
||||
choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
|
||||
help="Experiment name",
|
||||
)
|
||||
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
|
||||
@@ -44,15 +40,15 @@ def parse_args():
|
||||
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
|
||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
|
||||
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
|
||||
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
|
||||
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
|
||||
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
|
||||
parser.add_argument(
|
||||
"--keep_last_n_checkpoints",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
|
||||
)
|
||||
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
|
||||
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
|
||||
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
|
||||
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
|
||||
parser.add_argument(
|
||||
@@ -69,7 +65,7 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Log inferenced samples per ckpt save updates",
|
||||
)
|
||||
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
|
||||
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
|
||||
parser.add_argument(
|
||||
"--bnb_optimizer",
|
||||
action="store_true",
|
||||
@@ -88,19 +84,54 @@ def main():
|
||||
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
|
||||
|
||||
# Model parameters based on experiment name
|
||||
if args.exp_name == "F5TTS_Base":
|
||||
|
||||
if args.exp_name == "F5TTS_v1_Base":
|
||||
wandb_resume_id = None
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
conv_layers=4,
|
||||
)
|
||||
if args.finetune:
|
||||
if args.pretrain is None:
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
|
||||
else:
|
||||
ckpt_path = args.pretrain
|
||||
|
||||
elif args.exp_name == "F5TTS_Base":
|
||||
wandb_resume_id = None
|
||||
model_cls = DiT
|
||||
model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
text_mask_padding=False,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
)
|
||||
if args.finetune:
|
||||
if args.pretrain is None:
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
else:
|
||||
ckpt_path = args.pretrain
|
||||
|
||||
elif args.exp_name == "E2TTS_Base":
|
||||
wandb_resume_id = None
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=24,
|
||||
heads=16,
|
||||
ff_mult=4,
|
||||
text_mask_padding=False,
|
||||
pe_attn_head=1,
|
||||
)
|
||||
if args.finetune:
|
||||
if args.pretrain is None:
|
||||
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
@@ -120,6 +151,7 @@ def main():
|
||||
print("copy checkpoint for finetune")
|
||||
|
||||
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
||||
|
||||
tokenizer = args.tokenizer
|
||||
if tokenizer == "custom":
|
||||
if not args.tokenizer_path:
|
||||
@@ -156,7 +188,7 @@ def main():
|
||||
save_per_updates=args.save_per_updates,
|
||||
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
|
||||
checkpoint_path=checkpoint_path,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
batch_size_per_gpu=args.batch_size_per_gpu,
|
||||
batch_size_type=args.batch_size_type,
|
||||
max_samples=args.max_samples,
|
||||
grad_accumulation_steps=args.grad_accumulation_steps,
|
||||
|
||||
@@ -1,36 +1,36 @@
|
||||
import threading
|
||||
import queue
|
||||
import re
|
||||
|
||||
import gc
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import platform
|
||||
import psutil
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import signal
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from glob import glob
|
||||
from importlib.resources import files
|
||||
from scipy.io import wavfile
|
||||
|
||||
import click
|
||||
import gradio as gr
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from datasets import Dataset as Dataset_
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from safetensors.torch import save_file
|
||||
from scipy.io import wavfile
|
||||
from cached_path import cached_path
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from f5_tts.api import F5TTS
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
from f5_tts.infer.utils_infer import transcribe
|
||||
from importlib.resources import files
|
||||
|
||||
|
||||
training_process = None
|
||||
@@ -118,24 +118,24 @@ def load_settings(project_name):
|
||||
|
||||
# Default settings
|
||||
default_settings = {
|
||||
"exp_name": "F5TTS_Base",
|
||||
"learning_rate": 1e-05,
|
||||
"batch_size_per_gpu": 1000,
|
||||
"exp_name": "F5TTS_v1_Base",
|
||||
"learning_rate": 1e-5,
|
||||
"batch_size_per_gpu": 3200,
|
||||
"batch_size_type": "frame",
|
||||
"max_samples": 64,
|
||||
"grad_accumulation_steps": 1,
|
||||
"max_grad_norm": 1,
|
||||
"max_grad_norm": 1.0,
|
||||
"epochs": 100,
|
||||
"num_warmup_updates": 2,
|
||||
"save_per_updates": 300,
|
||||
"num_warmup_updates": 100,
|
||||
"save_per_updates": 500,
|
||||
"keep_last_n_checkpoints": -1,
|
||||
"last_per_updates": 100,
|
||||
"finetune": True,
|
||||
"file_checkpoint_train": "",
|
||||
"tokenizer_type": "pinyin",
|
||||
"tokenizer_file": "",
|
||||
"mixed_precision": "none",
|
||||
"logger": "wandb",
|
||||
"mixed_precision": "fp16",
|
||||
"logger": "none",
|
||||
"bnb_optimizer": False,
|
||||
}
|
||||
|
||||
@@ -361,27 +361,27 @@ def terminate_process(pid):
|
||||
|
||||
|
||||
def start_training(
|
||||
dataset_name="",
|
||||
exp_name="F5TTS_Base",
|
||||
learning_rate=1e-4,
|
||||
batch_size_per_gpu=400,
|
||||
batch_size_type="frame",
|
||||
max_samples=64,
|
||||
grad_accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
epochs=11,
|
||||
num_warmup_updates=200,
|
||||
save_per_updates=400,
|
||||
keep_last_n_checkpoints=-1,
|
||||
last_per_updates=800,
|
||||
finetune=True,
|
||||
file_checkpoint_train="",
|
||||
tokenizer_type="pinyin",
|
||||
tokenizer_file="",
|
||||
mixed_precision="fp16",
|
||||
stream=False,
|
||||
logger="wandb",
|
||||
ch_8bit_adam=False,
|
||||
dataset_name,
|
||||
exp_name,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
batch_size_type,
|
||||
max_samples,
|
||||
grad_accumulation_steps,
|
||||
max_grad_norm,
|
||||
epochs,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
keep_last_n_checkpoints,
|
||||
last_per_updates,
|
||||
finetune,
|
||||
file_checkpoint_train,
|
||||
tokenizer_type,
|
||||
tokenizer_file,
|
||||
mixed_precision,
|
||||
stream,
|
||||
logger,
|
||||
ch_8bit_adam,
|
||||
):
|
||||
global training_process, tts_api, stop_signal
|
||||
|
||||
@@ -458,7 +458,10 @@ def start_training(
|
||||
|
||||
cmd += f" --tokenizer {tokenizer_type}"
|
||||
|
||||
cmd += f" --log_samples --logger {logger}"
|
||||
if logger != "none":
|
||||
cmd += f" --logger {logger}"
|
||||
|
||||
cmd += " --log_samples"
|
||||
|
||||
if ch_8bit_adam:
|
||||
cmd += " --bnb_optimizer"
|
||||
@@ -515,7 +518,7 @@ def start_training(
|
||||
training_process = subprocess.Popen(
|
||||
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
|
||||
)
|
||||
yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
|
||||
yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
|
||||
|
||||
stdout_queue = queue.Queue()
|
||||
stderr_queue = queue.Queue()
|
||||
@@ -584,7 +587,11 @@ def start_training(
|
||||
gr.update(interactive=True),
|
||||
)
|
||||
else:
|
||||
yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
|
||||
yield (
|
||||
"Training complete or paused ...",
|
||||
gr.update(interactive=False),
|
||||
gr.update(interactive=True),
|
||||
)
|
||||
break
|
||||
|
||||
# Small sleep to prevent CPU thrashing
|
||||
@@ -598,9 +605,9 @@ def start_training(
|
||||
time.sleep(1)
|
||||
|
||||
if training_process is None:
|
||||
text_info = "train stop"
|
||||
text_info = "Train stopped !"
|
||||
else:
|
||||
text_info = "train complete !"
|
||||
text_info = "Train complete at end !"
|
||||
|
||||
except Exception as e: # Catch all exceptions
|
||||
# Ensure that we reset the training process variable in case of an error
|
||||
@@ -615,11 +622,11 @@ def stop_training():
|
||||
global training_process, stop_signal
|
||||
|
||||
if training_process is None:
|
||||
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
|
||||
return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
|
||||
terminate_process_tree(training_process.pid)
|
||||
# training_process = None
|
||||
stop_signal = True
|
||||
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
|
||||
return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_list_projects():
|
||||
@@ -797,14 +804,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
print(f"Error processing {file_audio}: {e}")
|
||||
continue
|
||||
|
||||
if duration < 1 or duration > 25:
|
||||
if duration > 25:
|
||||
error_files.append([file_audio, "duration > 25 sec"])
|
||||
if duration < 1 or duration > 30:
|
||||
if duration > 30:
|
||||
error_files.append([file_audio, "duration > 30 sec"])
|
||||
if duration < 1:
|
||||
error_files.append([file_audio, "duration < 1 sec "])
|
||||
continue
|
||||
if len(text) < 3:
|
||||
error_files.append([file_audio, "very small text len 3"])
|
||||
error_files.append([file_audio, "very short text length 3"])
|
||||
continue
|
||||
|
||||
text = clear_text(text)
|
||||
@@ -871,40 +878,37 @@ def check_user(value):
|
||||
|
||||
def calculate_train(
|
||||
name_project,
|
||||
epochs,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
batch_size_type,
|
||||
max_samples,
|
||||
learning_rate,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_updates,
|
||||
finetune,
|
||||
):
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
file_duraction = os.path.join(path_project, "duration.json")
|
||||
file_duration = os.path.join(path_project, "duration.json")
|
||||
|
||||
if not os.path.isfile(file_duraction):
|
||||
hop_length = 256
|
||||
sampling_rate = 24000
|
||||
|
||||
if not os.path.isfile(file_duration):
|
||||
return (
|
||||
1000,
|
||||
epochs,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
max_samples,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_updates,
|
||||
"project not found !",
|
||||
learning_rate,
|
||||
)
|
||||
|
||||
with open(file_duraction, "r") as file:
|
||||
with open(file_duration, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
duration_list = data["duration"]
|
||||
samples = len(duration_list)
|
||||
hours = sum(duration_list) / 3600
|
||||
|
||||
# if torch.cuda.is_available():
|
||||
# gpu_properties = torch.cuda.get_device_properties(0)
|
||||
# total_memory = gpu_properties.total_memory / (1024**3)
|
||||
# elif torch.backends.mps.is_available():
|
||||
# total_memory = psutil.virtual_memory().available / (1024**3)
|
||||
max_sample_length = max(duration_list) * sampling_rate / hop_length
|
||||
total_samples = len(duration_list)
|
||||
total_duration = sum(duration_list)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
@@ -912,64 +916,39 @@ def calculate_train(
|
||||
for i in range(gpu_count):
|
||||
gpu_properties = torch.cuda.get_device_properties(i)
|
||||
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
||||
|
||||
elif torch.xpu.is_available():
|
||||
gpu_count = torch.xpu.device_count()
|
||||
total_memory = 0
|
||||
for i in range(gpu_count):
|
||||
gpu_properties = torch.xpu.get_device_properties(i)
|
||||
total_memory += gpu_properties.total_memory / (1024**3)
|
||||
|
||||
elif torch.backends.mps.is_available():
|
||||
gpu_count = 1
|
||||
total_memory = psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
avg_gpu_memory = total_memory / gpu_count
|
||||
|
||||
# rough estimate of batch size
|
||||
if batch_size_type == "frame":
|
||||
batch = int(total_memory * 0.5)
|
||||
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
|
||||
batch_size_per_gpu = int(38400 / batch)
|
||||
else:
|
||||
batch_size_per_gpu = int(total_memory / 8)
|
||||
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
|
||||
batch = batch_size_per_gpu
|
||||
batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
|
||||
elif batch_size_type == "sample":
|
||||
batch_size_per_gpu = int(200 / (total_duration / total_samples))
|
||||
|
||||
if batch_size_per_gpu <= 0:
|
||||
batch_size_per_gpu = 1
|
||||
if total_samples < 64:
|
||||
max_samples = int(total_samples * 0.25)
|
||||
|
||||
if samples < 64:
|
||||
max_samples = int(samples * 0.25)
|
||||
else:
|
||||
max_samples = 64
|
||||
num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
|
||||
|
||||
num_warmup_updates = int(samples * 0.05)
|
||||
save_per_updates = int(samples * 0.10)
|
||||
last_per_updates = int(save_per_updates * 0.25)
|
||||
# take 1.2M updates as the maximum
|
||||
max_updates = 1200000
|
||||
|
||||
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
||||
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
|
||||
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
|
||||
last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
|
||||
if last_per_updates <= 0:
|
||||
last_per_updates = 2
|
||||
if batch_size_type == "frame":
|
||||
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
|
||||
updates_per_epoch = total_duration / mini_batch_duration
|
||||
elif batch_size_type == "sample":
|
||||
updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
|
||||
|
||||
total_hours = hours
|
||||
mel_hop_length = 256
|
||||
mel_sampling_rate = 24000
|
||||
|
||||
# target
|
||||
wanted_max_updates = 1000000
|
||||
|
||||
# train params
|
||||
gpus = gpu_count
|
||||
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
|
||||
grad_accum = 1
|
||||
|
||||
# intermediate
|
||||
mini_batch_frames = frames_per_gpu * grad_accum * gpus
|
||||
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
|
||||
updates_per_epoch = total_hours / mini_batch_hours
|
||||
# steps_per_epoch = updates_per_epoch * grad_accum
|
||||
epochs = wanted_max_updates / updates_per_epoch
|
||||
epochs = int(max_updates / updates_per_epoch)
|
||||
|
||||
if finetune:
|
||||
learning_rate = 1e-5
|
||||
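A rough walk-through of the auto-setting heuristics above, using an example 24 GB GPU and a 10-second average clip length (both assumptions for illustration, not values from this diff):

```python
# Example numbers only; calculate_train() derives these from the project's duration.json and GPU info.
hop_length, sampling_rate = 256, 24000
avg_gpu_memory = 24.0                                 # assumed single 24 GB GPU
max_sample_length = 30 * sampling_rate / hop_length   # longest allowed clip (30 s) in mel frames ~ 2812

# "frame" batch type: scale 38400 frames by available memory, but never below the longest sample
frame_batch = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))  # 9728 frames
# "sample" batch type: roughly 200 seconds of audio per batch
avg_clip_seconds = 10.0                               # assumed average clip length
sample_batch = int(200 / avg_clip_seconds)            # 20 samples

print(frame_batch, sample_batch)
```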
@@ -977,32 +956,32 @@ def calculate_train(
|
||||
learning_rate = 7.5e-5
|
||||
|
||||
return (
|
||||
epochs,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
max_samples,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_updates,
|
||||
samples,
|
||||
learning_rate,
|
||||
int(epochs),
|
||||
total_samples,
|
||||
)
|
||||
|
||||
|
||||
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
|
||||
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, weights_only=True)
|
||||
print("Original Checkpoint Keys:", checkpoint.keys())
|
||||
|
||||
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
|
||||
if ema_model_state_dict is None:
|
||||
return "No 'ema_model_state_dict' found in the checkpoint."
|
||||
to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
|
||||
try:
|
||||
model_state_dict_to_retain = checkpoint[to_retain]
|
||||
except KeyError:
|
||||
return f"{to_retain} not found in the checkpoint."
|
||||
|
||||
if safetensors:
|
||||
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
|
||||
save_file(ema_model_state_dict, new_checkpoint_path)
|
||||
save_file(model_state_dict_to_retain, new_checkpoint_path)
|
||||
else:
|
||||
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
|
||||
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
|
||||
new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
|
||||
torch.save(new_checkpoint, new_checkpoint_path)
|
||||
|
||||
return f"New checkpoint saved at: {new_checkpoint_path}"
|
||||
@@ -1021,7 +1000,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
if ckpt_path.endswith(".safetensors"):
|
||||
ckpt = load_file(ckpt_path, device="cpu")
|
||||
ckpt = {"ema_model_state_dict": ckpt}
|
||||
elif ckpt_path.endswith(".pt"):
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
ema_sd = ckpt.get("ema_model_state_dict", {})
|
||||
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
|
||||
@@ -1039,7 +1022,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
|
||||
|
||||
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
|
||||
|
||||
torch.save(ckpt, new_ckpt_path)
|
||||
if new_ckpt_path.endswith(".safetensors"):
|
||||
save_file(ema_sd, new_ckpt_path)
|
||||
elif new_ckpt_path.endswith(".pt"):
|
||||
torch.save(ckpt, new_ckpt_path)
|
||||
|
||||
return vocab_new
|
||||
|
||||
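The vocab-extension path above works by appending rows to the text-embedding matrix for the newly added symbols. A minimal sketch of that idea follows; the actual `expand_embeddings` helper in `finetune_gradio.py` may differ, for example in how the new rows are initialized, and the sizes used here are assumptions for the example.

```python
import torch

def expand_embeddings_sketch(weight: torch.Tensor, num_new_tokens: int = 42) -> torch.Tensor:
    # weight: (vocab_size, text_dim); append freshly initialized rows for the new symbols
    new_rows = torch.randn(num_new_tokens, weight.shape[1]) * 0.02
    return torch.cat([weight, new_rows], dim=0)

old = torch.zeros(2546, 512)                # assumed original vocab size and text_dim
print(expand_embeddings_sketch(old).shape)  # torch.Size([2588, 512])
```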
@@ -1089,9 +1075,11 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
with open(file_vocab_project, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(vocab))
|
||||
|
||||
if model_type == "F5-TTS":
|
||||
if model_type == "F5TTS_v1_Base":
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
|
||||
elif model_type == "F5TTS_Base":
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
else:
|
||||
elif model_type == "E2TTS_Base":
|
||||
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
|
||||
vocab_size_new = len(miss_symbols)
|
||||
@@ -1101,7 +1089,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
os.makedirs(new_ckpt_path, exist_ok=True)
|
||||
|
||||
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
|
||||
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
|
||||
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
|
||||
|
||||
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
|
||||
|
||||
@@ -1149,7 +1137,7 @@ def vocab_check(project_name):
|
||||
info = "You can train using your language !"
|
||||
else:
|
||||
vocab_miss = ",".join(miss_symbols)
|
||||
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
|
||||
info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"
|
||||
|
||||
return info, vocab_miss
|
||||
|
||||
@@ -1231,21 +1219,24 @@ def infer(
|
||||
vocab_file = os.path.join(path_data, project, "vocab.txt")
|
||||
|
||||
tts_api = F5TTS(
|
||||
model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
|
||||
model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
|
||||
)
|
||||
|
||||
print("update >> ", device_test, file_checkpoint, use_ema)
|
||||
|
||||
if seed == -1: # -1 used for random
|
||||
seed = None
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
tts_api.infer(
|
||||
gen_text=gen_text.lower().strip(),
|
||||
ref_text=ref_text.lower().strip(),
|
||||
ref_file=ref_audio,
|
||||
ref_text=ref_text.lower().strip(),
|
||||
gen_text=gen_text.lower().strip(),
|
||||
nfe_step=nfe_step,
|
||||
file_wave=f.name,
|
||||
speed=speed,
|
||||
seed=seed,
|
||||
remove_silence=remove_silence,
|
||||
file_wave=f.name,
|
||||
seed=seed,
|
||||
)
|
||||
return f.name, tts_api.device, str(tts_api.seed)
|
||||
|
||||
@@ -1404,14 +1395,14 @@ def get_audio_select(file_sample):
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# E2/F5 TTS Automatic Finetune
|
||||
# F5 TTS Automatic Finetune
|
||||
|
||||
This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
|
||||
This is a local web UI with F5 TTS finetuning support. This app supports the following TTS models:
|
||||
|
||||
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
||||
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
||||
|
||||
The checkpoints support English and Chinese.
|
||||
The pretrained checkpoints support English and Chinese.
|
||||
|
||||
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
|
||||
"""
|
||||
@@ -1454,9 +1445,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
|
||||
)
|
||||
|
||||
audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
|
||||
txt_lang = gr.Text(label="Language", value="English")
|
||||
txt_lang = gr.Textbox(label="Language", value="English")
|
||||
bt_transcribe = bt_create = gr.Button("Transcribe")
|
||||
txt_info_transcribe = gr.Text(label="Info", value="")
|
||||
txt_info_transcribe = gr.Textbox(label="Info", value="")
|
||||
bt_transcribe.click(
|
||||
fn=transcribe_all,
|
||||
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
|
||||
@@ -1467,7 +1458,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
|
||||
random_sample_transcribe = gr.Button("Random Sample")
|
||||
|
||||
with gr.Row():
|
||||
random_text_transcribe = gr.Text(label="Text")
|
||||
random_text_transcribe = gr.Textbox(label="Text")
|
||||
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
|
||||
|
||||
random_sample_transcribe.click(
|
||||
@@ -1482,13 +1473,15 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
|
||||
```""")
|
||||
|
||||
check_button = gr.Button("Check Vocab")
|
||||
txt_info_check = gr.Text(label="Info", value="")
|
||||
txt_info_check = gr.Textbox(label="Info", value="")
|
||||
|
||||
gr.Markdown("""```plaintext
|
||||
Using the extended model, you can finetune on a new language whose symbols are missing from the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
|
||||
```""")
|
||||
|
||||
exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
||||
exp_name_extend = gr.Radio(
|
||||
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
txt_extend = gr.Textbox(
|
||||
@@ -1500,7 +1493,7 @@ Using the extended model, you can finetune to a new language that is missing sym
|
||||
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
|
||||
|
||||
extend_button = gr.Button("Extend")
|
||||
txt_info_extend = gr.Text(label="Info", value="")
|
||||
txt_info_extend = gr.Textbox(label="Info", value="")
|
||||
|
||||
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
||||
@@ -1540,8 +1533,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
|
||||
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
|
||||
|
||||
bt_prepare = bt_create = gr.Button("Prepare")
|
||||
txt_info_prepare = gr.Text(label="Info", value="")
|
||||
txt_vocab_prepare = gr.Text(label="Vocab", value="")
|
||||
txt_info_prepare = gr.Textbox(label="Info", value="")
|
||||
txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
|
||||
|
||||
bt_prepare.click(
|
||||
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
|
||||
@@ -1550,61 +1543,73 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
|
||||
random_sample_prepare = gr.Button("Random Sample")
|
||||
|
||||
with gr.Row():
|
||||
random_text_prepare = gr.Text(label="Tokenizer")
|
||||
random_text_prepare = gr.Textbox(label="Tokenizer")
|
||||
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
|
||||
|
||||
random_sample_prepare.click(
|
||||
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
|
||||
)
|
||||
|
||||
with gr.TabItem("Train Data"):
|
||||
with gr.TabItem("Train Model"):
|
||||
gr.Markdown("""```plaintext
|
||||
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
|
||||
The auto-setting is still experimental. Set a large epoch value if unsure, and keep only the last N checkpoints if disk space is limited.
|
||||
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
|
||||
```""")
|
||||
with gr.Row():
|
||||
bt_calculate = bt_create = gr.Button("Auto Settings")
|
||||
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
|
||||
tokenizer_file = gr.Textbox(label="Tokenizer File")
|
||||
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
|
||||
|
||||
with gr.Row():
|
||||
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
|
||||
lb_samples = gr.Label(label="Samples")
|
||||
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
|
||||
bt_calculate = bt_create = gr.Button("Auto Settings")
|
||||
|
||||
with gr.Row():
|
||||
ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
|
||||
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
|
||||
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
|
||||
epochs = gr.Number(label="Epochs")
|
||||
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
|
||||
max_grad_norm = gr.Number(label="Max Gradient Norm")
|
||||
num_warmup_updates = gr.Number(label="Warmup Updates")
|
||||
|
||||
with gr.Row():
|
||||
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
||||
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
|
||||
batch_size_type = gr.Radio(
|
||||
label="Batch Size Type",
|
||||
choices=["frame", "sample"],
|
||||
info="frame is calculated as seconds * sampling_rate / hop_length",
|
||||
)
|
||||
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
|
||||
grad_accumulation_steps = gr.Number(
|
||||
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
|
||||
)
|
||||
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
|
||||
|
||||
with gr.Row():
|
||||
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
||||
max_samples = gr.Number(label="Max Samples", value=64)
|
||||
|
||||
with gr.Row():
|
||||
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
||||
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
|
||||
|
||||
with gr.Row():
|
||||
epochs = gr.Number(label="Epochs", value=10)
|
||||
num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
|
||||
|
||||
with gr.Row():
|
||||
save_per_updates = gr.Number(label="Save per Updates", value=300)
|
||||
save_per_updates = gr.Number(
|
||||
label="Save per Updates",
|
||||
info="Save intermediate checkpoints every N updates",
|
||||
minimum=10,
|
||||
)
|
||||
keep_last_n_checkpoints = gr.Number(
|
||||
label="Keep Last N Checkpoints",
|
||||
value=-1,
|
||||
step=1,
|
||||
precision=0,
|
||||
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
|
||||
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
|
||||
minimum=-1,
|
||||
)
|
||||
last_per_updates = gr.Number(label="Last per Updates", value=100)
|
||||
last_per_updates = gr.Number(
|
||||
label="Last per Updates",
|
||||
info="Save latest checkpoint with suffix _last.pt every N updates",
|
||||
minimum=10,
|
||||
)
|
||||
gr.Radio(label="") # placeholder
|
||||
|
||||
with gr.Row():
|
||||
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
|
||||
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
|
||||
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
|
||||
start_button = gr.Button("Start Training")
|
||||
stop_button = gr.Button("Stop Training", interactive=False)
|
||||
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
|
||||
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
|
||||
with gr.Column():
|
||||
start_button = gr.Button("Start Training")
|
||||
stop_button = gr.Button("Stop Training", interactive=False)
|
||||
|
||||
if projects_selelect is not None:
|
||||
(
|
||||
@@ -1651,7 +1656,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
ch_8bit_adam.value = bnb_optimizer_value
|
||||
|
||||
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
|
||||
txt_info_train = gr.Text(label="Info", value="")
|
||||
txt_info_train = gr.Textbox(label="Info", value="")
|
||||
|
||||
list_audios, select_audio = get_audio_project(projects_selelect, False)
|
||||
|
||||
@@ -1718,23 +1723,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
fn=calculate_train,
|
||||
inputs=[
|
||||
cm_project,
|
||||
epochs,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
batch_size_type,
|
||||
max_samples,
|
||||
learning_rate,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_updates,
|
||||
ch_finetune,
|
||||
],
|
||||
outputs=[
|
||||
epochs,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
max_samples,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_updates,
|
||||
lb_samples,
|
||||
learning_rate,
|
||||
epochs,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1744,25 +1747,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
|
||||
def setup_load_settings():
|
||||
output_components = [
|
||||
exp_name, # 1
|
||||
learning_rate, # 2
|
||||
batch_size_per_gpu, # 3
|
||||
batch_size_type, # 4
|
||||
max_samples, # 5
|
||||
grad_accumulation_steps, # 6
|
||||
max_grad_norm, # 7
|
||||
epochs, # 8
|
||||
num_warmup_updates, # 9
|
||||
save_per_updates, # 10
|
||||
keep_last_n_checkpoints, # 11
|
||||
last_per_updates, # 12
|
||||
ch_finetune, # 13
|
||||
file_checkpoint_train, # 14
|
||||
tokenizer_type, # 15
|
||||
tokenizer_file, # 16
|
||||
mixed_precision, # 17
|
||||
cd_logger, # 18
|
||||
ch_8bit_adam, # 19
|
||||
exp_name,
|
||||
learning_rate,
|
||||
batch_size_per_gpu,
|
||||
batch_size_type,
|
||||
max_samples,
|
||||
grad_accumulation_steps,
|
||||
max_grad_norm,
|
||||
epochs,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
keep_last_n_checkpoints,
|
||||
last_per_updates,
|
||||
ch_finetune,
|
||||
file_checkpoint_train,
|
||||
tokenizer_type,
|
||||
tokenizer_file,
|
||||
mixed_precision,
|
||||
cd_logger,
|
||||
ch_8bit_adam,
|
||||
]
|
||||
return output_components
|
||||
|
||||
@@ -1782,19 +1785,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
|
||||
with gr.TabItem("Test Model"):
|
||||
gr.Markdown("""```plaintext
|
||||
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
|
||||
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
|
||||
```""")
|
||||
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
||||
exp_name = gr.Radio(
|
||||
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
|
||||
)
|
||||
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
|
||||
|
||||
with gr.Row():
|
||||
nfe_step = gr.Number(label="NFE Step", value=32)
|
||||
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
|
||||
seed = gr.Number(label="Seed", value=-1, minimum=-1)
|
||||
seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
|
||||
remove_silence = gr.Checkbox(label="Remove Silence")
|
||||
|
||||
ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
|
||||
with gr.Row():
|
||||
ch_use_ema = gr.Checkbox(
|
||||
label="Use EMA", value=True, info="Turn off at early stage might offer better results"
|
||||
)
|
||||
cm_checkpoint = gr.Dropdown(
|
||||
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
|
||||
)
|
||||
@@ -1802,20 +1809,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
|
||||
|
||||
random_sample_infer = gr.Button("Random Sample")
|
||||
|
||||
ref_text = gr.Textbox(label="Ref Text")
|
||||
ref_audio = gr.Audio(label="Audio Ref", type="filepath")
|
||||
gen_text = gr.Textbox(label="Gen Text")
|
||||
ref_text = gr.Textbox(label="Reference Text")
|
||||
ref_audio = gr.Audio(label="Reference Audio", type="filepath")
|
||||
gen_text = gr.Textbox(label="Text to Generate")
|
||||
|
||||
random_sample_infer.click(
|
||||
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
txt_info_gpu = gr.Textbox("", label="Device")
|
||||
seed_info = gr.Text(label="Seed :")
|
||||
check_button_infer = gr.Button("Infer")
|
||||
txt_info_gpu = gr.Textbox("", label="Inference on Device :")
|
||||
seed_info = gr.Textbox(label="Used Random Seed :")
|
||||
check_button_infer = gr.Button("Inference")
|
||||
|
||||
gen_audio = gr.Audio(label="Audio Gen", type="filepath")
|
||||
gen_audio = gr.Audio(label="Generated Audio", type="filepath")
|
||||
|
||||
check_button_infer.click(
|
||||
fn=infer,
|
||||
@@ -1838,18 +1845,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
|
||||
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
|
||||
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
|
||||
|
||||
with gr.TabItem("Reduce Checkpoint"):
|
||||
with gr.TabItem("Prune Checkpoint"):
|
||||
gr.Markdown("""```plaintext
|
||||
Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
|
||||
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out the optimizer state and other training-only data; it can be used for inference or finetuning afterward, but cannot be used to resume pretraining.
|
||||
```""")
|
||||
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
|
||||
txt_path_checkpoint_small = gr.Text(label="Path to Output:")
|
||||
ch_safetensors = gr.Checkbox(label="Safetensors", value="")
|
||||
txt_info_reduse = gr.Text(label="Info", value="")
|
||||
reduse_button = gr.Button("Reduce")
|
||||
txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
|
||||
txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
|
||||
with gr.Row():
|
||||
ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
|
||||
ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
|
||||
txt_info_reduse = gr.Textbox(label="Info", value="")
|
||||
reduse_button = gr.Button("Prune")
|
||||
reduse_button.click(
|
||||
fn=extract_and_save_ema_model,
|
||||
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
|
||||
fn=prune_checkpoint,
|
||||
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
|
||||
outputs=[txt_info_reduse],
|
||||
)
|
||||
|
||||
|
||||
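A usage sketch for the pruning tab above, calling the `prune_checkpoint` function shown earlier in this diff; the paths are placeholders.

```python
# Paths are hypothetical; prune_checkpoint is defined in finetune_gradio.py as shown above.
msg = prune_checkpoint(
    "ckpts/my_project/model_last.pt",     # full training checkpoint (~5 GB, includes optimizer state)
    "ckpts/my_project/model_pruned.pt",   # output path; switched to .safetensors when safetensors=True
    save_ema=True,                        # keep ema_model_state_dict rather than model_state_dict
    safetensors=True,
)
print(msg)  # "New checkpoint saved at: ..."
```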
@@ -4,8 +4,9 @@ import os
|
||||
from importlib.resources import files
|
||||
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.model import CFM, DiT, Trainer, UNetT
|
||||
from f5_tts.model import CFM, Trainer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
@@ -13,61 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
|
||||
|
||||
|
||||
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
|
||||
def main(cfg):
|
||||
tokenizer = cfg.model.tokenizer
|
||||
mel_spec_type = cfg.model.mel_spec.mel_spec_type
|
||||
exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
|
||||
def main(model_cfg):
|
||||
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
tokenizer = model_cfg.model.tokenizer
|
||||
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
||||
|
||||
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
|
||||
wandb_resume_id = None
|
||||
|
||||
# set text tokenizer
|
||||
if tokenizer != "custom":
|
||||
tokenizer_path = cfg.datasets.name
|
||||
tokenizer_path = model_cfg.datasets.name
|
||||
else:
|
||||
tokenizer_path = cfg.model.tokenizer_path
|
||||
tokenizer_path = model_cfg.model.tokenizer_path
|
||||
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
||||
|
||||
# set model
|
||||
if "F5TTS" in cfg.model.name:
|
||||
model_cls = DiT
|
||||
elif "E2TTS" in cfg.model.name:
|
||||
model_cls = UNetT
|
||||
wandb_resume_id = None
|
||||
|
||||
model = CFM(
|
||||
transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
|
||||
mel_spec_kwargs=cfg.model.mel_spec,
|
||||
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
|
||||
mel_spec_kwargs=model_cfg.model.mel_spec,
|
||||
vocab_char_map=vocab_char_map,
|
||||
)
|
||||
|
||||
# init trainer
|
||||
trainer = Trainer(
|
||||
model,
|
||||
epochs=cfg.optim.epochs,
|
||||
learning_rate=cfg.optim.learning_rate,
|
||||
num_warmup_updates=cfg.optim.num_warmup_updates,
|
||||
save_per_updates=cfg.ckpts.save_per_updates,
|
||||
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
|
||||
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
|
||||
batch_size=cfg.datasets.batch_size_per_gpu,
|
||||
batch_size_type=cfg.datasets.batch_size_type,
|
||||
max_samples=cfg.datasets.max_samples,
|
||||
grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
|
||||
max_grad_norm=cfg.optim.max_grad_norm,
|
||||
logger=cfg.ckpts.logger,
|
||||
epochs=model_cfg.optim.epochs,
|
||||
learning_rate=model_cfg.optim.learning_rate,
|
||||
num_warmup_updates=model_cfg.optim.num_warmup_updates,
|
||||
save_per_updates=model_cfg.ckpts.save_per_updates,
|
||||
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
|
||||
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
|
||||
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
|
||||
batch_size_type=model_cfg.datasets.batch_size_type,
|
||||
max_samples=model_cfg.datasets.max_samples,
|
||||
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
|
||||
max_grad_norm=model_cfg.optim.max_grad_norm,
|
||||
logger=model_cfg.ckpts.logger,
|
||||
wandb_project="CFM-TTS",
|
||||
wandb_run_name=exp_name,
|
||||
wandb_resume_id=wandb_resume_id,
|
||||
last_per_updates=cfg.ckpts.last_per_updates,
|
||||
log_samples=True,
|
||||
bnb_optimizer=cfg.optim.bnb_optimizer,
|
||||
last_per_updates=model_cfg.ckpts.last_per_updates,
|
||||
log_samples=model_cfg.ckpts.log_samples,
|
||||
bnb_optimizer=model_cfg.optim.bnb_optimizer,
|
||||
mel_spec_type=mel_spec_type,
|
||||
is_local_vocoder=cfg.model.vocoder.is_local,
|
||||
local_vocoder_path=cfg.model.vocoder.local_path,
|
||||
is_local_vocoder=model_cfg.model.vocoder.is_local,
|
||||
local_vocoder_path=model_cfg.model.vocoder.local_path,
|
||||
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
|
||||
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
|
||||
trainer.train(
|
||||
train_dataset,
|
||||
num_workers=cfg.datasets.num_workers,
|
||||
num_workers=model_cfg.datasets.num_workers,
|
||||
resumable_with_seed=666, # seed for shuffling dataset
|
||||
)
|
||||
|
||||
|
||||