Mirror of https://github.com/SWivid/F5-TTS.git, synced 2025-12-25 20:34:27 -08:00.

Compare commits (44 commits)
| SHA1 |
|---|
| 25b3291715 |
| 16c480a61d |
| d9dfbe47cc |
| d1f6c95fe8 |
| 2428d01a56 |
| 9401842930 |
| eca56943ec |
| ae51cc3d34 |
| 4681a1c177 |
| 5b178397e0 |
| 2724f9f101 |
| 7258b09529 |
| 784e3862b4 |
| 6f6968b034 |
| 9bd2d13be1 |
| b7c41af9cd |
| eaa7fd8a01 |
| f34465d118 |
| 393993321d |
| 29d3326bed |
| 67e43dc0fb |
| 8469025b1c |
| 5bd8cd7aed |
| 7236536f9a |
| 6b7f6eefdc |
| b9156c0ad5 |
| 3ad3211915 |
| f6726a78cc |
| 1d0cf2b8ba |
| 1d82b7928e |
| 4ae5347282 |
| 621559cbbe |
| 526b09eebd |
| 9afa80f204 |
| c6b3189bbd |
| c87ce39515 |
| 10ef27065b |
| f374640f34 |
| d5f4c88aa4 |
| f968e13b6d |
| 339b17fed3 |
| 79302b694a |
| a1e88c2a9e |
| 1ab90505a4 |
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, 12 lines changed)

```diff
@@ -1,6 +1,6 @@
 name: "Bug Report"
 description: |
-  Please provide as much details to help address the issue, including logs and screenshots.
+  Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
 labels:
   - bug
 body:
@@ -15,13 +15,13 @@ body:
         required: true
       - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
         required: true
-      - label: I confirm that I am using English to submit this report in order to facilitate communication.
+      - label: I am using English to submit this issue to facilitate community communication.
         required: true
   - type: textarea
     attributes:
       label: Environment Details
-      description: "Provide details such as OS, Python version, and any relevant software or dependencies."
-      placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
+      description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
+      placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
     validations:
       required: true
   - type: textarea
@@ -39,12 +39,12 @@ body:
   - type: textarea
     attributes:
       label: ✔️ Expected Behavior
-      placeholder: Describe what you expected to happen.
+      placeholder: Describe in detail what you expected to happen.
     validations:
       required: false
   - type: textarea
     attributes:
       label: ❌ Actual Behavior
-      placeholder: Describe what actually happened.
+      placeholder: Describe in detail what actually happened.
     validations:
       required: false
```
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, 2 lines changed)

```diff
@@ -15,7 +15,7 @@ body:
         required: true
       - label: I have searched for existing issues, including closed ones, and found not discussion yet.
         required: true
-      - label: I confirm that I am using English to submit this report in order to facilitate communication.
+      - label: I am using English to submit this issue to facilitate community communication.
         required: true
   - type: textarea
     attributes:
```
.github/ISSUE_TEMPLATE/help_wanted.yml (vendored, 16 lines changed)

```diff
@@ -1,6 +1,6 @@
 name: "Help Wanted"
 description: |
-  Please provide as much details to help address the issue, including logs and screenshots.
+  Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
 labels:
   - help wanted
 body:
@@ -15,36 +15,40 @@ body:
         required: true
       - label: I have searched for existing issues, including closed ones, and couldn't find a solution.
         required: true
-      - label: I confirm that I am using English to submit this report in order to facilitate communication.
+      - label: I am using English to submit this issue to facilitate community communication.
         required: true
   - type: textarea
     attributes:
       label: Environment Details
       description: "Provide details such as OS, Python version, and any relevant software or dependencies."
-      placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
+      placeholder: |
+        e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
+        If training or finetuning related, provide detailed configuration including GPU info and training setup.
     validations:
       required: true
   - type: textarea
     attributes:
       label: Steps to Reproduce
       description: |
-        Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
+        Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
       placeholder: |
         1. Create a new conda environment.
         2. Clone the repository and install as pip package.
         3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
         4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
+        5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
+        6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
       validations:
         required: true
   - type: textarea
     attributes:
       label: ✔️ Expected Behavior
-      placeholder: Describe what you expected to happen, e.g. output a generated audio
+      placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
     validations:
       required: false
   - type: textarea
     attributes:
       label: ❌ Actual Behavior
-      placeholder: Describe what actually happened, failure messages, etc.
+      placeholder: Describe what actually happened in detail, failure messages, etc.
     validations:
       required: false
```
.github/ISSUE_TEMPLATE/question.yml (vendored, 6 lines changed)

```diff
@@ -1,6 +1,6 @@
 name: "Question"
 description: |
-  Pure question or inquiry about the project, usage issue goes with "help wanted".
+  Research question or pure inquiry about the project, usage issue goes with "help wanted".
 labels:
   - question
 body:
@@ -9,13 +9,13 @@ body:
       label: Checks
       description: "To help us grasp quickly, please confirm the following:"
       options:
-        - label: This template is only for question, not feature requests or bug reports.
+        - label: This template is only for research question, not usage problems, feature requests or bug reports.
          required: true
        - label: I have thoroughly reviewed the project documentation and read the related paper(s).
          required: true
        - label: I have searched for existing issues, including closed ones, no similar questions.
          required: true
-        - label: I confirm that I am using English to submit this report in order to facilitate communication.
+        - label: I am using English to submit this issue to facilitate community communication.
          required: true
   - type: textarea
     attributes:
```
.gitignore (vendored, 2 lines changed)

```diff
@@ -7,8 +7,6 @@ ckpts/
 wandb/
 results/
-
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
```
.pre-commit-config.yaml

```diff
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.7.0
+    rev: v0.11.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -9,6 +9,6 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v2.3.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
```
Dockerfile

```diff
@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
 
 ENV SHELL=/bin/bash
 
+VOLUME /root/.cache/huggingface/hub/
+
+EXPOSE 7860
+
 WORKDIR /workspace/F5-TTS
```
README.md (29 lines changed)

````diff
@@ -100,13 +100,19 @@ conda activate f5-tts
 # Build from Dockerfile
 docker build -t f5tts:v1 .
 
-# Or pull from GitHub Container Registry
-docker pull ghcr.io/swivid/f5-tts:main
+# Run from GitHub Container Registry
+docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
+
+# Quickstart if you want to just run the web interface (not CLI)
+docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
 ```
 
 
 ## Inference
 
+- In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer).
+- By properly searching the keywords of problem encountered, [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful.
+
 ### 1. Gradio App
 
 Currently supported features:
@@ -173,10 +179,18 @@ f5-tts_infer-cli -c custom.toml
 f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
 ```
 
-### 3. More instructions
+### 3. Runtime
 
-- In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
-- The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
+Deployment solution with Triton and TensorRT-LLM.
+
+#### Benchmark Results
+Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs.
+
+| Model | Concurrency | Avg Latency | RTF |
+|-------|-------------|-------------|-----|
+| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |
+
+See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
 
 
 ## Training
@@ -200,7 +214,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
 
 ## Development
 
-Use pre-commit to ensure code quality (will run linters and formatters automatically)
+Use pre-commit to ensure code quality (will run linters and formatters automatically):
 
 ```bash
 pip install pre-commit
@@ -213,7 +227,7 @@ When making a pull request, before each commit, run:
 pre-commit run --all-files
 ```
 
-Note: Some model components have linting exceptions for E722 to accommodate tensor notation
+Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
 
 
 ## Acknowledgements
@@ -228,6 +242,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
 - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
+- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
 
 ## Citation
 If our work and codebase is useful for you, please cite as:
````
ckpts/README.md

````diff
@@ -1,12 +1,3 @@
-Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
+The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
 
-```
-ckpts/
-    F5TTS_v1_Base/
-        model_1250000.safetensors
-    F5TTS_Base/
-        model_1200000.safetensors
-    E2TTS_Base/
-        model_1200000.safetensors
-```
+Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
````
pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.0.3"
+version = "1.1.0"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -26,6 +26,7 @@ dependencies = [
     "librosa",
     "matplotlib",
     "numpy<=1.26.4",
+    "pydantic<=2.10.6",
     "pydub",
     "pypinyin",
     "safetensors",
```
src/f5_tts/api.py

```diff
@@ -5,6 +5,7 @@ from importlib.resources import files
 
 import soundfile as sf
 import tqdm
 from cached_path import cached_path
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import (
@@ -16,7 +17,6 @@ from f5_tts.infer.utils_infer import (
     remove_silence_for_generated_wav,
     save_spectrogram,
 )
-from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import seed_everything
 
 
@@ -33,7 +33,7 @@ class F5TTS:
         hf_cache_dir=None,
     ):
         model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
-        model_cls = globals()[model_cfg.model.backbone]
+        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
         model_arc = model_cfg.model.arch
 
         self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
@@ -119,7 +119,7 @@ class F5TTS:
         seed_everything(seed)
         self.seed = seed
 
-        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
 
         wav, sr, spec = infer_process(
             ref_file,
```
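The `globals()` to `hydra.utils.get_class` swap above means the backbone class named in the YAML config is now resolved from its module path instead of requiring the class to be imported into the calling file. A minimal sketch of the mechanism (assuming the `f5_tts` package is installed; `get_class` is part of hydra-core):

```python
from hydra.utils import get_class
from omegaconf import OmegaConf

# A toy config carrying only the backbone name, as the model YAML does.
cfg = OmegaConf.create({"model": {"backbone": "DiT"}})

# get_class imports f5_tts.model and returns the attribute named "DiT",
# so the DiT / UNetT imports (and their noqa: F401 markers) can be dropped.
model_cls = get_class(f"f5_tts.model.{cfg.model.backbone}")
print(model_cls.__name__)  # -> "DiT"
```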
```diff
@@ -10,7 +10,7 @@ datasets:
   num_workers: 16
 
 optim:
-  epochs: 11
+  epochs: 11  # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
   learning_rate: 7.5e-5
   num_warmup_updates: 20000  # warmup updates
   grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
@@ -49,4 +49,4 @@ ckpts:
   save_per_updates: 50000  # save checkpoint per updates
   keep_last_n_checkpoints: -1  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000  # save last checkpoint per updates
-  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
```
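A back-of-envelope sketch of the bookkeeping noted in the config comments above: the scheduler counts updates, not optimizer steps, and `updates = steps / grad_accumulation_steps`. The 1.25M figure below is taken from the released `model_1250000` checkpoint named elsewhere in this diff; it is illustrative, not prescriptive:

```python
# Values copied from the config hunk above.
grad_accumulation_steps = 1
num_warmup_updates = 20_000

# Assumed total training length, matching the released checkpoint step count.
total_updates = 1_250_000
optimizer_steps = total_updates * grad_accumulation_steps

print(f"warmup covers {num_warmup_updates / total_updates:.1%} of training")  # -> 1.6%
```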
src/f5_tts/eval/eval_infer_batch.py

```diff
@@ -10,6 +10,7 @@ from importlib.resources import files
 
 import torch
 import torchaudio
 from accelerate import Accelerator
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 from tqdm import tqdm
 
@@ -19,7 +20,7 @@ from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
 )
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
-from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+from f5_tts.model import CFM
 from f5_tts.model.utils import get_tokenizer
 
 accelerator = Accelerator()
@@ -65,7 +66,7 @@ def main():
     no_ref_audio = False
 
     model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
-    model_cls = globals()[model_cfg.model.backbone]
+    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
     model_arc = model_cfg.model.arch
 
     dataset_name = model_cfg.datasets.name
@@ -195,7 +196,7 @@ def main():
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         timediff = time.time() - start
-        print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
+        print(f"Done batch inference in {timediff / 60:.2f} minutes.")
 
 
 if __name__ == "__main__":
```
src/f5_tts/eval/utils_eval.py

```diff
@@ -148,9 +148,9 @@ def get_inference_prompt(
 
     # deal with batch
     assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
-    assert (
-        min_tokens <= total_mel_len <= max_tokens
-    ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+    assert min_tokens <= total_mel_len <= max_tokens, (
+        f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
+    )
     bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
 
     utts[bucket_i].append(utt)
```
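The `bucket_i` line kept as context above maps a mel length to one of `num_buckets` equal-width duration bins. A worked example with assumed values (the mel settings and bucket count here are hypothetical stand-ins, not read from the repo):

```python
import math

hop_length, target_sample_rate = 256, 24000  # assumed mel settings
min_secs, max_secs = 1, 30                   # assumed duration range
min_tokens = min_secs * target_sample_rate // hop_length  # 93
max_tokens = max_secs * target_sample_rate // hop_length  # 2812
num_buckets = 200

total_mel_len = 1000  # an utterance of roughly 10.7 seconds
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(bucket_i)  # -> 66, i.e. about a third of the way through the duration range
```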
src/f5_tts/infer/README.md

```diff
@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h
 
 **More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
 
-Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
+Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.
 
 To avoid possible inference failures, make sure you have seen through the following instructions.
 
-- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
-- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
-- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
-- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
-- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
-- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
+- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
+- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
+- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
+- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
+- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
+- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
+- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
 
 
 ## Gradio App
@@ -23,7 +24,7 @@ Currently supported features:
 
 - Basic TTS with Chunk Inference
 - Multi-Style / Multi-Speaker Generation
 - Voice Chat powered by Qwen2.5-3B-Instruct
-- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
+- [Custom inference with more language support](SHARED.md)
 
 The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
```
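The new bullet about sentence-final punctuation needing a trailing space hints at how chunk boundaries are detected. An illustrative sketch of that rule (not the repo's actual chunker, just the stated behavior):

```python
import re

def naive_chunks(text: str) -> list[str]:
    # Split after . ! ? only when a space follows, mirroring the README rule:
    # punctuation glued to the next word is not treated as a chunk boundary.
    parts = re.split(r"(?<=[.!?]) ", text)
    return [p.strip() for p in parts if p.strip()]

print(naive_chunks("Hello.World"))          # ['Hello.World']   (no space, one chunk)
print(naive_chunks("Hello. World again."))  # ['Hello.', 'World again.']
```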
src/f5_tts/infer/SHARED.md

````diff
@@ -44,6 +44,7 @@
 
 ```bash
 Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
+# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
 Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
 Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
 ```
@@ -136,11 +137,11 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 #### F5-TTS Base @ ja @ Jmica
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_25498980)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
 
 ```bash
-Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
-Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
+Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
+Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
 Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
````
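A hedged sketch of pointing the Python API at one of these community checkpoints; the `hf://` paths above are resolvable with `cached_path`, and the constructor keywords follow the `F5TTS` class in `src/f5_tts/api.py` shown earlier in this diff (treat the exact argument names as an assumption, and check `api.py` before relying on them):

```python
from cached_path import cached_path
from f5_tts.api import F5TTS

# Download the Japanese community checkpoint and vocab listed in SHARED.md.
ckpt = str(cached_path("hf://Jmica/F5TTS/JA_21999120/model_21999120.pt"))
vocab = str(cached_path("hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt"))

# Assumed kwargs per api.py; an F5TTS_Base-architecture config is implied here.
f5tts = F5TTS(model="F5TTS_Base", ckpt_file=ckpt, vocab_file=vocab)
wav, sr, _ = f5tts.infer(
    ref_file="ref.wav",
    ref_text="reference transcription",
    gen_text="こんにちは",
    file_wave="out.wav",
)
```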
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import soundfile as sf
|
||||
import tomli
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
@@ -21,13 +22,13 @@ from f5_tts.infer.utils_infer import (
|
||||
sway_sampling_coef,
|
||||
speed,
|
||||
fix_duration,
|
||||
device,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -162,6 +163,11 @@ parser.add_argument(
|
||||
type=float,
|
||||
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
help="Specify the device to run on",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -202,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
|
||||
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
|
||||
speed = args.speed or config.get("speed", speed)
|
||||
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
|
||||
device = args.device or config.get("device", device)
|
||||
|
||||
|
||||
# patches for pip pkg user
|
||||
@@ -239,20 +246,23 @@ if vocoder_name == "vocos":
|
||||
elif vocoder_name == "bigvgan":
|
||||
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
||||
|
||||
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
|
||||
)
|
||||
|
||||
|
||||
# load TTS model
|
||||
|
||||
model_cfg = OmegaConf.load(
|
||||
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
|
||||
).model
|
||||
model_cls = globals()[model_cfg.backbone]
|
||||
)
|
||||
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
|
||||
model_arc = model_cfg.model.arch
|
||||
|
||||
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
|
||||
|
||||
if model != "F5TTS_Base":
|
||||
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
|
||||
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
|
||||
|
||||
# override for previous models
|
||||
if model == "F5TTS_Base":
|
||||
@@ -269,7 +279,9 @@ if not ckpt_file:
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
|
||||
ema_model = load_model(
|
||||
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
|
||||
)
|
||||
|
||||
|
||||
# inference process
|
||||
@@ -325,6 +337,7 @@ def main():
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
generated_audio_segments.append(audio_segment)
|
||||
|
||||
@@ -332,7 +345,7 @@ def main():
|
||||
if len(gen_text_) > 200:
|
||||
gen_text_ = gen_text_[:200] + " ... "
|
||||
sf.write(
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
||||
audio_segment,
|
||||
final_sample_rate,
|
||||
)
|
||||
|
||||
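The new `--device` flag follows the same precedence pattern as the other CLI options in this file, shown here in isolation: CLI flag, then TOML config value, then module-level default, with a falsy CLI value (`None`) falling through to the next level:

```python
# Stand-in values; the real defaults come from f5_tts.infer.utils_infer.
device_default = "cuda"            # module-level default
config = {"device": "cpu"}         # parsed from the -c custom.toml file
args_device = None                 # no --device given on the command line

device = args_device or config.get("device", device_default)
print(device)  # -> "cpu": the config wins because the CLI flag is unset
```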
src/f5_tts/infer/infer_gradio.py

```diff
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file
 
+import gc
 import json
 import re
 import tempfile
@@ -11,6 +12,7 @@ import click
 import gradio as gr
 import numpy as np
 import soundfile as sf
+import torch
 import torchaudio
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -550,35 +552,50 @@ Have a conversation with an AI using your reference voice!
         """
     )
 
-    if not USING_SPACES:
-        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+    chat_model_name_list = [
+        "Qwen/Qwen2.5-3B-Instruct",
+        "microsoft/Phi-4-mini-instruct",
+    ]
 
-        chat_interface_container = gr.Column(visible=False)
+    @gpu_decorator
+    def load_chat_model(chat_model_name):
+        show_info = gr.Info
+        global chat_model_state, chat_tokenizer_state
+        if chat_model_state is not None:
+            chat_model_state = None
+            chat_tokenizer_state = None
+            gc.collect()
+            torch.cuda.empty_cache()
 
-        @gpu_decorator
-        def load_chat_model():
-            global chat_model_state, chat_tokenizer_state
-            if chat_model_state is None:
-                show_info = gr.Info
-                show_info("Loading chat model...")
-                model_name = "Qwen/Qwen2.5-3B-Instruct"
-                chat_model_state = AutoModelForCausalLM.from_pretrained(
-                    model_name, torch_dtype="auto", device_map="auto"
-                )
-                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
-                show_info("Chat model loaded.")
+        show_info(f"Loading chat model: {chat_model_name}")
+        chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
+        chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
+        show_info(f"Chat model {chat_model_name} loaded successfully!")
 
-            return gr.update(visible=False), gr.update(visible=True)
+        return gr.update(visible=False), gr.update(visible=True)
 
-        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+    if USING_SPACES:
+        load_chat_model(chat_model_name_list[0])
 
-    else:
-        chat_interface_container = gr.Column()
+    chat_model_name_input = gr.Dropdown(
+        choices=chat_model_name_list,
+        value=chat_model_name_list[0],
+        label="Chat Model Name",
+        info="Enter the name of a HuggingFace chat model",
+        allow_custom_value=not USING_SPACES,
+    )
+    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
+    chat_interface_container = gr.Column(visible=USING_SPACES)
 
-        if chat_model_state is None:
-            model_name = "Qwen/Qwen2.5-3B-Instruct"
-            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+    chat_model_name_input.change(
+        lambda: gr.update(visible=True),
+        None,
+        load_chat_model_btn,
+        show_progress="hidden",
+    )
+    load_chat_model_btn.click(
+        load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
+    )
 
     with chat_interface_container:
         with gr.Row():
@@ -758,9 +775,9 @@ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not
 
     The checkpoints currently support English and Chinese.
 
-    If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
+    If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
 
-    **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
+    **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
     """
     )
```
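The refactor above lets the Gradio app swap chat models at runtime; the key detail is releasing the old weights before loading the new ones. A minimal sketch of that pattern (the app does the reference-dropping via `global` state, as here):

```python
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

chat_model, chat_tokenizer = None, None

def swap_chat_model(new_name: str):
    global chat_model, chat_tokenizer
    if chat_model is not None:
        # Drop the only references, collect the unreachable weights, then
        # return the freed blocks from PyTorch's caching allocator to CUDA.
        chat_model, chat_tokenizer = None, None
        gc.collect()
        torch.cuda.empty_cache()
    chat_model = AutoModelForCausalLM.from_pretrained(new_name, torch_dtype="auto", device_map="auto")
    chat_tokenizer = AutoTokenizer.from_pretrained(new_name)
```

Without the `gc.collect()` and `empty_cache()` pair, the old model's VRAM can still be held by the allocator when the new `from_pretrained` call starts, which is exactly the out-of-memory failure this change guards against.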
src/f5_tts/infer/speech_edit.py

```diff
@@ -7,10 +7,11 @@ from importlib.resources import files
 
 import torch
 import torch.nn.functional as F
 import torchaudio
+from hydra.utils import get_class
 from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
-from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
+from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
 device = (
@@ -40,7 +41,7 @@ target_rms = 0.1
 
 model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
-model_cls = globals()[model_cfg.model.backbone]
+model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
 model_arc = model_cfg.model.arch
 
 dataset_name = model_cfg.datasets.name
```
src/f5_tts/infer/utils_infer.py

```diff
@@ -21,7 +21,7 @@
 import numpy as np
 import torch
 import torchaudio
 import tqdm
-from huggingface_hub import snapshot_download, hf_hub_download
+from huggingface_hub import hf_hub_download
 from pydub import AudioSegment, silence
 from transformers import pipeline
 from vocos import Vocos
@@ -128,11 +128,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
         except ImportError:
             print("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
-            """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
+            # download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         else:
-            local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
-            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
+            vocoder = bigvgan.BigVGAN.from_pretrained(
+                "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
+            )
 
         vocoder.remove_weight_norm()
         vocoder = vocoder.eval().to(device)
@@ -149,7 +150,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
     dtype = (
         torch.float16
         if "cuda" in device
-        and torch.cuda.get_device_properties(device).major >= 6
+        and torch.cuda.get_device_properties(device).major >= 7
         and not torch.cuda.get_device_name().endswith("[ZLUDA]")
         else torch.float32
     )
@@ -186,7 +187,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
     dtype = (
         torch.float16
         if "cuda" in device
-        and torch.cuda.get_device_properties(device).major >= 6
+        and torch.cuda.get_device_properties(device).major >= 7
        and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        else torch.float32
     )
@@ -289,7 +290,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
 # preprocess reference audio and text
 
 
-def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -302,7 +303,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
             non_silent_wave = AudioSegment.silent(duration=0)
             for non_silent_seg in non_silent_segs:
                 if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
-                    show_info("Audio is over 15s, clipping short. (1)")
+                    show_info("Audio is over 12s, clipping short. (1)")
                     break
                 non_silent_wave += non_silent_seg
@@ -314,7 +315,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
             non_silent_wave = AudioSegment.silent(duration=0)
             for non_silent_seg in non_silent_segs:
                 if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
-                    show_info("Audio is over 15s, clipping short. (2)")
+                    show_info("Audio is over 12s, clipping short. (2)")
                     break
                 non_silent_wave += non_silent_seg
@@ -323,7 +324,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
         # 3. if no proper silence found for clipping
         if len(aseg) > 12000:
             aseg = aseg[:12000]
-            show_info("Audio is over 15s, clipping short. (3)")
+            show_info("Audio is over 12s, clipping short. (3)")
 
         aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
         aseg.export(f.name, format="wav")
```
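The dtype gate tightened above (compute capability 6 raised to 7) restricts fp16 to GPUs with Volta-generation tensor cores and excludes ZLUDA-translated devices. The same logic as a standalone sketch:

```python
import torch

def pick_dtype(device: str) -> torch.dtype:
    # fp16 only on CUDA devices with compute capability >= 7 (Volta and newer),
    # and never on ZLUDA devices, whose names end with "[ZLUDA]".
    if (
        "cuda" in device
        and torch.cuda.get_device_properties(device).major >= 7
        and not torch.cuda.get_device_name().endswith("[ZLUDA]")
    ):
        return torch.float16
    return torch.float32
```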
src/f5_tts/model/cfm.py

```diff
@@ -270,7 +270,7 @@ class CFM(nn.Module):
         else:
             drop_text = False
 
-        # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
+        # if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
         # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
         pred = self.transformer(
             x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
```
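The corrected comment refers to a padding mask that `collate_fn` would have to record and pass through. A hedged illustration of building such a mask from per-sample lengths (not code from the repo, just the standard construction the comment alludes to):

```python
import torch

lengths = torch.tensor([7, 4, 9])  # per-sample mel lengths in a batch
max_len = int(lengths.max())

# (batch, max_len) boolean mask: True on real frames, False on padding.
mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(mask.shape, mask.sum(dim=1))  # torch.Size([3, 9]) tensor([7, 4, 9])
```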
src/f5_tts/model/trainer.py

```diff
@@ -51,7 +51,7 @@ class Trainer:
         mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
         is_local_vocoder: bool = False,  # use local path vocoder
         local_vocoder_path: str = "",  # local vocoder path
-        cfg_dict: dict = dict(),  # training config
+        model_cfg_dict: dict = dict(),  # training config
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -73,8 +73,8 @@ class Trainer:
             else:
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
 
-            if not cfg_dict:
-                cfg_dict = {
+            if not model_cfg_dict:
+                model_cfg_dict = {
                     "epochs": epochs,
                     "learning_rate": learning_rate,
                     "num_warmup_updates": num_warmup_updates,
@@ -85,11 +85,11 @@ class Trainer:
                     "max_grad_norm": max_grad_norm,
                     "noise_scheduler": noise_scheduler,
                 }
-            cfg_dict["gpus"] = self.accelerator.num_processes
+            model_cfg_dict["gpus"] = self.accelerator.num_processes
             self.accelerator.init_trackers(
                 project_name=wandb_project,
                 init_kwargs=init_kwargs,
-                config=cfg_dict,
+                config=model_cfg_dict,
             )
 
         elif self.logger == "tensorboard":
@@ -350,7 +350,7 @@ class Trainer:
 
         progress_bar = tqdm(
             range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
-            desc=f"Epoch {epoch+1}/{self.epochs}",
+            desc=f"Epoch {epoch + 1}/{self.epochs}",
             unit="update",
             disable=not self.accelerator.is_local_main_process,
             initial=progress_bar_initial,
@@ -395,6 +395,9 @@ class Trainer:
                     self.writer.add_scalar("loss", loss.item(), global_update)
                     self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
+                    self.save_checkpoint(global_update, last=True)
+
                 if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update)
@@ -428,9 +431,7 @@ class Trainer:
                         torchaudio.save(
                             f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                         )
-
-                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
-                    self.save_checkpoint(global_update, last=True)
+                    self.model.train()
 
         self.save_checkpoint(global_update, last=True)
```
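The moved block separates two checkpointing schedules: the rolling "last" checkpoint now runs on its own cadence inside the update loop, independent of the less frequent numbered checkpoints. A sketch using the intervals from the training config earlier in this diff:

```python
save_per_updates, last_per_updates = 50_000, 5_000  # from the ckpts: config section

saves = [
    u for u in range(1, 100_001)
    if u % last_per_updates == 0 or u % save_per_updates == 0
]
numbered = [u for u in saves if u % save_per_updates == 0]
# Every 5k updates the rolling model_last.pt is overwritten; only every 50k
# updates does a numbered model_<update>.pt get kept.
print(len(saves), numbered)  # -> 20 [50000, 100000]
```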
src/f5_tts/runtime/triton_trtllm/Dockerfile.server (new file, 3 lines)

```diff
@@ -0,0 +1,3 @@
+FROM nvcr.io/nvidia/tritonserver:24.12-py3
+RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
+WORKDIR /workspace
```
src/f5_tts/runtime/triton_trtllm/README.md (new file, 47 lines)

````markdown
## Triton Inference Serving Best Practice for F5-TTS

### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
```

### Build Image
Build the docker image from scratch.
```sh
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```

### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```

### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
```

### HTTP Client
```sh
python3 client_http.py
```

### Benchmark using Dataset
```sh
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```

### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.

| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|-------------|-----|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |

### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
````
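The RTF (real-time factor) figure in the benchmark table is computed the way `client_grpc.py` below does it: wall-clock decoding time divided by the duration of audio produced. The numbers here are hypothetical stand-ins chosen to reproduce the tabled value:

```python
elapsed_s = 10.2       # hypothetical wall-clock time for a benchmark run
total_audio_s = 258.9  # hypothetical total duration of generated speech

rtf = elapsed_s / total_audio_s
print(f"RTF: {rtf:.4f}")  # -> RTF: 0.0394; values below 1.0 are faster than real time
```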
470
src/f5_tts/runtime/triton_trtllm/client_grpc.py
Normal file
470
src/f5_tts/runtime/triton_trtllm/client_grpc.py
Normal file
@@ -0,0 +1,470 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2023 Nvidia (authors: Yuekai Zhang)
|
||||
# 2023 Recurrent.ai (authors: Songtao Shi)
|
||||
# See LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script supports to load dataset from huggingface and sends it to the server
|
||||
for decoding, in parallel.
|
||||
|
||||
Usage:
|
||||
num_task=2
|
||||
|
||||
# For offline F5-TTS
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name f5_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import tritonclient
|
||||
import tritonclient.grpc.aio as grpcclient
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
|
||||
def write_triton_stats(stats, summary_file):
|
||||
with open(summary_file, "w") as summary_f:
|
||||
model_stats = stats["model_stats"]
|
||||
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
||||
summary_f.write(
|
||||
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
||||
)
|
||||
summary_f.write("To learn more about the log, please refer to: \n")
|
||||
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
||||
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
||||
summary_f.write(
|
||||
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
||||
)
|
||||
summary_f.write(
|
||||
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
||||
)
|
||||
for model_state in model_stats:
|
||||
if "last_inference" not in model_state:
|
||||
continue
|
||||
summary_f.write(f"model name is {model_state['name']} \n")
|
||||
model_inference_stats = model_state["inference_stats"]
|
||||
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
||||
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
||||
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
||||
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
||||
summary_f.write(
|
||||
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
||||
)
|
||||
model_batch_stats = model_state["batch_stats"]
|
||||
for batch in model_batch_stats:
|
||||
batch_size = int(batch["batch_size"])
|
||||
compute_input = batch["compute_input"]
|
||||
compute_output = batch["compute_output"]
|
||||
compute_infer = batch["compute_infer"]
|
||||
batch_count = int(compute_infer["count"])
|
||||
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
||||
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
||||
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
||||
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
||||
summary_f.write(
|
||||
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
||||
)
|
||||
summary_f.write(
|
||||
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
||||
)
|
||||
summary_f.write(
|
||||
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-addr",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Address of the server",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=8001,
|
||||
help="Grpc port of the triton server, default is 8001",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-audio",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reference-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-text",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--huggingface-dataset",
|
||||
type=str,
|
||||
default="yuekai/seed_tts",
|
||||
help="dataset name in huggingface dataset hub",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="dataset split name, default is 'test'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-tasks",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of concurrent tasks for sending",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Controls how frequently we print the log.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-wer",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="""True to compute WER.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Inference batch_size per request for offline mode.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
async def send(
|
||||
manifest_item_list: list,
|
||||
name: str,
|
||||
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||
protocol_client: types.ModuleType,
|
||||
log_interval: int,
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
task_id = int(name[5:])
|
||||
|
||||
print(f"manifest_item_list: {manifest_item_list}")
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
duration = len(waveform) / sample_rate
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
reference_text, target_text = item["reference_text"], item["target_text"]
|
||||
|
||||
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
||||
|
||||
if padding_duration:
|
||||
# padding to nearset 10 seconds
|
||||
samples = np.zeros(
|
||||
(
|
||||
1,
|
||||
padding_duration
|
||||
* sample_rate
|
||||
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
samples[0, : len(waveform)] = waveform
|
||||
else:
|
||||
samples = waveform
|
||||
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
|
||||
inputs = [
|
||||
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
||||
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
|
||||
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
||||
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
||||
]
|
||||
inputs[0].set_data_from_numpy(samples)
|
||||
inputs[1].set_data_from_numpy(lengths)
|
||||
|
||||
input_data_numpy = np.array([reference_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[2].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
input_data_numpy = np.array([target_text], dtype=object)
|
||||
input_data_numpy = input_data_numpy.reshape((1, 1))
|
||||
inputs[3].set_data_from_numpy(input_data_numpy)
|
||||
|
||||
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
||||
|
||||
sequence_id = 100000000 + i + task_id * 10
|
||||
start = time.time()
|
||||
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
||||
|
||||
audio = response.as_numpy("waveform").reshape(-1)
|
||||
|
||||
end = time.time() - start
|
||||
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, estimated_target_duration))
|
||||
total_duration += estimated_target_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
|
||||
def load_manifests(manifest_path):
|
||||
with open(manifest_path, "r") as f:
|
||||
manifest_list = []
|
||||
for line in f:
|
||||
assert len(line.strip().split("|")) == 4
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
utt = Path(utt).stem
|
||||
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
||||
manifest_list.append(
|
||||
{
|
||||
"audio_filepath": prompt_wav,
|
||||
"reference_text": prompt_text,
|
||||
"target_text": gt_text,
|
||||
"target_audio_path": utt,
|
||||
}
|
||||
)
|
||||
return manifest_list
|
||||
|
||||
|
||||
def split_data(data, k):
|
||||
n = len(data)
|
||||
if n < k:
|
||||
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
||||
k = n
|
||||
|
||||
quotient = n // k
|
||||
remainder = n % k
|
||||
|
||||
result = []
|
||||
start = 0
|
||||
for i in range(k):
|
||||
if i < remainder:
|
||||
end = start + quotient + 1
|
||||
else:
|
||||
end = start + quotient
|
||||
|
||||
result.append(data[start:end])
|
||||
start = end
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
    protocol_client = grpcclient

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    args.num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, args.num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(args.num_tasks):
        task = asyncio.create_task(
            send(
                manifest_item_list[i],
                name=f"task-{i}",
                triton_client=triton_client,
                protocol_client=protocol_client,
                log_interval=args.log_interval,
                model_name=args.model_name,
                audio_save_dir=args.log_dir,
                padding_duration=1,
                save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
            )
        )
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        total_duration += ans[0]
        latency_data += ans[1]

    rtf = elapsed / total_duration

    s = f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
    latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
    latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
    s += f"latency_variance: {latency_variance:.2f}\n"
    s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
    s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
    s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
    s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
    s += f"average_latency_ms: {latency_ms:.2f}\n"

    print(s)
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    else:
        name = "single_utt"  # fallback for --reference-audio runs; without it `name` is unbound below
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
    write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")

    metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
    with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
        json.dump(metadata, f, indent=4)


if __name__ == "__main__":
    asyncio.run(main())
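# Example invocations (a sketch: flag spellings are inferred from the args used
# in main() above, the file name client_grpc.py comes from this change's file
# listing, and localhost:8001 assumes the gRPC port mapped in docker-compose.yml):
#   single utterance:
#     python3 client_grpc.py --server-addr localhost --server-port 8001 \
#         --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
#         --reference-text "Some call me nature, others call me mother nature." \
#         --target-text "Hello world." --model-name f5_tts --log-dir ./log_single
#   dataset benchmark:
#     python3 client_grpc.py --server-addr localhost --server-port 8001 \
#         --huggingface-dataset <dataset_name> --split-name test \
#         --num-tasks 4 --model-name f5_tts --log-dir ./log_benchmark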
142
src/f5_tts/runtime/triton_trtllm/client_http.py
Normal file
@@ -0,0 +1,142 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import numpy as np
import argparse


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-url",
        type=str,
        default="localhost:8000",
        help="Address of the server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default="../../infer/examples/basic/basic_ref_en.wav",
        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="Some call me nature, others call me mother nature.",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--output-audio",
        type=str,
        default="output.wav",
        help="Path to save the output audio",
    )
    return parser.parse_args()


def prepare_request(
    samples,
    reference_text,
    target_text,
    sample_rate=16000,
    audio_save_dir: str = "./",
):
    assert len(samples.shape) == 1, "samples should be 1D"
    lengths = np.array([[len(samples)]], dtype=np.int32)
    samples = samples.reshape(1, -1).astype(np.float32)

    data = {
        "inputs": [
            {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
            {
                "name": "reference_wav_len",
                "shape": lengths.shape,
                "datatype": "INT32",
                "data": lengths.tolist(),
            },
            {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
            {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
        ]
    }

    return data


def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "hard coding in server"
    if isinstance(wav_path, dict):
        samples = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        samples, sample_rate = sf.read(wav_path)
    # resample applies to both the dict and the file-path case
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(samples) * (target_sample_rate / sample_rate))
        samples = resample(samples, num_samples)
    return samples, target_sample_rate


if __name__ == "__main__":
    args = get_args()
    server_url = args.server_url
    if not server_url.startswith(("http://", "https://")):
        server_url = f"http://{server_url}"

    url = f"{server_url}/v2/models/{args.model_name}/infer"
    samples, sr = load_audio(args.reference_audio)
    assert sr == 16000, "sample rate hardcoded in server"

    samples = np.array(samples, dtype=np.float32)
    data = prepare_request(samples, args.reference_text, args.target_text)

    rsp = requests.post(
        url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
    )
    result = rsp.json()
    audio = result["outputs"][0]["data"]
    audio = np.array(audio, dtype=np.float32)
    sf.write(args.output_audio, audio, 24000, "PCM_16")
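# Example invocation (a sketch; every flag has the default shown in get_args(),
# so a bare `python3 client_http.py` against a local server also works):
#   python3 client_http.py --server-url localhost:8000 \
#       --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
#       --reference-text "Some call me nature, others call me mother nature." \
#       --target-text "Hello, this is a test." --output-audio output.wav
# The response JSON carries the waveform as a flat FP32 list, which the code
# above writes out as 24 kHz PCM_16 audio.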
20
src/f5_tts/runtime/triton_trtllm/docker-compose.yml
Normal file
@@ -0,0 +1,20 @@
services:
  tts:
    image: soar97/triton-f5-tts:24.12
    shm_size: '1gb'
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
      - MODEL_ID=${MODEL_ID}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL_ID"
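# Usage sketch (an assumption: MODEL_ID must be one of the model names that
# run.sh understands, which is not shown in this change):
#   MODEL_ID=F5TTS_Base docker compose up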
@@ -0,0 +1,431 @@
import tensorrt as trt
import os
import math
import time
from typing import List, Optional
from functools import wraps

import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session

import torch
import torch.nn as nn
import torch.nn.functional as F


def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
    # Audio tensor case: batch, seq_len, feature_len
    # position_ids case: batch, seq_len
    assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"

    # Initialize a list to collect valid sequences
    valid_sequences = []

    for i in range(input_tensor.shape[0]):
        valid_length = input_tensor_lengths[i]
        valid_sequences.append(input_tensor[i, :valid_length])

    # Concatenate all valid sequences along the batch dimension
    output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
    return output_tensor


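# Illustration of remove_tensor_padding (not part of the original file): a
# padded batch of shape (B, T_max, C) with per-item lengths is packed along
# dim 0, keeping only the valid frames:
#   x = torch.zeros(2, 5, 3)
#   lens = torch.tensor([5, 2])
#   remove_tensor_padding(x, lens).shape  # -> torch.Size([7, 3])

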
class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
        self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
        self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])

    def forward(self, text):
        # only keep tensors with value not -1
        text_mask = text != -1
        text_pad_cut_off_index = text_mask.sum(dim=1).max()

        text = text[:, :text_pad_cut_off_index]
        text = self.text_embed(text)
        text = text + self.freqs_cis[: text.shape[1], :]
        for block in self.text_blocks:
            text = block(text)
        # padding text to the original length
        # text shape: B,seq_len,C
        # pad at the second dimension
        text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
        return text


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def load_checkpoint(ckpt_path, use_ema=True):
    checkpoint = torch.load(ckpt_path, weights_only=True)
    if use_ema:
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }
    dict_state = checkpoint["model_state_dict"]
    text_embed_dict = {}
    for key in dict_state.keys():
        # transformer.text_embed.text_embed.weight -> text_embed.weight
        if "text_embed" in key:
            text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
    return text_embed_dict


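# Key-mapping example for load_checkpoint above (names follow the F5-TTS
# training checkpoint layout assumed by this file): with use_ema=True,
#   "ema_model.transformer.text_embed.text_embed.weight"
# loses the "ema_model." prefix, survives the "text_embed" filter, and is
# stored as "text_embed.weight", which is exactly what TextEmbedding expects.

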
class F5TTS(object):
    def __init__(
        self,
        config,
        debug_mode=True,
        stream: Optional[torch.cuda.Stream] = None,
        tllm_model_dir: Optional[str] = None,
        model_path: Optional[str] = None,
        vocab_size: Optional[int] = None,
    ):
        self.dtype = config["pretrained_config"]["dtype"]

        rank = tensorrt_llm.mpi_rank()
        world_size = config["pretrained_config"]["mapping"]["world_size"]
        cp_size = config["pretrained_config"]["mapping"]["cp_size"]
        tp_size = config["pretrained_config"]["mapping"]["tp_size"]
        pp_size = config["pretrained_config"]["mapping"]["pp_size"]
        assert pp_size == 1
        self.mapping = tensorrt_llm.Mapping(
            world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
        )

        local_rank = rank % self.mapping.gpus_per_node
        self.device = torch.device(f"cuda:{local_rank}")

        torch.cuda.set_device(self.device)

        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)

        engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
        logger.info(f"Loading engine from {engine_file}")
        with open(engine_file, "rb") as f:
            engine_buffer = f.read()

        assert engine_buffer is not None

        self.session = Session.from_serialized_engine(engine_buffer)

        self.debug_mode = debug_mode

        self.inputs = {}
        self.outputs = {}
        self.buffer_allocated = False

        expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]

        found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
        if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
            logger.error(
                f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
            )
            logger.error(
                f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
            )
            logger.error(f"Expected tensor names: {expected_tensor_names}")
            logger.error(f"Found tensor names: {found_tensor_names}")
            raise RuntimeError("Tensor names in engine are not the same as expected.")
        if self.debug_mode:
            self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))

        self.max_mel_len = 4096
        self.text_embedding = TextEmbedding(
            text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
        ).to(self.device)
        self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)

        self.target_audio_sample_rate = 24000
        self.target_rms = 0.15  # target rms for audio
        self.n_fft = 1024
        self.win_length = 1024
        self.hop_length = 256
        self.n_mel_channels = 100
        # self.max_mel_len = 3000
        self.head_dim = 64
        self.base_rescale_factor = 1.0
        self.interpolation_factor = 1.0
        base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
        self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
        self.rope_cos = self.freqs.cos().half()
        self.rope_sin = self.freqs.sin().half()
        self.nfe_steps = 16
        t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
        # sway-sampling time schedule; the expression below simplifies to 1 - cos(pi / 2 * t)
        time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
        delta_t = torch.diff(time_step)
        # WAR: hard coding 256 here
        tmp_dim = 256
        time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
        half_dim = tmp_dim // 2
        emb_factor = math.log(10000) / (half_dim - 1)
        emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
        for i in range(self.nfe_steps):
            emb = time_step[i] * emb_factor
            time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
        self.time_expand = time_expand.to(self.device)
        self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)

    def _tensor_dtype(self, name):
        # return torch dtype given tensor name for convenience
        dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
        return dtype

    def _setup(self, batch_size, seq_len):
        for i in range(self.session.engine.num_io_tensors):
            name = self.session.engine.get_tensor_name(i)
            if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                shape = list(self.session.engine.get_tensor_shape(name))
                shape[0] = batch_size
                shape[1] = seq_len
                self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)

        self.buffer_allocated = True

    def cuda_stream_guard(func):
        """Sync external stream and set current stream to the one bound to the session. Reset on exit."""

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            external_stream = torch.cuda.current_stream()
            if external_stream != self.stream:
                external_stream.synchronize()
                torch.cuda.set_stream(self.stream)
            ret = func(self, *args, **kwargs)
            if external_stream != self.stream:
                self.stream.synchronize()
                torch.cuda.set_stream(external_stream)
            return ret

        return wrapper

    @cuda_stream_guard
    def forward(
        self,
        noise: torch.Tensor,
        cond: torch.Tensor,
        time_expand: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        input_lengths: torch.Tensor,
        delta_t: torch.Tensor,
        use_perf: bool = False,
    ):
        if use_perf:
            torch.cuda.nvtx.range_push("flow matching")
        cfg_strength = 2.0
        batch_size = noise.shape[0]
        half_batch = batch_size // 2
        noise_half = noise[:half_batch]  # Store the initial half of noise

        input_type = str_dtype_to_torch(self.dtype)

        # Keep a copy of the initial tensors
        cond = cond.to(input_type)
        rope_cos = rope_cos.to(input_type)
        rope_sin = rope_sin.to(input_type)
        input_lengths = input_lengths.to(str_dtype_to_torch("int32"))

        # Instead of iteratively updating noise within a single model context,
        # we'll do a single forward pass for each iteration with fresh context setup
        for i in range(self.nfe_steps):
            # Re-setup the buffers for clean execution
            self._setup(batch_size, noise.shape[1])
            if not self.buffer_allocated:
                raise RuntimeError("Buffer not allocated, please call setup first!")

            # Re-create combined noises for this iteration
            current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)

            # Get time step for this iteration
            current_time = time_expand[:, i].to(input_type)

            # Create fresh input dictionary for this iteration
            current_inputs = {
                "noise": current_noise,
                "cond": cond,
                "time": current_time,
                "rope_cos": rope_cos,
                "rope_sin": rope_sin,
                "input_lengths": input_lengths,
            }

            # Update inputs and set shapes
            self.inputs.clear()  # Clear previous inputs
            self.inputs.update(**current_inputs)
            self.session.set_shapes(self.inputs)

            if use_perf:
                torch.cuda.nvtx.range_push(f"execute {i}")
            ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
            assert ok, "Failed to execute model"
            # self.session.context.execute_async_v3(self.stream.cuda_stream)
            if use_perf:
                torch.cuda.nvtx.range_pop()
            # Process results
            t_scale = delta_t[i].unsqueeze(0).to(input_type)

            # Extract predictions
            pred_cond = self.outputs["denoised"][:half_batch]
            pred_uncond = self.outputs["denoised"][half_batch:]

            # Apply classifier-free guidance with safeguards
            guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
            # Calculate update for noise
            noise_half = noise_half + guidance * t_scale
        if use_perf:
            torch.cuda.nvtx.range_pop()
        return noise_half

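    # Note on forward() above (descriptive comment, not original code): each of
    # the nfe_steps engine runs sees a doubled batch -- the first half carries
    # the text+mel condition, the second half the dropped condition -- and the
    # halves are recombined with classifier-free guidance,
    #   guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength,
    # before the explicit Euler update noise_half += guidance * delta_t[i].
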
    def sample(
        self,
        text_pad_sequence: torch.Tensor,
        ref_mel_batch: torch.Tensor,
        ref_mel_len_batch: torch.Tensor,
        estimated_reference_target_mel_len: List[int],
        remove_input_padding: bool = False,
        use_perf: bool = False,
    ):
        if use_perf:
            torch.cuda.nvtx.range_push("text embedding")
        batch = text_pad_sequence.shape[0]
        max_seq_len = ref_mel_batch.shape[1]

        text_pad_sequence_drop = torch.cat(
            (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
        )

        text_embedding_drop_list = []
        for i in range(batch + 1):
            text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
        text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)

        text_embedding = text_embedding_drop_condition[:-1]
        # text_embedding_drop B,T,C batch should be the same
        text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)

        noise = torch.randn_like(ref_mel_batch).to(self.device)
        rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
        rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)

        cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
        cat_mel_text_drop = torch.cat(
            (
                torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
                text_embedding_drop,
            ),
            dim=-1,
        )

        time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()

        # Convert estimated_reference_target_mel_len to tensor
        input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)

        # combine above along the batch dimension
        inputs = {
            "noise": torch.cat((noise, noise), dim=0).contiguous(),
            "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
            "time_expand": time_expand,
            "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
            "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
            "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
            "delta_t": self.delta_t,
        }
        if use_perf and remove_input_padding:
            torch.cuda.nvtx.range_push("remove input padding")
        if remove_input_padding:
            max_seq_len = inputs["cond"].shape[1]
            inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
            inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
            # for time_expand, convert from B,D to B,T,D by repeat
            inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
            inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
            inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
            inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
        if use_perf and remove_input_padding:
            torch.cuda.nvtx.range_pop()
        for key in inputs:
            inputs[key] = inputs[key].to(self.device)
        if use_perf:
            torch.cuda.nvtx.range_pop()
        start_time = time.time()
        denoised = self.forward(**inputs, use_perf=use_perf)
        cost_time = time.time() - start_time
        if use_perf and remove_input_padding:
            torch.cuda.nvtx.range_push("remove input padding output")
        if remove_input_padding:
            denoised_list = []
            start_idx = 0
            for i in range(batch):
                denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
                start_idx += inputs["input_lengths"][i]
            if use_perf and remove_input_padding:
                torch.cuda.nvtx.range_pop()
            return denoised_list, cost_time
        return denoised, cost_time
@@ -0,0 +1,275 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.dlpack import from_dlpack, to_dlpack
import torchaudio
import jieba
import triton_python_backend_utils as pb_utils
from pypinyin import Style, lazy_pinyin
import os
from f5_tts_trtllm import F5TTS


def get_tokenizer(vocab_file_path: str):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    """
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        vocab_char_map = {}
        for i, char in enumerate(f):
            vocab_char_map[char[:-1]] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size


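# Illustration (hypothetical vocab.txt contents; the loop above maps each line,
# minus its trailing newline, to its 0-based line index): a file whose three
# lines are " ", "a" and "b1" yields
#   vocab_char_map == {" ": 0, "a": 1, "b1": 2} and vocab_size == 3.

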
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    final_reference_target_texts_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in reference_target_texts_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_reference_target_texts_list.append(char_list)

    return final_reference_target_texts_list


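# Illustration (exact pinyin spellings depend on the installed pypinyin and
# jieba versions): convert_char_to_pinyin(["hello你好"]) yields roughly
#   [["h", "e", "l", "l", "o", " ", "ni3", " ", "hao3"]]
# -- ASCII segments pass through character by character, while each Chinese
# character is prefixed with a space and replaced by its TONE3 pinyin.

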
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
):  # noqa: F722
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    return list_idx_tensors


class TritonPythonModel:
    def initialize(self, args):
        self.use_perf = True
        self.device = torch.device("cuda")
        self.target_audio_sample_rate = 24000
        self.target_rms = 0.15  # target rms for audio
        self.n_fft = 1024
        self.win_length = 1024
        self.hop_length = 256
        self.n_mel_channels = 100
        self.max_mel_len = 3000
        self.head_dim = 64

        parameters = json.loads(args["model_config"])["parameters"]
        for key, value in parameters.items():
            parameters[key] = value["string_value"]

        self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
        self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
        self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)

        self.tllm_model_dir = parameters["tllm_model_dir"]
        config_file = os.path.join(self.tllm_model_dir, "config.json")
        with open(config_file) as f:
            config = json.load(f)
        self.model = F5TTS(
            config,
            debug_mode=False,
            tllm_model_dir=self.tllm_model_dir,
            model_path=parameters["model_path"],
            vocab_size=self.vocab_size,
        )

        self.vocoder = parameters["vocoder"]
        assert self.vocoder in ["vocos", "bigvgan"]
        if self.vocoder == "vocos":
            self.mel_stft = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.target_audio_sample_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_mels=self.n_mel_channels,
                power=1,
                center=True,
                normalized=False,
                norm=None,
            ).to(self.device)
            self.compute_mel_fn = self.get_vocos_mel_spectrogram
        elif self.vocoder == "bigvgan":
            # NOTE: get_bigvgan_mel_spectrogram is not defined in this file; the
            # bigvgan path assumes it is provided elsewhere.
            self.compute_mel_fn = self.get_bigvgan_mel_spectrogram

    def get_vocos_mel_spectrogram(self, waveform):
        mel = self.mel_stft(waveform)
        mel = mel.clamp(min=1e-5).log()
        return mel.transpose(1, 2)

    def forward_vocoder(self, mel):
        mel = mel.to(torch.float32).contiguous().cpu()
        input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))

        inference_request = pb_utils.InferenceRequest(
            model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        else:
            waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
            waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()

        return waveform

    def execute(self, requests):
        (
            reference_text_list,
            target_text_list,
            reference_target_texts_list,
            estimated_reference_target_mel_len,
            reference_mel_len,
        ) = [], [], [], [], []
        mel_features_list = []
        if self.use_perf:
            torch.cuda.nvtx.range_push("preprocess")
        for request in requests:
            wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode("utf-8")
            reference_text_list.append(reference_text)
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode("utf-8")
            target_text_list.append(target_text)

            text = reference_text + target_text
            reference_target_texts_list.append(text)

            wav = from_dlpack(wav_tensor.to_dlpack())
            wav_len = from_dlpack(wav_lens.to_dlpack())
            wav_len = wav_len.squeeze()
            assert wav.shape[0] == 1, "Only support batch size 1 for now."
            wav = wav[:, :wav_len]

            ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
            if ref_rms < self.target_rms:
                wav = wav * self.target_rms / ref_rms
            if self.reference_sample_rate != self.target_audio_sample_rate:
                wav = self.resampler(wav)
            wav = wav.to(self.device)
            if self.use_perf:
                torch.cuda.nvtx.range_push("compute_mel")
            mel_features = self.compute_mel_fn(wav)
            if self.use_perf:
                torch.cuda.nvtx.range_pop()
            mel_features_list.append(mel_features)

            reference_mel_len.append(mel_features.shape[1])
            estimated_reference_target_mel_len.append(
                int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
            )

        max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)

        batch = len(requests)
        mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
        for i, mel in enumerate(mel_features_list):
            mel_features[i, : mel.shape[1], :] = mel

        reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)

        pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
        text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)

        for i, item in enumerate(text_pad_sequence):
            text_pad_sequence[i] = F.pad(
                item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
            )
            text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
        text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
        text_pad_sequence = F.pad(
            text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
        )
        if self.use_perf:
            torch.cuda.nvtx.range_pop()

        denoised, cost_time = self.model.sample(
            text_pad_sequence,
            mel_features,
            reference_mel_len_tensor,
            estimated_reference_target_mel_len,
            remove_input_padding=False,
            use_perf=self.use_perf,
        )
        if self.use_perf:
            torch.cuda.nvtx.range_push("vocoder")

        responses = []
        for i in range(batch):
            ref_mel_len = reference_mel_len[i]
            estimated_mel_len = estimated_reference_target_mel_len[i]
            denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
            audio = self.forward_vocoder(denoised_one_item)
            rms = torch.sqrt(torch.mean(torch.square(audio)))
            if rms < self.target_rms:
                audio = audio * self.target_rms / rms

            audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
            responses.append(inference_response)
        if self.use_perf:
            torch.cuda.nvtx.range_pop()
        return responses
@@ -0,0 +1,81 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "f5_tts"
backend: "python"
max_batch_size: 4
dynamic_batching {
  max_queue_delay_microseconds: 1000
}
parameters [
  {
    key: "vocab_file"
    value: { string_value: "${vocab}" }
  },
  {
    key: "model_path",
    value: { string_value: "${model}" }
  },
  {
    key: "tllm_model_dir",
    value: { string_value: "${trtllm}" }
  },
  {
    key: "reference_audio_sample_rate",
    value: { string_value: "16000" }
  },
  {
    key: "vocoder",
    value: { string_value: "${vocoder}" }
  }
]

input [
  {
    name: "reference_wav"
    data_type: TYPE_FP32
    dims: [-1]
    optional: True
  },
  {
    name: "reference_wav_len"
    data_type: TYPE_INT32
    dims: [1]
    optional: True
  },
  {
    name: "reference_text"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    name: "target_text"
    data_type: TYPE_STRING
    dims: [1]
  }
]
output [
  {
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
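# The ${vocab}, ${model}, ${trtllm} and ${vocoder} placeholders above are
# template variables: they must be substituted with concrete values before
# Triton loads this model (in this repo the run.sh launch script is expected
# to perform the substitution; a manual equivalent would be, e.g.,
#   sed -i "s|\${vocab}|/path/to/vocab.txt|g" config.pbtxt ).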
@@ -0,0 +1,32 @@
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4

input [
  {
    name: "mel"
    data_type: TYPE_FP32
    dims: [ 100, -1 ]
  }
]

output [
  {
    name: "waveform"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

dynamic_batching {
  preferred_batch_size: [1, 2, 4]
  max_queue_delay_microseconds: 1
}

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
198
src/f5_tts/runtime/triton_trtllm/patch/__init__.py
Normal file
@@ -0,0 +1,198 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertModel,
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .phi.model import PhiForCausalLM, PhiModel
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
from .f5tts.model import F5TTS

__all__ = [
    "BertModel",
    "BertForQuestionAnswering",
    "BertForSequenceClassification",
    "RobertaModel",
    "RobertaForQuestionAnswering",
    "RobertaForSequenceClassification",
    "BloomModel",
    "BloomForCausalLM",
    "DiT",
    "DeepseekForCausalLM",
    "FalconConfig",
    "DeepseekV2ForCausalLM",
    "FalconForCausalLM",
    "FalconModel",
    "GPTConfig",
    "GPTModel",
    "GPTForCausalLM",
    "OPTForCausalLM",
    "OPTModel",
    "LLaMAConfig",
    "LLaMAForCausalLM",
    "LLaMAModel",
    "MedusaConfig",
    "MedusaForCausalLm",
    "ReDrafterForCausalLM",
    "GPTJConfig",
    "GPTJModel",
    "GPTJForCausalLM",
    "GPTNeoXModel",
    "GPTNeoXForCausalLM",
    "PhiModel",
    "PhiConfig",
    "Phi3Model",
    "Phi3Config",
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    "ChatGLMConfig",
    "ChatGLMForCausalLM",
    "ChatGLMModel",
    "BaichuanForCausalLM",
"QWenConfigQWenForCausalLM",
|
||||
"QWenModel",
|
||||
"EncoderModel",
|
||||
"DecoderModel",
|
||||
"PretrainedConfig",
|
||||
"PretrainedModel",
|
||||
"WhisperEncoder",
|
||||
"MambaForCausalLM",
|
||||
"MambaConfig",
|
||||
"MPTForCausalLM",
|
||||
"MPTModel",
|
||||
"SkyworkForCausalLM",
|
||||
"GemmaConfig",
|
||||
"GemmaForCausalLM",
|
||||
"DbrxConfig",
|
||||
"DbrxForCausalLM",
|
||||
"RecurrentGemmaForCausalLM",
|
||||
"CogVLMConfig",
|
||||
"CogVLMForCausalLM",
|
||||
"EagleForCausalLM",
|
||||
"SpeculativeDecodingMode",
|
||||
"CohereForCausalLM",
|
||||
"MLLaMAModel",
|
||||
"F5TTS",
|
||||
]
|
||||
|
||||
MODEL_MAP = {
    "GPT2LMHeadModel": GPTForCausalLM,
    "GPT2LMHeadCustomModel": GPTForCausalLM,
    "GPTBigCodeForCausalLM": GPTForCausalLM,
    "Starcoder2ForCausalLM": GPTForCausalLM,
    "FuyuForCausalLM": GPTForCausalLM,
    "Kosmos2ForConditionalGeneration": GPTForCausalLM,
    "JAISLMHeadModel": GPTForCausalLM,
    "GPTForCausalLM": GPTForCausalLM,
    "NemotronForCausalLM": GPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "BloomForCausalLM": BloomForCausalLM,
    "RWForCausalLM": FalconForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "PhiForCausalLM": PhiForCausalLM,
    "Phi3ForCausalLM": Phi3ForCausalLM,
    "Phi3VForCausalLM": Phi3ForCausalLM,
    "Phi3SmallForCausalLM": Phi3ForCausalLM,
    "PhiMoEForCausalLM": Phi3ForCausalLM,
    "MambaForCausalLM": MambaForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "MPTForCausalLM": MPTForCausalLM,
    "GLMModel": ChatGLMForCausalLM,
    "ChatGLMModel": ChatGLMForCausalLM,
    "ChatGLMForCausalLM": ChatGLMForCausalLM,
    "LlamaForCausalLM": LLaMAForCausalLM,
    "ExaoneForCausalLM": LLaMAForCausalLM,
    "MistralForCausalLM": LLaMAForCausalLM,
    "MixtralForCausalLM": LLaMAForCausalLM,
    "ArcticForCausalLM": LLaMAForCausalLM,
    "Grok1ModelForCausalLM": GrokForCausalLM,
    "InternLMForCausalLM": LLaMAForCausalLM,
    "InternLM2ForCausalLM": LLaMAForCausalLM,
    "MedusaForCausalLM": MedusaForCausalLm,
    "ReDrafterForCausalLM": ReDrafterForCausalLM,
    "BaichuanForCausalLM": BaichuanForCausalLM,
    "BaiChuanForCausalLM": BaichuanForCausalLM,
    "SkyworkForCausalLM": LLaMAForCausalLM,
    GEMMA_ARCHITECTURE: GemmaForCausalLM,
    GEMMA2_ARCHITECTURE: GemmaForCausalLM,
    "QWenLMHeadModel": QWenForCausalLM,
    "QWenForCausalLM": QWenForCausalLM,
    "Qwen2ForCausalLM": QWenForCausalLM,
    "Qwen2MoeForCausalLM": QWenForCausalLM,
    "Qwen2ForSequenceClassification": QWenForCausalLM,
    "Qwen2VLForConditionalGeneration": QWenForCausalLM,
    "WhisperEncoder": WhisperEncoder,
    "EncoderModel": EncoderModel,
    "DecoderModel": DecoderModel,
    "DbrxForCausalLM": DbrxForCausalLM,
    "RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
    "CogVLMForCausalLM": CogVLMForCausalLM,
    "DiT": DiT,
    "DeepseekForCausalLM": DeepseekForCausalLM,
    "DeciLMForCausalLM": DeciLMForCausalLM,
    "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
    "EagleForCausalLM": EagleForCausalLM,
    "CohereForCausalLM": CohereForCausalLM,
    "MllamaForConditionalGeneration": MLLaMAModel,
    "BertForQuestionAnswering": BertForQuestionAnswering,
    "BertForSequenceClassification": BertForSequenceClassification,
    "BertModel": BertModel,
    "RobertaModel": RobertaModel,
    "RobertaForQuestionAnswering": RobertaForQuestionAnswering,
    "RobertaForSequenceClassification": RobertaForSequenceClassification,
    "F5TTS": F5TTS,
}
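# Sketch of how MODEL_MAP is consumed (an assumption based on the upstream
# tensorrt_llm.models package, which this patched __init__.py replaces): the
# TRT-LLM checkpoint/build flow looks up the architecture string from a model
# config to pick a model class, e.g.
#   model_cls = MODEL_MAP["F5TTS"]  # -> the F5TTS class registered above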
225
src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
Normal file
@@ -0,0 +1,225 @@
from __future__ import annotations
import sys
import os

import tensorrt as trt
from collections import OrderedDict
from ..._utils import str_dtype_to_trt
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from ...functional import Tensor, concat
from ...module import Module, ModuleList
from tensorrt_llm._common import default_net
from ...layers import Linear

from .modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
)

current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)


class InputEmbedding(Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x, cond):
        x = self.proj(concat([x, cond], dim=-1))
        return self.conv_pos_embed(x) + x


class F5TTS(PretrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.dtype = str_dtype_to_trt(config.dtype)

        self.time_embed = TimestepEmbedding(config.hidden_size)
        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)

        self.dim = config.hidden_size
        self.depth = config.num_hidden_layers
        self.transformer_blocks = ModuleList(
            [
                DiTBlock(
                    dim=self.dim,
                    heads=config.num_attention_heads,
                    dim_head=config.dim_head,
                    ff_mult=config.ff_mult,
                    dropout=config.dropout,
                )
                for _ in range(self.depth)
            ]
        )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = Linear(config.hidden_size, config.mel_dim)

    def forward(
        self,
        noise,  # noised input audio
        cond,  # masked cond audio
        time,  # time step
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
    ):
        t = self.time_embed(time)
        x = self.input_embed(noise, cond)
        for block in self.transformer_blocks:
            x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
        denoise = self.proj_out(self.norm_out(x, t))
        denoise.mark_output("denoised", self.dtype)
        return denoise

    def prepare_inputs(self, **kwargs):
        max_batch_size = kwargs["max_batch_size"]
        batch_size_range = [2, 2, max_batch_size]
        mel_size = 100
        max_seq_len = 3000
        num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
        hidden_size = 512
        concat_feature_dim = mel_size + hidden_size
        freq_embed_dim = 256
        head_dim = 64
        mapping = self.config.mapping
        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(mapping, 1)
        if default_net().plugin_config.remove_input_padding:
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, mel_size],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("n_mels", [mel_size]),
                    ]
                ),
            )
            cond = Tensor(
                name="cond",
                dtype=self.dtype,
                shape=[-1, concat_feature_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("embedded_length", [concat_feature_dim]),
                    ]
                ),
            )
            time = Tensor(
                name="time",
                dtype=self.dtype,
                shape=[-1, freq_embed_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("freq_dim", [freq_embed_dim]),
                    ]
                ),
            )
            rope_cos = Tensor(
                name="rope_cos",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
            rope_sin = Tensor(
                name="rope_sin",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )

        else:
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, -1, mel_size],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("n_mels", [mel_size]),
                    ]
                ),
            )
            cond = Tensor(
                name="cond",
                dtype=self.dtype,
                shape=[-1, -1, concat_feature_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("embedded_length", [concat_feature_dim]),
                    ]
                ),
            )
            time = Tensor(
                name="time",
                dtype=self.dtype,
                shape=[-1, freq_embed_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("freq_dim", [freq_embed_dim]),
                    ]
                ),
            )
            rope_cos = Tensor(
                name="rope_cos",
                dtype=self.dtype,
                shape=[-1, -1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
            rope_sin = Tensor(
                name="rope_sin",
                dtype=self.dtype,
                shape=[-1, -1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
        input_lengths = Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [batch_size_range])]),
        )
        return {
            "noise": noise,
            "cond": cond,
            "time": time,
            "rope_cos": rope_cos,
            "rope_sin": rope_sin,
            "input_lengths": input_lengths,
        }
410
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
Normal file
@@ -0,0 +1,410 @@
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F

import numpy as np
from tensorrt_llm._common import default_net
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
from ...functional import (
    Tensor,
    chunk,
    concat,
    constant,
    expand,
    shape,
    silu,
    slice,
    permute,
    expand_mask,
    expand_dims_like,
    unsqueeze,
    matmul,
    softmax,
    squeeze,
    cast,
    gelu,
)
from ...functional import expand_dims, view, bert_attention
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
from ...module import Module


class FeedForward(Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        self.project_in = Linear(dim, inner_dim)
        self.ff = Linear(inner_dim, dim_out)

    def forward(self, x):
        return self.ff(gelu(self.project_in(x)))


class AdaLayerNormZero(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 6)
        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
        x = self.norm(x)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = x * (ones + scale_msa) + shift_msa
        else:
            x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZero_Final(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 2)

        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(silu(emb))
        scale, shift = chunk(emb, 2, dim=1)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = self.norm(x) * (ones + scale) + shift
        else:
            x = self.norm(x) * unsqueeze((ones + scale), 1)
            x = x + unsqueeze(shift, 1)
        return x


class ConvPositionEmbedding(Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.mish = Mish()

    def forward(self, x, mask=None):  # noqa: F722
        if default_net().plugin_config.remove_input_padding:
            x = unsqueeze(x, 0)
        x = permute(x, [0, 2, 1])
        x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
        out = permute(x, [0, 2, 1])
        if default_net().plugin_config.remove_input_padding:
            out = squeeze(out, 0)
        return out


class Attention(Module):
    def __init__(
        self,
        processor: AttnProcessor,
        dim: int,
        heads: int = 16,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.processor = processor
|
||||
|
||||
self.dim = dim # hidden_size
|
||||
self.heads = heads
|
||||
self.inner_dim = dim_head * heads
|
||||
self.dropout = dropout
|
||||
self.attention_head_size = dim_head
|
||||
self.context_dim = context_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.tp_size = 1
|
||||
self.num_attention_heads = heads // self.tp_size
|
||||
self.num_attention_kv_heads = heads // self.tp_size # 8
|
||||
self.dtype = str_dtype_to_trt("float32")
|
||||
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
|
||||
self.to_q = ColumnLinear(
|
||||
dim,
|
||||
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
||||
bias=True,
|
||||
dtype=self.dtype,
|
||||
tp_group=None,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
self.to_k = ColumnLinear(
|
||||
dim,
|
||||
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
||||
bias=True,
|
||||
dtype=self.dtype,
|
||||
tp_group=None,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
self.to_v = ColumnLinear(
|
||||
dim,
|
||||
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
||||
bias=True,
|
||||
dtype=self.dtype,
|
||||
tp_group=None,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
|
||||
if self.context_dim is not None:
|
||||
self.to_k_c = Linear(context_dim, self.inner_dim)
|
||||
self.to_v_c = Linear(context_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.to_q_c = Linear(context_dim, self.inner_dim)
|
||||
|
||||
self.to_out = RowLinear(
|
||||
self.tp_size * self.num_attention_heads * self.attention_head_size,
|
||||
dim,
|
||||
bias=True,
|
||||
dtype=self.dtype,
|
||||
tp_group=None,
|
||||
tp_size=self.tp_size,
|
||||
)
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_out_c = Linear(self.inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x, # noised input x
|
||||
rope_cos,
|
||||
rope_sin,
|
||||
input_lengths,
|
||||
c=None, # context c
|
||||
scale=1.0,
|
||||
rope=None,
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
if c is not None:
|
||||
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
|
||||
else:
|
||||
return self.processor(
|
||||
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
|
||||
shape_tensor = concat(
|
||||
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
|
||||
)
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
assert tensor.ndim() == 2
|
||||
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
|
||||
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
|
||||
x1 = expand_dims(x1, 2)
|
||||
x2 = expand_dims(x2, 2)
|
||||
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
|
||||
x2 = zero - x2
|
||||
x = concat([x2, x1], 2)
|
||||
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
|
||||
else:
|
||||
assert tensor.ndim() == 3
|
||||
|
||||
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
|
||||
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
|
||||
x1 = expand_dims(x1, 3)
|
||||
x2 = expand_dims(x2, 3)
|
||||
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
|
||||
x2 = zero - x2
|
||||
x = concat([x2, x1], 3)
|
||||
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
rot_dim = shape(rope_cos, -1) # 64
|
||||
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
|
||||
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
|
||||
end_dim = shape(x, -1) - shape(rope_cos, -1)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
else:
|
||||
rot_dim = shape(rope_cos, 2) # 64
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
|
||||
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
|
||||
end_dim = shape(x, 2) - shape(rope_cos, 2)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
return out
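The two functions above implement rotary position embedding (RoPE) with the interleaved-pair convention: `rotate_every_two_3dim` maps each feature pair (x1, x2) to (-x2, x1), so `x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin` applies a 2-D rotation per pair, leaving the tail beyond `rot_dim` unrotated. A minimal eager-PyTorch reference of the same identity, useful for checking the TRT graph against plain tensors (the names here are illustrative, not part of the patch):

import torch

def rotate_every_two_ref(x: torch.Tensor) -> torch.Tensor:
    # (..., 2k): pairs (x1, x2) become (-x2, x1), matching rotate_every_two_3dim
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rope_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    rot = cos.shape[-1]
    x_, x_pass = x[..., :rot], x[..., rot:]
    return torch.cat([x_ * cos + rotate_every_two_ref(x_) * sin, x_pass], dim=-1)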

class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
        rope=None,
    ) -> torch.FloatTensor:
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)
        # q, k, v are all (2, 1226, 1024) in the padded example
        query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
        key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)

        # attention
        inner_dim = key.shape[-1]
        norm_factor = math.sqrt(attn.attention_head_size)
        q_scaling = 1.0 / norm_factor
        mask = None
        if not default_net().plugin_config.remove_input_padding:
            N = shape(x, 1)
            B = shape(x, 0)
            seq_len_2d = concat([1, N])
            max_position_embeddings = 4096
            # create position ids
            position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
            tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
            tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # BxL
            tmp_input_lengths = unsqueeze(input_lengths, 1)  # Bx1
            tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # BxL
            mask = tmp_position_ids < tmp_input_lengths  # BxL
            mask = mask.cast("int32")

        if default_net().plugin_config.bert_attention_plugin:
            qkv = concat([query, key, value], dim=-1)
            # TRT plugin mode
            assert input_lengths is not None
            if default_net().plugin_config.remove_input_padding:
                qkv = qkv.view(concat([-1, 3 * inner_dim]))
                max_input_length = constant(
                    np.zeros(
                        [
                            2048,
                        ],
                        dtype=np.int32,
                    )
                )
            else:
                max_input_length = None
            context = bert_attention(
                qkv,
                input_lengths,
                attn.num_attention_heads,
                attn.attention_head_size,
                q_scaling=q_scaling,
                max_input_length=max_input_length,
            )
        else:
            assert not default_net().plugin_config.remove_input_padding

            def transpose_for_scores(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.transpose(1, 2)
                return y

            def transpose_for_scores_k(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.permute([0, 2, 3, 1])
                return y

            query = transpose_for_scores(query)
            key = transpose_for_scores_k(key)
            value = transpose_for_scores(value)

            attention_scores = matmul(query, key, use_fp32_acc=False)

            if mask is not None:
                attention_mask = expand_mask(mask, shape(query, 2))
                attention_mask = cast(attention_mask, attention_scores.dtype)
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
            context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
        context = attn.to_out(context)
        if mask is not None:
            mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
            mask = expand_dims_like(mask, context)
            mask = cast(mask, context.dtype)
            context = context * mask
        return context


# DiT Block
class DiTBlock(Module):
    def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

    def forward(
        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None  # `rope=ModuleNotFoundError` in the original was a typo for None
    ):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
        # attention
        # norm ----> (2, 1226, 1024)
        attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)

        # process attention output for input x
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_msa * attn_output
        else:
            x = x + unsqueeze(gate_msa, 1) * attn_output
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        else:
            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
        ff_output = self.ff(norm)
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_mlp * ff_output
        else:
            x = x + unsqueeze(gate_mlp, 1) * ff_output

        return x


class TimestepEmbedding(Module):
    def __init__(self, dim, freq_embed_dim=256, dtype=None):
        super().__init__()
        # self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        # (the sinusoidal embedding is computed outside the engine, so the
        # `time` input arrives here already frequency-embedded)
        self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
        self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)

    def forward(self, timestep):
        t_freq = self.mlp1(timestep)
        t_freq = silu(t_freq)
        t_emb = self.mlp2(t_freq)
        return t_emb
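Because `SinusPositionEmbedding` is commented out above, the engine expects its `time` input to be pre-embedded on the host. A sketch of the usual sinusoidal embedding that would produce such an input (this assumes the common F5-TTS formulation with `dim=256` and a scale of 1000; verify against the client code before relying on it):

import math
import torch

def sinus_time_embedding(t: torch.Tensor, dim: int = 256, scale: float = 1000.0) -> torch.Tensor:
    # t: (batch,) flow-matching timesteps; returns (batch, dim), sin half then cos half
    half = dim // 2
    freqs = torch.exp(torch.arange(half) * (-math.log(10000.0) / (half - 1)))
    args = scale * t[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)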

src/f5_tts/runtime/triton_trtllm/run.sh (new file, 70 lines)
@@ -0,0 +1,70 @@
stage=$1
stop_stage=$2
model=$3  # F5TTS_Base
if [ -z "$model" ]; then
    echo "Model is not provided, exiting"
    exit 1
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0

F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine

vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
    echo "Downloading F5-TTS from huggingface"
    huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
    echo "Converting checkpoint"
    python3 ./scripts/convert_checkpoint.py \
        --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
        --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
    python_package_path=/usr/local/lib/python3.12/dist-packages
    cp -r patch/* $python_package_path/tensorrt_llm/models
    trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
        --max_batch_size 8 \
        --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
    echo "Exporting vocos vocoder"
    onnx_vocoder_path=vocos_vocoder.onnx
    python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
    bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    echo "Building triton server"
    rm -r $model_repo
    cp -r ./model_repo_f5_tts $model_repo
    python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
    cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
    echo "Starting triton server"
    tritonserver --model-repository=$model_repo
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
    echo "Testing triton server"
    num_task=1
    log_dir=./log_concurrent_tasks_${num_task}
    rm -r $log_dir
    python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
    echo "Testing http client"
    audio=../../infer/examples/basic/basic_ref_en.wav
    reference_text="Some call me nature, others call me mother nature."
    target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
    python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
fi
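For reference, the stages are meant to be run in order: `bash run.sh 0 4 F5TTS_Base` downloads the checkpoint, converts it to a TRT-LLM engine, exports the vocoder, assembles the model repository and launches the server, while `bash run.sh 5 6 F5TTS_Base` runs only the gRPC benchmark and the HTTP smoke test against an already-running server. Note that stage 4 blocks in the foreground, so the test stages are typically run from a second shell.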

src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License

# Copyright (c) 2020 Shimin Zhang

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window

support_clp_op = None
if th.__version__ >= "1.7.0":
    from torch.fft import rfft as fft

    support_clp_op = True
else:
    from torch import rfft as fft


class STFT(th.nn.Module):
    def __init__(
        self,
        win_len=1024,
        win_hop=512,
        fft_len=1024,
        enframe_mode="continue",
        win_type="hann",
        win_sqrt=False,
        pad_center=True,
    ):
        """
        Implementation of the STFT with a 1D convolution and of the iSTFT with a
        1D transposed convolution. Framing the signal is implemented in two
        ways: `break` is kaldi-like framing, `continue` is librosa-like framing.

        More information about `perfect reconstruction`:
        1. https://ww2.mathworks.cn/help/signal/ref/stft.html
        2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html

        Args:
            win_len (int): Number of points in one frame. Defaults to 1024.
            win_hop (int): Number of points in the framing stride. Defaults to 512.
            fft_len (int): Number of DFT points. Defaults to 1024.
            enframe_mode (str, optional): `break` or `continue`. Defaults to 'continue'.
            win_type (str, optional): The type of window to create. Defaults to 'hann'.
            win_sqrt (bool, optional): Use the square root of the window. Defaults to False.
            pad_center (bool, optional): `perfect reconstruction` option. Defaults to True.
        """
        super(STFT, self).__init__()
        assert enframe_mode in ["break", "continue"]
        assert fft_len >= win_len
        self.win_len = win_len
        self.win_hop = win_hop
        self.fft_len = fft_len
        self.mode = enframe_mode
        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        self.pad_amount = self.fft_len // 2

        en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
        self.register_buffer("en_k", en_k)
        self.register_buffer("fft_k", fft_k)
        self.register_buffer("ifft_k", ifft_k)
        self.register_buffer("ola_k", ola_k)

    def __init_kernel__(self):
        """
        Generate the enframe, fft, ifft and overlap-add kernels.
        ** enframe_kernel: an identity matrix applied through a conv1d layer.
        ** fft_kernel: a linear layer used for the matrix multiplication. In fact,
           enframe_kernel and fft_kernel could be combined, but they are kept
           apart for the sake of readability.
        ** ifft_kernel: the pseudo-inverse of fft_kernel.
        ** overlap-add kernel: just like enframe_kernel, but transposed.

        Returns:
            tuple: the four kernels.
        """
        enframed_kernel = th.eye(self.fft_len)[:, None, :]
        if support_clp_op:
            tmp = fft(th.eye(self.fft_len))
            fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
        else:
            fft_kernel = fft(th.eye(self.fft_len), 1)
        if self.mode == "break":
            enframed_kernel = th.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[: self.win_len]
        fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
        window = get_window(self.win_type, self.win_len)

        self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
        window = th.FloatTensor(window)
        if self.mode == "continue":
            left_pad = (self.fft_len - self.win_len) // 2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        if self.win_sqrt:
            self.padded_window = window
            window = th.sqrt(window)
        else:
            self.padded_window = window**2

        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
        if self.mode == "continue":
            ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel

    def is_perfect(self):
        """
        Whether the parameters win_len, win_hop and win_sqrt
        obey the constant overlap-add (COLA) constraint.

        Returns:
            bool: True if the parameters obey COLA.
        """
        return self.perfect_reconstruct and self.pad_center

    def transform(self, inputs, return_type="complex"):
        """Take input data (audio) to the STFT domain.

        Args:
            inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
            return_type (str, optional): return (mag, phase) when `magphase`,
                (real, imag) when `realimag` and complex(real, imag) when `complex`.
                Defaults to 'complex'.

        Returns:
            tuple: (mag, phase) when `magphase`, (real, imag) when `realimag`,
                or a single complex tensor when `complex`; each element has shape
                [num_batch, num_frequencies, num_frames]
        """
        assert return_type in ["magphase", "realimag", "complex"]
        if inputs.dim() == 2:
            inputs = th.unsqueeze(inputs, 1)
        self.num_samples = inputs.size(-1)
        if self.pad_center:
            inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
        enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
        outputs = th.transpose(enframe_inputs, 1, 2)
        outputs = F.linear(outputs, self.fft_k)
        outputs = th.transpose(outputs, 1, 2)
        dim = self.fft_len // 2 + 1
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        if return_type == "realimag":
            return real, imag
        elif return_type == "complex":
            assert support_clp_op
            return th.complex(real, imag)
        else:
            mags = th.sqrt(real**2 + imag**2)
            phase = th.atan2(imag, real)
            return mags, phase

    def inverse(self, input1, input2=None, input_type="magphase"):
        """Call the inverse STFT (iSTFT), given tensors produced
        by the `transform` function.

        Args:
            input1 (tensor): Magnitude/Real-part of the STFT with shape
                [num_batch, num_frequencies, num_frames]
            input2 (tensor): Phase/Imag-part of the STFT with shape
                [num_batch, num_frequencies, num_frames]
            input_type (str, optional): Mathematical meaning of the input tensors.
                Defaults to 'magphase'.

        Returns:
            tensor: Reconstructed audio given magnitude and phase, of
                shape [num_batch, num_samples]
        """
        assert input_type in ["magphase", "realimag"]
        if input_type == "realimag":
            real, imag = None, None
            if support_clp_op and th.is_complex(input1):
                real, imag = input1.real, input1.imag
            else:
                real, imag = input1, input2
        else:
            real = input1 * th.cos(input2)
            imag = input1 * th.sin(input2)
        inputs = th.cat([real, imag], dim=1)
        outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
        t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
        t = t.to(inputs.device)
        coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)

        num_frames = input1.size(-1)
        num_samples = num_frames * self.win_hop

        rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples

        outputs = outputs[..., rm_start:rm_end]
        coff = coff[..., rm_start:rm_end]
        coffidx = th.where(coff > 1e-8)
        outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
        return outputs.squeeze(dim=1)

    def forward(self, inputs):
        """Take input data (audio) to the STFT domain and then back to audio.

        Args:
            inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]

        Returns:
            tensor: Reconstructed audio given magnitude and phase,
                of shape [num_batch, num_samples]
        """
        # explicitly request magphase: the default "complex" return is a single
        # tensor, which would break the two-way unpacking below
        mag, phase = self.transform(inputs, return_type="magphase")
        rec_wav = self.inverse(mag, phase)
        return rec_wav
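A minimal round-trip sketch of the class above (the hop of 256 is chosen only for illustration; hann at 75% overlap satisfies COLA):

import torch as th

stft = STFT(win_len=1024, win_hop=256, fft_len=1024)
wav = th.randn(1, 24000)
mag, phase = stft.transform(wav, return_type="magphase")  # each (1, 513, num_frames)
rec = stft.inverse(mag, phase)                            # (1, 24000) when COLA holds
print(stft.is_perfect(), rec.shape)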

src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py (new file, 359 lines)
@@ -0,0 +1,359 @@
import argparse
import json
import math
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

import safetensors.torch
import torch

from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp


def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
    split_v = split(v, tensor_parallel, rank, dim=1)
    return split_v.contiguous()


def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    split_v = split(v, tensor_parallel, rank, dim=0)
    return split_v.contiguous()


# Rename checkpoint weights to the names used by the TRT-LLM modules.
# The time/input embedding entries are listed once; the per-block renames are
# identical for all 22 DiT blocks, so they are generated in a loop instead of
# being spelled out entry by entry.
FACEBOOK_DIT_NAME_MAPPING = {
    "^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
    "^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
    "^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
    "^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
    "^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
    "^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
    "^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
    "^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
}
for _block in range(22):
    FACEBOOK_DIT_NAME_MAPPING.update(
        {
            f"^transformer_blocks.{_block}.attn.to_out.0.weight$": f"transformer_blocks.{_block}.attn.to_out.weight",
            f"^transformer_blocks.{_block}.attn.to_out.0.bias$": f"transformer_blocks.{_block}.attn.to_out.bias",
            f"^transformer_blocks.{_block}.ff.ff.0.0.weight$": f"transformer_blocks.{_block}.ff.project_in.weight",
            f"^transformer_blocks.{_block}.ff.ff.0.0.bias$": f"transformer_blocks.{_block}.ff.project_in.bias",
            f"^transformer_blocks.{_block}.ff.ff.2.weight$": f"transformer_blocks.{_block}.ff.ff.weight",
            f"^transformer_blocks.{_block}.ff.ff.2.bias$": f"transformer_blocks.{_block}.ff.ff.bias",
        }
    )


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="F5TTS_Base",
        choices=[
            "F5TTS_Base",
        ],
    )  # TODO: support F5TTS_v1_Base
    parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
    parser.add_argument(
        "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
    )
    parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
    parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
    parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of the attention module")
    parser.add_argument("--cfg_scale", type=float, default=4.0)
    parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
    parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
    parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
    parser.add_argument("--fp8_linear", action="store_true", help="Whether to use FP8 for linear layers")
    parser.add_argument(
        "--workers", type=int, default=1, help="The number of workers for converting the checkpoint in parallel"
    )
    args = parser.parse_args()
    return args


def convert_timm_dit(args, mapping, dtype="float32"):
    tik = time.time()
    torch_dtype = str_dtype_to_torch(dtype)
    tensor_parallel = mapping.tp_size

    model_params = dict(torch.load(args.timm_ckpt))
    model_params = {
        k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
    }
    prefix = "ema_model.transformer."
    model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}

    timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING

    def get_trtllm_name(timm_name):
        for k, v in timm_to_trtllm_name.items():
            m = re.match(k, timm_name)
            if m is not None:
                if "*" in v:
                    v = v.replace("*", m.groups()[0])
                return v
        return timm_name

    weights = dict()
    for name, param in model_params.items():
        if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
            weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
        else:
            weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)

    assert len(weights) == len(model_params)

    # new_prefix = 'f5_transformer.'
    new_prefix = ""
    weights = {new_prefix + key: value for key, value in weights.items()}

    # fold the attention scaling into the q/k projections (see the note below)
    scale_factor = math.pow(64, -0.25)
    for k, v in weights.items():
        if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
            weights[k] *= scale_factor
            weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)

        elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
            weights[k] *= scale_factor
            weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)

        elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
            weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
            weights[k] *= scale_factor

        elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
            weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
            weights[k] *= scale_factor

        elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
            weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)

        elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
            weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)

        elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
            weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    print(f"Weights loaded. Total time: {t}")
    return weights
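On the `scale_factor = math.pow(64, -0.25)` above: multiplying both the q and the k projections by d^(-1/4), with d = 64 the head dimension, means their product q·kᵀ picks up a combined factor of d^(-1/2), i.e. the familiar 1/√d attention scaling gets folded into the weights themselves. A one-line check of the algebra:

import math

head_dim = 64
scale_factor = math.pow(head_dim, -0.25)
assert abs(scale_factor**2 - 1.0 / math.sqrt(head_dim)) < 1e-12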

def save_config(args):
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    config = {
        "architecture": "F5TTS",
        "dtype": args.dtype,
        "hidden_size": 1024,
        "num_hidden_layers": 22,
        "num_attention_heads": 16,
        "dim_head": 64,
        "dropout": 0.1,
        "ff_mult": 2,
        "mel_dim": 100,
        "text_num_embeds": 256,
        "text_dim": 512,
        "conv_layers": 4,
        "long_skip_connection": False,
        "mapping": {
            "world_size": args.cp_size * args.tp_size * args.pp_size,
            "cp_size": args.cp_size,
            "tp_size": args.tp_size,
            "pp_size": args.pp_size,
        },
    }
    if args.fp8_linear:
        config["quantization"] = {
            "quant_algo": "FP8",
            # TODO: add support for exclude modules.
            # 'exclude_modules': "*final_layer*",
        }

    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)


def convert_and_save(args, rank):  # renamed from the original "covert_and_save" typo
    if rank == 0:
        save_config(args)

    mapping = Mapping(
        world_size=args.cp_size * args.tp_size * args.pp_size,
        rank=rank,
        cp_size=args.cp_size,
        tp_size=args.tp_size,
        pp_size=args.pp_size,
    )

    weights = convert_timm_dit(args, mapping, dtype=args.dtype)

    safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))


def execute(workers, func, args):
    if workers == 1:
        for rank, f in enumerate(func):
            f(args, rank)
    else:
        with ThreadPoolExecutor(max_workers=workers) as p:
            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(exceptions) == 0, "Checkpoint conversion failed, please check the error log."


def main():
    args = parse_arguments()
    world_size = args.cp_size * args.tp_size * args.pp_size

    assert args.pp_size == 1, "PP is not supported yet."

    tik = time.time()
    if args.timm_ckpt is None:
        return
    print("start execute")
    execute(args.workers, [convert_and_save] * world_size, args)

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    print(f"Total time of converting checkpoints: {t}")


if __name__ == "__main__":
    main()

src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

from conv_stft import STFT
from vocos import Vocos

opset_version = 17


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--vocoder",
        type=str,
        default="vocos",
        choices=["vocos", "bigvgan"],
        help="Vocoder to export",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./vocos_vocoder.onnx",
        help="Output path",
    )
    return parser.parse_args()


class ISTFTHead(nn.Module):
    def __init__(self, n_fft: int, hop_length: int):
        super().__init__()
        self.out = None
        self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)

    def forward(self, x: torch.Tensor):
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)
        real = mag * torch.cos(p)
        imag = mag * torch.sin(p)
        audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
        return audio


class VocosVocoder(nn.Module):
    def __init__(self, vocos_vocoder):
        super(VocosVocoder, self).__init__()
        self.vocos_vocoder = vocos_vocoder
        # swap the stock ISTFT head for the conv-based one above, so the iSTFT
        # is expressed through ONNX-exportable ops
        istft_head_out = self.vocos_vocoder.head.out
        n_fft = self.vocos_vocoder.head.istft.n_fft
        hop_length = self.vocos_vocoder.head.istft.hop_length
        istft_head_for_export = ISTFTHead(n_fft, hop_length)
        istft_head_for_export.out = istft_head_out
        self.vocos_vocoder.head = istft_head_for_export

    def forward(self, mel):
        waveform = self.vocos_vocoder.decode(mel)
        return waveform


def export_VocosVocoder(vocos_vocoder, output_path, verbose):
    vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
    vocos_vocoder.eval()

    dummy_batch_size = 8
    dummy_input_length = 500

    dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()

    with torch.no_grad():
        dummy_waveform = vocos_vocoder(mel=dummy_mel)
        print(dummy_waveform.shape)

    dummy_input = dummy_mel

    torch.onnx.export(
        vocos_vocoder,
        dummy_input,
        output_path,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["mel"],
        output_names=["waveform"],
        dynamic_axes={
            "mel": {0: "batch_size", 2: "input_length"},
            "waveform": {0: "batch_size", 1: "output_length"},
        },
        verbose=verbose,
    )

    print("Exported to {}".format(output_path))


def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
    if vocoder_name == "vocos":
        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        vocoder.load_state_dict(state_dict)
        vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not supported yet")
        # unreachable until BigVGAN support lands
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)
    return vocoder


if __name__ == "__main__":
    args = get_args()
    vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
    if args.vocoder == "vocos":
        export_VocosVocoder(vocoder, args.output_path, verbose=False)
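After exporting, a quick sanity check with onnxruntime (assumed to be installed; it is not required by the export itself) confirms the dynamic axes survived:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("vocos_vocoder.onnx", providers=["CPUExecutionProvider"])
mel = np.random.randn(2, 100, 333).astype(np.float32)  # arbitrary batch and length
(waveform,) = sess.run(["waveform"], {"mel": mel})
print(waveform.shape)  # expect (2, 333 * hop_length)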

src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh (new file, 43 lines)
@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

TRTEXEC="/usr/src/tensorrt/bin/trtexec"

ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
PRECISION="fp32"  # currently informational only; no precision flag is passed to trtexec below

MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8

MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000

MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"

${TRTEXEC} \
    --minShapes="mel:${MEL_MIN_SHAPE}" \
    --optShapes="mel:${MEL_OPT_SHAPE}" \
    --maxShapes="mel:${MEL_MAX_SHAPE}" \
    --onnx=${ONNX_PATH} \
    --saveEngine=${ENGINE_PATH}
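A minimal sketch (assuming TensorRT's Python bindings are available in the same container) to confirm the resulting plan deserializes and exposes the expected I/O tensors:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("vocos_vocoder.plan", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
print([engine.get_tensor_name(i) for i in range(engine.num_io_tensors)])  # expect ['mel', 'waveform']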

src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py (new file, 36 lines)
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
from argparse import ArgumentParser
from string import Template


def main(file_path, substitutions, in_place, participant_ids):
    with open(file_path) as f:
        pbtxt = Template(f.read())

    sub_dict = {"max_queue_size": 0}
    sub_dict["participant_ids"] = participant_ids
    for sub in substitutions.split(","):
        key, value = sub.split(":")  # note: neither keys nor values may contain ":" or ","
        sub_dict[key] = value

    pbtxt = pbtxt.safe_substitute(sub_dict)

    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("file_path", help="path of the .pbtxt to modify")
    parser.add_argument(
        "substitutions",
        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
    )
    parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
    parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
    args = parser.parse_args()

    main(**vars(args))
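The script is a thin wrapper around `string.Template`: the `${...}` placeholders in the pbtxt get replaced and everything else passes through untouched (`safe_substitute` leaves unknown placeholders in place rather than raising). A self-contained illustration, with a made-up one-line pbtxt rather than the real Triton config:

from string import Template

# hypothetical placeholder line; the real config.pbtxt in model_repo_f5_tts is much larger
line = 'parameters { key: "vocab" value: { string_value: "${vocab}" } }'
print(Template(line).safe_substitute({"vocab": "./F5-TTS/F5TTS_Base/vocab.txt"}))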

@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
 
 # result
 epochs = wanted_max_updates / updates_per_epoch
-print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
+print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
 # print(f" or approx. 0/{steps_per_epoch:.0f} steps")
@@ -13,9 +13,9 @@ from importlib.resources import files
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.model.backbones.dit import DiT  # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
    chunk_text,
    preprocess_ref_audio_text,
@@ -80,7 +80,7 @@ class TTSStreamingProcessor:
            else "cpu"
        )
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        self.model_cls = globals()[model_cfg.model.backbone]
        self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        self.model_arc = model_cfg.model.arch
        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
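The switch from a `globals()` lookup to `hydra.utils.get_class` resolves the backbone class from its dotted path, so it no longer needs to be imported into the module namespace purely for the config. A minimal sketch of the pattern:

    from hydra.utils import get_class

    backbone = "DiT"  # e.g. the value of model_cfg.model.backbone
    model_cls = get_class(f"f5_tts.model.{backbone}")  # resolves f5_tts.model.DiT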
@@ -51,7 +51,11 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio

Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).

The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
If you want to finetune with a variant version, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in the path accordingly on the web interface.

If using tensorboard as logger, install it first with `pip install tensorboard`.

<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which have gone through only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see if that offers better results.
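For scripted loading outside the Gradio app, a minimal sketch of disabling EMA at load time, mirroring the config-loading pattern used elsewhere in the repo (the checkpoint path is hypothetical; check the `load_model` signature in `utils_infer` for your version):

    from importlib.resources import files
    from omegaconf import OmegaConf
    from f5_tts.infer.utils_infer import load_model
    from f5_tts.model import DiT

    model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml")))
    model = load_model(
        DiT,
        model_cfg.model.arch,
        "ckpts/my_project/model_5000.pt",  # hypothetical early-stage checkpoint
        use_ema=False,  # compare against the default use_ema=True
    )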
### 3. W&B Logging
@@ -122,7 +122,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
        for future in tqdm(
            chunk_futures,
            total=len(chunk),
            desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
            desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
        ):
            try:
                result = future.result()
@@ -233,7 +233,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")


def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
@@ -198,7 +198,7 @@ def main():

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
    if "ZH" in langs:
        print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs:
@@ -72,7 +72,7 @@ def main():

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")


if __name__ == "__main__":
@@ -50,7 +50,7 @@ def main():

    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")


if __name__ == "__main__":
@@ -40,15 +40,15 @@ def parse_args():
    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
    parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
    parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
    parser.add_argument(
        "--keep_last_n_checkpoints",
        type=int,
        default=-1,
        help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
    )
    parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
    parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
    parser.add_argument("--finetune", action="store_true", help="Use Finetune")
    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
    parser.add_argument(
@@ -65,7 +65,7 @@ def parse_args():
        action="store_true",
        help="Log inferenced samples per ckpt save updates",
    )
    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
    parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
    parser.add_argument(
        "--bnb_optimizer",
        action="store_true",
@@ -120,11 +120,11 @@ def load_settings(project_name):
    default_settings = {
        "exp_name": "F5TTS_v1_Base",
        "learning_rate": 1e-5,
        "batch_size_per_gpu": 1,
        "batch_size_type": "sample",
        "batch_size_per_gpu": 3200,
        "batch_size_type": "frame",
        "max_samples": 64,
        "grad_accumulation_steps": 4,
        "max_grad_norm": 1,
        "grad_accumulation_steps": 1,
        "max_grad_norm": 1.0,
        "epochs": 100,
        "num_warmup_updates": 100,
        "save_per_updates": 500,
@@ -134,8 +134,8 @@ def load_settings(project_name):
        "file_checkpoint_train": "",
        "tokenizer_type": "pinyin",
        "tokenizer_file": "",
        "mixed_precision": "none",
        "logger": "wandb",
        "mixed_precision": "fp16",
        "logger": "none",
        "bnb_optimizer": False,
    }
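The new "frame" default is easier to reason about with quick arithmetic; assuming the repo's usual mel settings of a 24 kHz sample rate and hop length 256 (check your model yaml), 3200 frames per GPU corresponds to roughly 34 seconds of audio per update:

    sample_rate = 24000        # assumption, see your config
    hop_length = 256           # assumption, see your config
    batch_size_per_gpu = 3200  # frames, the default above

    seconds_per_batch = batch_size_per_gpu * hop_length / sample_rate  # ~34.1 s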
@@ -361,27 +361,27 @@ def terminate_process(pid):


def start_training(
    dataset_name="",
    exp_name="F5TTS_v1_Base",
    learning_rate=1e-5,
    batch_size_per_gpu=1,
    batch_size_type="sample",
    max_samples=64,
    grad_accumulation_steps=4,
    max_grad_norm=1.0,
    epochs=100,
    num_warmup_updates=100,
    save_per_updates=500,
    keep_last_n_checkpoints=-1,
    last_per_updates=100,
    finetune=True,
    file_checkpoint_train="",
    tokenizer_type="pinyin",
    tokenizer_file="",
    mixed_precision="fp16",
    stream=False,
    logger="wandb",
    ch_8bit_adam=False,
    dataset_name,
    exp_name,
    learning_rate,
    batch_size_per_gpu,
    batch_size_type,
    max_samples,
    grad_accumulation_steps,
    max_grad_norm,
    epochs,
    num_warmup_updates,
    save_per_updates,
    keep_last_n_checkpoints,
    last_per_updates,
    finetune,
    file_checkpoint_train,
    tokenizer_type,
    tokenizer_file,
    mixed_precision,
    stream,
    logger,
    ch_8bit_adam,
):
    global training_process, tts_api, stop_signal
@@ -458,7 +458,10 @@ def start_training(

    cmd += f" --tokenizer {tokenizer_type}"

    cmd += f" --log_samples --logger {logger}"
    if logger != "none":
        cmd += f" --logger {logger}"

    cmd += " --log_samples"

    if ch_8bit_adam:
        cmd += " --bnb_optimizer"
@@ -515,7 +518,7 @@ def start_training(
        training_process = subprocess.Popen(
            cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
        )
        yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
        yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)

        stdout_queue = queue.Queue()
        stderr_queue = queue.Queue()
@@ -584,7 +587,11 @@ def start_training(
                    gr.update(interactive=True),
                )
            else:
                yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
                yield (
                    "Training complete or paused ...",
                    gr.update(interactive=False),
                    gr.update(interactive=True),
                )
                break

            # Small sleep to prevent CPU thrashing
@@ -598,9 +605,9 @@ def start_training(
            time.sleep(1)

        if training_process is None:
            text_info = "train stop"
            text_info = "Train stopped !"
        else:
            text_info = "train complete !"
            text_info = "Train complete at end !"

    except Exception as e:  # Catch all exceptions
        # Ensure that we reset the training process variable in case of an error
@@ -615,11 +622,11 @@ def stop_training():
    global training_process, stop_signal

    if training_process is None:
        return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
        return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
    terminate_process_tree(training_process.pid)
    # training_process = None
    stop_signal = True
    return "train stop", gr.update(interactive=True), gr.update(interactive=False)
    return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)


def get_list_projects():
@@ -958,21 +965,23 @@ def calculate_train(
    )


def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
    try:
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        print("Original Checkpoint Keys:", checkpoint.keys())

        ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
        if ema_model_state_dict is None:
            return "No 'ema_model_state_dict' found in the checkpoint."
        to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
        try:
            model_state_dict_to_retain = checkpoint[to_retain]
        except KeyError:
            return f"{to_retain} not found in the checkpoint."

        if safetensors:
            new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
            save_file(ema_model_state_dict, new_checkpoint_path)
            save_file(model_state_dict_to_retain, new_checkpoint_path)
        else:
            new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
            new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
            new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
            torch.save(new_checkpoint, new_checkpoint_path)

        return f"New checkpoint saved at: {new_checkpoint_path}"
@@ -1013,7 +1022,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):

    ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])

    torch.save(ckpt, new_ckpt_path)
    if new_ckpt_path.endswith(".safetensors"):
        save_file(ema_sd, new_ckpt_path)
    elif new_ckpt_path.endswith(".pt"):
        torch.save(ckpt, new_ckpt_path)

    return vocab_new
@@ -1125,7 +1137,7 @@ def vocab_check(project_name):
        info = "You can train using your language !"
    else:
        vocab_miss = ",".join(miss_symbols)
        info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
        info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"

    return info, vocab_miss
@@ -1212,6 +1224,9 @@ def infer(

    print("update >> ", device_test, file_checkpoint, use_ema)

    if seed == -1:  # -1 used for random
        seed = None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        tts_api.infer(
            ref_file=ref_audio,
@@ -1430,9 +1445,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
            )

            audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
            txt_lang = gr.Text(label="Language", value="English")
            txt_lang = gr.Textbox(label="Language", value="English")
            bt_transcribe = bt_create = gr.Button("Transcribe")
            txt_info_transcribe = gr.Text(label="Info", value="")
            txt_info_transcribe = gr.Textbox(label="Info", value="")
            bt_transcribe.click(
                fn=transcribe_all,
                inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
@@ -1443,7 +1458,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
            random_sample_transcribe = gr.Button("Random Sample")

            with gr.Row():
                random_text_transcribe = gr.Text(label="Text")
                random_text_transcribe = gr.Textbox(label="Text")
                random_audio_transcribe = gr.Audio(label="Audio", type="filepath")

            random_sample_transcribe.click(
@@ -1458,7 +1473,7 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
```""")

            check_button = gr.Button("Check Vocab")
            txt_info_check = gr.Text(label="Info", value="")
            txt_info_check = gr.Textbox(label="Info", value="")

            gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
@@ -1478,7 +1493,7 @@ Using the extended model, you can finetune to a new language that is missing sym
                txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)

            extend_button = gr.Button("Extend")
            txt_info_extend = gr.Text(label="Info", value="")
            txt_info_extend = gr.Textbox(label="Info", value="")

            txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
@@ -1518,8 +1533,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
            ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)

            bt_prepare = bt_create = gr.Button("Prepare")
            txt_info_prepare = gr.Text(label="Info", value="")
            txt_vocab_prepare = gr.Text(label="Vocab", value="")
            txt_info_prepare = gr.Textbox(label="Info", value="")
            txt_vocab_prepare = gr.Textbox(label="Vocab", value="")

            bt_prepare.click(
                fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
@@ -1528,7 +1543,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
            random_sample_prepare = gr.Button("Random Sample")

            with gr.Row():
                random_text_prepare = gr.Text(label="Tokenizer")
                random_text_prepare = gr.Textbox(label="Tokenizer")
                random_audio_prepare = gr.Audio(label="Audio", type="filepath")

            random_sample_prepare.click(
@@ -1541,50 +1556,60 @@ The auto-setting is still experimental. Set a large value of epoch if not sure;
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
            with gr.Row():
                bt_calculate = bt_create = gr.Button("Auto Settings")
                exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
                tokenizer_file = gr.Textbox(label="Tokenizer File")
                file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")

            with gr.Row():
                ch_finetune = bt_create = gr.Checkbox(label="Finetune")
                lb_samples = gr.Label(label="Samples")
                batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
                bt_calculate = bt_create = gr.Button("Auto Settings")

            with gr.Row():
                ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
                tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
                file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
                epochs = gr.Number(label="Epochs")
                learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
                max_grad_norm = gr.Number(label="Max Gradient Norm")
                num_warmup_updates = gr.Number(label="Warmup Updates")

            with gr.Row():
                exp_name = gr.Radio(
                    label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
                batch_size_type = gr.Radio(
                    label="Batch Size Type",
                    choices=["frame", "sample"],
                    info="frame is calculated as seconds * sampling_rate / hop_length",
                )
                learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
                batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
                grad_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
                )
                max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")

            with gr.Row():
                batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
                max_samples = gr.Number(label="Max Samples", value=64)

            with gr.Row():
                grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
                max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)

            with gr.Row():
                epochs = gr.Number(label="Epochs", value=100)
                num_warmup_updates = gr.Number(label="Warmup Updates", value=100)

            with gr.Row():
                save_per_updates = gr.Number(label="Save per Updates", value=500)
                save_per_updates = gr.Number(
                    label="Save per Updates",
                    info="Save intermediate checkpoints every N updates",
                    minimum=10,
                )
                keep_last_n_checkpoints = gr.Number(
                    label="Keep Last N Checkpoints",
                    value=-1,
                    step=1,
                    precision=0,
                    info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
                    info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
                    minimum=-1,
                )
                last_per_updates = gr.Number(label="Last per Updates", value=100)
                last_per_updates = gr.Number(
                    label="Last per Updates",
                    info="Save latest checkpoint with suffix _last.pt every N updates",
                    minimum=10,
                )
                gr.Radio(label="")  # placeholder

            with gr.Row():
                ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
                mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
                cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
                start_button = gr.Button("Start Training")
                stop_button = gr.Button("Stop Training", interactive=False)
                mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
                cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
                with gr.Column():
                    start_button = gr.Button("Start Training")
                    stop_button = gr.Button("Stop Training", interactive=False)

            if projects_selelect is not None:
                (
@@ -1631,7 +1656,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                ch_8bit_adam.value = bnb_optimizer_value

            ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
            txt_info_train = gr.Text(label="Info", value="")
            txt_info_train = gr.Textbox(label="Info", value="")

            list_audios, select_audio = get_audio_project(projects_selelect, False)
@@ -1760,7 +1785,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle

        with gr.TabItem("Test Model"):
            gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
            exp_name = gr.Radio(
                label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
@@ -1770,11 +1795,13 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
            with gr.Row():
                nfe_step = gr.Number(label="NFE Step", value=32)
                speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
                seed = gr.Number(label="Seed", value=-1, minimum=-1)
                seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
                remove_silence = gr.Checkbox(label="Remove Silence")

            ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
            with gr.Row():
                ch_use_ema = gr.Checkbox(
                    label="Use EMA", value=True, info="Turn off at early stage might offer better results"
                )
                cm_checkpoint = gr.Dropdown(
                    choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
                )
@@ -1782,20 +1809,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works

            random_sample_infer = gr.Button("Random Sample")

            ref_text = gr.Textbox(label="Ref Text")
            ref_audio = gr.Audio(label="Audio Ref", type="filepath")
            gen_text = gr.Textbox(label="Gen Text")
            ref_text = gr.Textbox(label="Reference Text")
            ref_audio = gr.Audio(label="Reference Audio", type="filepath")
            gen_text = gr.Textbox(label="Text to Generate")

            random_sample_infer.click(
                fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
            )

            with gr.Row():
                txt_info_gpu = gr.Textbox("", label="Device")
                seed_info = gr.Text(label="Seed :")
                check_button_infer = gr.Button("Infer")
                txt_info_gpu = gr.Textbox("", label="Inference on Device :")
                seed_info = gr.Textbox(label="Used Random Seed :")
                check_button_infer = gr.Button("Inference")

            gen_audio = gr.Audio(label="Audio Gen", type="filepath")
            gen_audio = gr.Audio(label="Generated Audio", type="filepath")

            check_button_infer.click(
                fn=infer,
@@ -1822,14 +1849,16 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
            gr.Markdown("""```plaintext
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
```""")
            txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
            txt_path_checkpoint_small = gr.Text(label="Path to Output:")
            ch_safetensors = gr.Checkbox(label="Safetensors", value="")
            txt_info_reduse = gr.Text(label="Info", value="")
            reduse_button = gr.Button("Reduce")
            txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
            txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
            with gr.Row():
                ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
                ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
            txt_info_reduse = gr.Textbox(label="Info", value="")
            reduse_button = gr.Button("Prune")
            reduse_button.click(
                fn=extract_and_save_ema_model,
                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
                fn=prune_checkpoint,
                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
                outputs=[txt_info_reduse],
            )
@@ -6,7 +6,7 @@ from importlib.resources import files
import hydra
from omegaconf import OmegaConf

from f5_tts.model import CFM, DiT, UNetT, Trainer  # noqa: F401. used for config
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
@@ -14,60 +14,60 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to


@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
    model_cls = globals()[cfg.model.backbone]
    model_arc = cfg.model.arch
    tokenizer = cfg.model.tokenizer
    mel_spec_type = cfg.model.mel_spec.mel_spec_type
def main(model_cfg):
    model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch
    tokenizer = model_cfg.model.tokenizer
    mel_spec_type = model_cfg.model.mel_spec.mel_spec_type

    exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
    exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
    wandb_resume_id = None

    # set text tokenizer
    if tokenizer != "custom":
        tokenizer_path = cfg.datasets.name
        tokenizer_path = model_cfg.datasets.name
    else:
        tokenizer_path = cfg.model.tokenizer_path
        tokenizer_path = model_cfg.model.tokenizer_path
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # set model
    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=cfg.model.mel_spec,
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=model_cfg.model.mel_spec,
        vocab_char_map=vocab_char_map,
    )

    # init trainer
    trainer = Trainer(
        model,
        epochs=cfg.optim.epochs,
        learning_rate=cfg.optim.learning_rate,
        num_warmup_updates=cfg.optim.num_warmup_updates,
        save_per_updates=cfg.ckpts.save_per_updates,
        keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
        batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
        batch_size_type=cfg.datasets.batch_size_type,
        max_samples=cfg.datasets.max_samples,
        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
        max_grad_norm=cfg.optim.max_grad_norm,
        logger=cfg.ckpts.logger,
        epochs=model_cfg.optim.epochs,
        learning_rate=model_cfg.optim.learning_rate,
        num_warmup_updates=model_cfg.optim.num_warmup_updates,
        save_per_updates=model_cfg.ckpts.save_per_updates,
        keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
        batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
        batch_size_type=model_cfg.datasets.batch_size_type,
        max_samples=model_cfg.datasets.max_samples,
        grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
        max_grad_norm=model_cfg.optim.max_grad_norm,
        logger=model_cfg.ckpts.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_updates=cfg.ckpts.last_per_updates,
        log_samples=cfg.ckpts.log_samples,
        bnb_optimizer=cfg.optim.bnb_optimizer,
        last_per_updates=model_cfg.ckpts.last_per_updates,
        log_samples=model_cfg.ckpts.log_samples,
        bnb_optimizer=model_cfg.optim.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=cfg.model.vocoder.is_local,
        local_vocoder_path=cfg.model.vocoder.local_path,
        cfg_dict=OmegaConf.to_container(cfg, resolve=True),
        is_local_vocoder=model_cfg.model.vocoder.is_local,
        local_vocoder_path=model_cfg.model.vocoder.local_path,
        model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
    )

    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
    train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=cfg.datasets.num_workers,
        num_workers=model_cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )
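Since the `@hydra.main` decorator declares `config_name=None`, the config must be supplied on the command line, e.g. `python src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml` (possibly behind a launcher such as `accelerate launch`, depending on your setup); every `model_cfg.*` access above then resolves against that yaml.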