fix: assert_never for exhaustive checks, tfile/line unbound, import paths, progress columns

This commit is contained in:
Willi Ballenthin
2026-04-22 14:28:19 +03:00
committed by Willi Ballenthin
parent dadf8b0961
commit 98d62bd39a
2 changed files with 12 additions and 12 deletions

View File

@@ -27,10 +27,10 @@ from zipfile import ZipFile
from datetime import datetime
import msgspec.json
from rich.text import Text
from rich.console import Console
from rich.progress import (
Task,
Text,
Progress,
BarColumn,
TextColumn,
@@ -142,6 +142,7 @@ def stdout_redirector(stream):
# Save a copy of the original stdout fd in saved_stdout_fd
saved_stdout_fd = os.dup(original_stdout_fd)
tfile = None
try:
# Create a temporary file and redirect stdout to it
tfile = tempfile.TemporaryFile(mode="w+b")
@@ -154,7 +155,8 @@ def stdout_redirector(stream):
tfile.seek(0, io.SEEK_SET)
stream.write(tfile.read())
finally:
tfile.close()
if tfile is not None:
tfile.close()
os.close(saved_stdout_fd)
@@ -197,9 +199,7 @@ def load_one_jsonl_from_path(jsonl_path: Path):
except gzip.BadGzipFile:
with jsonl_path.open(mode="rb") as f:
line = next(iter(f))
finally:
line = msgspec.json.decode(line.decode(errors="ignore"))
return line
return msgspec.json.decode(line.decode(errors="ignore"))
def get_format_from_report(sample: Path) -> str:
@@ -444,18 +444,18 @@ class MofNCompleteColumnWithUnit(MofNCompleteColumn):
class CapaProgressBar(Progress):
@classmethod
def get_default_columns(cls):
def get_default_columns(cls) -> tuple[ProgressColumn, ...]:
return (
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
TaskProgressColumn(),
BarColumn(),
MofNCompleteColumnWithUnit(),
"",
TextColumn(""),
TimeElapsedColumn(),
"<",
TextColumn("<"),
TimeRemainingColumn(),
"",
TextColumn(""),
RateColumn(),
PostfixColumn(),
)

View File

@@ -14,7 +14,7 @@
import logging
import textwrap
from typing import Iterable, Optional
from typing import Iterable, Optional, assert_never
from rich.text import Text
from rich.table import Table
@@ -183,7 +183,7 @@ def render_statement(console: Console, layout: rd.Layout, match: rd.Match, state
console.writeln()
else:
raise RuntimeError("unexpected match statement type: " + str(statement))
assert_never(statement)
def render_string_value(s: str) -> str:
@@ -281,7 +281,7 @@ def render_node(console: Console, layout: rd.Layout, rule: rd.RuleMatches, match
elif isinstance(node, rd.FeatureNode):
render_feature(console, layout, rule, match, node.feature, indent=indent)
else:
raise RuntimeError("unexpected node type: " + str(node))
assert_never(node)
# display nodes that successfully evaluated against the sample.