mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-19 08:11:29 -08:00
import transcribe from utils_infer
This commit is contained in:
@@ -26,12 +26,13 @@ from datasets import Dataset as Dataset_
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from safetensors.torch import save_file
|
||||
from scipy.io import wavfile
|
||||
from transformers import pipeline
|
||||
from cached_path import cached_path
|
||||
from f5_tts.api import F5TTS
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
from f5_tts.infer.utils_infer import transcribe
|
||||
from importlib.resources import files
|
||||
|
||||
|
||||
training_process = None
|
||||
system = platform.system()
|
||||
python_executable = sys.executable or "python"
|
||||
@@ -47,8 +48,6 @@ file_train = "src/f5_tts/train/finetune_cli.py"
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
pipe = None
|
||||
|
||||
|
||||
# Save settings from a JSON file
|
||||
def save_settings(
|
||||
@@ -390,17 +389,15 @@ def start_training(
|
||||
logger="wandb",
|
||||
ch_8bit_adam=False,
|
||||
):
|
||||
global training_process, tts_api, stop_signal, pipe
|
||||
global training_process, tts_api, stop_signal
|
||||
|
||||
if tts_api is not None or pipe is not None:
|
||||
if tts_api is not None:
|
||||
if tts_api is not None:
|
||||
del tts_api
|
||||
if pipe is not None:
|
||||
del pipe
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
tts_api = None
|
||||
pipe = None
|
||||
|
||||
path_project = os.path.join(path_data, dataset_name)
|
||||
|
||||
@@ -652,27 +649,6 @@ def create_data_project(name, tokenizer_type):
|
||||
return gr.update(choices=project_list, value=name)
|
||||
|
||||
|
||||
def transcribe(file_audio, language="english"):
    """Transcribe an audio input to text using a lazily created ASR pipeline.

    The pipeline is built on first use and cached in the module-level
    ``pipe`` variable so that subsequent calls reuse the loaded model.

    Args:
        file_audio: Audio input handed directly to the pipeline
            (presumably a path to an audio file -- confirm against callers).
        language: Language name forwarded in the generation kwargs.

    Returns:
        The ``"text"`` field of the pipeline result, stripped of leading
        and trailing whitespace.
    """
    global pipe

    if pipe is None:
        # Half precision is only dependable on accelerators; on CPU many
        # kernels lack float16 support, so fall back to float32 there.
        asr_dtype = torch.float32 if device == "cpu" else torch.float16
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=asr_dtype,
            device=device,
        )

    result = pipe(
        file_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language},
        return_timestamps=False,
    )
    return result["text"].strip()
|
||||
|
||||
|
||||
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_dataset = os.path.join(path_project, "dataset")
|
||||
|
||||
Reference in New Issue
Block a user