mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-30 14:42:13 -08:00
convert to pkg, reorganize repo (#228)
* group files in f5_tts directory
* add setup.py
* use global imports
* simplify demo
* add install directions for library mode
* fix old huggingface_hub version constraint
* move finetune to package
* change imports to f5_tts.model
* bump version
* fix bad merge
* Update inference-cli.py
* fix HF space
* reformat
* fix utils.py vocab.txt import
* fix format
* adapt README for f5_tts package structure
* simplify app.py
* add gradio.Dockerfile and workflow
* refactored for pyproject.toml
* refactored for pyproject.toml
* added in reference to packaged files
* use fork for testing docker image
* added in reference to packaged files
* minor tweaks
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* fixed inference-cli.toml path
* refactor eval_infer_batch.py
* fix typo
* added eval_infer_batch to scripts

---------

Co-authored-by: Roberts Slisans <rsxdalv@gmail.com>
Co-authored-by: Adam Kessel <adam@rosi-kessel.org>
Co-authored-by: Roberts Slisans <roberts.slisans@gmail.com>
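The recurring "change imports to f5_tts.model" items boil down to one pattern, sketched here as a before/after drawn from the import changes in the diff below (module names as they appear in the changed files):

```python
# Before: flat imports, resolvable only when running from the repo root
# from model import CFM, UNetT, DiT, Trainer
# from model.utils import get_tokenizer

# After: absolute imports against the installed f5_tts package
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
```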
61
.github/workflows/publish-docker-image.yaml
vendored
Normal file
@@ -0,0 +1,61 @@
name: Create and publish a Docker image

# Configures this workflow to run every time a change is pushed to the branch called `main`.
on:
  push:
    branches: ['main']

# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
  build-and-push-image:
    runs-on: ubuntu-latest
    # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
    permissions:
      contents: read
      packages: write
    #
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
        uses: jlumbroso/free-disk-space@main
        with:
          # If set to "true", this might remove tools that are actually needed, but it frees about 6 GB.
          tool-cache: false

          # All of these default to true, but feel free to set to "false" if necessary for your workflow.
          android: true
          dotnet: true
          haskell: true
          large-packages: false
          swap-storage: false
          docker-images: false
      # Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
      - name: Log in to the Container registry
        uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
      # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
      # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
      # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
      - name: Build and push Docker image
        uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
        with:
          context: .
          file: ./gradio.Dockerfile
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
52
README.md
@@ -63,11 +63,35 @@ pre-commit run --all-files
Note: Some model components have linting exceptions for E722 to accommodate tensor notation


## Prepare Dataset

Example data processing scripts are provided for Emilia and WenetSpeech4TTS, and you may tailor your own along with a Dataset class in `model/dataset.py`.
### As a pip package

```bash
pip install git+https://github.com/SWivid/F5-TTS.git
```

```python
import gradio as gr
from f5_tts.gradio_app import app


with gr.Blocks() as main_app:
    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")

    # ... other Gradio components

    app.render()

main_app.launch()

```

## Prepare Dataset

Example data processing scripts are provided for Emilia and WenetSpeech4TTS, and you may tailor your own along with a Dataset class in `f5_tts/model/dataset.py`.
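As a rough illustration of what such a tailored Dataset could look like (hypothetical field names and file layout, not the packaged API; the real implementations live in `f5_tts/model/dataset.py`):

```python
import torchaudio
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    """Hypothetical sketch: yields audio/text pairs for preprocessing or training."""

    def __init__(self, metadata):
        # metadata: a list of (wav_path, transcript) tuples you prepared yourself
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        wav_path, text = self.metadata[idx]
        audio, sample_rate = torchaudio.load(wav_path)
        return {"audio": audio, "sample_rate": sample_rate, "text": text}
```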

```bash
# switch to the main directory
cd f5_tts

# prepare custom dataset up to your need
# download corresponding dataset first, and fill in the path in scripts
@@ -83,6 +107,9 @@ python scripts/prepare_wenetspeech4tts.py
Once your datasets are prepared, you can start the training process.

```bash
# switch to the main directory
cd f5_tts

# setup accelerate config, e.g. use multi-gpu ddp, fp16
# will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
@@ -90,7 +117,7 @@ accelerate launch train.py
```

For initial guidance on finetuning, see [#57](https://github.com/SWivid/F5-TTS/discussions/57).

For Gradio UI finetuning with `finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
For Gradio UI finetuning with `f5_tts/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

### Wandb Logging

@@ -136,6 +163,9 @@ for change model use `--ckpt_file` to specify the model you want to load,
to change vocab.txt, use `--vocab_file` to provide your own vocab.txt file.

```bash
# switch to the main directory
cd f5_tts

python inference-cli.py \
--model "F5-TTS" \
--ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
@@ -161,19 +191,19 @@ Currently supported features:
You can launch a Gradio app (web interface) as a GUI for inference (it will load the checkpoint from Hugging Face; you may also use a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS, and E2 TTS all at once, and thus uses more GPU memory than `inference-cli`.

```bash
python gradio_app.py
python f5_tts/gradio_app.py
```

You can specify the port/host:

```bash
python gradio_app.py --port 7860 --host 0.0.0.0
python f5_tts/gradio_app.py --port 7860 --host 0.0.0.0
```

Or launch a share link:

```bash
python gradio_app.py --share
python f5_tts/gradio_app.py --share
```

### Speech Editing
@@ -181,7 +211,7 @@ python gradio_app.py --share
To test speech editing capabilities, use the following command.

```bash
python speech_edit.py
python f5_tts/speech_edit.py
```

## Evaluation
@@ -199,6 +229,9 @@ python speech_edit.py
To run batch inference for evaluations, execute the following commands:

```bash
# switch to the main directory
cd f5_tts

# batch inference for evaluations
accelerate config  # if not set before
bash scripts/eval_infer_batch.sh
@@ -234,6 +267,9 @@ pip install faster-whisper==0.10.1

Update the path to your batch inference results, and carry out WER / SIM evaluations:
```bash
# switch to the main directory
cd f5_tts

# Evaluation for Seed-TTS test set
python scripts/eval_seedtts_testset.py

3
app.py
Normal file
@@ -0,0 +1,3 @@
from f5_tts.gradio_app import app


app.queue().launch()
27
gradio.Dockerfile
Normal file
@@ -0,0 +1,27 @@
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel

USER root

ARG DEBIAN_FRONTEND=noninteractive

LABEL github_repo="https://github.com/rsxdalv/F5-TTS"

RUN set -x \
    && apt-get update \
    && apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
    && apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

WORKDIR /workspace

RUN git clone https://github.com/rsxdalv/F5-TTS.git \
    && cd F5-TTS \
    && pip install --no-cache-dir -r requirements.txt

ENV SHELL=/bin/bash

WORKDIR /workspace/F5-TTS/f5_tts

EXPOSE 7860
CMD python gradio_app.py
@@ -1,10 +0,0 @@
from model.cfm import CFM

from model.backbones.unett import UNetT
from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT

from model.trainer import Trainer


__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
52
pyproject.toml
Normal file
@@ -0,0 +1,52 @@
[build-system]
requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
dynamic = ["version"]
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
]
dependencies = [
    "accelerate>=0.33.0",
    "cached_path @ git+https://github.com/rsxdalv/cached_path@main",
    "click",
    "datasets",
    "einops>=0.8.0",
    "einx>=0.3.0",
    "ema_pytorch>=0.5.2",
    "gradio",
    "jieba",
    "librosa",
    "matplotlib",
    "numpy<=1.26.4",
    "pydub",
    "pypinyin",
    "safetensors",
    "soundfile",
    "tomli",
    "torch>=2.0.0",
    "torchaudio>=2.0.0",
    "torchdiffeq",
    "tqdm>=4.65.0",
    "transformers",
    "vocos",
    "wandb",
    "x_transformers>=1.31.14",
]

[[project.authors]]
name = "Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen"

[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"

[project.scripts]
"finetune-cli" = "f5_tts.finetune_cli:main"
"inference-cli" = "f5_tts.inference_cli:main"
"eval_infer_batch" = "f5_tts.scripts.eval_infer_batch:main"
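Each `[project.scripts]` entry maps a console command to a package function, so after installation `inference-cli` behaves roughly like this illustrative wrapper (the module path is taken from the table above; the wrapper itself is generated by the installer):

```python
# Illustrative equivalent of the generated "inference-cli" console script,
# per "inference-cli" = "f5_tts.inference_cli:main" in [project.scripts].
import sys

from f5_tts.inference_cli import main

if __name__ == "__main__":
    sys.exit(main())
```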
@@ -1,198 +0,0 @@
import sys
import os

sys.path.append(os.getcwd())

import time
import random
from tqdm import tqdm
import argparse

import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos

from model import CFM, UNetT, DiT
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    get_seedtts_testset_metainfo,
    get_librispeech_test_clean_metainfo,
    get_inference_prompt,
)

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1

tokenizer = "pinyin"


# ---------------------- infer setting ---------------------- #

parser = argparse.ArgumentParser(description="batch inference")

parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)

parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

parser.add_argument("-t", "--testset", required=True)

args = parser.parse_args()


seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"

nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling

testset = args.testset


infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False


if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)


if testset == "ls_pc_test_clean":
    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

elif testset == "seedtts_test_zh":
    metalst = "data/seedtts_testset/zh/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

elif testset == "seedtts_test_en":
    metalst = "data/seedtts_testset/en/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)


# path to save generated wavs
if seed is None:
    seed = random.randint(-10000, 10000)
output_dir = (
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}"
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
    f"_cfg{cfg_strength}_speed{speed}"
    f"{'_gt-dur' if use_truth_duration else ''}"
    f"{'_no-ref-audio' if no_ref_audio else ''}"
)


# -------------------------------------------------#

use_ema = True

prompts_all = get_inference_prompt(
    metainfo,
    speed=speed,
    tokenizer=tokenizer,
    target_sample_rate=target_sample_rate,
    n_mel_channels=n_mel_channels,
    hop_length=hop_length,
    target_rms=target_rms,
    use_truth_duration=use_truth_duration,
    infer_batch_size=infer_batch_size,
)

# Vocoder model
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model
model = CFM(
    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
    mel_spec_kwargs=dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)

model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)

if not os.path.exists(output_dir) and accelerator.is_main_process:
    os.makedirs(output_dir)

# start batch inference
accelerator.wait_for_everyone()
start = time.time()

with accelerator.split_between_processes(prompts_all) as prompts:
    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
        ref_mels = ref_mels.to(device)
        ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

        # Inference
        with torch.inference_mode():
            generated, _ = model.sample(
                cond=ref_mels,
                text=final_text_list,
                duration=total_mel_lens,
                lens=ref_mel_lens,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                no_ref_audio=no_ref_audio,
                seed=seed,
            )
        # Final result
        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1)
            generated_wave = vocos.decode(gen_mel_spec.cpu())
            if ref_rms_list[i] < target_rms:
                generated_wave = generated_wave * ref_rms_list[i] / target_rms
            torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    timediff = time.time() - start
    print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
@@ -3,11 +3,11 @@ import torch
import tqdm
from cached_path import cached_path

from model import DiT, UNetT
from model.utils import save_spectrogram
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import save_spectrogram

from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from model.utils import seed_everything
from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from f5_tts.model.utils import seed_everything
import random
import sys

@@ -1,7 +1,7 @@
import argparse
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset
from cached_path import cached_path
import shutil
import os
@@ -17,14 +17,14 @@ import shutil
import time

import json
from model.utils import convert_char_to_pinyin
from f5_tts.model.utils import convert_char_to_pinyin
import signal
import psutil
import platform
import subprocess
from datasets.arrow_writer import ArrowWriter
from datasets import Dataset as Dataset_
from api import F5TTS
from f5_tts.api import F5TTS


training_process = None
@@ -27,11 +27,11 @@ def gpu_decorator(func):
    return func


from model import DiT, UNetT
from model.utils import (
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import (
    save_spectrogram,
)
from model.utils_infer import (
from f5_tts.model.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
@@ -1,15 +1,17 @@
import argparse
import codecs
import re
import os
from pathlib import Path
from importlib.resources import files

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from model import DiT, UNetT
from model.utils_infer import (
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
@@ -26,8 +28,8 @@ parser = argparse.ArgumentParser(
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=cli-config.toml",
    default="inference-cli.toml",
    help="Configuration file. Default=inference-cli.toml",
    default=os.path.join(files('f5_tts').joinpath('data'), 'inference-cli.toml')
)
parser.add_argument(
    "-m",
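The new default resolves the packaged config with `importlib.resources`, so it works regardless of the current working directory; the same pattern is applied to `vocab.txt` in `f5_tts/model/utils.py` further below. A minimal sketch of the pattern:

```python
import os
from importlib.resources import files

# Locate a data file shipped inside the installed f5_tts package;
# "inference-cli.toml" is the packaged default config referenced above.
config_path = os.path.join(files("f5_tts").joinpath("data"), "inference-cli.toml")
print(config_path)  # e.g. .../site-packages/f5_tts/data/inference-cli.toml
```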
@@ -166,5 +168,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
        remove_silence_for_generated_wav(f.name)
    print(f.name)


def main():
    main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)

main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
if __name__ == "__main__":
    main()
10
src/f5_tts/model/__init__.py
Normal file
@@ -0,0 +1,10 @@
from f5_tts.model.cfm import CFM

from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT

from f5_tts.model.trainer import Trainer


__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
@@ -15,7 +15,7 @@ import torch.nn.functional as F

from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
@@ -14,7 +14,7 @@ from torch import nn

from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    MMDiTBlock,
@@ -17,7 +17,7 @@ import torch.nn.functional as F
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
from f5_tts.model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
@@ -18,8 +18,8 @@ from torch.nn.utils.rnn import pad_sequence

from torchdiffeq import odeint

from model.modules import MelSpec
from model.utils import (
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
    default,
    exists,
    list_str_to_idx,
@@ -10,8 +10,8 @@ from datasets import load_from_disk
from datasets import Dataset as Dataset_
from torch import nn

from model.modules import MelSpec
from model.utils import default
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default


class HFDataset(Dataset):
@@ -15,9 +15,9 @@ from accelerate.utils import DistributedDataParallelKwargs

from ema_pytorch import EMA

from model import CFM
from model.utils import exists, default
from model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn


# trainer
@@ -4,6 +4,7 @@ import os
import math
import random
import string
from importlib.resources import files
from tqdm import tqdm
from collections import defaultdict

@@ -20,8 +21,8 @@ import torchaudio
import jieba
from pypinyin import lazy_pinyin, Style

from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec
from f5_tts.model.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec


# seed everything
@@ -121,7 +122,8 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
        - if use "byte", set to 256 (unicode byte range)
    """
    if tokenizer in ["pinyin", "char"]:
        with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
        tokenizer_path = os.path.join(files('f5_tts').joinpath('data'), f"{dataset_name}_{tokenizer}/vocab.txt")
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
@@ -12,8 +12,8 @@ from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from model import CFM
from model.utils import (
from f5_tts.model import CFM
from f5_tts.model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
@@ -3,7 +3,7 @@ import os

sys.path.append(os.getcwd())

from model import M2_TTS, DiT
from f5_tts.model import M2_TTS, DiT

import torch
import thop
204
src/f5_tts/scripts/eval_infer_batch.py
Normal file
@@ -0,0 +1,204 @@
import sys
import os

sys.path.append(os.getcwd())

import time
import random
from tqdm import tqdm
import argparse
from importlib.resources import files

import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos

from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import (
    load_checkpoint,
    get_tokenizer,
    get_seedtts_testset_metainfo,
    get_librispeech_test_clean_metainfo,
    get_inference_prompt,
)

accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1

tokenizer = "pinyin"


def main():
    # ---------------------- infer setting ---------------------- #

    parser = argparse.ArgumentParser(description="batch inference")

    parser.add_argument("-s", "--seed", default=None, type=int)
    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
    parser.add_argument("-n", "--expname", required=True)
    parser.add_argument("-c", "--ckptstep", default=1200000, type=int)

    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
    parser.add_argument("-o", "--odemethod", default="euler")
    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)

    parser.add_argument("-t", "--testset", required=True)

    args = parser.parse_args()


    seed = args.seed
    dataset_name = args.dataset
    exp_name = args.expname
    ckpt_step = args.ckptstep
    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"

    nfe_step = args.nfestep
    ode_method = args.odemethod
    sway_sampling_coef = args.swaysampling

    testset = args.testset


    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
    cfg_strength = 2.0
    speed = 1.0
    use_truth_duration = False
    no_ref_audio = False


    if exp_name == "F5TTS_Base":
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

    elif exp_name == "E2TTS_Base":
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)


    datapath = files('f5_tts').joinpath('data')

    if testset == "ls_pc_test_clean":
        metalst = os.path.join(datapath, "librispeech_pc_test_clean_cross_sentence.lst")
        librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
        metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

    elif testset == "seedtts_test_zh":
        metalst = os.path.join(datapath, "seedtts_testset/zh/meta.lst")
        metainfo = get_seedtts_testset_metainfo(metalst)

    elif testset == "seedtts_test_en":
        metalst = os.path.join(datapath, "seedtts_testset/en/meta.lst")
        metainfo = get_seedtts_testset_metainfo(metalst)


    # path to save generated wavs
    if seed is None:
        seed = random.randint(-10000, 10000)
    output_dir = (
        f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}"
        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
        f"_cfg{cfg_strength}_speed{speed}"
        f"{'_gt-dur' if use_truth_duration else ''}"
        f"{'_no-ref-audio' if no_ref_audio else ''}"
    )


    # -------------------------------------------------#

    use_ema = True

    prompts_all = get_inference_prompt(
        metainfo,
        speed=speed,
        tokenizer=tokenizer,
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
        target_rms=target_rms,
        use_truth_duration=use_truth_duration,
        infer_batch_size=infer_batch_size,
    )

    # Vocoder model
    local = False
    if local:
        vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
        vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
        vocos.load_state_dict(state_dict)
        vocos.eval()
    else:
        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

    # Tokenizer
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    # Model
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)

    if not os.path.exists(output_dir) and accelerator.is_main_process:
        os.makedirs(output_dir)

    # start batch inference
    accelerator.wait_for_everyone()
    start = time.time()

    with accelerator.split_between_processes(prompts_all) as prompts:
        for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
            utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
            ref_mels = ref_mels.to(device)
            ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
            total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

            # Inference
            with torch.inference_mode():
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=final_text_list,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                    no_ref_audio=no_ref_audio,
                    seed=seed,
                )
            # Final result
            for i, gen in enumerate(generated):
                gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                gen_mel_spec = gen.permute(0, 2, 1)
                generated_wave = vocos.decode(gen_mel_spec.cpu())
                if ref_rms_list[i] < target_rms:
                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        timediff = time.time() - start
        print(f"Done batch inference in {timediff / 60 :.2f} minutes.")


if __name__ == "__main__":
    main()
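As a worked example of the `output_dir` naming scheme in the script above (assuming seed 42 and the CLI defaults; a falsy sway sampling coefficient would drop the `_ss` part):

```python
exp_name, ckpt_step, testset = "F5TTS_Base", 1200000, "seedtts_test_zh"
seed, ode_method, nfe_step = 42, "euler", 32
sway_sampling_coef, cfg_strength, speed = -1.0, 2.0, 1.0

# Mirrors the f-string in the script, minus the two False-flag suffixes.
output_dir = (
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}"
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
    f"_cfg{cfg_strength}_speed{speed}"
)
print(output_dir)
# results/F5TTS_Base_1200000/seedtts_test_zh/seed42_euler_nfe32_ss-1.0_cfg2.0_speed1.0
```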
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np

from model.utils import (
from f5_tts.model.utils import (
    get_librispeech_test,
    run_asr_wer,
    run_sim,
@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np

from model.utils import (
from f5_tts.model.utils import (
    get_seed_tts_test,
    run_asr_wer,
    run_sim,
@@ -13,7 +13,7 @@ import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter

from model.utils import (
from f5_tts.model.utils import (
    convert_char_to_pinyin,
)

@@ -16,7 +16,7 @@ from concurrent.futures import ProcessPoolExecutor

from datasets.arrow_writer import ArrowWriter

from model.utils import (
from f5_tts.model.utils import (
    repetition_found,
    convert_char_to_pinyin,
)
@@ -13,7 +13,7 @@ from concurrent.futures import ProcessPoolExecutor
import torchaudio
from datasets import Dataset

from model.utils import convert_char_to_pinyin
from f5_tts.model.utils import convert_char_to_pinyin


def deal_with_sub_path_files(dataset_path, sub_path):
@@ -5,8 +5,8 @@ import torch.nn.functional as F
import torchaudio
from vocos import Vocos

from model import CFM, UNetT, DiT
from model.utils import (
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
@@ -1,6 +1,6 @@
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset


# -------------------------- Dataset Settings --------------------------- #