fix: auto-detect training device instead of defaulting to CUDA

The PassGPT training device menu now uses _detect_device() to default
to the best available device (CUDA > MPS > CPU) rather than always
defaulting to CUDA, which fails on systems without NVIDIA GPUs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Justin Bollinger
2026-02-18 18:47:41 -05:00
parent 00a60af9a6
commit f0bba73225
2 changed files with 21 additions and 8 deletions

View File

@@ -574,13 +574,20 @@ def passgpt_attack(ctx: Any) -> None:
if not base:
base = default_model
from hate_crack.passgpt_train import _detect_device
detected = _detect_device()
device_labels = {"cuda": "cuda", "mps": "mps (Apple Silicon)", "cpu": "cpu"}
device_options = ["cuda", "mps", "cpu"]
print("\n\tSelect training device:")
print("\t (1) cuda (Recommended)")
print("\t (2) mps (Apple Silicon)")
print("\t (3) cpu")
device_choice = input("\n\tDevice [1]: ").strip()
device_map = {"1": "cuda", "2": "mps", "3": "cpu", "": "cuda"}
device = device_map.get(device_choice, "cuda")
for i, dev in enumerate(device_options, 1):
label = device_labels[dev]
suffix = " (detected)" if dev == detected else ""
print(f"\t ({i}) {label}{suffix}")
default_idx = device_options.index(detected) + 1
device_choice = input(f"\n\tDevice [{default_idx}]: ").strip()
device_map = {"1": "cuda", "2": "mps", "3": "cpu", "": detected}
device = device_map.get(device_choice, detected)
result = ctx.hcatPassGPTTrain(training_file, base, device=device)
if result is None:

View File

@@ -261,11 +261,14 @@ class TestPassGPTAttackHandler:
ctx.select_file_with_autocomplete.return_value = "/tmp/wordlist.txt"
ctx.hcatPassGPTTrain.return_value = "/home/user/.hate_crack/passgpt/wordlist"
# "T" for train, "" for default base model, "" for default device (cuda), "" for default max candidates
# "T" for train, "" for default base model, "" for default device (auto-detected), "" for default max candidates
inputs = iter(["T", "", "", ""])
with (
patch("builtins.input", side_effect=inputs),
patch("hate_crack.attacks.os.path.isdir", return_value=False),
patch(
"hate_crack.passgpt_train._detect_device", return_value="cuda"
),
):
from hate_crack.attacks import passgpt_attack
@@ -283,11 +286,14 @@ class TestPassGPTAttackHandler:
ctx.select_file_with_autocomplete.return_value = "/tmp/wordlist.txt"
ctx.hcatPassGPTTrain.return_value = None
# "T" for train, "" for default base model, "" for default device (cuda)
# "T" for train, "" for default base model, "" for default device (auto-detected)
inputs = iter(["T", "", ""])
with (
patch("builtins.input", side_effect=inputs),
patch("hate_crack.attacks.os.path.isdir", return_value=False),
patch(
"hate_crack.passgpt_train._detect_device", return_value="cuda"
),
):
from hate_crack.attacks import passgpt_attack