From 4ae53472822629c7a2fd49a1173992013db2895c Mon Sep 17 00:00:00 2001 From: SWivid Date: Fri, 21 Mar 2025 23:01:00 +0800 Subject: [PATCH] pre-commit update and formatting --- .gitignore | 2 - .pre-commit-config.yaml | 4 +- Dockerfile | 3 +- README.md | 4 +- ckpts/README.md | 13 +--- src/f5_tts/api.py | 4 +- src/f5_tts/eval/eval_infer_batch.py | 7 ++- src/f5_tts/eval/utils_eval.py | 6 +- src/f5_tts/infer/infer_cli.py | 13 ++-- src/f5_tts/infer/speech_edit.py | 5 +- src/f5_tts/model/trainer.py | 2 +- src/f5_tts/scripts/count_max_epoch.py | 2 +- src/f5_tts/socket_server.py | 4 +- src/f5_tts/train/datasets/prepare_csv_wavs.py | 4 +- src/f5_tts/train/datasets/prepare_emilia.py | 2 +- src/f5_tts/train/datasets/prepare_libritts.py | 2 +- src/f5_tts/train/datasets/prepare_ljspeech.py | 2 +- src/f5_tts/train/train.py | 62 +++++++++---------- 18 files changed, 66 insertions(+), 75 deletions(-) diff --git a/.gitignore b/.gitignore index b794c56..fac2555 100644 --- a/.gitignore +++ b/.gitignore @@ -7,8 +7,6 @@ ckpts/ wandb/ results/ - - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ac5ee1..aae76d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.7.0 + rev: v0.11.2 hooks: # Run the linter. - id: ruff @@ -9,6 +9,6 @@ repos: # Run the formatter. - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v5.0.0 hooks: - id: check-yaml diff --git a/Dockerfile b/Dockerfile index 03cc9c1..069b76d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,9 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \ ENV SHELL=/bin/bash -# models are downloaded into this folder, so user should mount it VOLUME /root/.cache/huggingface/hub/ -# port the GUI is exposed on by default, if it is run + EXPOSE 7860 WORKDIR /workspace/F5-TTS diff --git a/README.md b/README.md index dea4b79..ea0793f 100644 --- a/README.md +++ b/README.md @@ -203,7 +203,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions. ## Development -Use pre-commit to ensure code quality (will run linters and formatters automatically) +Use pre-commit to ensure code quality (will run linters and formatters automatically): ```bash pip install pre-commit @@ -216,7 +216,7 @@ When making a pull request, before each commit, run: pre-commit run --all-files ``` -Note: Some model components have linting exceptions for E722 to accommodate tensor notation +Note: Some model components have linting exceptions for E722 to accommodate tensor notation. ## Acknowledgements diff --git a/ckpts/README.md b/ckpts/README.md index 0d6b048..e1a4b7a 100644 --- a/ckpts/README.md +++ b/ckpts/README.md @@ -1,12 +1,3 @@ +The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS. -Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS - -``` -ckpts/ - F5TTS_v1_Base/ - model_1250000.safetensors - F5TTS_Base/ - model_1200000.safetensors - E2TTS_Base/ - model_1200000.safetensors -``` \ No newline at end of file +Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`. diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index ea2c1a0..0ead776 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -5,6 +5,7 @@ from importlib.resources import files import soundfile as sf import tqdm from cached_path import cached_path +from hydra.utils import get_class from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( @@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import ( remove_silence_for_generated_wav, save_spectrogram, ) -from f5_tts.model import DiT, UNetT # noqa: F401. used for config from f5_tts.model.utils import seed_everything @@ -33,7 +33,7 @@ class F5TTS: hf_cache_dir=None, ): model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) - model_cls = globals()[model_cfg.model.backbone] + model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index e779ff0..17dd130 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -10,6 +10,7 @@ from importlib.resources import files import torch import torchaudio from accelerate import Accelerator +from hydra.utils import get_class from omegaconf import OmegaConf from tqdm import tqdm @@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import ( get_seedtts_testset_metainfo, ) from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder -from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config +from f5_tts.model import CFM from f5_tts.model.utils import get_tokenizer accelerator = Accelerator() @@ -65,7 +66,7 @@ def main(): no_ref_audio = False model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) - model_cls = globals()[model_cfg.model.backbone] + model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch dataset_name = model_cfg.datasets.name @@ -195,7 +196,7 @@ def main(): accelerator.wait_for_everyone() if accelerator.is_main_process: timediff = time.time() - start - print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + print(f"Done batch inference in {timediff / 60:.2f} minutes.") if __name__ == "__main__": diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py index d8407ad..f819cdc 100644 --- a/src/f5_tts/eval/utils_eval.py +++ b/src/f5_tts/eval/utils_eval.py @@ -148,9 +148,9 @@ def get_inference_prompt( # deal with batch assert infer_batch_size > 0, "infer_batch_size should be greater than 0." - assert ( - min_tokens <= total_mel_len <= max_tokens - ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + assert min_tokens <= total_mel_len <= max_tokens, ( + f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]." + ) bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets) utts[bucket_i].append(utt) diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py index 5c7a1bb..acabf6f 100644 --- a/src/f5_tts/infer/infer_cli.py +++ b/src/f5_tts/infer/infer_cli.py @@ -10,6 +10,7 @@ import numpy as np import soundfile as sf import tomli from cached_path import cached_path +from hydra.utils import get_class from omegaconf import OmegaConf from f5_tts.infer.utils_infer import ( @@ -27,7 +28,6 @@ from f5_tts.infer.utils_infer import ( preprocess_ref_audio_text, remove_silence_for_generated_wav, ) -from f5_tts.model import DiT, UNetT # noqa: F401. used for config parser = argparse.ArgumentParser( @@ -246,13 +246,14 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc model_cfg = OmegaConf.load( args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) -).model -model_cls = globals()[model_cfg.backbone] +) +model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") +model_arc = model_cfg.model.arch repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors" if model != "F5TTS_Base": - assert vocoder_name == model_cfg.mel_spec.mel_spec_type + assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type # override for previous models if model == "F5TTS_Base": @@ -269,7 +270,7 @@ if not ckpt_file: ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}")) print(f"Using {model}...") -ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) +ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file) # inference process @@ -332,7 +333,7 @@ def main(): if len(gen_text_) > 200: gen_text_ = gen_text_[:200] + " ... " sf.write( - os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"), + os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"), audio_segment, final_sample_rate, ) diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py index d8d073e..b724258 100644 --- a/src/f5_tts/infer/speech_edit.py +++ b/src/f5_tts/infer/speech_edit.py @@ -7,10 +7,11 @@ from importlib.resources import files import torch import torch.nn.functional as F import torchaudio +from hydra.utils import get_class from omegaconf import OmegaConf from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram -from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config +from f5_tts.model import CFM from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer device = ( @@ -40,7 +41,7 @@ target_rms = 0.1 model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml"))) -model_cls = globals()[model_cfg.model.backbone] +model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch dataset_name = model_cfg.datasets.name diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index d9ab4a8..6e3888f 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -350,7 +350,7 @@ class Trainer: progress_bar = tqdm( range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)), - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="update", disable=not self.accelerator.is_local_main_process, initial=progress_bar_initial, diff --git a/src/f5_tts/scripts/count_max_epoch.py b/src/f5_tts/scripts/count_max_epoch.py index fe291e5..5e62b76 100644 --- a/src/f5_tts/scripts/count_max_epoch.py +++ b/src/f5_tts/scripts/count_max_epoch.py @@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours # result epochs = wanted_max_updates / updates_per_epoch -print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})") +print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})") print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates") # print(f" or approx. 0/{steps_per_epoch:.0f} steps") diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py index 344b1d7..23df15f 100644 --- a/src/f5_tts/socket_server.py +++ b/src/f5_tts/socket_server.py @@ -13,9 +13,9 @@ from importlib.resources import files import torch import torchaudio from huggingface_hub import hf_hub_download +from hydra.utils import get_class from omegaconf import OmegaConf -from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config from f5_tts.infer.utils_infer import ( chunk_text, preprocess_ref_audio_text, @@ -80,7 +80,7 @@ class TTSStreamingProcessor: else "cpu" ) model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml"))) - self.model_cls = globals()[model_cfg.model.backbone] + self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") self.model_arc = model_cfg.model.arch self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py index 323a143..4fa1099 100644 --- a/src/f5_tts/train/datasets/prepare_csv_wavs.py +++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py @@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None): for future in tqdm( chunk_futures, total=len(chunk), - desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}", + desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}", ): try: result = future.result() @@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine dataset_name = out_dir.stem print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") - print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") + print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None): diff --git a/src/f5_tts/train/datasets/prepare_emilia.py b/src/f5_tts/train/datasets/prepare_emilia.py index d9b276a..d9a3520 100644 --- a/src/f5_tts/train/datasets/prepare_emilia.py +++ b/src/f5_tts/train/datasets/prepare_emilia.py @@ -198,7 +198,7 @@ def main(): print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") - print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") + print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}") if "EN" in langs: diff --git a/src/f5_tts/train/datasets/prepare_libritts.py b/src/f5_tts/train/datasets/prepare_libritts.py index 2a35dd9..9af48d4 100644 --- a/src/f5_tts/train/datasets/prepare_libritts.py +++ b/src/f5_tts/train/datasets/prepare_libritts.py @@ -72,7 +72,7 @@ def main(): print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") - print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") + print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if __name__ == "__main__": diff --git a/src/f5_tts/train/datasets/prepare_ljspeech.py b/src/f5_tts/train/datasets/prepare_ljspeech.py index 19a5b2a..129ff45 100644 --- a/src/f5_tts/train/datasets/prepare_ljspeech.py +++ b/src/f5_tts/train/datasets/prepare_ljspeech.py @@ -50,7 +50,7 @@ def main(): print(f"\nFor {dataset_name}, sample count: {len(result)}") print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}") - print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours") + print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours") if __name__ == "__main__": diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index 2e191a3..a935e36 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -6,7 +6,7 @@ from importlib.resources import files import hydra from omegaconf import OmegaConf -from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config +from f5_tts.model import CFM, Trainer from f5_tts.model.dataset import load_dataset from f5_tts.model.utils import get_tokenizer @@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to @hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None) -def main(cfg): - model_cls = globals()[cfg.model.backbone] - model_arc = cfg.model.arch - tokenizer = cfg.model.tokenizer - mel_spec_type = cfg.model.mel_spec.mel_spec_type +def main(model_cfg): + model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}") + model_arc = model_cfg.model.arch + tokenizer = model_cfg.model.tokenizer + mel_spec_type = model_cfg.model.mel_spec.mel_spec_type - exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}" + exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}" wandb_resume_id = None # set text tokenizer if tokenizer != "custom": - tokenizer_path = cfg.datasets.name + tokenizer_path = model_cfg.datasets.name else: - tokenizer_path = cfg.model.tokenizer_path + tokenizer_path = model_cfg.model.tokenizer_path vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) # set model model = CFM( - transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels), - mel_spec_kwargs=cfg.model.mel_spec, + transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels), + mel_spec_kwargs=model_cfg.model.mel_spec, vocab_char_map=vocab_char_map, ) # init trainer trainer = Trainer( model, - epochs=cfg.optim.epochs, - learning_rate=cfg.optim.learning_rate, - num_warmup_updates=cfg.optim.num_warmup_updates, - save_per_updates=cfg.ckpts.save_per_updates, - keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints, - checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), - batch_size_per_gpu=cfg.datasets.batch_size_per_gpu, - batch_size_type=cfg.datasets.batch_size_type, - max_samples=cfg.datasets.max_samples, - grad_accumulation_steps=cfg.optim.grad_accumulation_steps, - max_grad_norm=cfg.optim.max_grad_norm, - logger=cfg.ckpts.logger, + epochs=model_cfg.optim.epochs, + learning_rate=model_cfg.optim.learning_rate, + num_warmup_updates=model_cfg.optim.num_warmup_updates, + save_per_updates=model_cfg.ckpts.save_per_updates, + keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints, + checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")), + batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu, + batch_size_type=model_cfg.datasets.batch_size_type, + max_samples=model_cfg.datasets.max_samples, + grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps, + max_grad_norm=model_cfg.optim.max_grad_norm, + logger=model_cfg.ckpts.logger, wandb_project="CFM-TTS", wandb_run_name=exp_name, wandb_resume_id=wandb_resume_id, - last_per_updates=cfg.ckpts.last_per_updates, - log_samples=cfg.ckpts.log_samples, - bnb_optimizer=cfg.optim.bnb_optimizer, + last_per_updates=model_cfg.ckpts.last_per_updates, + log_samples=model_cfg.ckpts.log_samples, + bnb_optimizer=model_cfg.optim.bnb_optimizer, mel_spec_type=mel_spec_type, - is_local_vocoder=cfg.model.vocoder.is_local, - local_vocoder_path=cfg.model.vocoder.local_path, - cfg_dict=OmegaConf.to_container(cfg, resolve=True), + is_local_vocoder=model_cfg.model.vocoder.is_local, + local_vocoder_path=model_cfg.model.vocoder.local_path, + model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True), ) - train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec) + train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec) trainer.train( train_dataset, - num_workers=cfg.datasets.num_workers, + num_workers=model_cfg.datasets.num_workers, resumable_with_seed=666, # seed for shuffling dataset )