import transcribe from utils_infer

This commit is contained in:
unknown
2024-11-16 18:39:52 +02:00
parent c4d7252cf8
commit 96946f85fa

View File

@@ -26,12 +26,13 @@ from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from transformers import pipeline
from cached_path import cached_path
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from importlib.resources import files
training_process = None
system = platform.system()
python_executable = sys.executable or "python"
@@ -47,8 +48,6 @@ file_train = "src/f5_tts/train/finetune_cli.py"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
pipe = None
# Save settings from a JSON file
def save_settings(
@@ -390,17 +389,15 @@ def start_training(
logger="wandb",
ch_8bit_adam=False,
):
global training_process, tts_api, stop_signal, pipe
global training_process, tts_api, stop_signal
if tts_api is not None or pipe is not None:
if tts_api is not None:
if tts_api is not None:
del tts_api
if pipe is not None:
del pipe
gc.collect()
torch.cuda.empty_cache()
tts_api = None
pipe = None
path_project = os.path.join(path_data, dataset_name)
@@ -652,27 +649,6 @@ def create_data_project(name, tokenizer_type):
return gr.update(choices=project_list, value=name)
def transcribe(file_audio, language="english"):
global pipe
if pipe is None:
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)
text_transcribe = pipe(
file_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe", "language": language},
return_timestamps=False,
)["text"].strip()
return text_transcribe
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
path_project = os.path.join(path_data, name_project)
path_dataset = os.path.join(path_project, "dataset")