add transcribe function

This commit is contained in:
unknown
2024-11-16 18:26:42 +02:00
parent e636d98090
commit bb4d538dc5

View File

@@ -150,6 +150,22 @@ def initialize_asr_pipeline(device=device, dtype=None):
)
# transcribe
def transcribe(ref_audio, language=None):
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
return asr_pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
# load model checkpoint for inference
@@ -306,17 +322,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
show_info("Using cached reference text...")
ref_text = _ref_audio_cache[audio_hash]
else:
global asr_pipe
if asr_pipe is None:
initialize_asr_pipeline(device=device)
show_info("No reference text provided, transcribing reference audio...")
ref_text = asr_pipe(
ref_audio,
chunk_length_s=30,
batch_size=128,
generate_kwargs={"task": "transcribe"},
return_timestamps=False,
)["text"].strip()
ref_text = transcribe(ref_audio)
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
_ref_audio_cache[audio_hash] = ref_text
else: