mirror of
https://github.com/trustedsec/hate_crack.git
synced 2026-03-12 21:23:05 -07:00
fix: auto-detect training device instead of defaulting to CUDA
The PassGPT training device menu now uses _detect_device() to default to the best available device (CUDA > MPS > CPU) rather than always defaulting to CUDA, which fails on systems without NVIDIA GPUs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -574,13 +574,20 @@ def passgpt_attack(ctx: Any) -> None:
|
||||
if not base:
|
||||
base = default_model
|
||||
|
||||
from hate_crack.passgpt_train import _detect_device
|
||||
|
||||
detected = _detect_device()
|
||||
device_labels = {"cuda": "cuda", "mps": "mps (Apple Silicon)", "cpu": "cpu"}
|
||||
device_options = ["cuda", "mps", "cpu"]
|
||||
print("\n\tSelect training device:")
|
||||
print("\t (1) cuda (Recommended)")
|
||||
print("\t (2) mps (Apple Silicon)")
|
||||
print("\t (3) cpu")
|
||||
device_choice = input("\n\tDevice [1]: ").strip()
|
||||
device_map = {"1": "cuda", "2": "mps", "3": "cpu", "": "cuda"}
|
||||
device = device_map.get(device_choice, "cuda")
|
||||
for i, dev in enumerate(device_options, 1):
|
||||
label = device_labels[dev]
|
||||
suffix = " (detected)" if dev == detected else ""
|
||||
print(f"\t ({i}) {label}{suffix}")
|
||||
default_idx = device_options.index(detected) + 1
|
||||
device_choice = input(f"\n\tDevice [{default_idx}]: ").strip()
|
||||
device_map = {"1": "cuda", "2": "mps", "3": "cpu", "": detected}
|
||||
device = device_map.get(device_choice, detected)
|
||||
|
||||
result = ctx.hcatPassGPTTrain(training_file, base, device=device)
|
||||
if result is None:
|
||||
|
||||
@@ -261,11 +261,14 @@ class TestPassGPTAttackHandler:
|
||||
ctx.select_file_with_autocomplete.return_value = "/tmp/wordlist.txt"
|
||||
ctx.hcatPassGPTTrain.return_value = "/home/user/.hate_crack/passgpt/wordlist"
|
||||
|
||||
# "T" for train, "" for default base model, "" for default device (cuda), "" for default max candidates
|
||||
# "T" for train, "" for default base model, "" for default device (auto-detected), "" for default max candidates
|
||||
inputs = iter(["T", "", "", ""])
|
||||
with (
|
||||
patch("builtins.input", side_effect=inputs),
|
||||
patch("hate_crack.attacks.os.path.isdir", return_value=False),
|
||||
patch(
|
||||
"hate_crack.passgpt_train._detect_device", return_value="cuda"
|
||||
),
|
||||
):
|
||||
from hate_crack.attacks import passgpt_attack
|
||||
|
||||
@@ -283,11 +286,14 @@ class TestPassGPTAttackHandler:
|
||||
ctx.select_file_with_autocomplete.return_value = "/tmp/wordlist.txt"
|
||||
ctx.hcatPassGPTTrain.return_value = None
|
||||
|
||||
# "T" for train, "" for default base model, "" for default device (cuda)
|
||||
# "T" for train, "" for default base model, "" for default device (auto-detected)
|
||||
inputs = iter(["T", "", ""])
|
||||
with (
|
||||
patch("builtins.input", side_effect=inputs),
|
||||
patch("hate_crack.attacks.os.path.isdir", return_value=False),
|
||||
patch(
|
||||
"hate_crack.passgpt_train._detect_device", return_value="cuda"
|
||||
),
|
||||
):
|
||||
from hate_crack.attacks import passgpt_attack
|
||||
|
||||
|
||||
Reference in New Issue
Block a user