def _safe_project_path(base: str, name: str) -> str:
    """Resolve ``name`` to a project directory located strictly inside ``base``.

    ``name`` is treated as untrusted input. It is first reduced to a single
    filename component (path separators and ASCII control characters are
    removed, surrounding whitespace is trimmed), then joined onto ``base`` and
    fully resolved, following symlinks, before the containment check is made.

    Args:
        base: Directory that must contain the project directory.
        name: User-supplied project or dataset name.

    Returns:
        The resolved absolute path of the project directory. The directory is
        not created and does not need to exist.

    Raises:
        ValueError: If ``name`` is empty, absolute, contains a NUL byte, is
            ``"."`` or ``".."`` after sanitising, or resolves to ``base``
            itself or to a location outside ``base``.
    """
    if not name or os.path.isabs(name) or "\x00" in name:
        raise ValueError(f"invalid project_name: {name!r}")

    # Reduce the name to one plain filename component: drop both kinds of path
    # separator and every ASCII control character, then trim whitespace.
    cleaned = "".join(
        ch for ch in name if ch not in "/\\" and ch >= " " and ch != "\x7f"
    ).strip()
    if not cleaned or cleaned in (".", ".."):
        raise ValueError(f"invalid project_name: {cleaned!r}")

    base_real = os.path.realpath(base)
    candidate = os.path.realpath(os.path.join(base, cleaned))

    # Compare whole path components rather than string prefixes: a prefix test
    # breaks when base is the filesystem root. normcase keeps the comparison
    # correct on case-insensitive platforms (it is a no-op on POSIX).
    base_key = os.path.normcase(base_real)
    candidate_key = os.path.normcase(candidate)
    try:
        inside = os.path.commonpath([base_key, candidate_key]) == base_key
    except ValueError:
        # commonpath refuses paths on different drives; that is an escape too.
        inside = False
    if not inside:
        raise ValueError(f"project_name escapes base directory: {cleaned!r}")
    if candidate_key == base_key:
        # E.g. a symlink pointing back at base: a project must live in its own
        # subdirectory, never in the shared root itself.
        raise ValueError(f"project_name resolves to base directory: {cleaned!r}")
    return candidate
dataset_name) + path_project = _safe_project_path(path_data, dataset_name) if not os.path.isdir(path_project): yield ( @@ -610,14 +626,15 @@ def get_list_projects(): def create_data_project(name, tokenizer_type): name += "_" + tokenizer_type - os.makedirs(os.path.join(path_data, name), exist_ok=True) - os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True) + project_dir = _safe_project_path(path_data, name) + os.makedirs(project_dir, exist_ok=True) + os.makedirs(os.path.join(project_dir, "dataset"), exist_ok=True) project_list, projects_selelect = get_list_projects() return gr.update(choices=project_list, value=name) def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()): - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) path_dataset = os.path.join(path_project, "dataset") path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") @@ -726,7 +743,7 @@ def get_correct_audio_path( def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()): - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) path_project_wavs = os.path.join(path_project, "wavs") file_metadata = os.path.join(path_project, "metadata.csv") file_raw = os.path.join(path_project, "raw.arrow") @@ -850,7 +867,7 @@ def calculate_train( num_warmup_updates, finetune, ): - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) file_duration = os.path.join(path_project, "duration.json") hop_length = 256 @@ -1003,7 +1020,7 @@ def vocab_extend(project_name, symbols, model_type): return "Symbols empty!" 
name_project = project_name - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) file_vocab_project = os.path.join(path_project, "vocab.txt") file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt") @@ -1049,7 +1066,7 @@ def vocab_extend(project_name, symbols, model_type): vocab_size_new = len(miss_symbols) dataset_name = name_project.replace("_pinyin", "").replace("_char", "") - new_ckpt_path = os.path.join(path_project_ckpts, dataset_name) + new_ckpt_path = _safe_project_path(path_project_ckpts, dataset_name) os.makedirs(new_ckpt_path, exist_ok=True) # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py @@ -1063,7 +1080,7 @@ def vocab_extend(project_name, symbols, model_type): def vocab_check(project_name, tokenizer_type): name_project = project_name - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") @@ -1110,7 +1127,7 @@ def vocab_check(project_name, tokenizer_type): def get_random_sample_prepare(project_name): name_project = project_name - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) file_arrow = os.path.join(path_project, "raw.arrow") if not os.path.isfile(file_arrow): return "", None @@ -1123,7 +1140,7 @@ def get_random_sample_prepare(project_name): def get_random_sample_transcribe(project_name): name_project = project_name - path_project = os.path.join(path_data, name_project) + path_project = _safe_project_path(path_data, name_project) file_metadata = os.path.join(path_project, "metadata.csv") if not os.path.isfile(file_metadata): return "", None