diff --git a/hate_crack/api.py b/hate_crack/api.py index 6c02da5..f263e72 100644 --- a/hate_crack/api.py +++ b/hate_crack/api.py @@ -1,3 +1,4 @@ +import concurrent.futures import json import sys import os @@ -15,6 +16,16 @@ from hate_crack.formatting import print_multicolumn_list _TORRENT_CLEANUP_REGISTERED = False +def _get_hate_path(): + _package_path = os.path.dirname(os.path.realpath(__file__)) + _repo_root = os.path.dirname(_package_path) + if os.path.isdir(os.path.join(_package_path, "hashcat-utils")): + return _package_path + elif os.path.isdir(os.path.join(_repo_root, "hashcat-utils")): + return _repo_root + return _package_path + + def _candidate_roots(): cwd = os.getcwd() home = os.path.expanduser("~") @@ -71,7 +82,7 @@ def get_hcat_wordlists_dir(): if path: path = os.path.expanduser(path) if not os.path.isabs(path): - path = os.path.join(os.path.dirname(config_path), path) + path = os.path.normpath(os.path.join(_get_hate_path(), path)) os.makedirs(path, exist_ok=True) return path except Exception: @@ -91,7 +102,7 @@ def get_rules_dir(): if path: path = os.path.expanduser(path) if not os.path.isabs(path): - path = os.path.join(os.path.dirname(config_path), path) + path = os.path.normpath(os.path.join(_get_hate_path(), path)) os.makedirs(path, exist_ok=True) return path except Exception: @@ -1739,33 +1750,45 @@ def list_and_download_hashmob_rules(rules_dir=None): return sanitized in downloaded_rules if sel.lower() == "a": - for entry in rules: - file_name = entry.get("file_name") - if not file_name: - print("No file_name found for an entry, skipping.") - continue - out_path = os.path.join(rules_dir, sanitize_filename(file_name)) - if already_downloaded(file_name): - print(f"[i] Skipping already downloaded rule: {file_name}") - continue - download_hashmob_rule(file_name, out_path) - return + entries = rules + else: + indices = parse_indices(sel, len(rules)) + if not indices: + print("No valid selection.") + return + entries = [rules[idx - 1] for idx in indices] - indices = parse_indices(sel, len(rules)) - if not indices: - print("No valid selection.") - return - for idx in indices: - entry = rules[idx - 1] + jobs = [] + for entry in entries: file_name = entry.get("file_name") if not file_name: - print("No file_name found for selection, skipping.") + print("No file_name found for an entry, skipping.") continue - out_path = os.path.join(rules_dir, sanitize_filename(file_name)) if already_downloaded(file_name): print(f"[i] Skipping already downloaded rule: {file_name}") continue - download_hashmob_rule(file_name, out_path) + out_path = os.path.join(rules_dir, sanitize_filename(file_name)) + jobs.append((file_name, out_path)) + + if not jobs: + return + + succeeded = 0 + failed = 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = { + executor.submit(download_hashmob_rule, fn, op): fn for fn, op in jobs + } + for future in concurrent.futures.as_completed(futures): + file_name = futures[future] + try: + future.result() + succeeded += 1 + except Exception as exc: + print(f"[!] Failed to download {file_name}: {exc}") + failed += 1 + + print(f"[i] Rule downloads complete: {succeeded} succeeded, {failed} failed.") def download_official_wordlist(file_name, out_path): diff --git a/hate_crack/attacks.py b/hate_crack/attacks.py index fa2c92f..7a787a2 100644 --- a/hate_crack/attacks.py +++ b/hate_crack/attacks.py @@ -29,9 +29,7 @@ def quick_crack(ctx: Any) -> None: rule_choice = None selected_hcatRules = [] - wordlist_files = sorted( - f for f in os.listdir(ctx.hcatWordlists) if f != ".DS_Store" - ) + wordlist_files = ctx.list_wordlist_files(ctx.hcatWordlists) wordlist_entries = [ f"{i}. {file}" for i, file in enumerate(wordlist_files, start=1) ] @@ -511,6 +509,33 @@ def ollama_attack(ctx: Any) -> None: ctx.hcatOllama(ctx.hcatHashType, ctx.hcatHashFile, "target", target_info) +def _omen_pick_training_wordlist(ctx: Any): + """Show wordlist picker for OMEN training. Returns path or None.""" + wordlist_files = ctx.list_wordlist_files(ctx.hcatWordlists) + if wordlist_files: + entries = [f"{i}. {f}" for i, f in enumerate(wordlist_files, start=1)] + max_len = max((len(e) for e in entries), default=24) + print_multicolumn_list( + "Training Wordlists", + entries, + min_col_width=max_len, + max_col_width=max_len, + ) + print("\tp. Enter a custom path") + sel = input("\n\tSelect wordlist for training: ").strip() + if sel.lower() == "p": + path = input("\n\tPath to training wordlist: ").strip() + return path if path else None + try: + idx = int(sel) + if 1 <= idx <= len(wordlist_files): + return os.path.join(ctx.hcatWordlists, wordlist_files[idx - 1]) + except (ValueError, IndexError): + pass + print("\t[!] Invalid selection.") + return None + + def omen_attack(ctx: Any) -> None: print("\n\tOMEN Attack (Ordered Markov ENumerator)") omen_dir = os.path.join(ctx.hate_path, "omen") @@ -520,16 +545,36 @@ def omen_attack(ctx: Any) -> None: print("\n\tOMEN binaries not found. Build them with:") print(f"\t cd {omen_dir} && make") return - model_dir = os.path.join(os.path.expanduser("~"), ".hate_crack", "omen") - model_exists = os.path.isfile(os.path.join(model_dir, "createConfig")) - if not model_exists: - print("\n\tNo OMEN model found. Training is required before generation.") - training_source = input( - "\n\tTraining source (path to password list, or press Enter for default): " - ).strip() - if not training_source: - training_source = ctx.omenTrainingList - ctx.hcatOmenTrain(training_source) + + model_dir = ctx._omen_model_dir() + model_valid = ctx._omen_model_is_valid(model_dir) + need_training = True + + if model_valid: + info = ctx._omen_model_info(model_dir) + trained_with = info.get("training_file", "unknown") if info else "unknown" + print(f"\n\tOMEN model found (trained with: {trained_with})") + print("\t1. Use existing model") + print("\t2. Train new model (overwrites existing)") + print("\t3. Cancel") + choice = input("\n\tChoice: ").strip() + if choice == "1": + need_training = False + elif choice == "3": + return + elif choice != "2": + return + else: + print("\n\tNo valid OMEN model found. Training is required.") + + if need_training: + training_file = _omen_pick_training_wordlist(ctx) + if not training_file: + return + if not ctx.hcatOmenTrain(training_file): + print("\n\t[!] Training failed. Aborting OMEN attack.") + return + max_candidates = input( f"\n\tMax candidates to generate ({ctx.omenMaxCandidates}): " ).strip() diff --git a/tests/test_api_downloads.py b/tests/test_api_downloads.py index 8288709..aa358e5 100644 --- a/tests/test_api_downloads.py +++ b/tests/test_api_downloads.py @@ -1,7 +1,8 @@ import json import os -from unittest.mock import MagicMock, patch +import pytest +from unittest.mock import MagicMock, call, patch from hate_crack.api import ( check_7z, @@ -11,6 +12,7 @@ from hate_crack.api import ( get_hashmob_api_key, get_hcat_potfile_args, get_hcat_potfile_path, + list_and_download_hashmob_rules, sanitize_filename, ) @@ -225,3 +227,60 @@ class TestDownloadHashmobWordlist: patch("hate_crack.api.time.sleep"): result = download_hashmob_wordlist("test.txt", str(out)) assert result is False + + +class TestParallelRuleDownloads: + def _make_rules(self, names): + return [{"file_name": n} for n in names] + + def _patch_stdin_tty(self): + mock_stdin = MagicMock() + mock_stdin.isatty.return_value = True + return patch("hate_crack.api.sys.stdin", mock_stdin) + + def test_submits_to_thread_pool(self, tmp_path): + rules = self._make_rules(["rule1.rule", "rule2.rule", "rule3.rule"]) + rules_dir = str(tmp_path / "rules") + os.makedirs(rules_dir) + with patch("hate_crack.api.download_hashmob_rule_list", return_value=rules), \ + patch("hate_crack.api.download_hashmob_rule") as mock_dl, \ + self._patch_stdin_tty(), \ + patch("builtins.input", return_value="a"): + list_and_download_hashmob_rules(rules_dir=rules_dir) + assert mock_dl.call_count == 3 + downloaded_names = {c.args[0] for c in mock_dl.call_args_list} + assert downloaded_names == {"rule1.rule", "rule2.rule", "rule3.rule"} + + def test_failure_does_not_block_others(self, tmp_path, capsys): + rules = self._make_rules(["good.rule", "bad.rule", "also_good.rule"]) + rules_dir = str(tmp_path / "rules") + os.makedirs(rules_dir) + + def side_effect(file_name, out_path): + if file_name == "bad.rule": + raise RuntimeError("download error") + + with patch("hate_crack.api.download_hashmob_rule_list", return_value=rules), \ + patch("hate_crack.api.download_hashmob_rule", side_effect=side_effect), \ + self._patch_stdin_tty(), \ + patch("builtins.input", return_value="a"): + list_and_download_hashmob_rules(rules_dir=rules_dir) + + captured = capsys.readouterr() + assert "2 succeeded" in captured.out + assert "1 failed" in captured.out + + def test_skips_already_downloaded(self, tmp_path, capsys): + rules = self._make_rules(["existing.rule", "new.rule"]) + rules_dir = str(tmp_path / "rules") + os.makedirs(rules_dir) + (tmp_path / "rules" / "existing.rule").touch() + with patch("hate_crack.api.download_hashmob_rule_list", return_value=rules), \ + patch("hate_crack.api.download_hashmob_rule") as mock_dl, \ + self._patch_stdin_tty(), \ + patch("builtins.input", return_value="a"): + list_and_download_hashmob_rules(rules_dir=rules_dir) + assert mock_dl.call_count == 1 + assert mock_dl.call_args.args[0] == "new.rule" + captured = capsys.readouterr() + assert "Skipping already downloaded" in captured.out diff --git a/tests/test_utils.py b/tests/test_utils.py index c0b3397..eda4b26 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -58,6 +58,7 @@ def test_get_hcat_wordlists_dir_from_config(tmp_path, monkeypatch): config_path.write_text('{"hcatWordlists": "wordlists"}') monkeypatch.setattr(api, "_resolve_config_path", lambda: str(config_path)) + monkeypatch.setattr(api, "_get_hate_path", lambda: str(tmp_path)) result = api.get_hcat_wordlists_dir() assert result == str(tmp_path / "wordlists") @@ -79,6 +80,7 @@ def test_get_rules_dir_from_config(tmp_path, monkeypatch): config_path.write_text('{"rules_directory": "rules"}') monkeypatch.setattr(api, "_resolve_config_path", lambda: str(config_path)) + monkeypatch.setattr(api, "_get_hate_path", lambda: str(tmp_path)) result = api.get_rules_dir() assert result == str(tmp_path / "rules")