From 5af195f1f92ce32359b539b3c54c4ed0b0f150b6 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 1 Nov 2024 11:18:05 +0200 Subject: [PATCH 1/6] only mono duraction fix value bfp16 to bf16 --- src/f5_tts/train/finetune_gradio.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 007dad8..53512bc 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -172,10 +172,9 @@ def load_settings(project_name): # Load metadata def get_audio_duration(audio_path): - """Calculate the duration of an audio file.""" + """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - num_channels = audio.shape[0] - return audio.shape[1] / (sample_rate * num_channels) + return audio.shape[1] / sample_rate def clear_text(text): @@ -1557,7 +1556,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle last_per_steps = gr.Number(label="Last per Steps", value=100) with gr.Row(): - mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none") + mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none") cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb") start_button = gr.Button("Start Training") stop_button = gr.Button("Stop Training", interactive=False) From f7a698bc2f52659a8a21bc34b89dc59e0167dea8 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 1 Nov 2024 11:39:06 +0200 Subject: [PATCH 2/6] resample when need --- src/f5_tts/train/finetune_gradio.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 53512bc..2836fe3 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -28,7 +28,7 @@ from safetensors.torch import save_file from scipy.io import wavfile from transformers import pipeline from cached_path import cached_path -from f5_tts.api import F5TTS +from f5_tts.api import F5TTS, target_sample_rate from f5_tts.model.utils import convert_char_to_pinyin from importlib.resources import files @@ -174,7 +174,15 @@ def load_settings(project_name): def get_audio_duration(audio_path): """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - return audio.shape[1] / sample_rate + + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if sample_rate != target_sample_rate: + audio = torchaudio.transforms.Resample(sample_rate, target_sample_rate) + + num_channels = audio.shape[0] + return audio.shape[1] / (sample_rate * num_channels) def clear_text(text): From 199c56c23cf887bf7db45093b2740a23f35856a2 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 1 Nov 2024 12:05:12 +0200 Subject: [PATCH 3/6] clear pipe --- src/f5_tts/train/finetune_gradio.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 2836fe3..f37fb86 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -392,11 +392,15 @@ def start_training( ): global training_process, tts_api, stop_signal - if tts_api is not None: - del tts_api + if tts_api is not None or pipe is not None: + if tts_api is not None: + del tts_api + if pipe is not None: + del pipe gc.collect() torch.cuda.empty_cache() tts_api = None + pipe = None path_project = os.path.join(path_data, dataset_name) From 27d98a52cd0ab1145c540eb8c1512a0f6def4832 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 1 Nov 2024 12:08:47 +0200 Subject: [PATCH 4/6] clear pipe --- src/f5_tts/train/finetune_gradio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index f37fb86..beb509a 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -390,7 +390,7 @@ def start_training( stream=False, logger="wandb", ): - global training_process, tts_api, stop_signal + global training_process, tts_api, stop_signal, pipe if tts_api is not None or pipe is not None: if tts_api is not None: From 552c0fd99c3fafceb78b0085946c0fac68a504c5 Mon Sep 17 00:00:00 2001 From: Yushen CHEN <45333109+SWivid@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:17:57 +0800 Subject: [PATCH 5/6] Update prepare_csv_wavs.py --- src/f5_tts/train/datasets/prepare_csv_wavs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py index f39001a..dd51ef0 100644 --- a/src/f5_tts/train/datasets/prepare_csv_wavs.py +++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py @@ -54,8 +54,7 @@ def prepare_csv_wavs_dir(input_dir): def get_audio_duration(audio_path): audio, sample_rate = torchaudio.load(audio_path) - num_channels = audio.shape[0] - return audio.shape[1] / (sample_rate * num_channels) + return audio.shape[1] / sample_rate def read_audio_text_pairs(csv_file_path): From b664bc77771d3a84d931a57d0cdd701ca925860a Mon Sep 17 00:00:00 2001 From: Yushen CHEN <45333109+SWivid@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:20:39 +0800 Subject: [PATCH 6/6] Update finetune_gradio.py --- src/f5_tts/train/finetune_gradio.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index beb509a..9b46b01 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -28,7 +28,7 @@ from safetensors.torch import save_file from scipy.io import wavfile from transformers import pipeline from cached_path import cached_path -from f5_tts.api import F5TTS, target_sample_rate +from f5_tts.api import F5TTS from f5_tts.model.utils import convert_char_to_pinyin from importlib.resources import files @@ -174,15 +174,7 @@ def load_settings(project_name): def get_audio_duration(audio_path): """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - - if audio.shape[0] > 1: - audio = torch.mean(audio, dim=0, keepdim=True) - - if sample_rate != target_sample_rate: - audio = torchaudio.transforms.Resample(sample_rate, target_sample_rate) - - num_channels = audio.shape[0] - return audio.shape[1] / (sample_rate * num_channels) + return audio.shape[1] / sample_rate def clear_text(text):