27 Commits
1.0.0 ... 1.0.8

Author | SHA1 | Message | Date
SWivid | b9156c0ad5 | v1.0.8 fix a fatal bug with log_samples since 37eb3b50da | 2025-03-25 07:49:19 +08:00
SWivid | 3ad3211915 | Update F5TTS_Small.yaml | 2025-03-25 07:11:35 +08:00
Zhikang Niu | f6726a78cc | Update F5TTS_Small.yaml | 2025-03-23 22:27:02 +08:00
SWivid | 1d0cf2b8ba | add device option for infer-cli, patch-1 | 2025-03-22 17:35:16 +08:00
SWivid | 1d82b7928e | add device option for infer-cli | 2025-03-22 17:30:23 +08:00
SWivid | 4ae5347282 | pre-commit update and formatting | 2025-03-21 23:01:00 +08:00
SWivid | 621559cbbe | v1.0.7 | 2025-03-21 14:40:52 +08:00
SWivid | 526b09eebd | add no_zero_init v1 variant path to SHARED.md | 2025-03-21 14:37:14 +08:00
SWivid | 9afa80f204 | add option in finetune gradio to save non-ema model weight | 2025-03-21 13:36:11 +08:00
SWivid | c6b3189bbd | v1.0.6 improves docker usage | 2025-03-20 22:48:36 +08:00
Yushen CHEN | c87ce39515 | Merge pull request #890 from MicahZoltu/patch-1 (Improves documentation around docker usage.) | 2025-03-20 22:45:40 +08:00
Micah Zoltu | 10ef27065b | Improves documentation around docker usage. | 2025-03-20 21:37:48 +08:00
SWivid | f374640f34 | Merge branch 'main' of github.com:SWivid/F5-TTS | 2025-03-20 13:54:52 +08:00
SWivid | d5f4c88aa4 | update issue templates | 2025-03-20 13:54:15 +08:00
Yushen CHEN | f968e13b6d | Update README.md | 2025-03-20 10:15:47 +08:00
SWivid | 339b17fed3 | update README.md for infer & train | 2025-03-20 10:14:22 +08:00
SWivid | 79302b694a | update README.md for infer & train | 2025-03-20 10:03:54 +08:00
SWivid | a1e88c2a9e | v1.0.5 update finetune_gradio.py for clearer guidance | 2025-03-17 21:50:50 +08:00
SWivid | 1ab90505a4 | v1.0.4 fix finetune_gradio.py vocab extend with .safetensors ckpt | 2025-03-17 16:22:26 +08:00
SWivid | 7e4985ca56 | v1.0.3 fix api.py | 2025-03-17 02:39:20 +08:00
SWivid | f05ceda4cb | v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False | 2025-03-15 16:34:32 +08:00
Yushen CHEN | 2bd39dd813 | Merge pull request #859 from ZhikangNiu/main (fix #858 and pass use_reentrant explicitly in checkpoint_activation mode) | 2025-03-15 16:23:50 +08:00
ZhikangNiu | f017815083 | fix #858 and pass use_reentrant explicitly in checkpoint_activation mode | 2025-03-15 15:48:47 +08:00
Yushen CHEN | 297755fac3 | v1.0.1 VRAM usage management #851 | 2025-03-14 17:31:44 +08:00
Yushen CHEN | d05075205f | Merge pull request #851 from niknah/vram-usage (VRAM usage on long texts gradually uses up memory.) | 2025-03-14 17:25:56 +08:00
Yushen CHEN | 8722cf0766 | Update utils_infer.py | 2025-03-14 17:23:20 +08:00
niknah | 48d1a9312e | VRAM usage on long texts gradually uses up memory. | 2025-03-14 16:53:58 +11:00
32 changed files with 276 additions and 222 deletions

View File

@@ -1,6 +1,6 @@
name: "Bug Report"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
@@ -15,13 +15,13 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
@@ -39,12 +39,12 @@ body:
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen.
placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened.
placeholder: Describe in detail what actually happened.
validations:
required: false

View File

@@ -15,7 +15,7 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and found not discussion yet.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

View File

@@ -1,6 +1,6 @@
name: "Help Wanted"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
@@ -15,36 +15,40 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
placeholder: |
e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen, e.g. output a generated audio
placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened, failure messages, etc.
placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false

View File

@@ -1,6 +1,6 @@
name: "Question"
description: |
Pure question or inquiry about the project, usage issue goes with "help wanted".
Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
@@ -9,13 +9,13 @@ body:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for question, not feature requests or bug reports.
- label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

.gitignore vendored
View File

@@ -7,8 +7,6 @@ ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
@@ -9,6 +9,6 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-yaml

View File

@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
ENV SHELL=/bin/bash
VOLUME /root/.cache/huggingface/hub/
EXPOSE 7860
WORKDIR /workspace/F5-TTS

View File

@@ -100,8 +100,11 @@ conda activate f5-tts
# Build from Dockerfile
docker build -t f5tts:v1 .
# Or pull from GitHub Container Registry
docker pull ghcr.io/swivid/f5-tts:main
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
@@ -200,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
## Development
Use pre-commit to ensure code quality (will run linters and formatters automatically)
Use pre-commit to ensure code quality (will run linters and formatters automatically):
```bash
pip install pre-commit
@@ -213,7 +216,7 @@ When making a pull request, before each commit, run:
pre-commit run --all-files
```
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
## Acknowledgements

View File

@@ -1,12 +1,3 @@
The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
```
ckpts/
F5TTS_v1_Base/
model_1250000.safetensors
F5TTS_Base/
model_1200000.safetensors
E2TTS_Base/
model_1200000.safetensors
```
Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.0.0"
version = "1.0.8"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}

View File

@@ -5,6 +5,7 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
remove_silence_for_generated_wav,
save_spectrogram,
)
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
from f5_tts.model.utils import seed_everything
@@ -33,7 +33,7 @@ class F5TTS:
hf_cache_dir=None,
):
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
@@ -74,8 +74,6 @@ class F5TTS:
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
else:
raise ValueError(f"Unknown model type: {model}")
if not ckpt_file:
ckpt_file = str(
@@ -117,10 +115,11 @@ class F5TTS:
seed=None,
):
if seed is None:
self.seed = random.randint(0, sys.maxsize)
seed_everything(self.seed)
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
wav, sr, spec = infer_process(
ref_file,
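Note: several files in this range (api.py above, plus eval_infer_batch.py, infer_cli.py, speech_edit.py, the socket server, and train.py below) swap `globals()[model_cfg.model.backbone]` for `hydra.utils.get_class(...)`, which also lets the `# noqa: F401` config-only imports be dropped. A minimal sketch of the difference, assuming `f5_tts` is installed, not the repository's exact code:

```python
from hydra.utils import get_class
from omegaconf import OmegaConf

# globals()[cfg.model.backbone] only works if DiT/UNetT are imported into the
# current module, which forces unused imports kept alive with `# noqa: F401`.
# get_class resolves the class from its dotted path at runtime instead.
model_cfg = OmegaConf.create({"model": {"backbone": "DiT"}})

model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
print(model_cls)  # <class 'f5_tts.model.backbones.dit.DiT'>, if f5_tts is installed
```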

View File

@@ -10,7 +10,7 @@ datasets:
num_workers: 16
optim:
epochs: 11
epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -49,4 +49,4 @@ ckpts:
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -10,6 +10,7 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm
@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
@@ -65,7 +66,7 @@ def main():
no_ref_audio = False
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
@@ -195,7 +196,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
print(f"Done batch inference in {timediff / 60:.2f} minutes.")
if __name__ == "__main__":

View File

@@ -148,9 +148,9 @@ def get_inference_prompt(
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
assert min_tokens <= total_mel_len <= max_tokens, (
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)

View File

@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h
**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.
To avoid possible inference failures, make sure you have seen through the following instructions.
- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
## Gradio App

View File

@@ -44,6 +44,7 @@
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```

View File

@@ -10,6 +10,7 @@ import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
@@ -21,13 +22,13 @@ from f5_tts.infer.utils_infer import (
sway_sampling_coef,
speed,
fix_duration,
device,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
parser = argparse.ArgumentParser(
@@ -162,6 +163,11 @@ parser.add_argument(
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
parser.add_argument(
"--device",
type=str,
help="Specify the device to run on",
)
args = parser.parse_args()
@@ -202,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
device = args.device or config.get("device", device)
# patches for pip pkg user
@@ -239,20 +246,23 @@ if vocoder_name == "vocos":
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
).model
model_cls = globals()[model_cfg.backbone]
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
@@ -269,7 +279,9 @@ if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
# inference process
@@ -325,6 +337,7 @@ def main():
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
generated_audio_segments.append(audio_segment)
@@ -332,7 +345,7 @@ def main():
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)
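As a usage note for the `--device` option added above, a hedged sketch of an invocation; the flag names other than `--device` are assumptions about the existing CLI and are not part of this diff, and paths/texts are placeholders:

```bash
# Hedged usage sketch: only --device is introduced by this diff.
f5-tts_infer-cli \
    --model F5TTS_v1_Base \
    --ref_audio ref.wav \
    --ref_text "transcription of the reference audio" \
    --gen_text "text to synthesize" \
    --device cuda:1    # e.g. cuda:0 / cuda:1 / cpu
```

Per the diff, the value can also come from the TOML config (`config.get("device", device)`) when the flag is omitted.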

View File

@@ -758,9 +758,9 @@ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not
The checkpoints currently support English and Chinese.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
"""
)

View File

@@ -7,10 +7,11 @@ from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
@@ -40,7 +41,7 @@ target_rms = 0.1
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name

View File

@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
@@ -302,7 +302,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (1)")
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
@@ -314,7 +314,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (2)")
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
@@ -323,7 +323,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 15s, clipping short. (3)")
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
@@ -479,14 +479,15 @@ def infer_batch_process(
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
del _
generated = generated.to(torch.float32)
generated = generated.to(torch.float32) # generated mel spectrogram
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
generated = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(generated_mel_spec)
generated_wave = vocoder.decode(generated)
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(generated_mel_spec)
generated_wave = vocoder(generated)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
@@ -497,7 +498,9 @@ def infer_batch_process(
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
else:
yield generated_wave, generated_mel_spec[0].cpu().numpy()
generated_cpu = generated[0].cpu().numpy()
del generated
yield generated_wave, generated_cpu
if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
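The hunk above appears to be the VRAM-management change referenced in the commit list: the generated mel is yielded as a CPU copy and the GPU tensors (`_`, `generated`) are deleted, so chunked long-text generation stops accumulating memory. A minimal sketch of that pattern with placeholder shapes, not the project's code:

```python
import torch

def generate_chunks(n_chunks, device="cuda" if torch.cuda.is_available() else "cpu"):
    for _ in range(n_chunks):
        generated = torch.randn(1, 400, 100, device=device)  # stand-in for a mel chunk
        generated_cpu = generated[0].cpu().numpy()            # hand the caller a CPU copy
        del generated                                          # release the GPU tensor early
        yield generated_cpu

for mel in generate_chunks(3):
    pass  # e.g. vocode or concatenate on CPU
```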

View File

@@ -219,7 +219,8 @@ class DiT(nn.Module):
for block in self.transformer_blocks:
if self.checkpoint_activations:
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
x = block(x, t, mask=mask, rope=rope)
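The diff passes `use_reentrant=False` explicitly, as the linked PyTorch docs recommend. A small runnable sketch of non-reentrant activation checkpointing on a toy layer:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# Activations of `layer` are recomputed during backward instead of stored;
# use_reentrant=False selects the non-reentrant implementation.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])
```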

View File

@@ -350,7 +350,7 @@ class Trainer:
progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch+1}/{self.epochs}",
desc=f"Epoch {epoch + 1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
@@ -428,6 +428,7 @@ class Trainer:
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
self.model.train()
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
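The added `self.model.train()` puts the model back into training mode after the sample-logging block. A toy sketch of why that matters, assuming sample generation switches the model to eval mode (as is typical; this is not the trainer's actual code):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))

model.eval()                      # sample generation typically runs in eval mode
with torch.inference_mode():
    _ = model(torch.randn(1, 8))  # stand-in for generating a log sample

model.train()                     # the line this diff adds: without it, dropout
print(model.training)             # (and similar layers) would stay disabled -> True
```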

View File

@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
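For readers of this helper script, a worked example of the formula with made-up numbers (these are not the Emilia or LibriTTS figures):

```python
# Hypothetical values only, to show how the script's calculation behaves.
grad_accum = 1
wanted_max_updates = 1_000_000
total_hours = 50_000        # hypothetical total audio in the dataset
mini_batch_hours = 0.05     # hypothetical hours of audio consumed per update

updates_per_epoch = total_hours / mini_batch_hours   # 1,000,000
epochs = wanted_max_updates / updates_per_epoch      # 1.0

print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
```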

View File

@@ -13,9 +13,9 @@ from importlib.resources import files
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
else "cpu"
)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
self.model_cls = globals()[model_cfg.model.backbone]
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate

View File

@@ -51,7 +51,11 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
If want to finetune with a variant version e.g. *F5TTS_v1_Base_no_zero_init*, manually download pretrained checkpoint from model weight repository and fill in the path correspondingly on web interface.
If use tensorboard as logger, install it first with `pip install tensorboard`.
<ins>The `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off with finetune gradio option or `load_model(..., use_ema=False)`, see if offer better results.
### 3. W&B Logging

View File

@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
for future in tqdm(
chunk_futures,
total=len(chunk),
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):

View File

@@ -198,7 +198,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:

View File

@@ -72,7 +72,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -50,7 +50,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -40,15 +40,15 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -65,7 +65,7 @@ def parse_args():
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",

View File

@@ -120,11 +120,11 @@ def load_settings(project_name):
default_settings = {
"exp_name": "F5TTS_v1_Base",
"learning_rate": 1e-5,
"batch_size_per_gpu": 1,
"batch_size_type": "sample",
"batch_size_per_gpu": 3200,
"batch_size_type": "frame",
"max_samples": 64,
"grad_accumulation_steps": 4,
"max_grad_norm": 1,
"grad_accumulation_steps": 1,
"max_grad_norm": 1.0,
"epochs": 100,
"num_warmup_updates": 100,
"save_per_updates": 500,
@@ -134,8 +134,8 @@ def load_settings(project_name):
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
"tokenizer_file": "",
"mixed_precision": "none",
"logger": "wandb",
"mixed_precision": "fp16",
"logger": "none",
"bnb_optimizer": False,
}
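The new defaults switch to frame-based batching (3200 frames per GPU). For rough orientation, and assuming the project's usual mel settings of a 24 kHz sampling rate and hop length 256 (an assumption; check `model.mel_spec` in your config), 3200 frames correspond to roughly 34 s of audio per GPU batch, using the frame formula quoted further down in this diff:

```python
# Assumed mel settings; verify against your model config.
sampling_rate = 24_000
hop_length = 256
batch_size_per_gpu = 3200  # frames, the new default above

seconds_per_batch = batch_size_per_gpu * hop_length / sampling_rate
print(f"{seconds_per_batch:.1f} s of audio per GPU batch")  # ~34.1 s
```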
@@ -361,27 +361,27 @@ def terminate_process(pid):
def start_training(
dataset_name="",
exp_name="F5TTS_v1_Base",
learning_rate=1e-5,
batch_size_per_gpu=1,
batch_size_type="sample",
max_samples=64,
grad_accumulation_steps=4,
max_grad_norm=1.0,
epochs=100,
num_warmup_updates=100,
save_per_updates=500,
keep_last_n_checkpoints=-1,
last_per_updates=100,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
tokenizer_file="",
mixed_precision="fp16",
stream=False,
logger="wandb",
ch_8bit_adam=False,
dataset_name,
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
stream,
logger,
ch_8bit_adam,
):
global training_process, tts_api, stop_signal
@@ -458,7 +458,10 @@ def start_training(
cmd += f" --tokenizer {tokenizer_type}"
cmd += f" --log_samples --logger {logger}"
if logger != "none":
cmd += f" --logger {logger}"
cmd += " --log_samples"
if ch_8bit_adam:
cmd += " --bnb_optimizer"
@@ -515,7 +518,7 @@ def start_training(
training_process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
)
yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
stdout_queue = queue.Queue()
stderr_queue = queue.Queue()
@@ -584,7 +587,11 @@ def start_training(
gr.update(interactive=True),
)
else:
yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
yield (
"Training complete or paused ...",
gr.update(interactive=False),
gr.update(interactive=True),
)
break
# Small sleep to prevent CPU thrashing
@@ -598,9 +605,9 @@ def start_training(
time.sleep(1)
if training_process is None:
text_info = "train stop"
text_info = "Train stopped !"
else:
text_info = "train complete !"
text_info = "Train complete at end !"
except Exception as e: # Catch all exceptions
# Ensure that we reset the training process variable in case of an error
@@ -615,11 +622,11 @@ def stop_training():
global training_process, stop_signal
if training_process is None:
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
terminate_process_tree(training_process.pid)
# training_process = None
stop_signal = True
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)
def get_list_projects():
@@ -958,21 +965,23 @@ def calculate_train(
)
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
try:
checkpoint = torch.load(checkpoint_path, weights_only=True)
print("Original Checkpoint Keys:", checkpoint.keys())
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
if ema_model_state_dict is None:
return "No 'ema_model_state_dict' found in the checkpoint."
to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
try:
model_state_dict_to_retain = checkpoint[to_retain]
except KeyError:
return f"{to_retain} not found in the checkpoint."
if safetensors:
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
save_file(ema_model_state_dict, new_checkpoint_path)
save_file(model_state_dict_to_retain, new_checkpoint_path)
else:
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
torch.save(new_checkpoint, new_checkpoint_path)
return f"New checkpoint saved at: {new_checkpoint_path}"
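As a follow-up, a hedged sketch of consuming a pruned `.safetensors` file produced by this function (the path is a placeholder); as noted later in this diff, a pruned checkpoint is meant for inference or further finetuning, not for resuming pretraining, since the optimizer state is dropped:

```python
from safetensors.torch import load_file

# Placeholder path; the pruned file holds only the retained state dict
# (EMA or non-EMA, depending on the save_ema option).
state_dict = load_file("ckpts/my_project/model_pruned.safetensors")
print(len(state_dict), "tensors in the pruned checkpoint")
```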
@@ -1013,7 +1022,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
torch.save(ckpt, new_ckpt_path)
if new_ckpt_path.endswith(".safetensors"):
save_file(ema_sd, new_ckpt_path)
elif new_ckpt_path.endswith(".pt"):
torch.save(ckpt, new_ckpt_path)
return vocab_new
@@ -1125,7 +1137,7 @@ def vocab_check(project_name):
info = "You can train using your language !"
else:
vocab_miss = ",".join(miss_symbols)
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"
return info, vocab_miss
@@ -1212,6 +1224,9 @@ def infer(
print("update >> ", device_test, file_checkpoint, use_ema)
if seed == -1: # -1 used for random
seed = None
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
@@ -1430,9 +1445,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
)
audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
txt_lang = gr.Text(label="Language", value="English")
txt_lang = gr.Textbox(label="Language", value="English")
bt_transcribe = bt_create = gr.Button("Transcribe")
txt_info_transcribe = gr.Text(label="Info", value="")
txt_info_transcribe = gr.Textbox(label="Info", value="")
bt_transcribe.click(
fn=transcribe_all,
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
@@ -1443,7 +1458,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
random_sample_transcribe = gr.Button("Random Sample")
with gr.Row():
random_text_transcribe = gr.Text(label="Text")
random_text_transcribe = gr.Textbox(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
random_sample_transcribe.click(
@@ -1458,7 +1473,7 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
```""")
check_button = gr.Button("Check Vocab")
txt_info_check = gr.Text(label="Info", value="")
txt_info_check = gr.Textbox(label="Info", value="")
gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
@@ -1478,7 +1493,7 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
extend_button = gr.Button("Extend")
txt_info_extend = gr.Text(label="Info", value="")
txt_info_extend = gr.Textbox(label="Info", value="")
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
@@ -1518,8 +1533,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
bt_prepare = bt_create = gr.Button("Prepare")
txt_info_prepare = gr.Text(label="Info", value="")
txt_vocab_prepare = gr.Text(label="Vocab", value="")
txt_info_prepare = gr.Textbox(label="Info", value="")
txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
@@ -1528,7 +1543,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
random_sample_prepare = gr.Button("Random Sample")
with gr.Row():
random_text_prepare = gr.Text(label="Tokenizer")
random_text_prepare = gr.Textbox(label="Tokenizer")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
@@ -1541,50 +1556,60 @@ The auto-setting is still experimental. Set a large value of epoch if not sure;
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
tokenizer_file = gr.Textbox(label="Tokenizer File")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
lb_samples = gr.Label(label="Samples")
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
bt_calculate = bt_create = gr.Button("Auto Settings")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
epochs = gr.Number(label="Epochs")
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
max_grad_norm = gr.Number(label="Max Gradient Norm")
num_warmup_updates = gr.Number(label="Warmup Updates")
with gr.Row():
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
batch_size_type = gr.Radio(
label="Batch Size Type",
choices=["frame", "sample"],
info="frame is calculated as seconds * sampling_rate / hop_length",
)
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
grad_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
)
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
max_samples = gr.Number(label="Max Samples", value=64)
with gr.Row():
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
with gr.Row():
epochs = gr.Number(label="Epochs", value=100)
num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=500)
save_per_updates = gr.Number(
label="Save per Updates",
info="Save intermediate checkpoints every N updates",
minimum=10,
)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
minimum=-1,
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
last_per_updates = gr.Number(
label="Last per Updates",
info="Save latest checkpoint with suffix _last.pt every N updates",
minimum=10,
)
gr.Radio(label="") # placeholder
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
with gr.Column():
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
if projects_selelect is not None:
(
@@ -1631,7 +1656,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
ch_8bit_adam.value = bnb_optimizer_value
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Text(label="Info", value="")
txt_info_train = gr.Textbox(label="Info", value="")
list_audios, select_audio = get_audio_project(projects_selelect, False)
@@ -1760,7 +1785,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
@@ -1770,11 +1795,13 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
seed = gr.Number(label="Seed", value=-1, minimum=-1)
seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")
ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
with gr.Row():
ch_use_ema = gr.Checkbox(
label="Use EMA", value=True, info="Turn off at early stage might offer better results"
)
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
)
@@ -1782,20 +1809,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
random_sample_infer = gr.Button("Random Sample")
ref_text = gr.Textbox(label="Ref Text")
ref_audio = gr.Audio(label="Audio Ref", type="filepath")
gen_text = gr.Textbox(label="Gen Text")
ref_text = gr.Textbox(label="Reference Text")
ref_audio = gr.Audio(label="Reference Audio", type="filepath")
gen_text = gr.Textbox(label="Text to Generate")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)
with gr.Row():
txt_info_gpu = gr.Textbox("", label="Device")
seed_info = gr.Text(label="Seed :")
check_button_infer = gr.Button("Infer")
txt_info_gpu = gr.Textbox("", label="Inference on Device :")
seed_info = gr.Textbox(label="Used Random Seed :")
check_button_infer = gr.Button("Inference")
gen_audio = gr.Audio(label="Audio Gen", type="filepath")
gen_audio = gr.Audio(label="Generated Audio", type="filepath")
check_button_infer.click(
fn=infer,
@@ -1822,14 +1849,16 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
gr.Markdown("""```plaintext
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
```""")
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Text(label="Path to Output:")
ch_safetensors = gr.Checkbox(label="Safetensors", value="")
txt_info_reduse = gr.Text(label="Info", value="")
reduse_button = gr.Button("Reduce")
txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
with gr.Row():
ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
txt_info_reduse = gr.Textbox(label="Info", value="")
reduse_button = gr.Button("Prune")
reduse_button.click(
fn=extract_and_save_ema_model,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
fn=prune_checkpoint,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
outputs=[txt_info_reduse],
)

View File

@@ -6,7 +6,7 @@ from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
model_cls = globals()[cfg.model.backbone]
model_arc = cfg.model.arch
tokenizer = cfg.model.tokenizer
mel_spec_type = cfg.model.mel_spec.mel_spec_type
def main(model_cfg):
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
tokenizer_path = cfg.datasets.name
tokenizer_path = model_cfg.datasets.name
else:
tokenizer_path = cfg.model.tokenizer_path
tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=cfg.model.mel_spec,
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
# init trainer
trainer = Trainer(
model,
epochs=cfg.optim.epochs,
learning_rate=cfg.optim.learning_rate,
num_warmup_updates=cfg.optim.num_warmup_updates,
save_per_updates=cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
batch_size_type=cfg.datasets.batch_size_type,
max_samples=cfg.datasets.max_samples,
grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
max_grad_norm=cfg.optim.max_grad_norm,
logger=cfg.ckpts.logger,
epochs=model_cfg.optim.epochs,
learning_rate=model_cfg.optim.learning_rate,
num_warmup_updates=model_cfg.optim.num_warmup_updates,
save_per_updates=model_cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
batch_size_type=model_cfg.datasets.batch_size_type,
max_samples=model_cfg.datasets.max_samples,
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=cfg.ckpts.last_per_updates,
log_samples=cfg.ckpts.log_samples,
bnb_optimizer=cfg.optim.bnb_optimizer,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,
bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
cfg_dict=OmegaConf.to_container(cfg, resolve=True),
is_local_vocoder=model_cfg.model.vocoder.is_local,
local_vocoder_path=model_cfg.model.vocoder.local_path,
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
num_workers=cfg.datasets.num_workers,
num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666, # seed for shuffling dataset
)