convert to pkg, reorganize repo (#228)

* group files in f5_tts directory

* add setup.py

* use global imports

* simplify demo

* add install directions for library mode

* fix old huggingface_hub version constraint

* move finetune to package

* change imports to f5_tts.model

* bump version

* fix bad merge

* Update inference-cli.py

* fix HF space

* reformat

* fix utils.py vocab.txt import

* fix format

* adapt README for f5_tts package structure

* simplify app.py

* add gradio.Dockerfile and workflow

* refactored for pyproject.toml

* refactored for pyproject.toml

* added in reference to packaged files

* use fork for testing docker image

* added in reference to packaged files

* minor tweaks

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* refactor eval_infer_batch.py

* fix typo

* added eval_infer_batch to scripts

---------

Co-authored-by: Roberts Slisans <rsxdalv@gmail.com>
Co-authored-by: Adam Kessel <adam@rosi-kessel.org>
Co-authored-by: Roberts Slisans <roberts.slisans@gmail.com>
Author: Yushen CHEN
Date: 2024-10-23 21:07:59 +08:00
Committed by: GitHub
Parent: 32c3ee7701
Commit: c4eee0f96b
38 changed files with 451 additions and 259 deletions


@@ -0,0 +1,61 @@
name: Create and publish a Docker image
# Configures this workflow to run every time a change is pushed to the branch called `main`.
on:
push:
branches: ['main']
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
build-and-push-image:
runs-on: ubuntu-latest
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
permissions:
contents: read
packages: write
#
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Free Up GitHub Actions Ubuntu Runner Disk Space 🔧
uses: jlumbroso/free-disk-space@main
with:
# If set to "true", this might remove tools that are actually needed, but it frees about 6 GB
tool-cache: false
# All of these default to true, but feel free to set to "false" if necessary for your workflow
android: true
dotnet: true
haskell: true
large-packages: false
swap-storage: false
docker-images: false
# Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
- name: Build and push Docker image
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4
with:
context: .
file: ./gradio.Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}


@@ -63,11 +63,35 @@ pre-commit run --all-files
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
## Prepare Dataset
Example data processing scripts are provided for Emilia and Wenetspeech4TTS; you may tailor your own along with a Dataset class in `model/dataset.py`.
### As a pip package
```bash
pip install git+https://github.com/SWivid/F5-TTS.git
```
```python
import gradio as gr
from f5_tts.gradio_app import app
with gr.Blocks() as main_app:
gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")
# ... other Gradio components
app.render()
main_app.launch()
```
## Prepare Dataset
Example data processing scripts are provided for Emilia and Wenetspeech4TTS; you may tailor your own along with a Dataset class in `f5_tts/model/dataset.py`.
```bash
# switch to the main directory
cd f5_tts
# prepare custom dataset up to your need
# download corresponding dataset first, and fill in the path in scripts
@@ -83,6 +107,9 @@ python scripts/prepare_wenetspeech4tts.py
Once your datasets are prepared, you can start the training process.
```bash
# switch to the main directory
cd f5_tts
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
@@ -90,7 +117,7 @@ accelerate launch train.py
```
Initial guidance on finetuning: [#57](https://github.com/SWivid/F5-TTS/discussions/57).
For Gradio UI finetuning with `finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
For Gradio UI finetuning with `f5_tts/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
### Wandb Logging
@@ -136,6 +163,9 @@ to change the model, use `--ckpt_file` to specify the checkpoint you want to load;
to change vocab.txt, use `--vocab_file` to provide your own vocab.txt file.
```bash
# switch to the main directory
cd f5_tts
python inference-cli.py \
--model "F5-TTS" \
--ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
@@ -161,19 +191,19 @@ Currently supported features:
You can launch a Gradio app (web interface) with a GUI for inference (it will load the checkpoint from Huggingface; you may also use a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS and E2 TTS all at once, thus using more GPU memory than `inference-cli`.
```bash
python gradio_app.py
python f5_tts/gradio_app.py
```
You can specify the port/host:
```bash
python gradio_app.py --port 7860 --host 0.0.0.0
python f5_tts/gradio_app.py --port 7860 --host 0.0.0.0
```
Or launch a share link:
```bash
python gradio_app.py --share
python f5_tts/gradio_app.py --share
```
### Speech Editing
@@ -181,7 +211,7 @@ python gradio_app.py --share
To test speech editing capabilities, use the following command.
```bash
python speech_edit.py
python f5_tts/speech_edit.py
```
## Evaluation
@@ -199,6 +229,9 @@ python speech_edit.py
To run batch inference for evaluations, execute the following commands:
```bash
# switch to the main directory
cd f5_tts
# batch inference for evaluations
accelerate config # if not set before
bash scripts/eval_infer_batch.sh
@@ -234,6 +267,9 @@ pip install faster-whisper==0.10.1
Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
```bash
# switch to the main directory
cd f5_tts
# Evaluation for Seed-TTS test set
python scripts/eval_seedtts_testset.py

app.py Normal file

@@ -0,0 +1,3 @@
from f5_tts.gradio_app import app
app.queue().launch()
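This top-level `app.py` simply re-exports the packaged Gradio Blocks object and launches it. If you drive the packaged app from your own script instead, the usual Gradio `launch()` options apply; a small sketch, with the host/port values as arbitrary examples:

```python
from f5_tts.gradio_app import app

# Roughly what `python f5_tts/gradio_app.py --port 7860 --host 0.0.0.0` does,
# but driven from the installed package instead of the script.
app.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
```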

gradio.Dockerfile Normal file

@@ -0,0 +1,27 @@
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
USER root
ARG DEBIAN_FRONTEND=noninteractive
LABEL github_repo="https://github.com/rsxdalv/F5-TTS"
RUN set -x \
&& apt-get update \
&& apt-get -y install wget curl man git less openssl libssl-dev unzip unar build-essential aria2 tmux vim \
&& apt-get install -y openssh-server sox libsox-fmt-all libsox-fmt-mp3 libsndfile1-dev ffmpeg \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
WORKDIR /workspace
RUN git clone https://github.com/rsxdalv/F5-TTS.git \
&& cd F5-TTS \
&& pip install --no-cache-dir -r requirements.txt
ENV SHELL=/bin/bash
WORKDIR /workspace/F5-TTS/f5_tts
EXPOSE 7860
CMD python gradio_app.py


@@ -1,10 +0,0 @@
from model.cfm import CFM
from model.backbones.unett import UNetT
from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT
from model.trainer import Trainer
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]

pyproject.toml Normal file

@@ -0,0 +1,52 @@
[build-system]
requires = ["setuptools >= 61.0", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
dynamic = ["version"]
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
dependencies = [
"accelerate>=0.33.0",
"cached_path @ git+https://github.com/rsxdalv/cached_path@main",
"click",
"datasets",
"einops>=0.8.0",
"einx>=0.3.0",
"ema_pytorch>=0.5.2",
"gradio",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4",
"pydub",
"pypinyin",
"safetensors",
"soundfile",
"tomli",
"torch>=2.0.0",
"torchaudio>=2.0.0",
"torchdiffeq",
"tqdm>=4.65.0",
"transformers",
"vocos",
"wandb",
"x_transformers>=1.31.14",
]
[[project.authors]]
name = "Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen"
[project.urls]
Homepage = "https://github.com/SWivid/F5-TTS"
[project.scripts]
"finetune-cli" = "f5_tts.finetune_cli:main"
"inference-cli" = "f5_tts.inference_cli:main"
"eval_infer_batch" = "f5_tts.scripts.eval_infer_batch:main"


@@ -1,198 +0,0 @@
import sys
import os
sys.path.append(os.getcwd())
import time
import random
from tqdm import tqdm
import argparse
import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos
from model import CFM, UNetT, DiT
from model.utils import (
load_checkpoint,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
tokenizer = "pinyin"
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if testset == "ls_pc_test_clean":
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = "data/seedtts_testset/zh/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
metalst = "data/seedtts_testset/en/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
# path to save generated wavs
if seed is None:
seed = random.randint(-10000, 10000)
output_dir = (
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")


@@ -3,11 +3,11 @@ import torch
import tqdm
from cached_path import cached_path
from model import DiT, UNetT
from model.utils import save_spectrogram
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import save_spectrogram
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from model.utils import seed_everything
from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from f5_tts.model.utils import seed_everything
import random
import sys


@@ -1,7 +1,7 @@
import argparse
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset
from cached_path import cached_path
import shutil
import os


@@ -17,14 +17,14 @@ import shutil
import time
import json
from model.utils import convert_char_to_pinyin
from f5_tts.model.utils import convert_char_to_pinyin
import signal
import psutil
import platform
import subprocess
from datasets.arrow_writer import ArrowWriter
from datasets import Dataset as Dataset_
from api import F5TTS
from f5_tts.api import F5TTS
training_process = None


@@ -27,11 +27,11 @@ def gpu_decorator(func):
return func
from model import DiT, UNetT
from model.utils import (
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import (
save_spectrogram,
)
from model.utils_infer import (
from f5_tts.model.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,


@@ -1,15 +1,17 @@
import argparse
import codecs
import re
import os
from pathlib import Path
from importlib.resources import files
import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from model import DiT, UNetT
from model.utils_infer import (
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
@@ -26,8 +28,8 @@ parser = argparse.ArgumentParser(
parser.add_argument(
"-c",
"--config",
help="Configuration file. Default=cli-config.toml",
default="inference-cli.toml",
help="Configuration file. Default=inference-cli.toml",
default=os.path.join(files('f5_tts').joinpath('data'), 'inference-cli.toml')
)
parser.add_argument(
"-m",
@@ -166,5 +168,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
remove_silence_for_generated_wav(f.name)
print(f.name)
def main():
main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
main_process(ref_audio, ref_text, gen_text, ema_model, remove_silence)
if __name__ == "__main__":
main()


@@ -0,0 +1,10 @@
from f5_tts.model.cfm import CFM
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.trainer import Trainer
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]


@@ -15,7 +15,7 @@ import torch.nn.functional as F
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,


@@ -14,7 +14,7 @@ from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
from f5_tts.model.modules import (
TimestepEmbedding,
ConvPositionEmbedding,
MMDiTBlock,


@@ -17,7 +17,7 @@ import torch.nn.functional as F
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,


@@ -18,8 +18,8 @@ from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from model.modules import MelSpec
from model.utils import (
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
list_str_to_idx,


@@ -10,8 +10,8 @@ from datasets import load_from_disk
from datasets import Dataset as Dataset_
from torch import nn
from model.modules import MelSpec
from model.utils import default
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default
class HFDataset(Dataset):


@@ -15,9 +15,9 @@ from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from model import CFM
from model.utils import exists, default
from model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
# trainer


@@ -4,6 +4,7 @@ import os
import math
import random
import string
from importlib.resources import files
from tqdm import tqdm
from collections import defaultdict
@@ -20,8 +21,8 @@ import torchaudio
import jieba
from pypinyin import lazy_pinyin, Style
from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec
from f5_tts.model.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
# seed everything
@@ -121,7 +122,8 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
- if use "byte", set to 256 (unicode byte range)
"""
if tokenizer in ["pinyin", "char"]:
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
tokenizer_path = os.path.join(files('f5_tts').joinpath('data'), f"{dataset_name}_{tokenizer}/vocab.txt")
with open(tokenizer_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
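This change replaces the working-directory-relative `data/...` path with an `importlib.resources` lookup, so `vocab.txt` is resolved from the installed `f5_tts` package regardless of where the script is launched. A minimal sketch of that lookup, with `Emilia_ZH_EN` / `pinyin` as placeholder values for `dataset_name` and `tokenizer`, and assuming the vocab file is shipped under `f5_tts/data`:

```python
from importlib.resources import files

# Locate the packaged data directory of f5_tts (a Traversable; for a normally
# installed package this behaves like an ordinary filesystem path).
vocab_file = files("f5_tts") / "data" / "Emilia_ZH_EN_pinyin" / "vocab.txt"

# Read it the same way get_tokenizer does: one vocab entry per line,
# mapping each character/token to its index.
vocab_char_map = {}
with vocab_file.open("r", encoding="utf-8") as f:
    for i, char in enumerate(f):
        vocab_char_map[char[:-1]] = i
```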


@@ -12,8 +12,8 @@ from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from model import CFM
from model.utils import (
from f5_tts.model import CFM
from f5_tts.model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,


@@ -3,7 +3,7 @@ import os
sys.path.append(os.getcwd())
from model import M2_TTS, DiT
from f5_tts.model import M2_TTS, DiT
import torch
import thop


@@ -0,0 +1,204 @@
import sys
import os
sys.path.append(os.getcwd())
import time
import random
from tqdm import tqdm
import argparse
from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from vocos import Vocos
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import (
load_checkpoint,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
tokenizer = "pinyin"
def main():
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
datapath = files('f5_tts').joinpath('data')
if testset == "ls_pc_test_clean":
metalst = os.path.join(datapath,"librispeech_pc_test_clean_cross_sentence.lst")
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = os.path.join(datapath,"seedtts_testset/zh/meta.lst")
metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
metalst = os.path.join(datapath,"seedtts_testset/en/meta.lst")
metainfo = get_seedtts_testset_metainfo(metalst)
# path to save generated wavs
if seed is None:
seed = random.randint(-10000, 10000)
output_dir = (
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
if __name__ == "__main__":
main()


@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np
from model.utils import (
from f5_tts.model.utils import (
get_librispeech_test,
run_asr_wer,
run_sim,


@@ -8,7 +8,7 @@ sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np
from model.utils import (
from f5_tts.model.utils import (
get_seed_tts_test,
run_asr_wer,
run_sim,


@@ -13,7 +13,7 @@ import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from model.utils import (
from f5_tts.model.utils import (
convert_char_to_pinyin,
)


@@ -16,7 +16,7 @@ from concurrent.futures import ProcessPoolExecutor
from datasets.arrow_writer import ArrowWriter
from model.utils import (
from f5_tts.model.utils import (
repetition_found,
convert_char_to_pinyin,
)


@@ -13,7 +13,7 @@ from concurrent.futures import ProcessPoolExecutor
import torchaudio
from datasets import Dataset
from model.utils import convert_char_to_pinyin
from f5_tts.model.utils import convert_char_to_pinyin
def deal_with_sub_path_files(dataset_path, sub_path):


@@ -5,8 +5,8 @@ import torch.nn.functional as F
import torchaudio
from vocos import Vocos
from model import CFM, UNetT, DiT
from model.utils import (
from f5_tts.model import CFM, UNetT, DiT
from f5_tts.model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,


@@ -1,6 +1,6 @@
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset
# -------------------------- Dataset Settings --------------------------- #