From 5a258f499098394c0af25e2e3f00b1b603c2334d Mon Sep 17 00:00:00 2001 From: Eric Hennenfent Date: Wed, 27 Jan 2021 08:17:22 -0800 Subject: [PATCH 1/2] TUI Support Infrastructure (#1620) * Support for TUI (#1605) * Update worker thread for server creation * Add necessary files for TUI connectivity * Add necessary files for TUI connectivity * Update MonitorWorker * Update protocol * Blacken * Update setup.py dependencies * Remove state debugging messages * Update setup.py to build protobuf protocol upon install * Remove previously generated state_pb2.py * Change subprocess.Popen to subprocess.check_output * Remove extraneous output * First attempt at fixing protobuf installation It might work, it might not. We'll let the CI sort it out. * Can't forget the f-string * Error on missing protoc * Disable auto-generation of protobuf file * Ignore pb2_errors * Disable monitor start See if this makes the EVM tests pass Co-authored-by: Eric Hennenfent * Add log monitoring * Log monitoring via TCP * Switch to rendering state lists directly * Extraneous line * Switch log buffer to multiprocessing queue * Create state transition events Should make it possible to track movements between state lists * Plug new events into context This will break the state merging plugin (but I'll fix it eventually) * move most enums to their own module * Blacken * Add DaemonThread from TUI branch * Add interface for registering daemon threads * Timestamp StateDescriptor upon updates * Capture return value * Blacken * Add solver wrapper to StateBase * Add `solve` events to all instances of SelectedSolver.instance() * Remove executor constraints from WASM * Add solve events to memory.py * Add intermittent execution event * Be more generous with states whose initialization we missed * Add Native callback for updating state descriptor * Fix state killing * Blacken * codecov: Remove outdated 'yml' entry in CI From these commits 
https://github.com/codecov/codecov-action/commit/ebea5cacdf7d0b843f66bcb1ae27d8f27b758e81 https://github.com/codecov/codecov-action/commit/49c86d6a5fd072b05c31a42baa67b8fb2e87c8f7 * Add solve event to evm Make warning messages better Debug GH actions Revert "Debug GH actions" This reverts commit f575eea3c3a09dbf2cd2b8e81940fc4297fdf039. Fix some pycharm-detected problems Make symbolic function error message more verbose Add solve to published events Loud errors in callbacks by default Trying to find out what's killing truffle Revert "Trying to find out what's killing truffle" This reverts commit 8bd02245ccc2ce54c9161072eba231880c084874. Revert "Make symbolic function error message more verbose" This reverts commit bd3e90cdb0dfabf0ac874badad723be50998f63b. Debugging Truffle Restore introspector Add try_except on every callback Unconditionally print error message Add traceback Update event.py Debug subscriptions Debug arguments to callbacks Different debug msg 1ast arg Print statement debugging... Pass in `None` as state Revert "Add try_except on every callback" This reverts commit 1c689dd43c619ba322cda5e7283686308c8db4e0. 
* Drop solve events outside of a state context Forgot did_solve Remove traceback * Fix must/cannot_be_null usage * Fix missing solve event * Partially restore old did_fork_state ABI * Called internally * Clone iterators instead of creating a list * Use isgenerator instead of checking if iterable * Fix snapshot restoration * Slightly improve Unicorn test API usage * Temporarily disable property verifier tests * improper skip arg * Add simple tests for introspection API * Add test for custom introspector, improve base introspection test * Add intermittent update timestamp * Only allow daemon registration and introspection registration at initialization * Add docs to manticore.py * Add docs for plugin, add update_state_descriptor to EVM * Fix renamed will_start_run --> will_run * Docstrings for DaemonThread and EventSolver * Docs for enums * Improve pretty printer, add some mypy fixes * Don't run daemon threads if run is called multiple times * If at first you don't succeed, destroy all the evidence you tried. * Test the pretty printer * Add StateDescriptor to RTD * Add newlines for RTD parsing * Update to work with new state introspection API * Add termination messages * Also capture killed state messages * Make info logs debug logs * Apply suggestions from code review Newlines for doc comments Co-authored-by: Eric Kilmer * Add some type hints to manticore.py * Add some type hints to plugin.py * Fix type hint for get_state * Add termination message from TUI PR * Add example script * Add docstrings to the example script * Pass introspection plugin type as an argument * Unskip property verifier tests * Add mypy-requests type hints * Remove itertools.tee The problem with using tee is that only the first callback to use the iterator can write to it. In `ready_states`, the `save_state` after the `yield` statement is ignored for all others. 
* Make generator cloning a little bit more robust Now Manticore will give up and return the original argument instead of blowing up if it can't clone the generator * Clean up invalidated unit tests We now fire `introspect` for the first time before we have any states * Debug missing Truffle & Examples coverage * Merge coverage from XML file * Switch coverage to JSON, ignore debug logging and NotImplemented code * Fix copy commands * Move .coverage files directly * Set examples to append coverage * FLAG_NAME doesn't work the way we'd like * Use plugin dict to store introspector * Appease mypy * Fix missing property on unique name * Grab EVM PC * Blacken * Run black on all files if the git diff command fails * Fix mypy errors * Make plugin logging even less verbose * Move log capture and state monitoring to daemon threads * Use the config module for host & port * Fix worker configuration and add test for TUI API * Fix log messages breaking native tests * Split up base Manticore tests and logging tests The verbosity changes seem to be taking hold when they shouldn't * Merge LogTCPHandler and MonitorTCPHandler * Confirm that logging tests return to base level * Fix mypy * Switch back to using a deque for log buffering in the default case * Fix deque API * Update state_pb2.py * Reformat programmatically generated files * Drop max verbosity in logging tests Haven't been able to figure out why, but somehow other loggers get "stuck" at this high verbosity and the integration tests try to print out the values of every single register. 
* Fix duplicated code from bad merge * Remove is_main from state_monitor * Add comment about log buffer size * Remove vestigial is_main * Blacken Co-authored-by: Philip Wang Co-authored-by: Eric Kilmer --- manticore/core/manticore.py | 47 ++++- manticore/core/plugin.py | 18 ++ manticore/core/state.proto | 31 +++ manticore/core/state_pb2.py | 369 +++++++++++++++++++++++++++++++++ manticore/core/worker.py | 145 +++++++++++++ manticore/utils/helpers.py | 11 + manticore/utils/log.py | 21 +- mypy.ini | 3 + setup.py | 2 + tests/native/test_logging.py | 19 ++ tests/native/test_manticore.py | 17 -- tests/other/test_tui_api.py | 143 +++++++++++++ 12 files changed, 801 insertions(+), 25 deletions(-) create mode 100644 manticore/core/state.proto create mode 100644 manticore/core/state_pb2.py create mode 100644 tests/native/test_logging.py create mode 100644 tests/other/test_tui_api.py diff --git a/manticore/core/manticore.py b/manticore/core/manticore.py index 1650b66e3..7f48ab25b 100644 --- a/manticore/core/manticore.py +++ b/manticore/core/manticore.py @@ -2,6 +2,7 @@ import itertools import logging import sys +import time import typing import random import weakref @@ -21,11 +22,18 @@ from ..utils.deprecated import deprecated from ..utils.enums import StateLists, MProcessingType from ..utils.event import Eventful -from ..utils.helpers import PickleSerializer, pretty_print_state_descriptors +from ..utils.helpers import PickleSerializer, pretty_print_state_descriptors, deque from ..utils.log import set_verbosity from ..utils.nointerrupt import WithKeyboardInterruptAs from .workspace import Workspace, Testcase -from .worker import WorkerSingle, WorkerThread, WorkerProcess, DaemonThread +from .worker import ( + WorkerSingle, + WorkerThread, + WorkerProcess, + DaemonThread, + LogCaptureWorker, + state_monitor, +) from multiprocessing.managers import SyncManager import threading @@ -88,6 +96,7 @@ def wait_for(self, condition, *args, **kwargs): self._terminated_states = [] 
self._busy_states = [] self._killed_states = [] + self._log_queue = deque(maxlen=5000) self._shared_context = {} def _manticore_threading(self): @@ -99,6 +108,7 @@ def _manticore_threading(self): self._terminated_states = [] self._busy_states = [] self._killed_states = [] + self._log_queue = deque(maxlen=5000) self._shared_context = {} def _manticore_multiprocessing(self): @@ -120,6 +130,9 @@ def raise_signal(): self._terminated_states = self._manager.list() self._busy_states = self._manager.list() self._killed_states = self._manager.list() + # The multiprocessing queue is much slower than the deque when it gets full, so we + # triple the size in order to prevent that from happening. + self._log_queue = self._manager.Queue(15000) self._shared_context = self._manager.dict() self._context_value_types = {list: self._manager.list, dict: self._manager.dict} @@ -370,8 +383,10 @@ def __init__( # Workers will use manticore __dict__ So lets spawn them last self._workers = [self._worker_type(id=i, manticore=self) for i in range(consts.procs)] - # We won't create the daemons until .run() is called - self._daemon_threads: typing.List[DaemonThread] = [] + # Create log capture worker. We won't create the rest of the daemons until .run() is called + self._daemon_threads: typing.Dict[int, DaemonThread] = { + -1: LogCaptureWorker(id=-1, manticore=self) + } self._daemon_callbacks: typing.List[typing.Callable] = [] self._snapshot = None @@ -1102,21 +1117,27 @@ def run(self): # User subscription to events is disabled from now on self.subscribe = None + self.register_daemon(state_monitor) + self._daemon_threads[-1].start() # Start log capture worker + # Passing generators to callbacks is a bit hairy because the first callback would drain it if we didn't # clone the iterator in event.py. We're preserving the old API here, but it's something to avoid in the future. self._publish("will_run", self.ready_states) self._running.value = True + # start all the workers! 
for w in self._workers: w.start() # Create each daemon thread and pass it `self` - if not self._daemon_threads: # Don't recreate the threads if we call run multiple times - for i, cb in enumerate(self._daemon_callbacks): + for i, cb in enumerate(self._daemon_callbacks): + if ( + i not in self._daemon_threads + ): # Don't recreate the threads if we call run multiple times dt = DaemonThread( id=i, manticore=self ) # Potentially duplicated ids with workers. Don't mix! - self._daemon_threads.append(dt) + self._daemon_threads[dt.id] = dt dt.start(cb) # Main process. Lets just wait and capture CTRL+C at main @@ -1173,6 +1194,17 @@ def finalize(self): self.generate_testcase(state) self.remove_all() + def wait_for_log_purge(self): + """ + If a client has accessed the log server, and there are still buffered logs, + waits up to 2 seconds for the client to retrieve the logs. + """ + if self._daemon_threads[-1].activated: + for _ in range(8): + if self._log_queue.empty(): + break + time.sleep(0.25) + ############################################################################ ############################################################################ ############################################################################ @@ -1188,6 +1220,7 @@ def save_run_data(self): config.save(f) logger.info("Results in %s", self._output.store.uri) + self.wait_for_log_purge() def introspect(self) -> typing.Dict[int, StateDescriptor]: """ diff --git a/manticore/core/plugin.py b/manticore/core/plugin.py index 9d2dac38d..80a7c091a 100644 --- a/manticore/core/plugin.py +++ b/manticore/core/plugin.py @@ -670,6 +670,24 @@ def get_state_descriptors(self) -> typing.Dict[int, StateDescriptor]: out = context.copy() # TODO: is this necessary to break out of the lock? 
return out + def did_kill_state_callback(self, state, ex: Exception): + """ + Capture other state-killing exceptions so we can get the corresponding message + + :param state: State that was killed + :param ex: The exception w/ the termination message + """ + state_id = state.id + with self.locked_context("manticore_state", dict) as context: + if state_id not in context: + logger.warning( + "Caught killing of state %s, but failed to capture its initialization", + state_id, + ) + context.setdefault(state_id, StateDescriptor(state_id=state_id)).termination_msg = repr( + ex + ) + @property def unique_name(self) -> str: return IntrospectionAPIPlugin.NAME diff --git a/manticore/core/state.proto b/manticore/core/state.proto new file mode 100644 index 000000000..dff84b636 --- /dev/null +++ b/manticore/core/state.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package mserialize; + +message LogMessage{ + string content = 1; +} + +message State{ + + enum StateType{ + READY = 0; + BUSY = 1; + KILLED = 2; + TERMINATED = 3; + } + + int32 id = 2; // state ID + StateType type = 3; // Type of state + string reason = 4; // Reason for execution stopping + int32 num_executing = 5; // number of executing instructions + int32 wait_time = 6; +} + +message StateList{ + repeated State states = 7; +} + +message MessageList{ + repeated LogMessage messages = 8; +} diff --git a/manticore/core/state_pb2.py b/manticore/core/state_pb2.py new file mode 100644 index 000000000..05e942adb --- /dev/null +++ b/manticore/core/state_pb2.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: state.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="state.proto", + package="mserialize", + syntax="proto3", + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x0bstate.proto\x12\nmserialize"\x1d\n\nLogMessage\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t"\xb6\x01\n\x05State\x12\n\n\x02id\x18\x02 \x01(\x05\x12)\n\x04type\x18\x03 \x01(\x0e\x32\x1b.mserialize.State.StateType\x12\x0e\n\x06reason\x18\x04 \x01(\t\x12\x15\n\rnum_executing\x18\x05 \x01(\x05\x12\x11\n\twait_time\x18\x06 \x01(\x05"<\n\tStateType\x12\t\n\x05READY\x10\x00\x12\x08\n\x04\x42USY\x10\x01\x12\n\n\x06KILLED\x10\x02\x12\x0e\n\nTERMINATED\x10\x03".\n\tStateList\x12!\n\x06states\x18\x07 \x03(\x0b\x32\x11.mserialize.State"7\n\x0bMessageList\x12(\n\x08messages\x18\x08 \x03(\x0b\x32\x16.mserialize.LogMessageb\x06proto3', +) + + +_STATE_STATETYPE = _descriptor.EnumDescriptor( + name="StateType", + full_name="mserialize.State.StateType", + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name="READY", + index=0, + number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="BUSY", + index=1, + number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="KILLED", + index=2, + number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="TERMINATED", + index=3, + number=3, + 
serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + ], + containing_type=None, + serialized_options=None, + serialized_start=181, + serialized_end=241, +) +_sym_db.RegisterEnumDescriptor(_STATE_STATETYPE) + + +_LOGMESSAGE = _descriptor.Descriptor( + name="LogMessage", + full_name="mserialize.LogMessage", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="content", + full_name="mserialize.LogMessage.content", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=27, + serialized_end=56, +) + + +_STATE = _descriptor.Descriptor( + name="State", + full_name="mserialize.State", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="id", + full_name="mserialize.State.id", + index=0, + number=2, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="type", + full_name="mserialize.State.type", + index=1, + number=3, + type=14, + cpp_type=8, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + 
serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="reason", + full_name="mserialize.State.reason", + index=2, + number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="num_executing", + full_name="mserialize.State.num_executing", + index=3, + number=5, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="wait_time", + full_name="mserialize.State.wait_time", + index=4, + number=6, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + _STATE_STATETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=59, + serialized_end=241, +) + + +_STATELIST = _descriptor.Descriptor( + name="StateList", + full_name="mserialize.StateList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="states", + full_name="mserialize.StateList.states", + index=0, + number=7, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + 
enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=243, + serialized_end=289, +) + + +_MESSAGELIST = _descriptor.Descriptor( + name="MessageList", + full_name="mserialize.MessageList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="messages", + full_name="mserialize.MessageList.messages", + index=0, + number=8, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=291, + serialized_end=346, +) + +_STATE.fields_by_name["type"].enum_type = _STATE_STATETYPE +_STATE_STATETYPE.containing_type = _STATE +_STATELIST.fields_by_name["states"].message_type = _STATE +_MESSAGELIST.fields_by_name["messages"].message_type = _LOGMESSAGE +DESCRIPTOR.message_types_by_name["LogMessage"] = _LOGMESSAGE +DESCRIPTOR.message_types_by_name["State"] = _STATE +DESCRIPTOR.message_types_by_name["StateList"] = _STATELIST +DESCRIPTOR.message_types_by_name["MessageList"] = _MESSAGELIST +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +LogMessage = _reflection.GeneratedProtocolMessageType( + "LogMessage", + (_message.Message,), + { + "DESCRIPTOR": _LOGMESSAGE, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.LogMessage) + }, +) 
+_sym_db.RegisterMessage(LogMessage) + +State = _reflection.GeneratedProtocolMessageType( + "State", + (_message.Message,), + { + "DESCRIPTOR": _STATE, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.State) + }, +) +_sym_db.RegisterMessage(State) + +StateList = _reflection.GeneratedProtocolMessageType( + "StateList", + (_message.Message,), + { + "DESCRIPTOR": _STATELIST, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.StateList) + }, +) +_sym_db.RegisterMessage(StateList) + +MessageList = _reflection.GeneratedProtocolMessageType( + "MessageList", + (_message.Message,), + { + "DESCRIPTOR": _MESSAGELIST, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.MessageList) + }, +) +_sym_db.RegisterMessage(MessageList) + + +# @@protoc_insertion_point(module_scope) diff --git a/manticore/core/worker.py b/manticore/core/worker.py index 1e9d74012..5da1fa436 100644 --- a/manticore/core/worker.py +++ b/manticore/core/worker.py @@ -1,11 +1,22 @@ from ..utils.nointerrupt import WithKeyboardInterruptAs from .state import Concretize, TerminateState +from ..core.plugin import Plugin, StateDescriptor +from .state_pb2 import StateList, MessageList, State, LogMessage +from ..utils.log import register_log_callback +from ..utils import config +from ..utils.enums import StateStatus, StateLists +from datetime import datetime import logging import multiprocessing import threading +from collections import deque import os +import socketserver import typing +consts = config.get_group("core") +consts.add("HOST", "localhost", "Address to bind the log & state servers to") +consts.add("PORT", 3214, "Port to use for the log server. 
State server runs one port higher.") logger = logging.getLogger(__name__) # logger.setLevel(9) @@ -255,3 +266,137 @@ def start(self, target: typing.Optional[typing.Callable] = None): self._t = threading.Thread(target=self.run if target is None else target, args=(self,)) self._t.daemon = True self._t.start() + + +class DumpTCPHandler(socketserver.BaseRequestHandler): + """ TCP Handler that calls the `dump` method bound to the server """ + + def handle(self): + self.request.sendall(self.server.dump()) + + +class ReusableTCPServer(socketserver.TCPServer): + """ Custom socket server that gracefully allows the address to be reused """ + + allow_reuse_address = True + dump: typing.Optional[typing.Callable] = None + + +class LogCaptureWorker(DaemonThread): + """ Extended DaemonThread that runs a TCP server that dumps the captured logs """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.activated = False #: Whether a client has ever connected + register_log_callback(self.log_callback) + + def log_callback(self, msg): + q = self.manticore._log_queue + try: + q.append(msg) + except AttributeError: + # Appending to a deque with maxlen=n is about 25x faster than checking if a queue.Queue is full, + # popping if so, and appending. For that reason, we use a deque in the threading and single, but + # a manager.Queue in multiprocessing (since that's all it supports). 
Catching an AttributeError + # is slightly faster than using `isinstance` for the default case (threading) but does slow down + # log throughput by about 20% (on top of the 25x slowdown) when using Multiprocessing instead of + # threading + if q.full(): + q.get() + q.put(msg) + + def dump_logs(self): + """ + Converts captured logs into protobuf format + """ + self.activated = True + serialized = MessageList() + q = self.manticore._log_queue + i = 0 + while i < 50 and not q.empty(): + msg = LogMessage(content=q.get()) + serialized.messages.append(msg) + i += 1 + return serialized.SerializeToString() + + def run(self, *args): + logger.debug( + "Capturing Logs via Thread %d. Pid %d Tid %d).", + self.id, + os.getpid(), + threading.get_ident(), + ) + + m = self.manticore + + try: + with ReusableTCPServer((consts.HOST, consts.PORT), DumpTCPHandler) as server: + server.dump = self.dump_logs # type: ignore + server.serve_forever() + except OSError as e: + # TODO - this should be logger.warning, but we need to rewrite several unit tests that depend on + # specific stdout output in order to do that. 
+ logger.info("Could not start log capture server: %s", str(e)) + + +def render_state_descriptors(desc: typing.Dict[int, StateDescriptor]): + """ + Converts the built-in list of state descriptors into a StateList from Protobuf + + :param desc: Output from ManticoreBase.introspect + :return: Protobuf StateList to send over the wire + """ + out = StateList() + for st in desc.values(): + if st.status != StateStatus.destroyed: + now = datetime.now() + out.states.append( + State( + id=st.state_id, + type={ + StateLists.ready: State.READY, # type: ignore + StateLists.busy: State.BUSY, # type: ignore + StateLists.terminated: State.TERMINATED, # type: ignore + StateLists.killed: State.KILLED, # type: ignore + }[ + getattr(st, "state_list", StateLists.killed) + ], # If the state list is missing, assume it's killed + reason=st.termination_msg, + num_executing=st.own_execs, + wait_time=int( + (now - st.field_updated_at.get("state_list", now)).total_seconds() * 1000 + ), + ) + ) + return out + + +def state_monitor(self: DaemonThread): + """ + Daemon thread callback that runs a server that listens for incoming TCP connections and + dumps the list of state descriptors. + + :param self: DaemonThread created to run the server + """ + logger.debug( + "Monitoring States via Thread %d. Pid %d Tid %d).", + self.id, + os.getpid(), + threading.get_ident(), + ) + + m = self.manticore + + def dump_states(): + sts = m.introspect() + sts = render_state_descriptors(sts) + return sts.SerializeToString() + + try: + with ReusableTCPServer((consts.HOST, consts.PORT + 1), DumpTCPHandler) as server: + server.dump = dump_states # type: ignore + server.serve_forever() + except OSError as e: + # TODO - this should be logger.warning, but we need to rewrite several unit tests that depend on + # specific stdout output in order to do that. 
+ logger.info("Could not start state monitor server: %s", str(e)) diff --git a/manticore/utils/helpers.py b/manticore/utils/helpers.py index 5c918b3ec..6c1d431f1 100644 --- a/manticore/utils/helpers.py +++ b/manticore/utils/helpers.py @@ -1,3 +1,4 @@ +import collections import logging import pickle import string @@ -200,3 +201,13 @@ def pretty_print_state_descriptors(desc: Dict): print(tab) print() + + +class deque(collections.deque): + """ A wrapper around collections.deque that adds a few APIs present in SyncManager.Queue """ + + def empty(self) -> bool: + return len(self) == 0 + + def get(self): + return self.popleft() diff --git a/manticore/utils/log.py b/manticore/utils/log.py index f49595f0a..c9a03ec75 100644 --- a/manticore/utils/log.py +++ b/manticore/utils/log.py @@ -1,5 +1,6 @@ import logging import sys +import io from typing import List, Set, Tuple @@ -13,6 +14,23 @@ handler.setFormatter(formatter) +class CallbackStream(io.TextIOBase): + def __init__(self, callback): + self.callback = callback + + def write(self, log_str): + self.callback(log_str) + + +def register_log_callback(cb): + for name in all_loggers: + logger = logging.getLogger(name) + handler_internal = logging.StreamHandler(CallbackStream(cb)) + if name.startswith("manticore"): + handler_internal.setFormatter(formatter) + logger.addHandler(handler_internal) + + class ContextFilter(logging.Filter): """ This is a filter which injects contextual information into the log. 
@@ -101,7 +119,7 @@ def get_levels() -> List[List[Tuple[str, int]]]: ("manticore.core.worker", logging.INFO), ("manticore.platforms.*", logging.DEBUG), ("manticore.ethereum", logging.DEBUG), - ("manticore.core.plugin", logging.DEBUG), + ("manticore.core.plugin", logging.INFO), ("manticore.wasm.*", logging.INFO), ("manticore.utils.emulate", logging.INFO), ], @@ -112,6 +130,7 @@ def get_levels() -> List[List[Tuple[str, int]]]: ("manticore.native.memory", logging.DEBUG), ("manticore.native.cpu.*", logging.DEBUG), ("manticore.native.cpu.*.registers", logging.DEBUG), + ("manticore.core.plugin", logging.DEBUG), ("manticore.utils.helpers", logging.INFO), ], # 5 (-vvvv) diff --git a/mypy.ini b/mypy.ini index 7f7279360..cbe17228e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -39,3 +39,6 @@ ignore_missing_imports = True [mypy-wasm.*] ignore_missing_imports = True + +[mypy-manticore.core.state_pb2] +ignore_errors = True diff --git a/setup.py b/setup.py index bc99357a6..350b66311 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,8 @@ def rtd_dependent_deps(): python_requires=">=3.6", install_requires=[ "pyyaml", + "protobuf", + # evm dependencies "pysha3", "prettytable", "ply", diff --git a/tests/native/test_logging.py b/tests/native/test_logging.py new file mode 100644 index 000000000..0fae68a46 --- /dev/null +++ b/tests/native/test_logging.py @@ -0,0 +1,19 @@ +import unittest +import logging + +from manticore.utils.log import get_verbosity, set_verbosity, DEFAULT_LOG_LEVEL + + +class ManticoreLogger(unittest.TestCase): + """Make sure we set the logging levels correctly""" + + _multiprocess_can_split_ = True + + def test_logging(self): + set_verbosity(1) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.WARNING) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.INFO) + + set_verbosity(0) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), DEFAULT_LOG_LEVEL) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), 
DEFAULT_LOG_LEVEL) diff --git a/tests/native/test_manticore.py b/tests/native/test_manticore.py index 8217e45b7..6531905e3 100644 --- a/tests/native/test_manticore.py +++ b/tests/native/test_manticore.py @@ -134,20 +134,3 @@ def test_integration_basic_stdin(self): else: self.assertTrue(a <= 0x41) self.assertTrue(b > 0x41) - - -class ManticoreLogger(unittest.TestCase): - """Make sure we set the logging levels correctly""" - - _multiprocess_can_split_ = True - - def test_logging(self): - set_verbosity(5) - self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.DEBUG) - self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.DEBUG) - - set_verbosity(1) - self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.WARNING) - self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.INFO) - - set_verbosity(0) diff --git a/tests/other/test_tui_api.py b/tests/other/test_tui_api.py new file mode 100644 index 000000000..08c546e2d --- /dev/null +++ b/tests/other/test_tui_api.py @@ -0,0 +1,143 @@ +import unittest +import threading +import socket +import select +import subprocess +import logging +import time +import sys + +from google.protobuf.message import DecodeError +from manticore.core.state_pb2 import StateList, State, MessageList +from pathlib import Path + + +PYTHON_BIN: str = sys.executable + +HOST = "localhost" +PORT = 4123 + +ms_file = str( + Path(__file__).parent.parent.parent.joinpath("examples", "linux", "binaries", "multiple-styles") +) + +finished = False +logs = [] +state_captures = [] + + +def fetch_update(): + logger = logging.getLogger("FetchThread") + while not finished: + try: + # Attempts to (re)connect to manticore server + log_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + logger.debug("Connecting to %s:%s", HOST, PORT) + log_sock.connect((HOST, PORT)) + logger.info("Connected to %s:%s", HOST, PORT) + + read_sockets, write_sockets, error_sockets = select.select([log_sock], [], [], 60) + 
+ serialized = b"" + if read_sockets: + serialized = read_sockets[0].recv(10000) + logger.info("Pulled {} bytes".format(len(serialized))) + + try: + m = MessageList() + m.ParseFromString(serialized) + logs.extend(m.messages) + logger.info("Deserialized LogMessage") + + except DecodeError: + logger.info("Unable to deserialize message, malformed response") + + read_sockets[0].shutdown(socket.SHUT_RDWR) + + log_sock.close() + except socket.error: + logger.warning("Log Socket disconnected") + log_sock.close() + + try: + # Attempts to (re)connect to manticore server + state_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + logger.debug("Connecting to %s:%s", HOST, PORT + 1) + state_sock.connect((HOST, PORT + 1)) + logger.info("Connected to %s:%s", HOST, PORT + 1) + + read_sockets, write_sockets, error_sockets = select.select([state_sock], [], [], 60) + + serialized = b"" + if read_sockets: + serialized = read_sockets[0].recv(10000) + logger.info("Pulled {} bytes".format(len(serialized))) + + try: + m = StateList() + m.ParseFromString(serialized) + + logger.info("Got %d states", len(m.states)) + + state_captures.append(m.states) + + except DecodeError: + logger.info("Unable to deserialize message, malformed response") + + state_sock.shutdown(socket.SHUT_RDWR) + + state_sock.close() + except socket.error: + logger.warning("State Socket disconnected") + state_sock.close() + + time.sleep(0.5) + + +class MyTestCase(unittest.TestCase): + def test_something(self): + global finished + + fetch_thread = threading.Thread(target=fetch_update) + fetch_thread.start() + + cmd = [ + PYTHON_BIN, + "-m", + "manticore", + "-v", + "--no-color", + "--core.procs", + str(10), + "--core.seed", + str(100), + "--core.PORT", + str(PORT), + ms_file, + ] + + self.assertEqual(subprocess.check_call(cmd), 0, "Manticore had a non-zero exit code") + + finished = True + + # Check that logs look right + self.assertTrue(any("you got it!" 
in i.content for i in logs)) + self.assertTrue(any("Program finished with exit status: 0" in i.content for i in logs)) + self.assertEqual( + sum(1 if "Program finished with exit status: 1" in i.content else 0 for i in logs), 17 + ) + + # Check that state lists seem correct + self.assertEqual( + max(len(list(filter(lambda x: x.type == State.BUSY, i))) for i in state_captures), 10 + ) # At most ten running states + + self.assertEqual( + min(len(list(filter(lambda x: x.type == State.BUSY, i))) for i in state_captures), 0 + ) # No running states at the end + + self.assertEqual(max(len(i) for i in state_captures), 18) # Should have 18 states in total + + +if __name__ == "__main__": + unittest.main() From d85a9a270d89cee1097538bd65aeb330ad8ae4cb Mon Sep 17 00:00:00 2001 From: Sonya <60201678+sschriner@users.noreply.github.com> Date: Mon, 8 Feb 2021 13:38:19 -0500 Subject: [PATCH 2/2] Syscall specific hooks (#2389) * Non state specific functioning * State specific functioning * Add None to add_hook call in hook decorator * Moved will/did_invoke_syscall * Added functionality for hooking by function name to state specific hooks * Added functionality for hooking by sys function name to non state specific hooks * State specific tests --- manticore/native/cpu/abstractcpu.py | 1 + manticore/native/manticore.py | 108 +++++++++++++++---- manticore/native/state.py | 156 +++++++++++++++++++++++----- manticore/platforms/linux.py | 4 +- tests/native/test_manticore.py | 54 ++++++++++ tests/native/test_state.py | 24 +++++ 6 files changed, 300 insertions(+), 47 deletions(-) diff --git a/manticore/native/cpu/abstractcpu.py b/manticore/native/cpu/abstractcpu.py index 6e847937c..cccc2173d 100644 --- a/manticore/native/cpu/abstractcpu.py +++ b/manticore/native/cpu/abstractcpu.py @@ -503,6 +503,7 @@ class Cpu(Eventful): "read_memory", "decode_instruction", "execute_instruction", + "invoke_syscall", "set_descriptor", "map_memory", "protect_memory", diff --git a/manticore/native/manticore.py 
b/manticore/native/manticore.py index 9ec16062c..b9cdee2af 100644 --- a/manticore/native/manticore.py +++ b/manticore/native/manticore.py @@ -5,7 +5,7 @@ import os import shlex import time -from typing import Callable, Optional +from typing import Callable, Optional, Union import sys from elftools.elf.elffile import ELFFile from elftools.elf.sections import SymbolTableSection @@ -39,17 +39,24 @@ def __init__(self, path_or_state, argv=None, workspace_url=None, policy="random" initial_state = _make_initial_state(path_or_state, argv=argv, **kwargs) else: initial_state = path_or_state - super().__init__(initial_state, workspace_url=workspace_url, policy=policy, **kwargs) # Move the following into a linux plugin self._assertions = {} self.trace = None + self._linux_machine_arch: str # used when looking up syscall numbers for sys hooks # sugar for 'will_execute_instruction" self._hooks = {} self._after_hooks = {} + self._sys_hooks = {} + self._sys_after_hooks = {} self._init_hooks = set() + from ..platforms.linux import Linux + + if isinstance(initial_state.platform, Linux): + self._linux_machine_arch = initial_state.platform.current.machine + # self.subscribe('will_generate_testcase', self._generate_testcase_callback) ############################################################################ @@ -215,54 +222,91 @@ def init(self, f): self.subscribe("will_run", self._init_callback) return f - def hook(self, pc, after=False): + def hook( + self, pc_or_sys: Optional[Union[int, str]], after: bool = False, syscall: bool = False + ): """ A decorator used to register a hook function for a given instruction address. Equivalent to calling :func:`~add_hook`. - :param pc: Address of instruction to hook - :type pc: int or None + :param pc_or_sys: Address of instruction, syscall number, or syscall name to remove hook from + :type pc_or_sys: int or None if `syscall` = False. int, str, or None if `syscall` = True + :param after: Hook after PC (or after syscall) executes? 
+ :param syscall: Catch a syscall invocation instead of instruction? """ def decorator(f): - self.add_hook(pc, f, after) + self.add_hook(pc_or_sys, f, after, None, syscall) return f return decorator def add_hook( self, - pc: Optional[int], + pc_or_sys: Optional[Union[int, str]], callback: HookCallback, after: bool = False, state: Optional[State] = None, + syscall: bool = False, ): """ - Add a callback to be invoked on executing a program counter. Pass `None` - for pc to invoke callback on every instruction. `callback` should be a callable - that takes one :class:`~manticore.core.state.State` argument. + Add a callback to be invoked on executing a program counter (or syscall). Pass `None` + for `pc_or_sys` to invoke callback on every instruction (or syscall). `callback` should + be a callable that takes one :class:`~manticore.core.state.State` argument. - :param pc: Address of instruction to hook + :param pc_or_sys: Address of instruction, syscall number, or syscall name to remove hook from + :type pc_or_sys: int or None if `syscall` = False. int, str, or None if `syscall` = True :param callback: Hook function - :param after: Hook after PC executes? + :param after: Hook after PC (or after syscall) executes? :param state: Optionally, add hook for this state only, else all states + :param syscall: Catch a syscall invocation instead of instruction? 
""" - if not (isinstance(pc, int) or pc is None): - raise TypeError(f"pc must be either an int or None, not {pc.__class__.__name__}") + if not (isinstance(pc_or_sys, int) or pc_or_sys is None or syscall): + raise TypeError(f"pc must be either an int or None, not {pc_or_sys.__class__.__name__}") + elif not (isinstance(pc_or_sys, (int, str)) or pc_or_sys is None) and syscall: + raise TypeError( + f"syscall must be either an int, string, or None, not {pc_or_sys.__class__.__name__}" + ) + + if isinstance(pc_or_sys, str): + from ..platforms import linux_syscalls + + table = getattr(linux_syscalls, self._linux_machine_arch) + for index, name in table.items(): + if name == pc_or_sys: + pc_or_sys = index + break + if isinstance(pc_or_sys, str): + logger.warning( + f"{pc_or_sys} is not a valid syscall name in architecture {self._linux_machine_arch}. " + "Please refer to manticore/platforms/linux_syscalls.py to find the correct name." + ) + return if state is None: # add hook to all states - hooks, when, hook_callback = ( - (self._hooks, "will_execute_instruction", self._hook_callback) - if not after - else (self._after_hooks, "did_execute_instruction", self._after_hook_callback) - ) - hooks.setdefault(pc, set()).add(callback) + if not syscall: + hooks, when, hook_callback = ( + (self._hooks, "will_execute_instruction", self._hook_callback) + if not after + else (self._after_hooks, "did_execute_instruction", self._after_hook_callback) + ) + else: + hooks, when, hook_callback = ( + (self._sys_hooks, "will_invoke_syscall", self._sys_hook_callback) + if not after + else ( + self._sys_after_hooks, + "did_invoke_syscall", + self._sys_after_hook_callback, + ) + ) + hooks.setdefault(pc_or_sys, set()).add(callback) if hooks: self.subscribe(when, hook_callback) else: # only hook for the specified state - state.add_hook(pc, callback, after) + state.add_hook(pc_or_sys, callback, after, syscall) def _hook_callback(self, state, pc, instruction): "Invoke all registered generic hooks" @@ 
-293,6 +337,28 @@ def _after_hook_callback(self, state, last_pc, pc, instruction): for cb in self._after_hooks.get(None, []): cb(state) + def _sys_hook_callback(self, state, syscall_num): + "Invoke all registered generic hooks" + + # Invoke all syscall_num-specific hooks + for cb in self._sys_hooks.get(syscall_num, []): + cb(state) + + # Invoke all syscall_num-agnostic hooks + for cb in self._sys_hooks.get(None, []): + cb(state) + + def _sys_after_hook_callback(self, state, syscall_num): + "Invoke all registered generic hooks" + + # Invoke all syscall_num-specific hooks + for cb in self._sys_after_hooks.get(syscall_num, []): + cb(state) + + # Invoke all syscall_num-agnostic hooks + for cb in self._sys_after_hooks.get(None, []): + cb(state) + def _init_callback(self, ready_states): for cb in self._init_hooks: # We _should_ only ever have one starting state. Right now we're putting diff --git a/manticore/native/state.py b/manticore/native/state.py index 33606c034..c57031865 100644 --- a/manticore/native/state.py +++ b/manticore/native/state.py @@ -1,4 +1,5 @@ import copy +import logging from collections import namedtuple from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Union @@ -7,9 +8,11 @@ from .. 
import issymbolic from ..core.state import StateBase, Concretize, TerminateState from ..core.smtlib import Expression +from ..platforms import linux_syscalls HookCallback = Callable[[StateBase], None] +logger = logging.getLogger(__name__) class CheckpointData(NamedTuple): @@ -22,23 +25,31 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._hooks: Dict[Optional[int], Set[HookCallback]] = {} self._after_hooks: Dict[Optional[int], Set[HookCallback]] = {} + self._sys_hooks: Dict[Optional[int], Set[HookCallback]] = {} + self._sys_after_hooks: Dict[Optional[int], Set[HookCallback]] = {} def __getstate__(self) -> Dict[str, Any]: state = super().__getstate__() state["hooks"] = self._hooks state["after_hooks"] = self._after_hooks + state["sys_hooks"] = self._sys_hooks + state["sys_after_hooks"] = self._sys_after_hooks return state def __setstate__(self, state: Dict[str, Any]) -> None: super().__setstate__(state) self._hooks = state["hooks"] self._after_hooks = state["after_hooks"] + self._sys_hooks = state["sys_hooks"] + self._sys_after_hooks = state["sys_after_hooks"] self._resub_hooks() def __enter__(self) -> "State": new_state = super().__enter__() new_state._hooks = copy.copy(self._hooks) new_state._after_hooks = copy.copy(self._after_hooks) + new_state._sys_hooks = copy.copy(self._sys_hooks) + new_state._sys_after_hooks = copy.copy(self._sys_after_hooks) # Update constraint pointers in platform objects from ..platforms.linux import SLinux @@ -55,56 +66,111 @@ def __enter__(self) -> "State": return new_state def _get_hook_context( - self, after: bool = True + self, after: bool = True, syscall: bool = False ) -> Tuple[Dict[Optional[int], Set[HookCallback]], str, Any]: """ Internal helper function to get hook context information. :param after: Whether we want info pertaining to hooks after instruction executes or before + :param syscall: Catch a syscall invocation instead of instruction? 
:return: Information for hooks after or before: - set of hooks for specified after or before - string of callback event - State function that handles the callback """ - return ( - (self._hooks, "will_execute_instruction", self._state_hook_callback) - if not after - else (self._after_hooks, "did_execute_instruction", self._state_after_hook_callback) - ) - - def remove_hook(self, pc: Optional[int], callback: HookCallback, after: bool = False) -> bool: + if not syscall: + return ( + (self._hooks, "will_execute_instruction", self._state_hook_callback) + if not after + else (self._after_hooks, "did_execute_instruction", self._state_after_hook_callback) + ) + else: + return ( + (self._sys_hooks, "will_invoke_syscall", self._state_sys_hook_callback) + if not after + else ( + self._sys_after_hooks, + "did_invoke_syscall", + self._state_sys_after_hook_callback, + ) + ) + + def remove_hook( + self, + pc_or_sys: Optional[Union[int, str]], + callback: HookCallback, + after: bool = False, + syscall: bool = False, + ) -> bool: """ Remove a callback with the specified properties - :param pc: Address of instruction to remove from - :param callback: The callback function that was at the address + :param pc_or_sys: Address of instruction, syscall number, or syscall name to remove hook from + :type pc_or_sys: int or None if `syscall` = False. int, str, or None if `syscall` = True + :param callback: The callback function that was at the address (or syscall) :param after: Whether it was after instruction executed or not + :param syscall: Catch a syscall invocation instead of instruction? 
:return: Whether it was removed """ - hooks, when, _ = self._get_hook_context(after) - cbs = hooks.get(pc, set()) + + if isinstance(pc_or_sys, str): + table = getattr(linux_syscalls, self._platform.current.machine) + for index, name in table.items(): + if name == pc_or_sys: + pc_or_sys = index + break + if isinstance(pc_or_sys, str): + logger.warning( + f"{pc_or_sys} is not a valid syscall name in architecture {self._platform.current.machine}. " + "Please refer to manticore/platforms/linux_syscalls.py to find the correct name." + ) + return False + + hooks, when, _ = self._get_hook_context(after, syscall) + cbs = hooks.get(pc_or_sys, set()) if callback in cbs: cbs.remove(callback) else: return False - if len(hooks.get(pc, set())) == 0: - del hooks[pc] + if not len(hooks.get(pc_or_sys, set())): + del hooks[pc_or_sys] return True - def add_hook(self, pc: Optional[int], callback: HookCallback, after: bool = False) -> None: + def add_hook( + self, + pc_or_sys: Optional[Union[int, str]], + callback: HookCallback, + after: bool = False, + syscall: bool = False, + ) -> None: """ - Add a callback to be invoked on executing a program counter. Pass `None` - for pc to invoke callback on every instruction. `callback` should be a callable - that takes one :class:`~manticore.native.state.State` argument. + Add a callback to be invoked on executing a program counter (or syscall). Pass `None` + for `pc_or_sys` to invoke callback on every instruction (or syscall invocation). + `callback` should be a callable that takes one :class:`~manticore.native.state.State` argument. - :param pc: Address of instruction to hook + :param pc_or_sys: Address of instruction to hook, syscall number, or syscall name + :type pc_or_sys: int or None if `syscall` = False. int, str, or None if `syscall` = True :param callback: Hook function - :param after: Hook after PC executes? - :param state: Add hook to this state + :param after: Hook after PC (or after syscall) executes? 
+ :param syscall: Catch a syscall invocation instead of instruction? """ - hooks, when, hook_callback = self._get_hook_context(after) - hooks.setdefault(pc, set()).add(callback) + + if isinstance(pc_or_sys, str): + table = getattr(linux_syscalls, self._platform.current.machine) + for index, name in table.items(): + if name == pc_or_sys: + pc_or_sys = index + break + if isinstance(pc_or_sys, str): + logger.warning( + f"{pc_or_sys} is not a valid syscall name in architecture {self._platform.current.machine}. " + "Please refer to manticore/platforms/linux_syscalls.py to find the correct name." + ) + return + + hooks, when, hook_callback = self._get_hook_context(after, syscall) + hooks.setdefault(pc_or_sys, set()).add(callback) if hooks: self.subscribe(when, hook_callback) @@ -114,10 +180,16 @@ def _resub_hooks(self) -> None: state is active again. """ # TODO: check if the lists actually have hooks - _, when, hook_callback = self._get_hook_context(False) + _, when, hook_callback = self._get_hook_context(False, False) + self.subscribe(when, hook_callback) + + _, when, hook_callback = self._get_hook_context(True, False) + self.subscribe(when, hook_callback) + + _, when, hook_callback = self._get_hook_context(False, True) self.subscribe(when, hook_callback) - _, when, hook_callback = self._get_hook_context(True) + _, when, hook_callback = self._get_hook_context(True, True) self.subscribe(when, hook_callback) def _state_hook_callback(self, pc: int, _instruction: Instruction) -> None: @@ -156,6 +228,40 @@ def _state_after_hook_callback(self, last_pc: int, _pc: int, _instruction: Instr for cb in tmp_hooks.get(None, []): cb(self) + def _state_sys_hook_callback(self, syscall_num: int) -> None: + """ + Invoke all registered State hooks before the syscall executes. 
+ + :param syscall_num: index of the syscall about to be executed + """ + # Prevent crash if removing hook(s) during a callback + tmp_hooks = copy.deepcopy(self._sys_hooks) + + # Invoke all syscall-specific hooks + for cb in tmp_hooks.get(syscall_num, []): + cb(self) + + # Invoke all syscall-agnostic hooks + for cb in tmp_hooks.get(None, []): + cb(self) + + def _state_sys_after_hook_callback(self, syscall_num: int): + """ + Invoke all registered State hooks after the syscall executes. + + :param syscall_num: index of the syscall that was just executed + """ + # Prevent crash if removing hook(s) during a callback + tmp_hooks = copy.deepcopy(self._sys_after_hooks) + + # Invoke all syscall-specific hooks + for cb in tmp_hooks.get(syscall_num, []): + cb(self) + + # Invoke all syscall-agnostic hooks + for cb in tmp_hooks.get(None, []): + cb(self) + @property def cpu(self): """ diff --git a/manticore/platforms/linux.py b/manticore/platforms/linux.py index 19e18fa6e..ae1fa758d 100644 --- a/manticore/platforms/linux.py +++ b/manticore/platforms/linux.py @@ -2909,13 +2909,15 @@ def execute(self): self.check_timers() self.sched() except (Interruption, Syscall) as e: + index: int = self._syscall_abi.syscall_number() + self._syscall_abi._cpu._publish("will_invoke_syscall", index) try: self.syscall() if hasattr(e, "on_handled"): e.on_handled() + self._syscall_abi._cpu._publish("did_invoke_syscall", index) except RestartSyscall: pass - return True # 64bit syscalls diff --git a/tests/native/test_manticore.py b/tests/native/test_manticore.py index 6531905e3..1d91866fc 100644 --- a/tests/native/test_manticore.py +++ b/tests/native/test_manticore.py @@ -88,6 +88,60 @@ def tmp(state): assert tmp in self.m._after_hooks[entry] + def test_add_sys_hook(self): + name = "sys_brk" + index = 12 + + def tmp(state): + assert state._platformn._syscall_abi.syscall_number() == index + self.m.kill() + + self.m.add_hook(name, tmp, syscall=True) + self.assertTrue(tmp in self.m._sys_hooks[index]) + + 
def test_sys_hook_dec(self): + index = 12 + + @self.m.hook(index, syscall=True) + def tmp(state): + assert state._platformn._syscall_abi.syscall_number() == index + self.m.kill() + + self.assertTrue(tmp in self.m._sys_hooks[index]) + + def test_sys_hook(self): + self.m.context["x"] = 0 + + @self.m.hook(None, syscall=True) + def tmp(state): + with self.m.locked_context() as ctx: + ctx["x"] = 1 + self.m.kill() + + self.m.run() + + self.assertEqual(self.m.context["x"], 1) + + def test_add_sys_hook_after(self): + def tmp(state): + pass + + index = 12 + self.m.add_hook(index, tmp, after=True, syscall=True) + assert tmp in self.m._sys_after_hooks[index] + + def test_sys_hook_after_dec(self): + name = "sys_mmap" + index = 9 + + @self.m.hook(name, after=True, syscall=True) + def tmp(state): + pass + + self.m.run() + + assert tmp in self.m._sys_after_hooks[index] + def test_init_hook(self): self.m.context["x"] = 0 diff --git a/tests/native/test_state.py b/tests/native/test_state.py index 93597ec2d..d330760df 100644 --- a/tests/native/test_state.py +++ b/tests/native/test_state.py @@ -319,6 +319,30 @@ def process_hook(state: State) -> None: self.m.run() self.assertIn("Reached fin callback", f.getvalue()) + def test_state_sys_hooks(self): + @self.m.hook(12, after=False, syscall=True) + def process_hook(state: State) -> None: + # We can't remove because the globally applied hooks are stored in + # the Manticore class, not State + self.assertFalse(state.remove_hook(12, process_hook, after=True, syscall=True)) + # We can remove this one because it was applied specifically to this + # State (or its parent) + self.assertTrue(state.remove_hook(None, do_nothing, after=True, syscall=True)) + + state.add_hook(None, do_nothing, after=False, syscall=True) + state.add_hook(None, do_nothing, after=True, syscall=True) + + # Should execute directly after sys_brk invocation + state.add_hook("sys_brk", fin, after=True, syscall=True) + + for state in self.m.ready_states: + 
self.m.add_hook(None, do_nothing, after=True, state=state, syscall=True) + + f = io.StringIO() + with redirect_stdout(f): + self.m.run() + self.assertIn("Reached fin callback", f.getvalue()) + class StateMergeTest(unittest.TestCase):