From f7a698bc2f52659a8a21bc34b89dc59e0167dea8 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 1 Nov 2024 11:39:06 +0200 Subject: [PATCH] resample when need --- src/f5_tts/train/finetune_gradio.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 53512bc..2836fe3 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -28,7 +28,7 @@ from safetensors.torch import save_file from scipy.io import wavfile from transformers import pipeline from cached_path import cached_path -from f5_tts.api import F5TTS +from f5_tts.api import F5TTS, target_sample_rate from f5_tts.model.utils import convert_char_to_pinyin from importlib.resources import files @@ -174,7 +174,15 @@ def load_settings(project_name): def get_audio_duration(audio_path): """Calculate the duration mono of an audio file.""" audio, sample_rate = torchaudio.load(audio_path) - return audio.shape[1] / sample_rate + + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if sample_rate != target_sample_rate: + audio = torchaudio.transforms.Resample(sample_rate, target_sample_rate) + + num_channels = audio.shape[0] + return audio.shape[1] / (sample_rate * num_channels) def clear_text(text):