Merge branch 'main' of github.com:lpscr/F5-TTS into lpscr-main

This commit is contained in:
SWivid
2024-10-29 23:34:00 +08:00
4 changed files with 273 additions and 14 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import os
import gc
from tqdm import tqdm
-import wandb

import torch
from torch.optim import AdamW
@@ -19,7 +19,6 @@ from f5_tts.model import CFM
from f5_tts.model.utils import exists, default
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn

# trainer
@@ -39,6 +38,8 @@ class Trainer:
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
+        logger: str = "wandb",  # logger backend: "wandb", "tensorboard", or "none"
+        log_dir: str = "logs",  # log directory for the tensorboard backend
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str = None,
@@ -46,24 +47,29 @@ class Trainer:
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
+        export_samples=False,
    ):
+        # export audio and mel
+        self.export_samples = export_samples
+        if export_samples:
+            self.path_ckpts_project = checkpoint_path
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

-        logger = "wandb" if wandb.api.api_key else None
-        print(f"Using logger: {logger}")
-        self.accelerator = Accelerator(
-            log_with=logger,
-            kwargs_handlers=[ddp_kwargs],
-            gradient_accumulation_steps=grad_accumulation_steps,
-            **accelerate_kwargs,
-        )
-        if logger == "wandb":
+        self.logger = logger
+        if self.logger == "wandb":
+            self.accelerator = Accelerator(
+                log_with="wandb",
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
            if exists(wandb_resume_id):
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
@@ -80,12 +86,29 @@ class Trainer:
"noise_scheduler": noise_scheduler, "noise_scheduler": noise_scheduler,
}, },
) )
+        elif self.logger == "tensorboard":
+            from torch.utils.tensorboard import SummaryWriter
+
+            self.accelerator = Accelerator(
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
+            if self.is_main:
+                path_log_dir = os.path.join(log_dir, wandb_project)
+                os.makedirs(path_log_dir, exist_ok=True)
+                existing_folders = [folder for folder in os.listdir(path_log_dir) if folder.startswith("exp")]
+                next_number = len(existing_folders) + 2
+                folder_name = f"exp{next_number}"
+                folder_path = os.path.join(path_log_dir, folder_name)
+                os.makedirs(folder_path, exist_ok=True)
+                self.writer = SummaryWriter(log_dir=folder_path)

        self.model = model

        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
            self.ema_model.to(self.accelerator.device)

        self.epochs = epochs
@@ -175,7 +198,32 @@ class Trainer:
        gc.collect()
        return step
+    def log(self, metrics, step):
+        """Unified logging method for both WandB and TensorBoard"""
+        if self.logger == "none":
+            return
+        if self.logger == "wandb":
+            self.accelerator.log(metrics, step=step)
+        elif self.is_main:
+            for key, value in metrics.items():
+                self.writer.add_scalar(key, value, step)
    def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
+        # import only when export_samples is True
+        if self.export_samples:
+            from f5_tts.infer.utils_infer import (
+                target_sample_rate,
+                hop_length,
+                nfe_step,
+                cfg_strength,
+                sway_sampling_coef,
+                vocos,
+            )
+            from f5_tts.model.utils import get_sample
+
+            self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
+            os.makedirs(self.file_path_samples, exist_ok=True)
        if exists(resumable_with_seed):
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
@@ -259,6 +307,7 @@ class Trainer:
            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]
@@ -270,6 +319,40 @@ class Trainer:
                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                    )
+                    # save 4 audio samples per save step (at each quarter of save_per_updates)
+                    if (
+                        self.accelerator.is_local_main_process
+                        and self.export_samples
+                        and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
+                    ):
+                        try:
+                            wave_org, wave_gen, mel_org, mel_gen = get_sample(
+                                vocos,
+                                self.model,
+                                self.file_path_samples,
+                                global_step,
+                                batch["mel"][0],
+                                text_inputs,
+                                target_sample_rate,
+                                hop_length,
+                                nfe_step,
+                                cfg_strength,
+                                sway_sampling_coef,
+                            )
+                            if self.logger == "tensorboard":
+                                self.writer.add_audio(
+                                    "Audio/original", wave_org, global_step, sample_rate=target_sample_rate
+                                )
+                                self.writer.add_audio(
+                                    "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
+                                )
+                                self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
+                                self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
+                        except Exception as e:
+                            print("An error occurred:", e)
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -285,7 +368,7 @@ class Trainer:
                global_step += 1

                if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
                    progress_bar.set_postfix(step=str(global_step), loss=loss.item())
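The new `Trainer.log` method gives wandb, TensorBoard, and disabled logging a single call site, so the training loop no longer needs to know which backend is active. A minimal standalone sketch of the same dispatch pattern, assuming a hypothetical `MiniLogger` class (every name below is illustrative, not from this commit):

from torch.utils.tensorboard import SummaryWriter


class MiniLogger:
    # hypothetical stand-in for the Trainer's logger wiring
    def __init__(self, backend: str = "none", log_dir: str = "logs"):
        self.backend = backend
        self.writer = SummaryWriter(log_dir=log_dir) if backend == "tensorboard" else None

    def log(self, metrics: dict, step: int):
        if self.backend == "none":
            return  # logging disabled
        if self.backend == "wandb":
            import wandb  # imported lazily; only this backend needs it

            wandb.log(metrics, step=step)
        elif self.backend == "tensorboard":
            for key, value in metrics.items():
                self.writer.add_scalar(key, value, step)


MiniLogger(backend="tensorboard").log({"loss": 0.42, "lr": 1e-5}, step=100)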

View File

@@ -11,6 +11,10 @@ from torch.nn.utils.rnn import pad_sequence
import jieba
from pypinyin import lazy_pinyin, Style

+import numpy as np
+import matplotlib.pyplot as plt
+import soundfile as sf
+import torchaudio

# seed everything
@@ -183,3 +187,73 @@ def repetition_found(text, length=2, tolerance=10):
        if count > tolerance:
            return True
    return False
+
+def normalize_and_colorize_spectrogram(mel_org):
+    mel_min, mel_max = mel_org.min(), mel_org.max()
+    mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
+    mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+    mel_colored = np.transpose(mel_colored, (2, 0, 1))
+    return mel_colored
+
+
+def export_audio(file_out, wav, target_sample_rate):
+    sf.write(file_out, wav, samplerate=target_sample_rate)
+
+
+def export_mel(mel_colored_hwc, file_out):
+    plt.imsave(file_out, mel_colored_hwc)
+
+
+def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
+    audio, sr = torchaudio.load(file_wav_org)
+    audio = audio.to("cuda")
+    ref_audio_len = audio.shape[-1] // hop_length
+    text = [text_inputs[0] + [" . "] + text_inputs[0]]
+    duration = int((audio.shape[1] / 256) * 2.0)
+
+    with torch.inference_mode():
+        generated_gen, _ = model.sample(
+            cond=audio,
+            text=text,
+            duration=duration,
+            steps=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+        )
+        generated_gen = generated_gen.to(torch.float32)
+        generated_gen = generated_gen[:, ref_audio_len:, :]
+        generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
+        generated_wave_gen = vocos.decode(generated_mel_spec_gen.cpu())
+        generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
+
+    return generated_wave_gen, generated_mel_spec_gen
+
+
+def get_sample(
+    vocos,
+    model,
+    file_path_samples,
+    global_step,
+    mel_org,
+    text_inputs,
+    target_sample_rate,
+    hop_length,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+):
+    generated_wave_org = vocos.decode(mel_org.unsqueeze(0).cpu())
+    generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
+    file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
+    export_audio(file_wav_org, generated_wave_org, target_sample_rate)
+
+    generated_wave_gen, generated_mel_spec_gen = gen_sample(
+        model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
+    )
+    file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
+    export_audio(file_wav_gen, generated_wave_gen, target_sample_rate)
+
+    mel_org = normalize_and_colorize_spectrogram(mel_org)
+    mel_gen = normalize_and_colorize_spectrogram(generated_mel_spec_gen[0])
+
+    file_gen_org = os.path.join(file_path_samples, f"step_{global_step}_org.png")
+    export_mel(np.transpose(mel_org, (1, 2, 0)), file_gen_org)
+    file_gen_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.png")
+    export_mel(np.transpose(mel_gen, (1, 2, 0)), file_gen_gen)
+
+    return generated_wave_org, generated_wave_gen, mel_org, mel_gen
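`normalize_and_colorize_spectrogram` min-max scales the mel into [0, 1], maps it through matplotlib's viridis colormap, and drops the alpha channel, returning a (3, H, W) float array: the CHW layout that `SummaryWriter.add_image(..., dataformats="CHW")` consumes directly, and that `plt.imsave` accepts after a transpose back to HWC. A self-contained sketch with a random stand-in mel (the tensor below is dummy data, not repository output):

import numpy as np
import torch
import matplotlib.pyplot as plt

mel = torch.rand(100, 400)  # dummy (n_mels, frames) tensor standing in for a real mel
mel_norm = (mel - mel.min()) / (mel.max() - mel.min() + 1e-8)  # scale to [0, 1]
mel_rgb = plt.get_cmap("viridis")(mel_norm.numpy())[:, :, :3]  # (H, W, 3): colormap output minus alpha
mel_chw = np.transpose(mel_rgb, (2, 0, 1))  # (3, H, W) for add_image(..., dataformats="CHW")
plt.imsave("mel_preview.png", np.transpose(mel_chw, (1, 2, 0)))  # imsave expects HWC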

View File

@@ -56,6 +56,14 @@ def parse_args():
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')", help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
) )
parser.add_argument(
"--export_samples",
type=bool,
default=False,
help="Export 4 audio and spect samples for the checkpoint audio, per step.",
)
parser.add_argument("--logger", type=str, default="wandb", choices=["none", "wandb", "tensorboard"], help="logger")

    return parser.parse_args()
@@ -64,6 +72,7 @@ def parse_args():
def main():
    args = parse_args()

    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))

    # Model parameters based on experiment name
@@ -136,6 +145,8 @@ def main():
        wandb_run_name=args.exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=args.last_per_steps,
+        logger=args.logger,
+        export_samples=args.export_samples,
    )

    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
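One caveat in the new `--export_samples` flag: argparse's `type=bool` applies `bool()` to the raw string, and any non-empty string (including "False") is truthy, so the flag can effectively only be left at its default or turned on. The Gradio launcher below always passes `--export_samples True`, which works, but `action="store_true"` is the conventional, unambiguous pattern. A sketch of that alternative (not what this commit ships):

import argparse

parser = argparse.ArgumentParser()
# store_true: the flag is False unless present, avoiding the type=bool string pitfall
parser.add_argument("--export_samples", action="store_true",
                    help="Export audio and spectrogram samples at each save step.")
parser.add_argument("--logger", type=str, default="wandb", choices=["none", "wandb", "tensorboard"])

args = parser.parse_args(["--export_samples", "--logger", "tensorboard"])
assert args.export_samples is True and args.logger == "tensorboard"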

View File

@@ -69,6 +69,7 @@ def save_settings(
    tokenizer_type,
    tokenizer_file,
    mixed_precision,
+    logger,
):
    path_project = os.path.join(path_project_ckpts, project_name)
    os.makedirs(path_project, exist_ok=True)
@@ -91,6 +92,7 @@ def save_settings(
"tokenizer_type": tokenizer_type, "tokenizer_type": tokenizer_type,
"tokenizer_file": tokenizer_file, "tokenizer_file": tokenizer_file,
"mixed_precision": mixed_precision, "mixed_precision": mixed_precision,
"logger": logger,
} }
with open(file_setting, "w") as f: with open(file_setting, "w") as f:
json.dump(settings, f, indent=4) json.dump(settings, f, indent=4)
@@ -121,6 +123,7 @@ def load_settings(project_name):
"tokenizer_type": "pinyin", "tokenizer_type": "pinyin",
"tokenizer_file": "", "tokenizer_file": "",
"mixed_precision": "none", "mixed_precision": "none",
"logger": "wandb",
} }
return ( return (
settings["exp_name"], settings["exp_name"],
@@ -139,6 +142,7 @@ def load_settings(project_name):
settings["tokenizer_type"], settings["tokenizer_type"],
settings["tokenizer_file"], settings["tokenizer_file"],
settings["mixed_precision"], settings["mixed_precision"],
settings["logger"],
) )
with open(file_setting, "r") as f: with open(file_setting, "r") as f:
@@ -160,6 +164,7 @@ def load_settings(project_name):
settings["tokenizer_type"], settings["tokenizer_type"],
settings["tokenizer_file"], settings["tokenizer_file"],
settings["mixed_precision"], settings["mixed_precision"],
settings["logger"],
) )
@@ -374,6 +379,7 @@ def start_training(
tokenizer_file="", tokenizer_file="",
mixed_precision="fp16", mixed_precision="fp16",
stream=False, stream=False,
logger="wandb",
): ):
global training_process, tts_api, stop_signal global training_process, tts_api, stop_signal
@@ -447,6 +453,8 @@ def start_training(
cmd += f" --tokenizer {tokenizer_type} " cmd += f" --tokenizer {tokenizer_type} "
cmd += f" --export_samples True --logger {logger} "
print(cmd) print(cmd)
save_settings( save_settings(
@@ -467,6 +475,7 @@ def start_training(
        tokenizer_type,
        tokenizer_file,
        mixed_precision,
+        logger,
    )

    try:
@@ -1223,6 +1232,27 @@ def get_checkpoints_project(project_name, is_gradio=True):
    return files_checkpoints, selelect_checkpoint
+
+def get_audio_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    project_name = project_name.replace("_pinyin", "").replace("_char", "")
+
+    if os.path.isdir(path_project_ckpts):
+        files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
+        files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
+        files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
+    else:
+        files_audios = []
+
+    selelect_checkpoint = None if not files_audios else files_audios[0]
+
+    if is_gradio:
+        return gr.update(choices=files_audios, value=selelect_checkpoint)
+
+    return files_audios, selelect_checkpoint

def get_gpu_stats():
    gpu_stats = ""
@@ -1290,6 +1320,21 @@ def get_combined_stats():
    return combined_stats
+
+def get_audio_select(file_sample):
+    select_audio_org = file_sample
+    select_audio_gen = file_sample
+    select_image_org = file_sample
+    select_image_gen = file_sample
+
+    if file_sample is not None:
+        select_audio_org += "_org.wav"
+        select_audio_gen += "_gen.wav"
+        select_image_org += "_org.png"
+        select_image_gen += "_gen.png"
+
+    return select_audio_org, select_audio_gen, select_image_org, select_image_gen

with gr.Blocks() as app:
    gr.Markdown(
        """
@@ -1470,6 +1515,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
        with gr.Row():
            mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
+            cd_logger = gr.Radio(label="logger", choices=["none", "wandb", "tensorboard"], value="wandb")
            start_button = gr.Button("Start Training")
            stop_button = gr.Button("Stop Training", interactive=False)
@@ -1491,6 +1537,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                tokenizer_typev,
                tokenizer_filev,
                mixed_precisionv,
+                cd_loggerv,
            ) = load_settings(projects_selelect)
            exp_name.value = exp_namev
            learning_rate.value = learning_ratev
@@ -1508,9 +1555,51 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            tokenizer_type.value = tokenizer_typev
            tokenizer_file.value = tokenizer_filev
            mixed_precision.value = mixed_precisionv
+            cd_logger.value = cd_loggerv

        ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
        txt_info_train = gr.Text(label="info", value="")
+
+        list_audios, select_audio = get_audio_project(projects_selelect, False)
+
+        select_audio_org = select_audio
+        select_audio_gen = select_audio
+        select_image_org = select_audio
+        select_image_gen = select_audio
+
+        if select_audio is not None:
+            select_audio_org += "_org.wav"
+            select_audio_gen += "_gen.wav"
+            select_image_org += "_org.png"
+            select_image_gen += "_gen.png"
+
+        with gr.Row():
+            ch_list_audio = gr.Dropdown(
+                choices=list_audios,
+                value=select_audio,
+                label="audios",
+                allow_custom_value=True,
+                scale=6,
+                interactive=True,
+            )
+            bt_stream_audio = gr.Button("refresh", scale=1)
+            bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+            cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+
+        with gr.Row():
+            audio_org_stream = gr.Audio(label="original", type="filepath", value=select_audio_org)
+            mel_org_stream = gr.Image(label="original", type="filepath", value=select_image_org)
+
+        with gr.Row():
+            audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
+            mel_gen_stream = gr.Image(label="generate", type="filepath", value=select_image_gen)
+
+        ch_list_audio.change(
+            fn=get_audio_select,
+            inputs=[ch_list_audio],
+            outputs=[audio_org_stream, audio_gen_stream, mel_org_stream, mel_gen_stream],
+        )

        start_button.click(
            fn=start_training,
            inputs=[
@@ -1532,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                tokenizer_file,
                mixed_precision,
                ch_stream,
+                cd_logger,
            ],
            outputs=[txt_info_train, start_button, stop_button],
        )
@@ -1583,6 +1673,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            tokenizer_type,
            tokenizer_file,
            mixed_precision,
+            cd_logger,
        ]
    return output_components
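`get_audio_project` and `get_audio_select` both depend on the naming scheme that `get_sample` writes into the project's samples directory: every exported sample shares a `step_{global_step}` stem with four suffixed files, and the dropdown stores only the stem. A small sketch of resolving a stem back to its artifacts (the paths are illustrative):

import os

def resolve_sample_files(stem):
    # maps a dropdown stem like ".../samples/step_1200" to its four exported files
    return (
        stem + "_org.wav",  # reference audio decoded from the ground-truth mel
        stem + "_gen.wav",  # audio generated by the model
        stem + "_org.png",  # colorized ground-truth mel
        stem + "_gen.png",  # colorized generated mel
    )

audio_org, audio_gen, image_org, image_gen = resolve_sample_files(
    os.path.join("ckpts", "my_project", "samples", "step_1200")
)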