pre-commit update and formatting

Author: SWivid
Date: 2025-03-21 23:01:00 +08:00
Parent: 621559cbbe
Commit: 4ae5347282

18 changed files with 66 additions and 75 deletions

.gitignore

@@ -7,8 +7,6 @@ ckpts/
 wandb/
 results/
 
-
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

.pre-commit-config.yaml

@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.7.0
+    rev: v0.11.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -9,6 +9,6 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.3.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
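
For context (not part of the diff): `rev` bumps like these are what `pre-commit autoupdate` produces, moving each hook to its latest tagged release while leaving hook ids and options untouched.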

Dockerfile

@@ -23,9 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
 ENV SHELL=/bin/bash
 
+# models are downloaded into this folder, so user should mount it
 VOLUME /root/.cache/huggingface/hub/
 
+# port the GUI is exposed on by default, if it is run
 EXPOSE 7860
 
 WORKDIR /workspace/F5-TTS

README.md

@@ -203,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
 ## Development
 
-Use pre-commit to ensure code quality (will run linters and formatters automatically)
+Use pre-commit to ensure code quality (will run linters and formatters automatically):
 
 ```bash
 pip install pre-commit
@@ -216,7 +216,7 @@ When making a pull request, before each commit, run:
 pre-commit run --all-files
 ```
 
-Note: Some model components have linting exceptions for E722 to accommodate tensor notation
+Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
 
 ## Acknowledgements

ckpts/README.md

@@ -1,12 +1,3 @@
-Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
-
-```
-ckpts/
-    F5TTS_v1_Base/
-        model_1250000.safetensors
-    F5TTS_Base/
-        model_1200000.safetensors
-    E2TTS_Base/
-        model_1200000.safetensors
-```
+The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
+
+Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
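
As context for the new wording (not part of the commit): the inference scripts pull checkpoints through `cached_path`, as in the `infer_cli.py` hunk further down; the URI below is assembled from values in this diff. A minimal sketch:

```python
from cached_path import cached_path

# Resolves an hf:// URI through the Hugging Face cache, downloading on first
# use; by default the file lands under ~/.cache/huggingface/hub/.
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
print(ckpt_file)
```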

src/f5_tts/api.py

@@ -5,6 +5,7 @@ from importlib.resources import files
 import soundfile as sf
 import tqdm
 from cached_path import cached_path
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
     remove_silence_for_generated_wav,
     save_spectrogram,
 )
-from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import seed_everything
@@ -33,7 +33,7 @@ class F5TTS:
         hf_cache_dir=None,
     ):
         model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
-        model_cls = globals()[model_cfg.model.backbone]
+        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
         model_arc = model_cfg.model.arch
         self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
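
As context (not part of the commit): `hydra.utils.get_class` imports a class from its dotted path at runtime, which is why the `# noqa: F401` backbone imports can be dropped here and in the files below; the class no longer needs to sit in the module's namespace for a `globals()` lookup. A minimal sketch of the two patterns, with the backbone name assumed:

```python
from hydra.utils import get_class
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"model": {"backbone": "DiT"}})  # assumed config value

# Old pattern: needs `from f5_tts.model import DiT  # noqa: F401` at module
# level so the name exists in globals(), which linters flag as unused.
# model_cls = globals()[model_cfg.model.backbone]

# New pattern: resolve "f5_tts.model.DiT" lazily from the dotted path.
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
```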

src/f5_tts/eval/eval_infer_batch.py

@@ -10,6 +10,7 @@ from importlib.resources import files
 import torch
 import torchaudio
 from accelerate import Accelerator
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
 )
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
-from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+from f5_tts.model import CFM
 from f5_tts.model.utils import get_tokenizer
 
 accelerator = Accelerator()
@@ -65,7 +66,7 @@ def main():
     no_ref_audio = False
 
     model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
-    model_cls = globals()[model_cfg.model.backbone]
+    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
     model_arc = model_cfg.model.arch
 
     dataset_name = model_cfg.datasets.name
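
One shared detail worth noting (not part of the diff): these scripts locate configs relative to the installed `f5_tts` package via `importlib.resources`, so they work from any working directory. A small sketch, config name assumed:

```python
from importlib.resources import files

# files("f5_tts") resolves to the package directory (site-packages or an
# editable checkout), so the path below is independent of os.getcwd().
cfg_path = str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml"))
print(cfg_path)
```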

src/f5_tts/eval/utils_eval.py

@@ -148,9 +148,9 @@ def get_inference_prompt(
         # deal with batch
         assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
-        assert (
-            min_tokens <= total_mel_len <= max_tokens
-        ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+        assert min_tokens <= total_mel_len <= max_tokens, (
+            f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+        )
         bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
         utts[bucket_i].append(utt)
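
For readers of the hunk above: the unchanged bucketing line maps a mel length onto one of `num_buckets` equal ranges, and the `+ 1` in the denominator keeps `total_mel_len == max_tokens` in the last bucket instead of overflowing past it. A worked example with illustrative numbers (not the repo's actual settings):

```python
import math

min_tokens, max_tokens, num_buckets = 100, 3000, 10  # illustrative values

for total_mel_len in (100, 1550, 3000):
    bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
    print(total_mel_len, bucket_i)  # -> (100, 0), (1550, 4), (3000, 9); without the +1, 3000 would hit bucket 10
```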

src/f5_tts/infer/infer_cli.py

@@ -10,6 +10,7 @@ import numpy as np
 import soundfile as sf
 import tomli
 from cached_path import cached_path
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
@@ -27,7 +28,6 @@ from f5_tts.infer.utils_infer import (
     preprocess_ref_audio_text,
     remove_silence_for_generated_wav,
 )
-from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 
 parser = argparse.ArgumentParser(
@@ -246,13 +246,14 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
 model_cfg = OmegaConf.load(
     args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
-).model
-model_cls = globals()[model_cfg.backbone]
+)
+model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
+model_arc = model_cfg.model.arch
 
 repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
 
 if model != "F5TTS_Base":
-    assert vocoder_name == model_cfg.mel_spec.mel_spec_type
+    assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
 
 # override for previous models
 if model == "F5TTS_Base":
@@ -269,7 +270,7 @@ if not ckpt_file:
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
 
 # inference process
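
Note on the hunk above (context, not part of the commit): dropping the `.model` suffix means the loaded object is now the whole config file, so every downstream access goes through the `model` key. A sketch of the shape the code now assumes, with placeholder values:

```python
from omegaconf import OmegaConf

# Keys mirror the accesses in the hunk; values are illustrative only.
model_cfg = OmegaConf.create(
    {
        "model": {
            "backbone": "DiT",
            "arch": {"dim": 1024},
            "mel_spec": {"mel_spec_type": "vocos"},
        }
    }
)
assert model_cfg.model.mel_spec.mel_spec_type == "vocos"
```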

src/f5_tts/infer/speech_edit.py

@@ -7,10 +7,11 @@ from importlib.resources import files
 import torch
 import torch.nn.functional as F
 import torchaudio
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
-from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
 device = (
@@ -40,7 +41,7 @@ target_rms = 0.1
 
 model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
-model_cls = globals()[model_cfg.model.backbone]
+model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
 model_arc = model_cfg.model.arch
 
 dataset_name = model_cfg.datasets.name

src/f5_tts/socket_server.py

@@ -13,9 +13,9 @@ from importlib.resources import files
 import torch
 import torchaudio
 from huggingface_hub import hf_hub_download
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
-from f5_tts.model.backbones.dit import DiT  # noqa: F401. used for config
 from f5_tts.infer.utils_infer import (
     chunk_text,
     preprocess_ref_audio_text,
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
             else "cpu"
         )
         model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
-        self.model_cls = globals()[model_cfg.model.backbone]
+        self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
         self.model_arc = model_cfg.model.arch
         self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
         self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate

src/f5_tts/train/train.py

@@ -6,7 +6,7 @@ from importlib.resources import files
 import hydra
 from omegaconf import OmegaConf
 
-from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401. used for config
+from f5_tts.model import CFM, Trainer
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to
 @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
-def main(cfg):
-    model_cls = globals()[cfg.model.backbone]
-    model_arc = cfg.model.arch
-    tokenizer = cfg.model.tokenizer
-    mel_spec_type = cfg.model.mel_spec.mel_spec_type
-    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
+def main(model_cfg):
+    model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
+    model_arc = model_cfg.model.arch
+    tokenizer = model_cfg.model.tokenizer
+    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+    exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
 
     wandb_resume_id = None
 
     # set text tokenizer
     if tokenizer != "custom":
-        tokenizer_path = cfg.datasets.name
+        tokenizer_path = model_cfg.datasets.name
     else:
-        tokenizer_path = cfg.model.tokenizer_path
+        tokenizer_path = model_cfg.model.tokenizer_path
     vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
 
     # set model
     model = CFM(
-        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
-        mel_spec_kwargs=cfg.model.mel_spec,
+        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
+        mel_spec_kwargs=model_cfg.model.mel_spec,
         vocab_char_map=vocab_char_map,
     )
 
     # init trainer
     trainer = Trainer(
         model,
-        epochs=cfg.optim.epochs,
-        learning_rate=cfg.optim.learning_rate,
-        num_warmup_updates=cfg.optim.num_warmup_updates,
-        save_per_updates=cfg.ckpts.save_per_updates,
-        keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
-        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
-        batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
-        batch_size_type=cfg.datasets.batch_size_type,
-        max_samples=cfg.datasets.max_samples,
-        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
-        max_grad_norm=cfg.optim.max_grad_norm,
-        logger=cfg.ckpts.logger,
+        epochs=model_cfg.optim.epochs,
+        learning_rate=model_cfg.optim.learning_rate,
+        num_warmup_updates=model_cfg.optim.num_warmup_updates,
+        save_per_updates=model_cfg.ckpts.save_per_updates,
+        keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
+        checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
+        batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
+        batch_size_type=model_cfg.datasets.batch_size_type,
+        max_samples=model_cfg.datasets.max_samples,
+        grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
+        max_grad_norm=model_cfg.optim.max_grad_norm,
+        logger=model_cfg.ckpts.logger,
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        last_per_updates=cfg.ckpts.last_per_updates,
-        log_samples=cfg.ckpts.log_samples,
-        bnb_optimizer=cfg.optim.bnb_optimizer,
+        last_per_updates=model_cfg.ckpts.last_per_updates,
+        log_samples=model_cfg.ckpts.log_samples,
+        bnb_optimizer=model_cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
-        is_local_vocoder=cfg.model.vocoder.is_local,
-        local_vocoder_path=cfg.model.vocoder.local_path,
-        cfg_dict=OmegaConf.to_container(cfg, resolve=True),
+        is_local_vocoder=model_cfg.model.vocoder.is_local,
+        local_vocoder_path=model_cfg.model.vocoder.local_path,
+        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
     )
 
-    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
+    train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
     trainer.train(
         train_dataset,
-        num_workers=cfg.datasets.num_workers,
+        num_workers=model_cfg.datasets.num_workers,
         resumable_with_seed=666,  # seed for shuffling dataset
     )
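
The `cfg` → `model_cfg` rename above is mechanical, but it may help to see where the argument comes from: with `config_name=None` in the decorator, Hydra composes whatever YAML is named on the command line and passes it as `main`'s single argument. A sketch of the equivalent composition through Hydra's API (directory path and config name are placeholders):

```python
from hydra import compose, initialize_config_dir

with initialize_config_dir(version_base="1.3", config_dir="/abs/path/to/src/f5_tts/configs"):
    model_cfg = compose(config_name="F5TTS_v1_Base")  # placeholder config name
print(model_cfg.model.backbone)  # e.g. "DiT"
```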