mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-06-12 10:01:15 -07:00
fix: path traversal in finetune Gradio handlers (closes #1293)
Add _safe_project_path() helper that rejects empty names, absolute paths, and null bytes, strips path separators (rejecting names that reduce to "", "." or ".."), then verifies the resolved path stays within the intended base directory via realpath + startswith check. Apply to all 10 sinks in save_settings, load_settings, create_data_project, vocab_extend, transcribe_all, create_metadata, and related functions.
This commit is contained in:
@@ -33,6 +33,22 @@ from f5_tts.infer.utils_infer import transcribe
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
def _safe_project_path(base: str, name: str) -> str:
    """Return the resolved absolute path of ``name`` inside ``base``.

    ``name`` is treated as a single directory component: forward slashes and
    backslashes are removed and surrounding whitespace is trimmed before it is
    joined onto ``base``. The joined path is resolved with
    ``os.path.realpath`` (so symlinks are followed) and must still lie within
    the resolved ``base``.

    Args:
        base: Directory the result must stay inside.
        name: Caller-supplied project name.

    Returns:
        ``os.path.realpath`` of ``base`` joined with the cleaned name.

    Raises:
        ValueError: If ``name`` is empty, absolute, contains a null byte,
            reduces to ``""``, ``"."`` or ``".."`` once cleaned, or resolves
            outside ``base``.
    """
    if not name or os.path.isabs(name) or "\x00" in name:
        raise ValueError(f"invalid project_name: {name!r}")

    # Only path separators are removed (plus outer whitespace); other
    # characters, including control characters, are left untouched.
    cleaned = re.sub(r"[/\\]", "", name).strip()
    if cleaned in ("", ".", ".."):
        # Report what the caller passed, not the stripped remainder.
        raise ValueError(f"invalid project_name: {name!r}")

    base_real = os.path.realpath(base)
    candidate = os.path.realpath(os.path.join(base, cleaned))

    # Compare whole path components instead of string prefixes: a prefix test
    # of the form ``base_real + os.sep`` breaks when base_real is the
    # filesystem root, because the root already ends with a separator.
    try:
        inside = os.path.commonpath([base_real, candidate]) == os.path.commonpath([base_real])
    except ValueError:
        # commonpath raises when the paths cannot share a prefix at all
        # (e.g. different drives on Windows) -- that is an escape too.
        inside = False
    if not inside:
        raise ValueError(f"project_name escapes base directory: {cleaned!r}")
    return candidate
|
||||
|
||||
|
||||
training_process = None
|
||||
system = platform.system()
|
||||
python_executable = sys.executable or "python"
|
||||
@@ -80,7 +96,7 @@ def save_settings(
|
||||
logger,
|
||||
ch_8bit_adam,
|
||||
):
|
||||
path_project = os.path.join(path_project_ckpts, project_name)
|
||||
path_project = _safe_project_path(path_project_ckpts, project_name)
|
||||
os.makedirs(path_project, exist_ok=True)
|
||||
file_setting = os.path.join(path_project, "setting.json")
|
||||
|
||||
@@ -113,7 +129,7 @@ def save_settings(
|
||||
# Load settings from a JSON file
|
||||
def load_settings(project_name):
|
||||
project_name = project_name.replace("_pinyin", "").replace("_char", "")
|
||||
path_project = os.path.join(path_project_ckpts, project_name)
|
||||
path_project = _safe_project_path(path_project_ckpts, project_name)
|
||||
file_setting = os.path.join(path_project, "setting.json")
|
||||
|
||||
# Default settings
|
||||
@@ -356,7 +372,7 @@ def start_training(
|
||||
torch.cuda.empty_cache()
|
||||
tts_api = None
|
||||
|
||||
path_project = os.path.join(path_data, dataset_name)
|
||||
path_project = _safe_project_path(path_data, dataset_name)
|
||||
|
||||
if not os.path.isdir(path_project):
|
||||
yield (
|
||||
@@ -610,14 +626,15 @@ def get_list_projects():
|
||||
|
||||
def create_data_project(name, tokenizer_type):
|
||||
name += "_" + tokenizer_type
|
||||
os.makedirs(os.path.join(path_data, name), exist_ok=True)
|
||||
os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
|
||||
project_dir = _safe_project_path(path_data, name)
|
||||
os.makedirs(project_dir, exist_ok=True)
|
||||
os.makedirs(os.path.join(project_dir, "dataset"), exist_ok=True)
|
||||
project_list, projects_selelect = get_list_projects()
|
||||
return gr.update(choices=project_list, value=name)
|
||||
|
||||
|
||||
def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
path_dataset = os.path.join(path_project, "dataset")
|
||||
path_project_wavs = os.path.join(path_project, "wavs")
|
||||
file_metadata = os.path.join(path_project, "metadata.csv")
|
||||
@@ -726,7 +743,7 @@ def get_correct_audio_path(
|
||||
|
||||
|
||||
def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
path_project_wavs = os.path.join(path_project, "wavs")
|
||||
file_metadata = os.path.join(path_project, "metadata.csv")
|
||||
file_raw = os.path.join(path_project, "raw.arrow")
|
||||
@@ -850,7 +867,7 @@ def calculate_train(
|
||||
num_warmup_updates,
|
||||
finetune,
|
||||
):
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
file_duration = os.path.join(path_project, "duration.json")
|
||||
|
||||
hop_length = 256
|
||||
@@ -1003,7 +1020,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
return "Symbols empty!"
|
||||
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
file_vocab_project = os.path.join(path_project, "vocab.txt")
|
||||
|
||||
file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
@@ -1049,7 +1066,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
vocab_size_new = len(miss_symbols)
|
||||
|
||||
dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
|
||||
new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
|
||||
new_ckpt_path = _safe_project_path(path_project_ckpts, dataset_name)
|
||||
os.makedirs(new_ckpt_path, exist_ok=True)
|
||||
|
||||
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
|
||||
@@ -1063,7 +1080,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
|
||||
def vocab_check(project_name, tokenizer_type):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
|
||||
file_metadata = os.path.join(path_project, "metadata.csv")
|
||||
|
||||
@@ -1110,7 +1127,7 @@ def vocab_check(project_name, tokenizer_type):
|
||||
|
||||
def get_random_sample_prepare(project_name):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
file_arrow = os.path.join(path_project, "raw.arrow")
|
||||
if not os.path.isfile(file_arrow):
|
||||
return "", None
|
||||
@@ -1123,7 +1140,7 @@ def get_random_sample_prepare(project_name):
|
||||
|
||||
def get_random_sample_transcribe(project_name):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
path_project = _safe_project_path(path_data, name_project)
|
||||
file_metadata = os.path.join(path_project, "metadata.csv")
|
||||
if not os.path.isfile(file_metadata):
|
||||
return "", None
|
||||
|
||||
Reference in New Issue
Block a user