diff --git a/CHANGELOG.md b/CHANGELOG.md index b3f3868e..aedcc7ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ - fix: optimizer doesn't recurse into And/Or/Some children @williballenthin #3020 +- fix: address classes __eq__ and __lt__ assert on type @williballenthin #3021 + ### capa Explorer Web ### capa Explorer IDA Pro plugin diff --git a/capa/features/address.py b/capa/features/address.py index 95e4ba64..110c2816 100644 --- a/capa/features/address.py +++ b/capa/features/address.py @@ -68,11 +68,13 @@ class ProcessAddress(Address): return hash((self.ppid, self.pid)) def __eq__(self, other): - assert isinstance(other, ProcessAddress) + if not isinstance(other, ProcessAddress): + return NotImplemented return (self.ppid, self.pid) == (other.ppid, other.pid) def __lt__(self, other): - assert isinstance(other, ProcessAddress) + if not isinstance(other, ProcessAddress): + return NotImplemented return (self.ppid, self.pid) < (other.ppid, other.pid) @@ -91,11 +93,13 @@ class ThreadAddress(Address): return hash((self.process, self.tid)) def __eq__(self, other): - assert isinstance(other, ThreadAddress) + if not isinstance(other, ThreadAddress): + return NotImplemented return (self.process, self.tid) == (other.process, other.tid) def __lt__(self, other): - assert isinstance(other, ThreadAddress) + if not isinstance(other, ThreadAddress): + return NotImplemented return (self.process, self.tid) < (other.process, other.tid) @@ -114,10 +118,13 @@ class DynamicCallAddress(Address): return hash((self.thread, self.id)) def __eq__(self, other): - return isinstance(other, DynamicCallAddress) and (self.thread, self.id) == (other.thread, other.id) + if not isinstance(other, DynamicCallAddress): + return NotImplemented + return (self.thread, self.id) == (other.thread, other.id) def __lt__(self, other): - assert isinstance(other, DynamicCallAddress) + if not isinstance(other, DynamicCallAddress): + return NotImplemented return (self.thread, self.id) < (other.thread, other.id) diff --git a/tests/test_engine.py b/tests/test_engine.py index 3d5e0a6d..d3f78488 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import capa.features.address from capa.engine import Or, And, Not, Some, Range from capa.features.insn import Number +from capa.features.address import ThreadAddress, ProcessAddress, DynamicCallAddress ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001) ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002) @@ -194,3 +197,52 @@ def test_eval_order(): assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement == Number(2) assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement != Number(1) + + +def test_address_cross_type_eq(): + proc = ProcessAddress(pid=1, ppid=0) + ava = capa.features.address.AbsoluteVirtualAddress(0x401001) + + assert (proc == ava) is False + assert (ava == proc) is False + + +def test_process_address_sorting(): + proc1 = ProcessAddress(pid=1, ppid=0) + proc2 = ProcessAddress(pid=2, ppid=0) + + assert sorted([proc2, proc1]) == [proc1, proc2] + + +def test_process_address_cross_type_sort_raises(): + proc = ProcessAddress(pid=1, ppid=0) + ava = capa.features.address.AbsoluteVirtualAddress(0x401001) + + with pytest.raises(TypeError): + sorted([proc, ava]) + + +def test_process_address_lt_returns_not_implemented_for_other_types(): + proc = ProcessAddress(pid=1, ppid=0) + ava = capa.features.address.AbsoluteVirtualAddress(0x401001) + + assert proc.__lt__(ava) is NotImplemented + + +def test_thread_address_cross_type_eq(): + proc = ProcessAddress(pid=1, ppid=0) + thread = ThreadAddress(process=proc, tid=10) + ava = capa.features.address.AbsoluteVirtualAddress(0x401001) + + assert (thread == ava) is False + assert (ava == thread) is False + + +def test_dynamic_call_address_cross_type_eq(): + proc = ProcessAddress(pid=1, ppid=0) + thread = ThreadAddress(process=proc, tid=10) + call = DynamicCallAddress(thread=thread, id=0) + ava = capa.features.address.AbsoluteVirtualAddress(0x401001) + + assert (call == ava) is False + assert (ava == call) is False