SWivid
2024-10-08 21:56:51 +08:00
commit 074881635d
25 changed files with 4263 additions and 0 deletions

.gitignore (new file, 173 lines)

@@ -0,0 +1,173 @@
# Custom
.vscode/
tests/
runs/
data/
ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

LICENSE (new file, 395 lines)

@@ -0,0 +1,395 @@
Attribution 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution 4.0 International Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution 4.0 International Public License ("Public License"). To the
extent this Public License may be interpreted as a contract, You are
granted the Licensed Rights in consideration of Your acceptance of
these terms and conditions, and the Licensor grants You such rights in
consideration of benefits the Licensor receives from making the
Licensed Material available under these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
j. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
k. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part; and
b. produce, reproduce, and Share Adapted Material.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

README.md (new file, 73 lines)

@@ -0,0 +1,73 @@
# F5-TTS
A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
## Installation
```bash
pip install -r requirements.txt
```
## Dataset
```bash
# prepare a custom dataset as needed
# download the corresponding dataset first, and fill in its path in the scripts
python scripts/prepare_emilia.py
python scripts/prepare_wenetspeech4tts.py
```
## Training
```bash
# set up the accelerate config, e.g. multi-gpu ddp, fp16
# the config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
accelerate launch test_train.py
```
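Before launching, it can help to confirm that `accelerate config` actually wrote the file mentioned in the comment above. The following is a minimal sketch, assuming the default cache location has not been overridden; the helper name `accelerate_config_path` is hypothetical and is not part of this repo or of accelerate.

```python
from __future__ import annotations

from pathlib import Path


def accelerate_config_path(home: Path | None = None) -> Path:
    """Return the default accelerate config location under the given home directory.

    The relative path is the one quoted in the Training section; `home` is
    only a parameter so the function can be exercised without touching $HOME.
    """
    home = Path.home() if home is None else home
    return home / ".cache" / "huggingface" / "accelerate" / "default_config.yaml"


if __name__ == "__main__":
    path = accelerate_config_path()
    status = "found" if path.is_file() else "missing, run `accelerate config` first"
    print(f"{path}: {status}")
```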
## Inference
Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
```bash
# single test inference
python test_infer_single.py
```
## Evaluation
download seedtts testset. https://github.com/BytedanceSpeech/seed-tts-eval \
download test-clean. http://www.openslr.org/12/ \
unzip and place under data/, and fill in the path of test-clean in `test_infer_batch.py` \
our librispeech-pc 4-10s subset is already under data/ in this repo
zh asr model ckpt. https://huggingface.co/funasr/paraformer-zh \
en asr model ckpt. https://huggingface.co/Systran/faster-whisper-large-v3 \
wavlm model ckpt. https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view \
fill in the path of ckpts in `test_infer_batch.py`
```bash
# batch inference for evaluations
accelerate config # if not set before
bash test_infer_batch.sh
```
If you use faster-whisper with CUDA 11, run
`pip install --force-reinstall ctranslate2==3.24.0`
(recommended) `pip install faster-whisper==0.10.1`,
otherwise ASR may fail (abnormally repetitive output)
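
To see whether the local environment matches the pins above, one option is to compare installed versions against them. This is a sketch only: the pinned versions are copied from the note above, while the helper names `installed_version` and `check_pins` are hypothetical.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Pins copied from the note above.
RECOMMENDED = {"ctranslate2": "3.24.0", "faster-whisper": "0.10.1"}


def installed_version(package: str) -> str | None:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_pins(pins=RECOMMENDED, lookup=installed_version):
    """Return {package: (installed, wanted)} for every package that deviates from its pin."""
    mismatches = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (found, wanted)
    return mismatches


if __name__ == "__main__":
    for package, (found, wanted) in check_pins().items():
        print(f"{package}: installed {found}, recommended {wanted}")
```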
```bash
# evaluation for Seed-TTS test set
python scripts/eval_seedtts_testset.py
# evaluation for LibriSpeech-PC test-clean cross sentence
python scripts/eval_librispeech_test_clean.py
```
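The evaluation scripts are not included in this excerpt. ASR-based TTS evaluation commonly reports word error rate (WER) between the target text and the ASR transcript of the generated audio, so the sketch below assumes that metric. The function name `word_error_rate` and the whitespace tokenization are assumptions, not the repo's implementation; for Chinese, a character-level split is the usual choice.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference prefix
    # processed so far and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion (reference word missing from hypothesis)
                curr[j - 1] + 1,     # insertion (extra hypothesis word)
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1] / len(ref)


if __name__ == "__main__":
    print(word_error_rate("the cat sat on the mat", "the cat sat on mat"))
```

A lower value is better; WER can exceed 1.0 when the transcript contains many inserted words.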
## Appreciation
- <a href="https://arxiv.org/abs/2406.18009">E2-TTS</a> brilliant work, simple and effective
- <a href="https://arxiv.org/abs/2407.05361">Emilia</a>, <a href="https://arxiv.org/abs/2406.05763">WenetSpeech4TTS</a> valuable datasets
- <a href="https://github.com/lucidrains/e2-tts-pytorch">lucidrains</a> initial CFM structure, and <a href="https://github.com/bfs18">bfs18</a> for discussion
- <a href="https://arxiv.org/abs/2403.03206">SD3</a> & <a href="https://github.com/huggingface/diffusers">Huggingface diffusers</a> DiT and MMDiT code structure
- <a href="https://github.com/modelscope/FunASR">FunASR</a>, <a href="https://github.com/SYSTRAN/faster-whisper">faster-whisper</a> & <a href="https://github.com/microsoft/UniSpeech">UniSpeech</a> for evaluation tools
- <a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a> as ODE solver, <a href="https://huggingface.co/charactr/vocos-mel-24khz">Vocos</a> as vocoder

model/__init__.py (new file, 7 lines)

@@ -0,0 +1,7 @@
from model.cfm import CFM
from model.backbones.unett import UNetT
from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT
from model.trainer import Trainer

model/backbones/README.md (new file, 20 lines)

@@ -0,0 +1,20 @@
## Backbones quick introduction
### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
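
As a rough illustration of what "flat unet transformer" usually means: all layers keep the same sequence length and width, and the first half of the stack feeds skip connections to the second half in mirrored order. The sketch below only computes that pairing. It assumes an even depth and a last-in-first-out pairing; `skip_pairs` is a hypothetical helper, not a function in `unett.py`.

```python
from __future__ import annotations


def skip_pairs(depth: int) -> list[tuple[int, int]]:
    """List (source_layer, target_layer) index pairs for U-Net style skips in a flat stack."""
    if depth <= 0 or depth % 2 != 0:
        raise ValueError("this sketch assumes a positive, even number of layers")

    stack: list[int] = []
    pairs: list[tuple[int, int]] = []
    for layer in range(depth):
        if layer < depth // 2:
            stack.append(layer)                 # first half: remember this layer
        else:
            pairs.append((stack.pop(), layer))  # second half: consume the most recent one
    return pairs


if __name__ == "__main__":
    print(skip_pairs(8))
```

With `depth = 8` the pairs are `(3, 4)`, `(2, 5)`, `(1, 6)`, `(0, 7)`, so the outermost layers are connected to each other, as in a conventional U-Net.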
### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatted noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)
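
Concatenating text with the audio features along the channel axis only works if the text sequence has the same length as the mel frames. The plain-Python sketch below mirrors the index handling in `TextEmbedding.forward` in `dit.py` for a single sequence: shift ids by one so that 0 is the filler token, truncate to the frame count, then pad with filler. The function name `align_text_to_frames` is hypothetical.

```python
from __future__ import annotations


def align_text_to_frames(token_ids: list[int], seq_len: int, drop_text: bool = False) -> list[int]:
    """Return `seq_len` token ids with 0 used as the filler token."""
    shifted = [t + 1 for t in token_ids]       # batch padding value -1 becomes filler 0
    shifted = shifted[:seq_len]                # curtail text longer than the mel frames
    shifted += [0] * (seq_len - len(shifted))  # pad up to the frame count
    if drop_text:                              # classifier-free guidance: drop all text
        shifted = [0] * seq_len
    return shifted


if __name__ == "__main__":
    print(align_text_to_frames([5, 7, -1], seq_len=5))
```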
### mmdit.py
- sd3 structure
- timestep as condition
- left stream: text embedded, with an abs pos emb applied
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
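
The "abs pos emb" on the text stream is looked up from a precomputed table (`precompute_freqs_cis` in `model/modules`, which is not shown in this excerpt). As a hedged illustration of what such a table typically contains, here is a generic sinusoidal embedding in plain Python. The base `theta = 10000`, the cos-then-sin layout, and the function name are assumptions rather than the repo's definition.

```python
import math


def sinusoidal_position_embedding(num_positions, dim, theta=10000.0):
    """Return a num_positions x dim table: first half cosines, second half sines."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")

    half = dim // 2
    # One frequency per cos/sin channel pair, decaying geometrically with the channel index.
    freqs = [1.0 / (theta ** (2 * i / dim)) for i in range(half)]

    table = []
    for pos in range(num_positions):
        angles = [pos * f for f in freqs]
        table.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return table


if __name__ == "__main__":
    table = sinusoidal_position_embedding(num_positions=4, dim=8)
    print(len(table), len(table[0]))
```

Each row depends only on the absolute position, which is why it is added once to the text embeddings rather than applied inside attention like a rotary embedding.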

model/backbones/dit.py (new file, 158 lines)

@@ -0,0 +1,158 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis, get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
        else:
            self.extra_modeling = False

    def forward(self, text: int['b nt'], seq_len, drop_text = False):
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        # measure the length after curtailing so that the pad width below is never negative
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value = 0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinusoidal pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
        x = self.conv_pos_embed(x) + x
        return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
                 long_skip_connection = False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    ff_mult = ff_mult,
                    dropout = dropout
                )
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],  # noised input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],  # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        # keep the embedded input for the optional first-to-last long skip connection
        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            x = block(x, t, mask = mask, rope = rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim = -1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output

model/backbones/mmdit.py (new file, 136 lines)

@@ -0,0 +1,136 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    MMDiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis, get_pos_embed_indices,
)
# text embedding
class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int['b nt'], drop_text = False) -> float['b nt d']:
        text = text + 1  # shift ids so that 0 is the filler token
        if drop_text:  # cfg for text
            text = torch.zeros_like(text)
        text = self.text_embed(text)

        # sinusoidal pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        return text
# noised input & masked cond audio embedding
class AudioEmbedding(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim = -1)
        x = self.linear(x)
        x = self.conv_pos_embed(x) + x
        return x
# Transformer backbone using MM-DiT blocks
class MMDiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 text_num_embeds = 256, mel_dim = 100,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = dropout,
                    ff_mult = ff_mult,
                    context_pre_only = i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],  # noised input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],  # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch = x.shape[0]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c = self.text_embed(text, drop_text = drop_text)
        x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output

model/backbones/unett.py (new file, 201 lines)

@@ -0,0 +1,201 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat, pack, unpack
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
Attention,
AttnProcessor,
FeedForward,
precompute_freqs_cis, get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
else:
self.extra_modeling = False
def forward(self, text: int['b nt'], seq_len, drop_text = False):
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # shift token ids by 1 so that 0 is the filler token; the batch padding value -1 (see list_str_to_idx()) maps to 0
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value = 0) # use the post-truncation length so the pad width is never negative
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
x = self.conv_pos_embed(x) + x
return x
# Flat UNet Transformer backbone
class UNetT(nn.Module):
def __init__(self, *,
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
):
super().__init__()
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
# transformer layers & skip connections
self.dim = dim
self.skip_connect_type = skip_connect_type
needs_skip_proj = skip_connect_type == 'concat'
self.depth = depth
self.layers = nn.ModuleList([])
for idx in range(depth):
is_later_half = idx >= (depth // 2)
attn_norm = RMSNorm(dim)
attn = Attention(
processor = AttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
)
ff_norm = RMSNorm(dim)
ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
self.layers.append(nn.ModuleList([
skip_proj,
attn_norm,
attn,
ff_norm,
ff,
]))
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def forward(
self,
x: float['b n d'], # noised input audio
cond: float['b n d'], # masked cond audio
text: int['b nt'], # text
time: float['b'] | float[''], # time step
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool['b n'] | None = None,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = repeat(time, ' -> b', b = batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
# prepend time t to input x, [b n d] -> [b n+1 d]
x, ps = pack((t, x), 'b * d')
if mask is not None:
mask = F.pad(mask, (1, 0), value=1)
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
# flat unet transformer
skip_connect_type = self.skip_connect_type
skips = []
for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
layer = idx + 1
# skip connection logic
is_first_half = layer <= (self.depth // 2)
is_later_half = not is_first_half
if is_first_half:
skips.append(x)
if is_later_half:
skip = skips.pop()
if skip_connect_type == 'concat':
x = torch.cat((x, skip), dim = -1)
x = maybe_skip_proj(x)
elif skip_connect_type == 'add':
x = x + skip
# attention and feedforward blocks
x = attn(attn_norm(x), rope = rope, mask = mask) + x
x = ff(ff_norm(x)) + x
assert len(skips) == 0
_, x = unpack(self.norm_out(x), ps, 'b * d')
return self.proj_out(x)
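
The skip-connection bookkeeping in `UNetT.forward` above is a plain stack: layers in the first half push their input, and layers in the later half pop, so the pairing is last-in, first-out. The following sketch replays only that bookkeeping in pure Python, using layer numbers in place of activations. The helper name `skip_pairing` is an assumption made for illustration and is not part of the repository:

```python
def skip_pairing(depth: int) -> list[tuple[int, int]]:
    """Return (later_layer, source_layer) pairs using the same push/pop order as UNetT."""
    assert depth % 2 == 0, "depth should be even"
    skips: list[int] = []
    pairs: list[tuple[int, int]] = []
    for idx in range(depth):
        layer = idx + 1
        if layer <= depth // 2:
            # Stands in for `skips.append(x)`: remember the input of this layer.
            skips.append(layer)
        else:
            # Stands in for `skip = skips.pop()`: most recently stored input first.
            pairs.append((layer, skips.pop()))
    assert len(skips) == 0
    return pairs


pairs = skip_pairing(8)
print(pairs)
```

With `depth = 8`, layer 5 receives the input of layer 4 and layer 8 receives the input of layer 1, which is the mirrored U-Net pattern without any change in sequence resolution.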

273
model/cfm.py Normal file

@@ -0,0 +1,273 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Callable
from random import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from einops import rearrange
from model.modules import MelSpec
from model.utils import (
default, exists,
list_str_to_idx, list_str_to_tensor,
lens_to_mask, mask_from_frac_lengths,
)
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma = 0.,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method = 'euler' # 'midpoint'
),
audio_drop_prob = 0.3,
cond_drop_prob = 0.2,
num_channels = None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
vocab_char_map: dict[str, int] | None = None
):
super().__init__()
self.frac_lengths_mask = frac_lengths_mask
# mel spec
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.num_channels = num_channels
# classifier-free guidance
self.audio_drop_prob = audio_drop_prob
self.cond_drop_prob = cond_drop_prob
# transformer
self.transformer = transformer
dim = transformer.dim
self.dim = dim
# conditional flow related
self.sigma = sigma
# sampling related
self.odeint_kwargs = odeint_kwargs
# vocab map for tokenization
self.vocab_char_map = vocab_char_map
@property
def device(self):
return next(self.parameters()).device
@torch.no_grad()
def sample(
self,
cond: float['b n d'] | float['b nw'],
text: int['b nt'] | list[str],
duration: int | int['b'],
*,
lens: int['b'] | None = None,
steps = 32,
cfg_strength = 1.,
sway_sampling_coef = None,
seed: int | None = None,
max_duration = 4096,
vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
no_ref_audio = False,
duplicate_test = False,
t_inter = 0.1,
):
self.eval()
# raw wave
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = rearrange(cond, 'b d n -> b n d')
assert cond.shape[-1] == self.num_channels
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
# text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim = -1)
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
# duration
cond_mask = lens_to_mask(lens)
if isinstance(duration, int):
duration = torch.full((batch,), duration, device = device, dtype = torch.long)
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
duration = duration.clamp(max = max_duration)
max_duration = duration.amax()
# duplicate test corner case for inner time step observation
if duplicate_test:
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
cond_mask = rearrange(cond_mask, '... -> ... 1')
# test for no ref audio (zero the reference before step_cond is built, so the model does not see it)
if no_ref_audio:
cond = torch.zeros_like(cond)
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
mask = lens_to_mask(duration)
# neural ode
def fn(t, x):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
return pred + (pred - null_pred) * cfg_strength
# noise input
# sample the noise per item so that batched inference matches single-item inference regardless of batch size
# some differences may remain, possibly due to the convolutional layers
y0 = []
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(torch.randn(dur, self.num_channels, device = self.device))
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
t_start = 0
# duplicate test corner case for inner time step observation
if duplicate_test:
t_start = t_inter
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps, device = self.device)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
if exists(vocoder):
out = rearrange(out, 'b n d -> b d n')
out = vocoder(out)
return out, trajectory
def forward(
self,
inp: float['b n d'] | float['b nw'], # mel or raw wave
text: int['b nt'] | list[str],
*,
lens: int['b'] | None = None,
noise_scheduler: str | None = None,
):
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = rearrange(inp, 'b d n -> b n d')
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
# handle text as string
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device = device)
mask = lens_to_mask(lens, length = seq_len) # redundant here, since collate_fn pads to the max length in the batch
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype = dtype, device = self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
t = rearrange(time, 'b -> b 1 1')
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
# only predict what is within the random mask span for infilling
cond = torch.where(
rand_span_mask[..., None],
torch.zeros_like(x1), x1
)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
drop_audio_cond = True
drop_text = True
else:
drop_text = False
# to rigorously mask out padding, record the mask in collate_fn in dataset.py and pass it in here
# passing a mask uses more memory, so the batch sampler's frame threshold would also need to be scaled down for long sequences
pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
# flow matching loss
loss = F.mse_loss(pred, flow, reduction = 'none')
loss = loss[rand_span_mask]
return loss.mean(), cond, pred
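
The sway sampling step in `CFM.sample` above warps a uniform time grid with `t + s * (cos(pi / 2 * t) - 1 + t)`. The sketch below reproduces that formula with the standard library only, so the effect of the coefficient can be inspected without PyTorch. The function name `sway_schedule` and the value `-1.0` are illustrative assumptions, not values taken from the repository:

```python
from __future__ import annotations

import math


def sway_schedule(steps: int, coef: float | None, t_start: float = 0.0) -> list[float]:
    """Uniform grid on [t_start, 1] with `steps` points, optionally warped by sway sampling."""
    # Same points as torch.linspace(t_start, 1, steps).
    ts = [t_start + (1.0 - t_start) * i / (steps - 1) for i in range(steps)]
    if coef is None:
        return ts
    # Same expression as in CFM.sample: t + s * (cos(pi / 2 * t) - 1 + t).
    return [t + coef * (math.cos(math.pi / 2 * t) - 1 + t) for t in ts]


uniform = sway_schedule(5, None)
swayed = sway_schedule(5, -1.0)
print([round(t, 3) for t in uniform])
print([round(t, 3) for t in swayed])
```

The warp leaves `t = 0` and `t = 1` in place. With a negative coefficient the interior points move toward `t = 0`, so the ODE solver takes smaller steps early in the trajectory.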

242
model/dataset.py Normal file

@@ -0,0 +1,242 @@
import json
import random
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler
import torchaudio
from datasets import load_dataset, load_from_disk
from datasets import Dataset as Dataset_
from einops import rearrange
from model.modules import MelSpec
class HFDataset(Dataset):
def __init__(
self,
hf_dataset: Dataset,
target_sample_rate = 24_000,
n_mel_channels = 100,
hop_length = 256,
):
self.data = hf_dataset
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
def get_frame_len(self, index):
row = self.data[index]
audio = row['audio']['array']
sample_rate = row['audio']['sampling_rate']
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio = row['audio']['array']
# logger.info(f"Audio shape: {audio.shape}")
sample_rate = row['audio']['sampling_rate']
duration = audio.shape[-1] / sample_rate
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
audio_tensor = torch.from_numpy(audio).float()
if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = resampler(audio_tensor)
audio_tensor = rearrange(audio_tensor, 't -> 1 t')
mel_spec = self.mel_spectrogram(audio_tensor)
mel_spec = rearrange(mel_spec, '1 d t -> d t')
text = row['text']
return dict(
mel_spec = mel_spec,
text = text,
)
class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations = None,
target_sample_rate = 24_000,
hop_length = 256,
n_mel_channels = 100,
preprocessed_mel = False,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
def get_frame_len(self, index):
if self.durations is not None: # make sure the separately provided durations are correct; otherwise out-of-memory errors are very likely
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
if self.preprocessed_mel:
mel_spec = torch.tensor(row["mel_spec"])
else:
audio, source_sample_rate = torchaudio.load(audio_path)
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
mel_spec = self.mel_spectrogram(audio)
mel_spec = rearrange(mel_spec, '1 d t -> d t')
return dict(
mel_spec = mel_spec,
text = text,
)
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
""" Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
"""
def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
self.sampler = sampler
self.frames_threshold = frames_threshold
self.max_samples = max_samples
indices, batches = [], []
data_source = self.sampler.data_source
for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
indices.append((idx, data_source.get_frame_len(idx)))
indices.sort(key=lambda elem : elem[1])
batch = []
batch_frames = 0
for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
batch.append(idx)
batch_frames += frame_len
else:
if len(batch) > 0:
batches.append(batch)
if frame_len <= self.frames_threshold:
batch = [idx]
batch_frames = frame_len
else:
batch = []
batch_frames = 0
if not drop_last and len(batch) > 0:
batches.append(batch)
del indices
# to get different batches between epochs, set a seed here and log it in the checkpoint,
# because in multi-gpu training the batch on each gpu does not change between epochs, while the overall minibatch formed across gpus does
# e.g. for epoch n, use (random_seed + n)
random.seed(random_seed)
random.shuffle(batches)
self.batches = batches
def __iter__(self):
return iter(self.batches)
def __len__(self):
return len(self.batches)
# Load dataset
def load_dataset(
dataset_name: str,
tokenizer: str,
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
print("Loading dataset ...")
if dataset_type == "CustomDataset":
if audio_type == "raw":
try:
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
except Exception: # fall back to the single-file arrow layout
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
preprocessed_mel = False
elif audio_type == "mel":
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
preprocessed_mel = True
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
elif dataset_type == "HFDataset":
print("Please adjust the Hugging Face dataset path below to your needs.\n" +
"You may also need to adapt the corresponding script, since different datasets may use different formats.")
pre, post = dataset_name.split("_")
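from datasets import load_dataset as hf_load_dataset # local alias: this function shadows the module-level datasets.load_dataset import
train_dataset = HFDataset(hf_load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)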
return train_dataset
# collation
def collate_fn(batch):
mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO: maybe record an attention mask here
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value = 0)
padded_mel_specs.append(padded_spec)
mel_specs = torch.stack(padded_mel_specs)
text = [item['text'] for item in batch]
text_lengths = torch.LongTensor([len(item) for item in text])
return dict(
mel = mel_specs,
mel_lengths = mel_lengths,
text = text,
text_lengths = text_lengths,
)
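
The greedy packing in `DynamicBatchSampler.__init__` above can be replayed on plain integers. The sketch below follows the same steps (sort by frame length, fill a batch until the frame budget or the sample cap is hit, drop items that exceed the budget on their own), but leaves out the sampler, the progress bars and the final shuffle. The function name `make_dynamic_batches` and the example lengths are assumptions chosen for illustration:

```python
def make_dynamic_batches(frame_lens, frames_threshold, max_samples=0):
    """Group item indices so that each batch stays within `frames_threshold` frames."""
    # Sorting by length keeps similarly sized items together, which limits padding.
    order = sorted(range(len(frame_lens)), key=lambda i: frame_lens[i])
    batches, batch, batch_frames = [], [], 0
    for idx in order:
        n = frame_lens[idx]
        fits = batch_frames + n <= frames_threshold
        has_room = max_samples == 0 or len(batch) < max_samples
        if fits and has_room:
            batch.append(idx)
            batch_frames += n
            continue
        # Current batch is full: close it and start a new one.
        if batch:
            batches.append(batch)
        if n <= frames_threshold:
            batch, batch_frames = [idx], n
        else:
            # A single item longer than the budget is skipped entirely.
            batch, batch_frames = [], 0
    if batch:
        batches.append(batch)
    return batches


frame_lens = [300, 120, 900, 450, 80, 1500]
batches = make_dynamic_batches(frame_lens, frames_threshold=1000)
print(batches)
```

In this example the item with 1500 frames never appears in any batch, which mirrors how the sampler silently drops clips longer than `frames_threshold`.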

268
model/ecapa_tdnn.py Normal file

@@ -0,0 +1,268 @@
# just for speaker similarity evaluation, third-party code
# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
''' Res2Conv1d + BatchNorm1d + ReLU
'''
class Res2Conv1dReluBn(nn.Module):
'''
in_channels == out_channels == channels
'''
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
''' Conv1d + BatchNorm1d + ReLU
'''
class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
''' The SE connection of 1D case.
'''
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
''' Attentive weighted mean and standard deviation pooling.
'''
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
try:
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
except:
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
if self.update_extract:
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == 'fbank':
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
x = x[self.feature_selection]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x):
x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out = torch.cat([out2, out3, out4], dim=1)
out = F.relu(self.conv(out))
out = self.bn(self.pooling(out))
out = self.linear(out)
return out
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
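
`AttentiveStatsPool.forward` above computes an attention-weighted mean and standard deviation over time, with the variance obtained as the weighted second moment minus the squared mean. The sketch below does the same arithmetic for a single channel in pure Python. The function name `attentive_stats` and the inputs are illustrative assumptions, and the attention scores are passed in directly instead of being produced by the two `Conv1d` layers:

```python
import math


def attentive_stats(values, scores):
    """Attention-weighted mean and std over one channel's time steps."""
    # Softmax over time (numerically stabilised by subtracting the max score).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Weighted first and second moments, as in AttentiveStatsPool.forward.
    mean = sum(a * v for a, v in zip(alpha, values))
    second_moment = sum(a * v * v for a, v in zip(alpha, values))
    # Clamp before the square root, mirroring residuals.clamp(min=1e-9).
    variance = max(second_moment - mean ** 2, 1e-9)
    return mean, math.sqrt(variance)


mean, std = attentive_stats([1.0, 2.0, 3.0], [0.0, 0.0, 0.0])
peaked_mean, peaked_std = attentive_stats([1.0, 2.0, 3.0], [0.0, 0.0, 50.0])
print(mean, std)
print(peaked_mean, peaked_std)
```

With equal scores the result reduces to the ordinary mean and population standard deviation. With one dominant score the pooled mean collapses onto that frame and the clamp keeps the standard deviation strictly positive.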

575
model/modules.py Normal file

@@ -0,0 +1,575 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length = 1024,
hop_length = 256,
win_length = 1024,
n_mel_channels = 100,
target_sample_rate = 24_000,
normalize = False,
power = 1,
norm = None,
center = True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate = target_sample_rate,
n_fft = filter_length,
win_length = win_length,
hop_length = hop_length,
n_mels = n_mel_channels,
power = power,
center = center,
normalized = normalize,
norm = norm,
)
self.register_buffer('dummy', torch.tensor(0), persistent = False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = rearrange(inp, 'b 1 nw -> b nw')
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min = 1e-5).log()
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size = 31, groups = 16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.)
x = rearrange(x, 'b n d -> b d n')
x = self.conv1d(x)
out = rearrange(x, 'b d n -> b n d')
if mask is not None:
out = out.masked_fill(~mask, 0.)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = start.unsqueeze(1) + (
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
scale.unsqueeze(1)).long()
# clamp to max_pos - 1 so that overly long sequences do not index out of range
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
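# Illustration only, not part of the original commit: the index arithmetic of
# get_pos_embed_indices for a single sequence, as a pure-Python sketch. The name
# pos_embed_indices is hypothetical; the real function works on batched tensors.

```python
def pos_embed_indices(start, length, max_pos, scale=1.0):
    """Position indices start + int(i * scale), clamped to max_pos - 1."""
    # int() truncates toward zero, like .long() on a non-negative float tensor.
    return [min(start + int(i * scale), max_pos - 1) for i in range(length)]
```

# Example: pos_embed_indices(0, 5, 4) keeps the last index at 3 instead of
# running past the precomputed table of 4 positions.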
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
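# Illustration only, not part of the original commit: the GRN forward pass above,
# written out for one unbatched sequence of shape (n, d) with plain lists. The name
# grn is hypothetical; gamma and beta are per-channel lists here instead of
# nn.Parameter tensors.

```python
import math


def grn(x, gamma, beta, eps=1e-6):
    """Global Response Normalization for one sequence x of shape (n, d)."""
    n, d = len(x), len(x[0])
    # Gx: L2 norm of each channel, taken along the sequence axis.
    gx = [math.sqrt(sum(x[t][c] ** 2 for t in range(n))) for c in range(d)]
    # Nx: each channel norm divided by the mean norm over channels.
    mean_gx = sum(gx) / d
    nx = [g / (mean_gx + eps) for g in gx]
    # Scaled response plus bias plus the residual input.
    return [
        [gamma[c] * (x[t][c] * nx[c]) + beta[c] + x[t][c] for c in range(d)]
        for t in range(n)
    ]
```

# Because gamma and beta are initialised to zero in the module above, the layer
# starts out as an identity mapping; the sketch reproduces that when both are zero.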
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb = None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
)
self.ff = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only = None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention requires PyTorch 2.0; please upgrade PyTorch to 2.0 or later.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float['b n d'], # noised input x
c: float['b n d'] = None, # context c
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
else:
return self.processor(self, x, mask = mask, rope = rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
x = x.masked_fill(~mask, 0.)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
c: float['b nt d'] = None, # context c, here text
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, :residual.shape[1]],
x[:, residual.shape[1]:],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
x = x.masked_fill(~mask, 0.)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor = AttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor = JointAttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
context_dim = dim,
context_pre_only = context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(
nn.Linear(freq_embed_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
def forward(self, timestep: float['b']):
time_hidden = self.time_embed(timestep)
time = self.time_mlp(time_hidden) # b d
return time

245
model/trainer.py Normal file

@@ -0,0 +1,245 @@
from __future__ import annotations
import os
import gc
from tqdm import tqdm
import wandb
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from einops import rearrange
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from model import CFM
from model.utils import exists, default
from model.dataset import DynamicBatchSampler, collate_fn
# trainer
class Trainer:
def __init__(
self,
model: CFM,
epochs,
learning_rate,
num_warmup_updates = 20000,
save_per_updates = 1000,
checkpoint_path = None,
batch_size = 32,
batch_size_type: str = "sample",
max_samples = 32,
grad_accumulation_steps = 1,
max_grad_norm = 1.0,
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
wandb_project = "test_e2-tts",
wandb_run_name = "test_run",
wandb_resume_id: str = None,
last_per_steps = None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict()
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
self.accelerator = Accelerator(
log_with = "wandb",
kwargs_handlers = [ddp_kwargs],
gradient_accumulation_steps = grad_accumulation_steps,
**accelerate_kwargs
)
if exists(wandb_resume_id):
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
else:
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name = wandb_project,
init_kwargs=init_kwargs,
config={"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler}
)
self.model = model
if self.is_main:
self.ema_model = EMA(
model,
include_online_model = False,
**ema_kwargs
)
self.ema_model.to(self.accelerator.device)
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
self.batch_size = batch_size
self.batch_size_type = batch_size_type
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
self.max_grad_norm = max_grad_norm
self.noise_scheduler = noise_scheduler
self.duration_predictor = duration_predictor
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(
self.model, self.optimizer
)
@property
def is_main(self):
return self.accelerator.is_main_process
def save_checkpoint(self, step, last=False):
self.accelerator.wait_for_everyone()
if self.is_main:
checkpoint = dict(
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict = self.ema_model.state_dict(),
scheduler_state_dict = self.scheduler.state_dict(),
step = step
)
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
if last:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at step {step}")
else:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
def load_checkpoint(self):
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
return 0
self.accelerator.wait_for_everyone()
if "model_last.pt" in os.listdir(self.checkpoint_path):
latest_checkpoint = "model_last.pt"
else:
latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
if self.is_main:
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
step = checkpoint['step']
del checkpoint; gc.collect()
return step
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if exists(resumable_with_seed):
generator = torch.Generator()
generator.manual_seed(resumable_with_seed)
else:
generator = None
if self.batch_size_type == "sample":
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
batch_size=self.batch_size, shuffle=True, generator=generator)
elif self.batch_size_type == "frame":
self.accelerator.even_batches = False
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
batch_sampler=batch_sampler)
else:
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
# accelerator.prepare() dispatches batches across devices, so step counts derived
# from the unprepared dataloader have to account for the number of devices
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # keeps the number of warmup updates fixed under accelerate multi-gpu ddp
# otherwise, with the default split_batches=False, the warmup length would change with num_processes
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
decay_steps = total_steps - warmup_steps
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
self.scheduler = SequentialLR(self.optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[warmup_steps])
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
start_step = self.load_checkpoint()
global_step = start_step
if exists(resumable_with_seed):
orig_epoch_step = len(train_dataloader)
skipped_epoch = int(start_step // orig_epoch_step)
skipped_batch = start_step % orig_epoch_step
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
else:
skipped_epoch = 0
for epoch in range(skipped_epoch, self.epochs):
self.model.train()
if exists(resumable_with_seed) and epoch == skipped_epoch:
progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
initial=skipped_batch, total=orig_epoch_step)
else:
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
for batch in progress_bar:
with self.accelerator.accumulate(self.model):
text_inputs = batch['text']
mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
mel_lengths = batch["mel_lengths"]
# TODO. add duration predictor training
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
self.accelerator.backward(loss)
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
if self.is_main:
self.ema_model.update()
global_step += 1
if self.accelerator.is_local_main_process:
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
self.save_checkpoint(global_step)
if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
self.accelerator.end_training()
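# Illustration only, not part of the original commit: the shape of the learning-rate
# schedule that Trainer.train assembles from two LinearLR schedulers joined by
# SequentialLR (linear warmup, then linear decay). The name lr_factor is hypothetical,
# and this closed form is a sketch that ignores how accelerate steps the prepared
# scheduler on multiple processes.

```python
def lr_factor(step, warmup_steps, decay_steps, floor=1e-8):
    """Multiplier applied to the base learning rate at a given scheduler step."""
    if step < warmup_steps:
        # Linear ramp from `floor` up to 1.0 over the warmup phase.
        return floor + (1.0 - floor) * step / warmup_steps
    # Linear ramp from 1.0 down to `floor`, then held at `floor`.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return 1.0 + (floor - 1.0) * progress
```

# Example: with warmup_steps=100 and decay_steps=900 the factor peaks at exactly 1.0
# on step 100 and is back at the floor value from step 1000 onwards.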

545
model/utils.py Normal file

@@ -0,0 +1,545 @@
from __future__ import annotations
import os
import re
import math
import random
import string
from tqdm import tqdm
from collections import defaultdict
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import einx
from einops import rearrange, reduce
import jieba
from pypinyin import lazy_pinyin, Style
import zhconv
from zhon.hanzi import punctuation
from jiwer import compute_measures
from funasr import AutoModel
from faster_whisper import WhisperModel
from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec
# seed everything
def seed_everything(seed = 0):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# tensor helpers
def lens_to_mask(
t: int['b'],
length: int | None = None
) -> bool['b n']:
if not exists(length):
length = t.amax()
seq = torch.arange(length, device = t.device)
return einx.less('n, b -> b n', seq, t)
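# Illustration only, not part of the original commit: what lens_to_mask computes,
# as a pure-Python sketch on lists. The name lens_to_mask_py is hypothetical; the
# real helper uses einx on tensors.

```python
def lens_to_mask_py(lens, length=None):
    """Boolean padding mask: row b is True for the first lens[b] positions."""
    if length is None:
        length = max(lens)
    return [[i < n for i in range(length)] for n in lens]
```

# Example: lens_to_mask_py([1, 3]) marks one valid frame in the first row and
# three in the second; everything after that is treated as padding.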
def mask_from_start_end_indices(
seq_len: int['b'],
start: int['b'],
end: int['b']
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device = start.device).long()
return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
def mask_from_frac_lengths(
seq_len: int['b'],
frac_lengths: float['b']
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.rand_like(frac_lengths)
start = (max_start * rand).long().clamp(min = 0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
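# Illustration only, not part of the original commit: the span selection done by
# mask_from_frac_lengths for one sequence, as a pure-Python sketch. The name
# span_mask is hypothetical, and `rand` is passed in explicitly (a value in [0, 1))
# instead of being drawn with torch.rand_like, so the result is reproducible.

```python
def span_mask(seq_len, frac, rand):
    """Mask a contiguous span covering `frac` of the sequence at a random offset."""
    length = int(frac * seq_len)        # number of masked frames
    max_start = seq_len - length        # last admissible start position
    start = max(int(max_start * rand), 0)
    end = start + length
    return [start <= i < end for i in range(seq_len)]
```

# Example: span_mask(10, 0.5, 0.5) masks frames 2 to 6 (five frames in total).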
def maybe_masked_mean(
t: float['b n d'],
mask: bool['b n'] = None
) -> float['b d']:
if not exists(mask):
return t.mean(dim = 1)
t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
num = reduce(t, 'b n d -> b d', 'sum')
den = reduce(mask.float(), 'b n -> b', 'sum')
return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(
text: list[str],
padding_value = -1
) -> int['b nt']:
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
return text
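# Illustration only, not part of the original commit: the byte-level tokenisation
# and padding performed by list_str_to_tensor, as a pure-Python sketch. The name
# bytes_tokenize is hypothetical; the real helper returns a padded torch tensor.

```python
def bytes_tokenize(texts, padding_value=-1):
    """UTF-8 byte ids per string, right-padded to the longest row."""
    rows = [list(t.encode("utf-8")) for t in texts]
    width = max(len(r) for r in rows)
    return [r + [padding_value] * (width - len(r)) for r in rows]
```

# Note that a single Chinese character occupies three ids here, because it is
# encoded as three UTF-8 bytes.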
# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value = -1
) -> int['b nt']:
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
return text
# Get tokenizer
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
'''
tokenizer - "pinyin": apply g2p to Chinese characters only; requires a .txt vocab file
- "char": character-level tokenizer; requires a .txt vocab file
- "byte": UTF-8 byte tokenizer
vocab_size - "pinyin": all available pinyin types, common alphabets (including accented letters) and symbols
- "char": derived from the unfiltered character and symbol counts of the custom dataset
- "byte": 256 (the number of possible byte values)
'''
if tokenizer in ["pinyin", "char"]:
with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
elif tokenizer == "byte":
vocab_char_map = None
vocab_size = 256
return vocab_char_map, vocab_size
# convert char to pinyin
def convert_char_to_pinyin(text_list, polyphone = True):
final_text_list = []
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
for text in text_list:
char_list = []
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, 'UTF-8'))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for c in seg:
if c not in "。,、;:?!《》【】—…":
char_list.append(" ")
char_list.append(c)
else: # if mixed chinese characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
else:
if c not in "。,、;:?!《》【】—…":
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else: # if is zh punc
char_list.append(c)
final_text_list.append(char_list)
return final_text_list
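# Illustration only, not part of the original commit: the byte-length test that
# convert_char_to_pinyin uses to decide how to treat each jieba segment, as a
# pure-Python sketch. The name classify_segment and its return labels are
# hypothetical. The test is a heuristic: any segment made only of three-byte
# characters (for example fullwidth punctuation) also lands in the second branch.

```python
def classify_segment(seg):
    """Classify a text segment by comparing its UTF-8 byte length to its length."""
    byte_len = len(seg.encode("utf-8"))
    if byte_len == len(seg):
        return "ascii"        # every character is a single byte
    if byte_len == 3 * len(seg):
        return "three_byte"   # every character is three bytes, e.g. Chinese text
    return "mixed"            # anything else is handled character by character
```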
# save spectrogram
def save_spectrogram(spectrogram, path):
plt.figure(figsize=(12, 4))
plt.imshow(spectrogram, origin='lower', aspect='auto')
plt.colorbar()
plt.savefig(path)
plt.close()
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
f = open(metalst); lines = f.readlines(); f.close()
metainfo = []
for line in lines:
if len(line.strip().split('|')) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
elif len(line.strip().split('|')) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
f = open(metalst); lines = f.readlines(); f.close()
metainfo = []
for line in lines:
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
return metainfo
# padded to max length mel batch
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
return padded_ref_mels
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_inference_prompt(
metainfo,
speed = 1., tokenizer = "pinyin", polyphone = True,
target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
use_truth_duration = False,
infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
([[] for _ in range(num_buckets)] for _ in range(6))
mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
ref_audio = ref_audio * target_rms / ref_rms
assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio)
# Text
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone = polyphone)
else:
text_list = text
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
zh_pause_punc = r"[。,、;:?!]" # character class, so that each pause punctuation mark is counted individually
ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = rearrange(ref_mel, '1 d n -> d n')
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, \
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
final_text_list[bucket_i].extend(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append((
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i]
))
batch_accum[bucket_i] = 0
utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append((
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i]
))
# shuffle so that the last workers are not left with only the easy (short) batches
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
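The bucketing and flush logic above can be sketched on plain integers. This is a minimal illustration under stated assumptions, not the repository's implementation: `bucket_and_batch`, its parameters and the toy lengths are hypothetical, and real items would carry mel tensors rather than names.

```python
import math

def bucket_and_batch(items, min_tokens, max_tokens, num_buckets, frames_per_batch):
    """Group (name, length) pairs into batches of similar length (hypothetical helper)."""
    buckets = [[] for _ in range(num_buckets)]
    accum = [0] * num_buckets
    batches = []
    for name, length in items:
        if not min_tokens <= length <= max_tokens:
            raise ValueError(f"{name}: length {length} out of range")
        # same index formula as above: maps [min_tokens, max_tokens] onto 0..num_buckets-1
        i = math.floor((length - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
        buckets[i].append(name)
        accum[i] += length
        # flush the bucket as one batch once enough frames have accumulated
        if accum[i] >= frames_per_batch:
            batches.append(buckets[i])
            buckets[i], accum[i] = [], 0
    # residual: whatever is left in each bucket becomes a final, smaller batch
    batches.extend(b for b in buckets if b)
    return batches

items = [("a", 100), ("b", 900), ("c", 120), ("d", 950), ("e", 500)]
batches = bucket_and_batch(items, min_tokens=100, max_tokens=1000,
                           num_buckets=3, frames_per_batch=600)
print(batches)
```

Items of similar length end up in the same batch, which keeps padding low. The residual batches are the short ones, which is why the list is shuffled before being handed to the workers.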
# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    with open(metalst) as f:
        lines = f.readlines()
    test_set_ = []
    for line in tqdm(lines):
        fields = line.strip().split('|')
        if len(fields) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = fields
        elif len(fields) == 4:
            utt, prompt_text, prompt_wav, gt_text = fields
        else:
            continue  # skip malformed lines instead of reusing values from the previous line
        if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        test_set_.append((gen_wav, prompt_wav, gt_text))
    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]
    # contiguous chunks, one per GPU; trailing chunks may be shorter or empty
    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
    return test_set
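The per-GPU split at the end of this function can be tried in isolation. The sketch below uses a hypothetical helper, `split_for_gpus`, that mirrors the `len // num_jobs + 1` chunk size; it is an illustration, not part of the repository.

```python
def split_for_gpus(items, gpus):
    """Split items into contiguous chunks, one per GPU id (hypothetical helper)."""
    if len(gpus) == 1:
        return [(gpus[0], items)]
    per_job = len(items) // len(gpus) + 1
    return [(gpu, items[i * per_job:(i + 1) * per_job]) for i, gpu in enumerate(gpus)]

shards_10 = split_for_gpus(list(range(10)), gpus=[0, 1, 2, 3])
shards_8 = split_for_gpus(list(range(8)), gpus=[0, 1, 2, 3])
print([len(chunk) for _, chunk in shards_10])  # [3, 3, 3, 1]
print([len(chunk) for _, chunk in shards_8])   # [3, 3, 2, 0]
```

Because of the `+ 1`, the last GPU can receive a short or even empty chunk when the item count divides evenly. An empty chunk just makes that worker return an empty result list.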
# get librispeech test-clean cross sentence test
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
    with open(metalst) as f:
        lines = f.readlines()
    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
        test_set_.append((gen_wav, ref_wav, gen_txt))
    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]
    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
    return test_set
# load asr model
def load_asr_model(lang, ckpt_dir = ""):
    if lang == "zh":
        model = AutoModel(
            model = os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    return model
# WER Evaluation, the way Seed-TTS does
def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args
    if lang == "zh":
        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
    asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
    punctuation_all = punctuation + string.punctuation
    wers = []
    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, 'zh-cn')
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ''
            for segment in segments:
                hypo = hypo + ' ' + segment.text
        # raw_truth = truth
        # raw_hypo = hypo
        # strip punctuation from both reference and hypothesis
        for x in punctuation_all:
            truth = truth.replace(x, '')
            hypo = hypo.replace(x, '')
        # collapse double spaces left behind by the removed punctuation
        truth = truth.replace('  ', ' ')
        hypo = hypo.replace('  ', ' ')
        if lang == "zh":
            # character-level scoring: one token per character
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()
        measures = compute_measures(truth, hypo)
        wer = measures["wer"]
        # ref_list = truth.split(" ")
        # subs = measures["substitutions"] / len(ref_list)
        # dele = measures["deletions"] / len(ref_list)
        # inse = measures["insertions"] / len(ref_list)
        wers.append(wer)
    return wers
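The text normalization and scoring steps above can be illustrated without the ASR models. The sketch below is a stand-in under stated assumptions: `ZH_PUNCT` is a short hypothetical substitute for the `punctuation` string used above, and `word_error_rate` is a plain Levenshtein implementation written for illustration, not the `compute_measures` call used by the function.

```python
import string

# A few full-width marks standing in for the full punctuation list (assumption).
ZH_PUNCT = "。,、;:?!"

def normalize(text, lang):
    """Strip punctuation, then split zh into characters or lowercase en."""
    for mark in ZH_PUNCT + string.punctuation:
        text = text.replace(mark, "")
    text = text.replace("  ", " ")
    if lang == "zh":
        return " ".join(text)  # one token per character
    return text.lower()

def word_error_rate(truth, hypo):
    """Token-level edit distance divided by the reference length (reference must be non-empty)."""
    ref, hyp = truth.split(), hypo.split()
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution or match
        prev = cur
    return prev[-1] / len(ref)

wer_en = word_error_rate(normalize("Hello, world!", "en"), normalize("hello word", "en"))
wer_zh = word_error_rate(normalize("你好,世界。", "zh"), normalize("你好世界", "zh"))
print(wer_en, wer_zh)  # 0.5 0.0
```

For Chinese the score is effectively a character error rate, since every character becomes its own token before scoring.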
# SIM Evaluation
def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
    state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['model'], strict=False)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda(device)
    model.eval()
    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):
        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)
        # the speaker-verification model expects 16 kHz input
        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)
        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)
        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)
    return sim_list
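The SIM score is the cosine similarity between two speaker embeddings. As a minimal sketch of that measure, with plain Python lists standing in for the embedding tensors (the vectors below are made up for illustration):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

same = cosine_similarity([1.0, 0.0], [1.0, 0.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
opposite = cosine_similarity([1.0, 2.0], [-1.0, -2.0])
print(same, orthogonal, round(opposite, 6))
```

The value lies in [-1, 1] and ignores vector length, so it compares the direction of the embeddings rather than their magnitude.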
# filter func for dirty data with many repetitions
def repetition_found(text, length = 2, tolerance = 10):
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i:i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False
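A short usage sketch of the repetition filter, with the function copied here so the block runs on its own. The sample strings are made up for illustration.

```python
from collections import defaultdict

def repetition_found(text, length=2, tolerance=10):
    """True if any substring of the given length occurs more than `tolerance` times."""
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern_count[text[i:i + length]] += 1
    return any(count > tolerance for count in pattern_count.values())

laugh_12 = repetition_found("哈" * 12)   # 11 overlapping "哈哈" windows, above the tolerance
laugh_11 = repetition_found("哈" * 11)   # 10 windows, not above the tolerance
normal = repetition_found("The quick brown fox jumps over the lazy dog", length=4)
print(laugh_12, laugh_11, normal)  # True False False
```

Windows overlap, so a run of n identical characters yields n - 1 windows of length 2. The tolerance is exceeded only when the count is strictly greater than it.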

22
requirements.txt Normal file

@@ -0,0 +1,22 @@
accelerate>=0.33.0
datasets
einops>=0.8.0
einx>=0.3.0
ema_pytorch>=0.5.2
faster_whisper
funasr
jieba
jiwer
librosa
matplotlib
pypinyin
torch>=2.0
torchaudio>=2.3.0
torchdiffeq
tqdm>=4.65.0
transformers
vocos
wandb
x_transformers>=1.31.14
zhconv
zhon


@@ -0,0 +1,32 @@
'''ADAPTIVE BATCH SIZE'''
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
# data
total_hours = 95282
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
# train params
gpus = 8
frames_per_gpu = 38400 # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
steps_per_epoch = updates_per_epoch * grad_accum
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x grad_accum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# others
print(f"total {total_hours:.0f} hours")
print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
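The arithmetic in this script can be wrapped in a function to check the numbers for the settings shown. `epochs_for_updates` is a hypothetical name introduced for this sketch; the formulas are the ones used above.

```python
def epochs_for_updates(total_hours, frames_per_gpu, gpus, grad_accum,
                       wanted_updates, hop_length=256, sampling_rate=24000):
    """Return (epochs, updates_per_epoch) for a frame-based batch size."""
    frames_per_update = frames_per_gpu * grad_accum * gpus
    # each mel frame advances hop_length samples, so frames -> seconds -> hours
    hours_per_update = frames_per_update * hop_length / sampling_rate / 3600
    updates_per_epoch = total_hours / hours_per_update
    return wanted_updates / updates_per_epoch, updates_per_epoch

epochs, updates_per_epoch = epochs_for_updates(
    total_hours=95282, frames_per_gpu=38400, gpus=8, grad_accum=1,
    wanted_updates=1_000_000)
print(f"{updates_per_epoch:.0f} updates per epoch, {epochs:.2f} epochs")
```

With 8 GPUs at 38400 frames each, one update covers about 0.91 hours of audio, so one pass over 95282 hours takes roughly 104680 updates, and one million updates correspond to about 9.55 epochs.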


@@ -0,0 +1,35 @@
import sys, os
sys.path.append(os.getcwd())
from model import M2_TTS, UNetT, DiT, MMDiT
import torch
import thop
''' ~155M '''
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
''' ~335M '''
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
model = M2_TTS(transformer=transformer)
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")


@@ -0,0 +1,67 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import sys, os
sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np
from model.utils import (
get_librispeech_test,
run_asr_wer,
run_sim,
)
eval_task = "wer" # sim | wer
lang = "en"
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
gpus = [0,1,2,3,4,5,6,7]
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
## leading to a low similarity for the ground truth in some cases.
# test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
local = False
if local: # use local custom checkpoint dir
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
if eval_task == "wer":
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sim_list = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")


@@ -0,0 +1,69 @@
# Evaluate with Seed-TTS testset
import sys, os
sys.path.append(os.getcwd())
import multiprocessing as mp
import numpy as np
from model.utils import (
get_seed_tts_test,
run_asr_wer,
run_sim,
)
eval_task = "wer" # sim | wer
lang = "zh" # zh | en
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
# NOTE: paraformer-zh results differ slightly with the number of GPUs, because the effective batch size differs
# the zh WER of 1.254 appears to come from running wer_seed_tts with 4 workers
gpus = [0,1,2,3,4,5,6,7]
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
local = False
if local: # use local custom checkpoint dir
if lang == "zh":
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
elif lang == "en":
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
if eval_task == "wer":
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sim_list = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")

143
scripts/prepare_emilia.py Normal file

@@ -0,0 +1,143 @@
# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
# if you use the updated WebDataset version instead, modify this script or draft your own
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import sys, os
sys.path.append(os.getcwd())
from pathlib import Path
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter
from model.utils import (
repetition_found,
convert_char_to_pinyin,
)
out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
zh_filters = ["い", "て"]
# seems synthesized audios, or heavily code-switched
out_en = {
"EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
"EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
}
en_filters = ["ا", "い", "て"]
def deal_with_audio_dir(audio_dir):
    audio_jsonl = audio_dir.with_suffix(".jsonl")
    sub_result, durations = [], []
    vocab_set = set()
    bad_case_zh = 0
    bad_case_en = 0
    with open(audio_jsonl, "r", encoding="utf-8") as f:
        lines = f.readlines()
    for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
        obj = json.loads(line)
        text = obj["text"]
        if obj['language'] == "zh":
            if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
                bad_case_zh += 1
                continue
            else:
                # map half-width marks to full-width; "." is not mapped to "。" because much of the text is code-switched
                text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'}))
        if obj['language'] == "en":
            if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
                bad_case_en += 1
                continue
        if tokenizer == "pinyin":
            text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
        duration = obj["duration"]
        sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
        durations.append(duration)
        vocab_set.update(list(text))
    return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
def main():
    assert tokenizer in ["pinyin", "char"]
    result = []
    duration_list = []
    text_vocab_set = set()
    total_bad_case_zh = 0
    total_bad_case_en = 0
    # process raw data
    executor = ProcessPoolExecutor(max_workers=max_workers)
    futures = []
    for lang in langs:
        dataset_path = Path(os.path.join(dataset_dir, lang))
        for audio_dir in dataset_path.iterdir():
            if audio_dir.is_dir():
                futures.append(executor.submit(deal_with_audio_dir, audio_dir))
    # loop variable renamed so that it no longer shadows the futures list
    for future in tqdm(futures, total=len(futures)):
        sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
        result.extend(sub_result)
        duration_list.extend(durations)
        text_vocab_set.update(vocab_set)
        total_bad_case_zh += bad_case_zh
        total_bad_case_en += bad_case_en
    executor.shutdown()
    # save preprocessed dataset to disk
    os.makedirs(f"data/{dataset_name}", exist_ok=True)
    print(f"\nSaving to data/{dataset_name} ...")
    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
    # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
    with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
    # keep a separate JSON copy of the durations, for the convenience of DynamicBatchSampler
    with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)
    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    with open(f"data/{dataset_name}/vocab.txt", "w", encoding="utf-8") as f:
        for vocab in sorted(text_vocab_set):
            f.write(vocab + "\n")
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
langs = ["ZH", "EN"]
dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
print(f"\nPrepare for {dataset_name}\n")
main()
# Emilia ZH & EN
# samples count 37837916 (after removal)
# pinyin vocab size 2543 (polyphone)
# total duration 95281.87 (hours)
# bad zh asr cnt 230435 (samples)
# bad en asr cnt 37217 (samples)
# vocab size may differ slightly depending on the jieba and pypinyin versions (e.g. how polyphones are resolved)
# if you use a pretrained model, make sure your vocab.txt is identical to the one it was trained with
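The per-record filtering done by this script can be sketched on a few in-memory JSON lines. Everything below is illustrative: `keep_record`, the sample ids, the path layout `<batch>/<utt_id>/<file>` and the texts are assumptions made for the example, not data from the corpus.

```python
import json

def keep_record(line, blocklist, bad_chars, ngram=2, max_repeats=10):
    """Return the parsed record if it passes all filters, else None (hypothetical helper)."""
    obj = json.loads(line)
    utt_id = obj["wav"].split("/")[1]   # second path component is taken as the utterance id
    text = obj["text"]
    if utt_id in blocklist or any(ch in text for ch in bad_chars):
        return None
    counts = {}
    for i in range(len(text) - ngram + 1):
        gram = text[i:i + ngram]
        counts[gram] = counts.get(gram, 0) + 1
        if counts[gram] > max_repeats:  # too many repeats of one n-gram
            return None
    return obj

lines = [
    '{"wav": "EN_B00000/EN_B00000_S00001/a.mp3", "text": "hello there", "duration": 1.2}',
    '{"wav": "EN_B00000/EN_B00000_S00002/b.mp3", "text": "blocked clip", "duration": 2.0}',
    '{"wav": "EN_B00000/EN_B00000_S00003/c.mp3", "text": "' + "ha" * 12 + '", "duration": 3.0}',
]
kept = [r for r in (keep_record(l, {"EN_B00000_S00002"}, ["ا"]) for l in lines) if r]
print(len(kept), sum(r["duration"] for r in kept))
```

Note that a filter list containing an empty string would reject every record, because `"" in text` is always true in Python.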


@@ -0,0 +1,116 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import sys, os
sys.path.append(os.getcwd())
import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import torchaudio
from datasets import Dataset
from model.utils import convert_char_to_pinyin
def deal_with_sub_path_files(dataset_path, sub_path):
print(f"Dealing with: {sub_path}")
text_dir = os.path.join(dataset_path, sub_path, "txts")
audio_dir = os.path.join(dataset_path, sub_path, "wavs")
text_files = os.listdir(text_dir)
audio_paths, texts, durations = [], [], []
for text_file in tqdm(text_files):
with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
first_line = file.readline().split("\t")
audio_nm = first_line[0]
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
text = first_line[1].strip()
audio_paths.append(audio_path)
if tokenizer == "pinyin":
texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
elif tokenizer == "char":
texts.append(text)
audio, sample_rate = torchaudio.load(audio_path)
durations.append(audio.shape[-1] / sample_rate)
return audio_paths, texts, durations
def main():
assert tokenizer in ["pinyin", "char"]
audio_path_list, text_list, duration_list = [], [], []
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for dataset_path in dataset_paths:
sub_items = os.listdir(dataset_path)
sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
for sub_path in sub_paths:
futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
for future in tqdm(futures, total=len(futures)):
audio_paths, texts, durations = future.result()
audio_path_list.extend(audio_paths)
text_list.extend(texts)
duration_list.extend(durations)
executor.shutdown()
if not os.path.exists("data"):
os.makedirs("data")
print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
text_vocab_set = set()
for text in tqdm(text_list):
text_vocab_set.update(list(text))
# add alphabets and symbols (optional, if plan to ft on de/fr etc.)
if tokenizer == "pinyin":
text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
dataset_paths = [
"<SOME_PATH>/WenetSpeech4TTS/Basic",
"<SOME_PATH>/WenetSpeech4TTS/Standard",
"<SOME_PATH>/WenetSpeech4TTS/Premium",
][-dataset_choice:]
print(f"\nChoose Dataset: {dataset_name}\n")
main()
# Results (if adding alphabets with accents and symbols):
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may differ slightly depending on the jieba and pypinyin versions (e.g. how polyphones are resolved)
# if you use a pretrained model, make sure your vocab.txt is identical to the one it was trained with
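The vocabulary step collects every distinct token and writes them in sorted order. A minimal sketch, assuming a hypothetical helper `build_vocab` and toy inputs; items may be strings (char tokenizer) or lists of pinyin tokens.

```python
def build_vocab(texts, add_latin=False):
    """Collect distinct tokens; optionally add ASCII and Latin-1 code points (hypothetical helper)."""
    vocab = set()
    for text in texts:
        vocab.update(text)  # a str contributes characters, a list contributes its tokens
    if add_latin:
        vocab.update(chr(i) for i in range(32, 127))    # printable ASCII, 95 entries
        vocab.update(chr(i) for i in range(192, 256))   # U+00C0..U+00FF, 64 entries
    return sorted(vocab)

char_vocab = build_vocab(["abc", "bcd"])
pinyin_vocab = build_vocab([["ni3", " ", "hao3"], ["hao3"]])
latin_only = build_vocab([], add_latin=True)
print(char_vocab, pinyin_vocab, len(latin_only))
```

Since the position of a token in the sorted list becomes its id, a vocabulary built from different data or a different pinyin version will not line up with a pretrained checkpoint.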

202
test_infer_batch.py Normal file

@@ -0,0 +1,202 @@
import os
import time
import random
from tqdm import tqdm
import argparse
import torch
import torchaudio
from accelerate import Accelerator
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos
from model import CFM, UNetT, DiT
from model.utils import (
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
tokenizer = "pinyin"
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument('-s', '--seed', default=None, type=int)
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
parser.add_argument('-n', '--expname', required=True)
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
parser.add_argument('-o', '--odemethod', default="euler")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
parser.add_argument('-t', '--testset', required=True)
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling
testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.
speed = 1.
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
if testset == "ls_pc_test_clean":
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = "data/seedtts_testset/zh/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
metalst = "data/seedtts_testset/en/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
# path to save generated wavs
if seed is None: seed = random.randint(-10000, 10000)
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
f"_cfg{cfg_strength}_speed{speed}" \
f"{'_gt-dur' if use_truth_duration else ''}" \
f"{'_no-ref-audio' if no_ref_audio else ''}"
# -------------------------------------------------#
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed = speed,
tokenizer = tokenizer,
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
target_rms = target_rms,
use_truth_duration = use_truth_duration,
infer_batch_size = infer_batch_size,
)
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
).to(device)
if use_ema:
ema_model = EMA(model, include_online_model = False).to(device)
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
ema_model.copy_params_from_ema_to_model()
else:
model.load_state_dict(checkpoint['model_state_dict'])
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
# start batch inference
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond = ref_mels,
text = final_text_list,
duration = total_mel_lens,
lens = ref_mel_lens,
steps = nfe_step,
cfg_strength = cfg_strength,
sway_sampling_coef = sway_sampling_coef,
no_ref_audio = no_ref_audio,
seed = seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
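The loudness handling in this script scales the generated audio back down when the reference was quieter than `target_rms`. The sketch below reproduces both halves of that round trip on plain lists. `normalize_up` and `restore_level` are hypothetical names, the boost step is assumed to match the one applied when the prompt is prepared, and the sample values are made up.

```python
import math

def rms(samples):
    return math.sqrt(sum(s * s for s in samples) / len(samples))

def normalize_up(samples, target_rms=0.1):
    """Boost quiet audio to target_rms; louder audio is left unchanged. Returns (samples, original level)."""
    level = rms(samples)
    if level < target_rms:
        samples = [s * target_rms / level for s in samples]
    return samples, level

def restore_level(generated, ref_level, target_rms=0.1):
    """Undo the boost so the output matches the loudness of the original reference."""
    if ref_level < target_rms:
        generated = [s * ref_level / target_rms for s in generated]
    return generated

quiet = [0.01, -0.01, 0.01, -0.01]
boosted, level = normalize_up(quiet)
restored = restore_level(boosted, level)
loud, loud_level = normalize_up([0.5, -0.5])
print(round(rms(boosted), 6), round(rms(restored), 6), loud)
```

A completely silent reference would divide by zero in the boost step, so a real pipeline may want to guard against an RMS of 0.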

13
test_infer_batch.sh Normal file

@@ -0,0 +1,13 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
# etc.

162
test_infer_single.py Normal file

@@ -0,0 +1,162 @@
import os
import re
import torch
import torchaudio
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos
from model import CFM, UNetT, DiT, MMDiT
from model.utils import (
get_tokenizer,
convert_char_to_pinyin,
save_spectrogram,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
# ---------------------- infer setting ---------------------- #
seed = None # int | None
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000
nfe_step = 32 # 16, 32
cfg_strength = 2.
ode_method = 'euler' # euler | midpoint
sway_sampling_coef = -1.
speed = 1.
fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
output_dir = "tests"
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
# ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
# ref_text = "对,这就是我,万人敬仰的太乙真人。"
# gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
# -------------------------------------------------#
use_ema = True
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
).to(device)
if use_ema:
ema_model = EMA(model, include_online_model = False).to(device)
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
ema_model.copy_params_from_ema_to_model()
else:
model.load_state_dict(checkpoint['model_state_dict'])
# Audio
audio, sr = torchaudio.load(ref_audio)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
# Text
text_list = [ref_text + gen_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    final_text_list = text_list  # already a list of strings; wrapping it again would nest the list
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Duration
ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
    duration = int(fix_duration * target_sample_rate / hop_length)
else:  # simple linear estimate scaled from the reference audio
    zh_pause_punc = r"[。,、;:?!]"  # character class, so each pause mark is counted individually
    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
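# Illustrative sketch (not part of the original script): the linear duration
# estimate above, wrapped in a hypothetical helper so it can be tried in
# isolation. It assumes each Chinese pause mark should count as one extra
# character, which requires a regex character class rather than a literal string.
import re

def sketch_estimate_duration(ref_frames, ref_text, gen_text, speed=1.0):
    """Total mel frames: reference frames plus a length-proportional share for gen_text."""
    pause = r"[。,、;:?!]"
    ref_len = len(ref_text) + len(re.findall(pause, ref_text))
    gen_len = len(gen_text) + len(re.findall(pause, gen_text))
    return ref_frames + int(ref_frames / ref_len * gen_len / speed)

# 300 reference frames for 10 characters, then 20 more characters at normal speed.
sketch_total_frames = sketch_estimate_duration(300, "a" * 10, "b" * 20)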
# Inference
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond = audio,
        text = final_text_list,
        duration = duration,
        steps = nfe_step,
        cfg_strength = cfg_strength,
        sway_sampling_coef = sway_sampling_coef,
        seed = seed,
    )
print(f"Generated mel: {generated.shape}")
# Final result
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
    generated_wave = generated_wave * rms / target_rms
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")

91
test_train.py Normal file

@@ -0,0 +1,91 @@
from model import CFM, UNetT, DiT, MMDiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
# -------------------------- Training Settings -------------------------- #
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
learning_rate = 7.5e-5
batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame" # "frame" or "sample"
max_samples = 64 # max sequences per batch when using frame-wise batch_size; we use 32 for small models, 64 for base models
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.
epochs = 11 # linear decay is used, so the number of epochs controls the slope
num_warmup_updates = 20000 # number of warmup updates
save_per_updates = 50000 # save a checkpoint every this many updates
last_per_steps = 5000 # save the latest checkpoint every this many steps
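# Illustrative sketch (not part of the original script): how the warmup and
# decay settings above might interact. The Trainer's scheduler is not shown in
# this file, so this ASSUMES a linear warmup from zero to the peak learning rate
# followed by a linear decay to zero; the function name and the total update
# counts used below are hypothetical.
def sketch_lr_at(update, peak_lr, warmup_updates, total_updates):
    """Learning rate at a given update under linear warmup, then linear decay."""
    if update < warmup_updates:
        return peak_lr * update / warmup_updates
    decay_updates = max(total_updates - warmup_updates, 1)
    remaining = max(total_updates - update, 0)
    return peak_lr * remaining / decay_updates

# More epochs mean more total updates, hence a shallower decay slope:
sketch_lr_short = sketch_lr_at(60000, 7.5e-5, 20000, 100000)
sketch_lr_long = sketch_lr_at(60000, 7.5e-5, 20000, 200000)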
# model params
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# ----------------------------------------------------------------------- #
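# Illustrative sketch (not part of the original script): what a frame-wise
# batch size means in seconds of audio, using the dataset settings above.
# One mel frame spans hop_length samples, so frames * hop_length / sample_rate
# gives seconds. The helper name is hypothetical, and the 8-GPU figure is taken
# from the comment next to batch_size_per_gpu.
def sketch_frames_to_seconds(frames, hop, sample_rate):
    """Convert a number of mel frames to seconds of audio."""
    return frames * hop / sample_rate

sketch_per_gpu_seconds = sketch_frames_to_seconds(38400, 256, 24000)     # about 409.6 s per GPU
sketch_global_seconds = sketch_frames_to_seconds(8 * 38400, 256, 24000)  # about 3276.8 s across 8 GPUs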
def main():
    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    )
    e2tts = CFM(
        transformer = model_cls(
            **model_cfg,
            text_num_embeds = vocab_size,
            mel_dim = n_mel_channels
        ),
        mel_spec_kwargs = mel_spec_kwargs,
        vocab_char_map = vocab_char_map,
    )
    trainer = Trainer(
        e2tts,
        epochs,
        learning_rate,
        num_warmup_updates = num_warmup_updates,
        save_per_updates = save_per_updates,
        checkpoint_path = f'ckpts/{exp_name}',
        batch_size = batch_size_per_gpu,
        batch_size_type = batch_size_type,
        max_samples = max_samples,
        grad_accumulation_steps = grad_accumulation_steps,
        max_grad_norm = max_grad_norm,
        wandb_project = "CFM-TTS",
        wandb_run_name = exp_name,
        wandb_resume_id = wandb_resume_id,
        last_per_steps = last_per_steps,
    )
    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(train_dataset,
                  resumable_with_seed = 666 # seed for shuffling dataset
                  )
if __name__ == '__main__':
    main()