mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-25 20:34:27 -08:00
Compare commits
27 Commits
SHA1:
b9156c0ad5
3ad3211915
f6726a78cc
1d0cf2b8ba
1d82b7928e
4ae5347282
621559cbbe
526b09eebd
9afa80f204
c6b3189bbd
c87ce39515
10ef27065b
f374640f34
d5f4c88aa4
f968e13b6d
339b17fed3
79302b694a
a1e88c2a9e
1ab90505a4
7e4985ca56
f05ceda4cb
2bd39dd813
f017815083
297755fac3
d05075205f
8722cf0766
48d1a9312e
.github/ISSUE_TEMPLATE/bug_report.yml (vendored): 12 changed lines
@@ -1,6 +1,6 @@
name: "Bug Report"
description: |
- Please provide as much details to help address the issue, including logs and screenshots.
+ Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
@@ -15,13 +15,13 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
+ - label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
- description: "Provide details such as OS, Python version, and any relevant software or dependencies."
- placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
+ description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
+ placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
@@ -39,12 +39,12 @@ body:
- type: textarea
attributes:
label: ✔️ Expected Behavior
- placeholder: Describe what you expected to happen.
+ placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
- placeholder: Describe what actually happened.
+ placeholder: Describe in detail what actually happened.
validations:
required: false
.github/ISSUE_TEMPLATE/feature_request.yml (vendored): 2 changed lines
@@ -15,7 +15,7 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and found not discussion yet.
required: true
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
+ - label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
.github/ISSUE_TEMPLATE/help_wanted.yml (vendored): 16 changed lines
@@ -1,6 +1,6 @@
name: "Help Wanted"
description: |
- Please provide as much details to help address the issue, including logs and screenshots.
+ Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
@@ -15,36 +15,40 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
+ - label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
- placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
+ placeholder: |
+ e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
+ If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
- Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
+ Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
+ 5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
+ 6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
- placeholder: Describe what you expected to happen, e.g. output a generated audio
+ placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
- placeholder: Describe what actually happened, failure messages, etc.
+ placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false
.github/ISSUE_TEMPLATE/question.yml (vendored): 6 changed lines
@@ -1,6 +1,6 @@
name: "Question"
description: |
- Pure question or inquiry about the project, usage issue goes with "help wanted".
+ Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
@@ -9,13 +9,13 @@ body:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- - label: This template is only for question, not feature requests or bug reports.
+ - label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- - label: I confirm that I am using English to submit this report in order to facilitate communication.
+ - label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
.gitignore (vendored): 2 changed lines
@@ -7,8 +7,6 @@ ckpts/
wandb/
results/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
.pre-commit-config.yaml
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
- rev: v0.7.0
+ rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
@@ -9,6 +9,6 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v2.3.0
+ rev: v5.0.0
hooks:
- id: check-yaml
Dockerfile
@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
ENV SHELL=/bin/bash

+ VOLUME /root/.cache/huggingface/hub/

+ EXPOSE 7860

WORKDIR /workspace/F5-TTS
README.md: 11 changed lines
@@ -100,8 +100,11 @@ conda activate f5-tts
# Build from Dockerfile
docker build -t f5tts:v1 .

# Or pull from GitHub Container Registry
docker pull ghcr.io/swivid/f5-tts:main
+ # Run from GitHub Container Registry
+ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main

+ # Quickstart if you want to just run the web interface (not CLI)
+ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
@@ -200,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.

## Development

- Use pre-commit to ensure code quality (will run linters and formatters automatically)
+ Use pre-commit to ensure code quality (will run linters and formatters automatically):

```bash
pip install pre-commit
@@ -213,7 +216,7 @@ When making a pull request, before each commit, run:
pre-commit run --all-files
```

- Note: Some model components have linting exceptions for E722 to accommodate tensor notation
+ Note: Some model components have linting exceptions for E722 to accommodate tensor notation.

## Acknowledgements
ckpts/README.md
@@ -1,12 +1,3 @@
- The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
+ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS

- ```
- ckpts/
- F5TTS_v1_Base/
- model_1250000.safetensors
- F5TTS_Base/
- model_1200000.safetensors
- E2TTS_Base/
- model_1200000.safetensors
- ```
+ Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
- version = "1.0.0"
+ version = "1.0.8"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
src/f5_tts/api.py
@@ -5,6 +5,7 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
+ from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
remove_silence_for_generated_wav,
save_spectrogram,
)
- from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
from f5_tts.model.utils import seed_everything
@@ -33,7 +33,7 @@ class F5TTS:
hf_cache_dir=None,
):
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
- model_cls = globals()[model_cfg.model.backbone]
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
@@ -74,8 +74,6 @@ class F5TTS:
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
- else:
- raise ValueError(f"Unknown model type: {model}")

if not ckpt_file:
ckpt_file = str(
@@ -117,10 +115,11 @@ class F5TTS:
seed=None,
):
if seed is None:
- self.seed = random.randint(0, sys.maxsize)
- seed_everything(self.seed)
+ seed = random.randint(0, sys.maxsize)
+ seed_everything(seed)
+ self.seed = seed

- ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+ ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

wav, sr, spec = infer_process(
ref_file,
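For reference, a minimal sketch of the pattern this file now uses for resolving the backbone class, replacing the `globals()` lookup that needed an otherwise-unused import. The config contents below are illustrative assumptions, not taken from this diff.

```python
from hydra.utils import get_class
from omegaconf import OmegaConf

# Hypothetical config snippet with a backbone name; real configs live under
# src/f5_tts/configs and carry many more fields.
model_cfg = OmegaConf.create({"model": {"backbone": "DiT"}})

# Old pattern: `from f5_tts.model import DiT, UNetT  # noqa: F401` followed by
# globals()[model_cfg.model.backbone].
# New pattern: resolve the class from its dotted import path directly.
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
print(model_cls)  # the DiT class object, if f5_tts is installed
```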
src/f5_tts/configs (a model training config YAML, e.g. F5TTS_v1_Base.yaml)
@@ -10,7 +10,7 @@ datasets:
num_workers: 16

optim:
- epochs: 11
+ epochs: 11  # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000  # warmup updates
grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
@@ -49,4 +49,4 @@ ckpts:
save_per_updates: 50000  # save checkpoint per updates
keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000  # save last checkpoint per updates
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
+ save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/eval/eval_infer_batch.py
@@ -10,6 +10,7 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
+ from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm

@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
- from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+ from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer

accelerator = Accelerator()
@@ -65,7 +66,7 @@ def main():
no_ref_audio = False

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
- model_cls = globals()[model_cfg.model.backbone]
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

dataset_name = model_cfg.datasets.name
@@ -195,7 +196,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
- print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
+ print(f"Done batch inference in {timediff / 60:.2f} minutes.")

if __name__ == "__main__":
src/f5_tts/eval/utils_eval.py
@@ -148,9 +148,9 @@ def get_inference_prompt(

# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
- assert (
- min_tokens <= total_mel_len <= max_tokens
- ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+ assert min_tokens <= total_mel_len <= max_tokens, (
+ f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+ )
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

utts[bucket_i].append(utt)
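To make the bucketing line above concrete, here is a standalone sketch with illustrative numbers (hop length, sample rate, duration bounds and bucket count are assumptions chosen for the example, not read from this diff):

```python
import math

# Assumed mel settings and duration range for the example.
hop_length, target_sample_rate = 256, 24000
min_secs, max_secs, num_buckets = 5, 30, 200
min_tokens = min_secs * target_sample_rate // hop_length   # 468 mel frames
max_tokens = max_secs * target_sample_rate // hop_length   # 2812 mel frames

total_mel_len = 1200  # a hypothetical utterance (prompt + target) in mel frames
assert min_tokens <= total_mel_len <= max_tokens, "duration out of range"

# Utterances of similar total length land in the same bucket, so each inference
# batch contains similarly sized samples.
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(bucket_i)  # 62 for these numbers
```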
src/f5_tts/infer/README.md
@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h

**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**

- Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
+ Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.

To avoid possible inference failures, make sure you have seen through the following instructions.

- - Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- - Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- - If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
- - Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
+ - Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
+ - <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
+ - Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
+ - If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
+ - <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
+ - If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
+ - Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).

## Gradio App
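As a rough illustration of the text-side guidance added above (space after sentence-final punctuation so chunking can split there), here is a small hypothetical helper; it is not part of the repository and the regexes are deliberately naive (they would also split decimals like "2.5"):

```python
import re

def tidy_gen_text(text: str) -> str:
    """Hypothetical pre-cleaning sketch reflecting the README tips:
    ensure a space follows sentence-final punctuation and collapse double spaces."""
    text = re.sub(r"([.!?])(?=\S)", r"\1 ", text)  # "end.Next" -> "end. Next"
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip()

print(tidy_gen_text("Hello world.This is a test!Keep lowercase for normal words."))
```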
src/f5_tts/infer/SHARED.md
@@ -44,6 +44,7 @@

```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
+ # A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
src/f5_tts/infer/infer_cli.py
@@ -10,6 +10,7 @@ import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
+ from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
@@ -21,13 +22,13 @@ from f5_tts.infer.utils_infer import (
sway_sampling_coef,
speed,
fix_duration,
+ device,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
- from f5_tts.model import DiT, UNetT  # noqa: F401. used for config

parser = argparse.ArgumentParser(
@@ -162,6 +163,11 @@ parser.add_argument(
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
+ parser.add_argument(
+ "--device",
+ type=str,
+ help="Specify the device to run on",
+ )
args = parser.parse_args()

@@ -202,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
+ device = args.device or config.get("device", device)

# patches for pip pkg user
@@ -239,20 +246,23 @@ if vocoder_name == "vocos":
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

- vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
+ vocoder = load_vocoder(
+ vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
+ )

# load TTS model

model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
- ).model
- model_cls = globals()[model_cfg.backbone]
+ )
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
+ model_arc = model_cfg.model.arch

repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

if model != "F5TTS_Base":
- assert vocoder_name == model_cfg.mel_spec.mel_spec_type
+ assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type

# override for previous models
if model == "F5TTS_Base":
@@ -269,7 +279,9 @@ if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))

print(f"Using {model}...")
- ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+ ema_model = load_model(
+ model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
+ )

# inference process
@@ -325,6 +337,7 @@ def main():
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
+ device=device,
)
generated_audio_segments.append(audio_segment)

@@ -332,7 +345,7 @@ def main():
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
- os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
+ os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)
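For orientation, a short usage sketch of the Python API touched by these changes (the F5TTS class from src/f5_tts/api.py above). File paths and texts are placeholders; the keyword names follow the parameters visible in this diff, but treat the exact signature as an assumption and check the installed version:

```python
from f5_tts.api import F5TTS

# device/model values are illustrative; by default checkpoints are pulled from
# Hugging Face into ~/.cache/huggingface/hub/.
tts = F5TTS(model="F5TTS_v1_Base", device="cuda")

wav, sr, spec = tts.infer(
    ref_file="ref.wav",                # reference/prompt audio, ideally < 12 s
    ref_text="transcript of ref.wav",  # leave "" to let Whisper transcribe it
    gen_text="Text to be spoken in the reference voice.",
    seed=None,                         # None -> random; tts.seed records the seed used
    file_wave="out.wav",
)
print("used seed:", tts.seed)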
src/f5_tts/infer/infer_gradio.py
@@ -758,9 +758,9 @@ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not

The checkpoints currently support English and Chinese.

- If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).

- **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
"""
)
src/f5_tts/infer/speech_edit.py (likely)
@@ -7,10 +7,11 @@ from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
+ from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
- from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+ from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer

device = (
@@ -40,7 +41,7 @@ target_rms = 0.1

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
- model_cls = globals()[model_cfg.model.backbone]
+ model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch

dataset_name = model_cfg.datasets.name
src/f5_tts/infer/utils_infer.py
@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
dtype = (
torch.float16
if "cuda" in device
- and torch.cuda.get_device_properties(device).major >= 6
+ and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
dtype = (
torch.float16
if "cuda" in device
- and torch.cuda.get_device_properties(device).major >= 6
+ and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text

- def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
@@ -302,7 +302,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
- show_info("Audio is over 15s, clipping short. (1)")
+ show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg

@@ -314,7 +314,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
- show_info("Audio is over 15s, clipping short. (2)")
+ show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg

@@ -323,7 +323,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
- show_info("Audio is over 15s, clipping short. (3)")
+ show_info("Audio is over 12s, clipping short. (3)")

aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
@@ -479,14 +479,15 @@ def infer_batch_process(
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
+ del _

- generated = generated.to(torch.float32)
+ generated = generated.to(torch.float32)  # generated mel spectrogram
generated = generated[:, ref_audio_len:, :]
- generated_mel_spec = generated.permute(0, 2, 1)
+ generated = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
- generated_wave = vocoder.decode(generated_mel_spec)
+ generated_wave = vocoder.decode(generated)
elif mel_spec_type == "bigvgan":
- generated_wave = vocoder(generated_mel_spec)
+ generated_wave = vocoder(generated)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms

@@ -497,7 +498,9 @@ def infer_batch_process(
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
else:
- yield generated_wave, generated_mel_spec[0].cpu().numpy()
+ generated_cpu = generated[0].cpu().numpy()
+ del generated
+ yield generated_wave, generated_cpu

if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
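The compute-capability bump above (>= 6 to >= 7) decides when float16 is used. A standalone sketch of the same check, for reference only:

```python
import torch

def pick_half_precision_dtype(device: str = "cuda") -> torch.dtype:
    # Only use float16 on CUDA devices with compute capability >= 7 (Tensor Cores)
    # and not under ZLUDA; otherwise fall back to float32, mirroring the diff above.
    if (
        "cuda" in device
        and torch.cuda.is_available()
        and torch.cuda.get_device_properties(device).major >= 7
        and not torch.cuda.get_device_name().endswith("[ZLUDA]")
    ):
        return torch.float16
    return torch.float32
```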
src/f5_tts/model/backbones/dit.py
@@ -219,7 +219,8 @@ class DiT(nn.Module):

for block in self.transformer_blocks:
if self.checkpoint_activations:
- x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+ # https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
+ x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
x = block(x, t, mask=mask, rope=rope)
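A minimal, self-contained sketch of the non-reentrant activation checkpointing enabled above (toy module; the real DiT blocks also take t, mask and rope arguments):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(2, 16, requires_grad=True)

# use_reentrant=False selects the newer non-reentrant implementation, which is the
# variant PyTorch recommends and avoids silent-gradient pitfalls of the legacy path.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```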
src/f5_tts/model/trainer.py
@@ -350,7 +350,7 @@ class Trainer:

progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
- desc=f"Epoch {epoch+1}/{self.epochs}",
+ desc=f"Epoch {epoch + 1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
@@ -428,6 +428,7 @@ class Trainer:
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
+ self.model.train()

if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
src/f5_tts/scripts/count_max_epoch.py (likely)
@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours

# result
epochs = wanted_max_updates / updates_per_epoch
- print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
+ print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f"    or approx. 0/{steps_per_epoch:.0f} steps")
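A worked example of the epoch arithmetic this helper performs, with made-up numbers (they are not the Emilia statistics; they merely show how a target update count maps to an epoch setting):

```python
# Hypothetical dataset and training budget.
total_hours = 950            # total audio hours in the dataset
mini_batch_hours = 0.0093    # hours of audio covered by one optimizer update across all GPUs
wanted_max_updates = 1_100_000
grad_accum = 1

updates_per_epoch = total_hours / mini_batch_hours   # ~102,150 updates per epoch
epochs = wanted_max_updates / updates_per_epoch       # ~10.8 -> round up to epochs: 11
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
```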
src/f5_tts/socket_server.py (likely)
@@ -13,9 +13,9 @@ from importlib.resources import files
import torch
import torchaudio
from huggingface_hub import hf_hub_download
+ from hydra.utils import get_class
from omegaconf import OmegaConf

- from f5_tts.model.backbones.dit import DiT  # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
else "cpu"
)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
- self.model_cls = globals()[model_cfg.model.backbone]
+ self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
src/f5_tts/train/README.md
@@ -51,7 +51,11 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio

Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

- The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
+ If want to finetune with a variant version e.g. *F5TTS_v1_Base_no_zero_init*, manually download pretrained checkpoint from model weight repository and fill in the path correspondingly on web interface.

+ If use tensorboard as logger, install it first with `pip install tensorboard`.

+ <ins>The `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off with finetune gradio option or `load_model(..., use_ema=False)`, see if offer better results.

### 3. W&B Logging
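A sketch of the `load_model(..., use_ema=False)` route mentioned above for testing an early-stage finetuned checkpoint (paths are placeholders; the surrounding config-loading pattern follows the updated files in this diff):

```python
from importlib.resources import files
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_model

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")

ema_model = load_model(
    model_cls,
    model_cfg.model.arch,
    "ckpts/my_project/model_5000.pt",  # placeholder: an early finetuned checkpoint
    mel_spec_type=model_cfg.model.mel_spec.mel_spec_type,
    vocab_file="",
    use_ema=False,  # raw (non-EMA) weights often sound better after only a few updates
)
```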
src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
for future in tqdm(
chunk_futures,
total=len(chunk),
- desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
+ desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")

def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
src/f5_tts/train/datasets/prepare_emilia.py (likely)
@@ -198,7 +198,7 @@ def main():

print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:
(another dataset preparation script under src/f5_tts/train/datasets)
@@ -72,7 +72,7 @@ def main():

print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")

if __name__ == "__main__":
(another dataset preparation script under src/f5_tts/train/datasets)
@@ -50,7 +50,7 @@ def main():

print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
- print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+ print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")

if __name__ == "__main__":
src/f5_tts/train/finetune_cli.py
@@ -40,15 +40,15 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
- parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
- parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
+ parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
+ parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
- parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
+ parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -65,7 +65,7 @@ def parse_args():
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
- parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
+ parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",
src/f5_tts/train/finetune_gradio.py
@@ -120,11 +120,11 @@ def load_settings(project_name):
default_settings = {
"exp_name": "F5TTS_v1_Base",
"learning_rate": 1e-5,
- "batch_size_per_gpu": 1,
- "batch_size_type": "sample",
+ "batch_size_per_gpu": 3200,
+ "batch_size_type": "frame",
"max_samples": 64,
- "grad_accumulation_steps": 4,
- "max_grad_norm": 1,
+ "grad_accumulation_steps": 1,
+ "max_grad_norm": 1.0,
"epochs": 100,
"num_warmup_updates": 100,
"save_per_updates": 500,
@@ -134,8 +134,8 @@ def load_settings(project_name):
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
"tokenizer_file": "",
- "mixed_precision": "none",
- "logger": "wandb",
+ "mixed_precision": "fp16",
+ "logger": "none",
"bnb_optimizer": False,
}
@@ -361,27 +361,27 @@ def terminate_process(pid):

def start_training(
- dataset_name="",
- exp_name="F5TTS_v1_Base",
- learning_rate=1e-5,
- batch_size_per_gpu=1,
- batch_size_type="sample",
- max_samples=64,
- grad_accumulation_steps=4,
- max_grad_norm=1.0,
- epochs=100,
- num_warmup_updates=100,
- save_per_updates=500,
- keep_last_n_checkpoints=-1,
- last_per_updates=100,
- finetune=True,
- file_checkpoint_train="",
- tokenizer_type="pinyin",
- tokenizer_file="",
- mixed_precision="fp16",
- stream=False,
- logger="wandb",
- ch_8bit_adam=False,
+ dataset_name,
+ exp_name,
+ learning_rate,
+ batch_size_per_gpu,
+ batch_size_type,
+ max_samples,
+ grad_accumulation_steps,
+ max_grad_norm,
+ epochs,
+ num_warmup_updates,
+ save_per_updates,
+ keep_last_n_checkpoints,
+ last_per_updates,
+ finetune,
+ file_checkpoint_train,
+ tokenizer_type,
+ tokenizer_file,
+ mixed_precision,
+ stream,
+ logger,
+ ch_8bit_adam,
):
global training_process, tts_api, stop_signal
@@ -458,7 +458,10 @@ def start_training(

cmd += f" --tokenizer {tokenizer_type}"

- cmd += f" --log_samples --logger {logger}"
+ if logger != "none":
+ cmd += f" --logger {logger}"
+
+ cmd += " --log_samples"

if ch_8bit_adam:
cmd += " --bnb_optimizer"
@@ -515,7 +518,7 @@ def start_training(
training_process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
)
- yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
+ yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)

stdout_queue = queue.Queue()
stderr_queue = queue.Queue()
@@ -584,7 +587,11 @@ def start_training(
gr.update(interactive=True),
)
else:
- yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
+ yield (
+ "Training complete or paused ...",
+ gr.update(interactive=False),
+ gr.update(interactive=True),
+ )
break

# Small sleep to prevent CPU thrashing
@@ -598,9 +605,9 @@ def start_training(
time.sleep(1)

if training_process is None:
- text_info = "train stop"
+ text_info = "Train stopped !"
else:
- text_info = "train complete !"
+ text_info = "Train complete at end !"

except Exception as e:  # Catch all exceptions
# Ensure that we reset the training process variable in case of an error
@@ -615,11 +622,11 @@ def stop_training():
global training_process, stop_signal

if training_process is None:
- return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
+ return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
terminate_process_tree(training_process.pid)
# training_process = None
stop_signal = True
- return "train stop", gr.update(interactive=True), gr.update(interactive=False)
+ return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)

def get_list_projects():
@@ -958,21 +965,23 @@ def calculate_train(
)

- def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
+ def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
try:
checkpoint = torch.load(checkpoint_path, weights_only=True)
print("Original Checkpoint Keys:", checkpoint.keys())

- ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
- if ema_model_state_dict is None:
- return "No 'ema_model_state_dict' found in the checkpoint."
+ to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
+ try:
+ model_state_dict_to_retain = checkpoint[to_retain]
+ except KeyError:
+ return f"{to_retain} not found in the checkpoint."

if safetensors:
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
- save_file(ema_model_state_dict, new_checkpoint_path)
+ save_file(model_state_dict_to_retain, new_checkpoint_path)
else:
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
- new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
+ new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
torch.save(new_checkpoint, new_checkpoint_path)

return f"New checkpoint saved at: {new_checkpoint_path}"
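Stripped of the Gradio wiring, the renamed prune_checkpoint helper boils down to the following sketch (paths are placeholders; the key names match the ones shown in the diff above):

```python
import torch
from safetensors.torch import save_file

# Load a full training checkpoint and keep only the state dict needed for
# inference/finetuning: EMA weights by default, raw weights if save_ema is False.
ckpt = torch.load("ckpts/my_project/model_50000.pt", weights_only=True)
state = ckpt["ema_model_state_dict"]  # or ckpt["model_state_dict"]
save_file(state, "ckpts/my_project/model_50000_pruned.safetensors")
```

The pruned file drops optimizer and scheduler state, so it cannot be used to resume pretraining, as the Gradio tab text below notes.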
@@ -1013,7 +1022,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):

ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])

- torch.save(ckpt, new_ckpt_path)
+ if new_ckpt_path.endswith(".safetensors"):
+ save_file(ema_sd, new_ckpt_path)
+ elif new_ckpt_path.endswith(".pt"):
+ torch.save(ckpt, new_ckpt_path)

return vocab_new
@@ -1125,7 +1137,7 @@ def vocab_check(project_name):
info = "You can train using your language !"
else:
vocab_miss = ",".join(miss_symbols)
- info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
+ info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"

return info, vocab_miss
@@ -1212,6 +1224,9 @@ def infer(

print("update >> ", device_test, file_checkpoint, use_ema)

+ if seed == -1:  # -1 used for random
+ seed = None

with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
@@ -1430,9 +1445,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
)

audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
- txt_lang = gr.Text(label="Language", value="English")
+ txt_lang = gr.Textbox(label="Language", value="English")
bt_transcribe = bt_create = gr.Button("Transcribe")
- txt_info_transcribe = gr.Text(label="Info", value="")
+ txt_info_transcribe = gr.Textbox(label="Info", value="")
bt_transcribe.click(
fn=transcribe_all,
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
@@ -1443,7 +1458,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
random_sample_transcribe = gr.Button("Random Sample")

with gr.Row():
- random_text_transcribe = gr.Text(label="Text")
+ random_text_transcribe = gr.Textbox(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")

random_sample_transcribe.click(
@@ -1458,7 +1473,7 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
```""")

check_button = gr.Button("Check Vocab")
- txt_info_check = gr.Text(label="Info", value="")
+ txt_info_check = gr.Textbox(label="Info", value="")

gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
@@ -1478,7 +1493,7 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)

extend_button = gr.Button("Extend")
- txt_info_extend = gr.Text(label="Info", value="")
+ txt_info_extend = gr.Textbox(label="Info", value="")

txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
@@ -1518,8 +1533,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)

bt_prepare = bt_create = gr.Button("Prepare")
- txt_info_prepare = gr.Text(label="Info", value="")
- txt_vocab_prepare = gr.Text(label="Vocab", value="")
+ txt_info_prepare = gr.Textbox(label="Info", value="")
+ txt_vocab_prepare = gr.Textbox(label="Vocab", value="")

bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
@@ -1528,7 +1543,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
random_sample_prepare = gr.Button("Random Sample")

with gr.Row():
- random_text_prepare = gr.Text(label="Tokenizer")
+ random_text_prepare = gr.Textbox(label="Tokenizer")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")

random_sample_prepare.click(
@@ -1541,50 +1556,60 @@ The auto-setting is still experimental. Set a large value of epoch if not sure;
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
tokenizer_file = gr.Textbox(label="Tokenizer File")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")

with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
lb_samples = gr.Label(label="Samples")
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
bt_calculate = bt_create = gr.Button("Auto Settings")

with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
epochs = gr.Number(label="Epochs")
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
max_grad_norm = gr.Number(label="Max Gradient Norm")
num_warmup_updates = gr.Number(label="Warmup Updates")

with gr.Row():
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
batch_size_type = gr.Radio(
label="Batch Size Type",
choices=["frame", "sample"],
info="frame is calculated as seconds * sampling_rate / hop_length",
)
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
grad_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
)
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")

with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
max_samples = gr.Number(label="Max Samples", value=64)

with gr.Row():
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)

with gr.Row():
epochs = gr.Number(label="Epochs", value=100)
num_warmup_updates = gr.Number(label="Warmup Updates", value=100)

with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=500)
save_per_updates = gr.Number(
label="Save per Updates",
info="Save intermediate checkpoints every N updates",
minimum=10,
)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
minimum=-1,
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
last_per_updates = gr.Number(
label="Last per Updates",
info="Save latest checkpoint with suffix _last.pt every N updates",
minimum=10,
)
gr.Radio(label="")  # placeholder

with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
with gr.Column():
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)

if projects_selelect is not None:
(
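To put the new "frame" batch-size hint in perspective (the 24 kHz sample rate and 256 hop length are assumptions matching the default F5-TTS mel settings, not values stated in this hunk):

```python
# Rough feel for a frame-based batch size of 3200 per GPU.
sampling_rate, hop_length = 24000, 256
frames_per_second = sampling_rate / hop_length        # 93.75 mel frames per second
batch_size_per_gpu = 3200                              # new default, counted in frames
seconds_per_gpu_batch = batch_size_per_gpu / frames_per_second
print(f"{seconds_per_gpu_batch:.1f} s of audio per GPU per batch")  # ~34.1 s
```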
@@ -1631,7 +1656,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
ch_8bit_adam.value = bnb_optimizer_value

ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
- txt_info_train = gr.Text(label="Info", value="")
+ txt_info_train = gr.Textbox(label="Info", value="")

list_audios, select_audio = get_audio_project(projects_selelect, False)
@@ -1760,7 +1785,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle

with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
- SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
+ Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
@@ -1770,11 +1795,13 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
- seed = gr.Number(label="Seed", value=-1, minimum=-1)
+ seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")

- ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
+ with gr.Row():
+ ch_use_ema = gr.Checkbox(
+ label="Use EMA", value=True, info="Turn off at early stage might offer better results"
+ )
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
)
@@ -1782,20 +1809,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works

random_sample_infer = gr.Button("Random Sample")

- ref_text = gr.Textbox(label="Ref Text")
- ref_audio = gr.Audio(label="Audio Ref", type="filepath")
- gen_text = gr.Textbox(label="Gen Text")
+ ref_text = gr.Textbox(label="Reference Text")
+ ref_audio = gr.Audio(label="Reference Audio", type="filepath")
+ gen_text = gr.Textbox(label="Text to Generate")

random_sample_infer.click(
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)

with gr.Row():
- txt_info_gpu = gr.Textbox("", label="Device")
- seed_info = gr.Text(label="Seed :")
- check_button_infer = gr.Button("Infer")
+ txt_info_gpu = gr.Textbox("", label="Inference on Device :")
+ seed_info = gr.Textbox(label="Used Random Seed :")
+ check_button_infer = gr.Button("Inference")

- gen_audio = gr.Audio(label="Audio Gen", type="filepath")
+ gen_audio = gr.Audio(label="Generated Audio", type="filepath")

check_button_infer.click(
fn=infer,
@@ -1822,14 +1849,16 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
gr.Markdown("""```plaintext
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
```""")
- txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
- txt_path_checkpoint_small = gr.Text(label="Path to Output:")
- ch_safetensors = gr.Checkbox(label="Safetensors", value="")
- txt_info_reduse = gr.Text(label="Info", value="")
- reduse_button = gr.Button("Reduce")
+ txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
+ txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
+ with gr.Row():
+ ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
+ ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
+ txt_info_reduse = gr.Textbox(label="Info", value="")
+ reduse_button = gr.Button("Prune")
reduse_button.click(
- fn=extract_and_save_ema_model,
- inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
+ fn=prune_checkpoint,
+ inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
outputs=[txt_info_reduse],
)
src/f5_tts/train/train.py
@@ -6,7 +6,7 @@ from importlib.resources import files
import hydra
from omegaconf import OmegaConf

- from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401. used for config
+ from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer

@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to

@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
- def main(cfg):
- model_cls = globals()[cfg.model.backbone]
- model_arc = cfg.model.arch
- tokenizer = cfg.model.tokenizer
- mel_spec_type = cfg.model.mel_spec.mel_spec_type
+ def main(model_cfg):
+ model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
+ model_arc = model_cfg.model.arch
+ tokenizer = model_cfg.model.tokenizer
+ mel_spec_type = model_cfg.model.mel_spec.mel_spec_type

- exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+ exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None

# set text tokenizer
if tokenizer != "custom":
- tokenizer_path = cfg.datasets.name
+ tokenizer_path = model_cfg.datasets.name
else:
- tokenizer_path = cfg.model.tokenizer_path
+ tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

# set model
model = CFM(
- transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
- mel_spec_kwargs=cfg.model.mel_spec,
+ transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
+ mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)

# init trainer
trainer = Trainer(
model,
- epochs=cfg.optim.epochs,
- learning_rate=cfg.optim.learning_rate,
- num_warmup_updates=cfg.optim.num_warmup_updates,
- save_per_updates=cfg.ckpts.save_per_updates,
- keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
- checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
- batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
- batch_size_type=cfg.datasets.batch_size_type,
- max_samples=cfg.datasets.max_samples,
- grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
- max_grad_norm=cfg.optim.max_grad_norm,
- logger=cfg.ckpts.logger,
+ epochs=model_cfg.optim.epochs,
+ learning_rate=model_cfg.optim.learning_rate,
+ num_warmup_updates=model_cfg.optim.num_warmup_updates,
+ save_per_updates=model_cfg.ckpts.save_per_updates,
+ keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
+ checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
+ batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
+ batch_size_type=model_cfg.datasets.batch_size_type,
+ max_samples=model_cfg.datasets.max_samples,
+ grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
+ max_grad_norm=model_cfg.optim.max_grad_norm,
+ logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
- last_per_updates=cfg.ckpts.last_per_updates,
- log_samples=cfg.ckpts.log_samples,
- bnb_optimizer=cfg.optim.bnb_optimizer,
+ last_per_updates=model_cfg.ckpts.last_per_updates,
+ log_samples=model_cfg.ckpts.log_samples,
+ bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
- is_local_vocoder=cfg.model.vocoder.is_local,
- local_vocoder_path=cfg.model.vocoder.local_path,
- cfg_dict=OmegaConf.to_container(cfg, resolve=True),
+ is_local_vocoder=model_cfg.model.vocoder.is_local,
+ local_vocoder_path=model_cfg.model.vocoder.local_path,
+ model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)

- train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
+ train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
- num_workers=cfg.datasets.num_workers,
+ num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666,  # seed for shuffling dataset
)