diff --git a/manticore/core/manticore.py b/manticore/core/manticore.py index 1650b66e3..7f48ab25b 100644 --- a/manticore/core/manticore.py +++ b/manticore/core/manticore.py @@ -2,6 +2,7 @@ import itertools import logging import sys +import time import typing import random import weakref @@ -21,11 +22,18 @@ from ..utils.deprecated import deprecated from ..utils.enums import StateLists, MProcessingType from ..utils.event import Eventful -from ..utils.helpers import PickleSerializer, pretty_print_state_descriptors +from ..utils.helpers import PickleSerializer, pretty_print_state_descriptors, deque from ..utils.log import set_verbosity from ..utils.nointerrupt import WithKeyboardInterruptAs from .workspace import Workspace, Testcase -from .worker import WorkerSingle, WorkerThread, WorkerProcess, DaemonThread +from .worker import ( + WorkerSingle, + WorkerThread, + WorkerProcess, + DaemonThread, + LogCaptureWorker, + state_monitor, +) from multiprocessing.managers import SyncManager import threading @@ -88,6 +96,7 @@ def wait_for(self, condition, *args, **kwargs): self._terminated_states = [] self._busy_states = [] self._killed_states = [] + self._log_queue = deque(maxlen=5000) self._shared_context = {} def _manticore_threading(self): @@ -99,6 +108,7 @@ def _manticore_threading(self): self._terminated_states = [] self._busy_states = [] self._killed_states = [] + self._log_queue = deque(maxlen=5000) self._shared_context = {} def _manticore_multiprocessing(self): @@ -120,6 +130,9 @@ def raise_signal(): self._terminated_states = self._manager.list() self._busy_states = self._manager.list() self._killed_states = self._manager.list() + # The multiprocessing queue is much slower than the deque when it gets full, so we + # triple the size in order to prevent that from happening. + self._log_queue = self._manager.Queue(15000) self._shared_context = self._manager.dict() self._context_value_types = {list: self._manager.list, dict: self._manager.dict} @@ -370,8 +383,10 @@ def __init__( # Workers will use manticore __dict__ So lets spawn them last self._workers = [self._worker_type(id=i, manticore=self) for i in range(consts.procs)] - # We won't create the daemons until .run() is called - self._daemon_threads: typing.List[DaemonThread] = [] + # Create log capture worker. We won't create the rest of the daemons until .run() is called + self._daemon_threads: typing.Dict[int, DaemonThread] = { + -1: LogCaptureWorker(id=-1, manticore=self) + } self._daemon_callbacks: typing.List[typing.Callable] = [] self._snapshot = None @@ -1102,21 +1117,27 @@ def run(self): # User subscription to events is disabled from now on self.subscribe = None + self.register_daemon(state_monitor) + self._daemon_threads[-1].start() # Start log capture worker + # Passing generators to callbacks is a bit hairy because the first callback would drain it if we didn't # clone the iterator in event.py. We're preserving the old API here, but it's something to avoid in the future. self._publish("will_run", self.ready_states) self._running.value = True + # start all the workers! for w in self._workers: w.start() # Create each daemon thread and pass it `self` - if not self._daemon_threads: # Don't recreate the threads if we call run multiple times - for i, cb in enumerate(self._daemon_callbacks): + for i, cb in enumerate(self._daemon_callbacks): + if ( + i not in self._daemon_threads + ): # Don't recreate the threads if we call run multiple times dt = DaemonThread( id=i, manticore=self ) # Potentially duplicated ids with workers. Don't mix! 
- self._daemon_threads.append(dt) + self._daemon_threads[dt.id] = dt dt.start(cb) # Main process. Lets just wait and capture CTRL+C at main @@ -1173,6 +1194,17 @@ def finalize(self): self.generate_testcase(state) self.remove_all() + def wait_for_log_purge(self): + """ + If a client has accessed the log server, and there are still buffered logs, + waits up to 2 seconds for the client to retrieve the logs. + """ + if self._daemon_threads[-1].activated: + for _ in range(8): + if self._log_queue.empty(): + break + time.sleep(0.25) + ############################################################################ ############################################################################ ############################################################################ @@ -1188,6 +1220,7 @@ def save_run_data(self): config.save(f) logger.info("Results in %s", self._output.store.uri) + self.wait_for_log_purge() def introspect(self) -> typing.Dict[int, StateDescriptor]: """ diff --git a/manticore/core/plugin.py b/manticore/core/plugin.py index 9d2dac38d..80a7c091a 100644 --- a/manticore/core/plugin.py +++ b/manticore/core/plugin.py @@ -670,6 +670,24 @@ def get_state_descriptors(self) -> typing.Dict[int, StateDescriptor]: out = context.copy() # TODO: is this necessary to break out of the lock? return out + def did_kill_state_callback(self, state, ex: Exception): + """ + Capture other state-killing exceptions so we can get the corresponding message + + :param state: State that was killed + :param ex: The exception w/ the termination message + """ + state_id = state.id + with self.locked_context("manticore_state", dict) as context: + if state_id not in context: + logger.warning( + "Caught killing of state %s, but failed to capture its initialization", + state_id, + ) + context.setdefault(state_id, StateDescriptor(state_id=state_id)).termination_msg = repr( + ex + ) + @property def unique_name(self) -> str: return IntrospectionAPIPlugin.NAME diff --git a/manticore/core/state.proto b/manticore/core/state.proto new file mode 100644 index 000000000..dff84b636 --- /dev/null +++ b/manticore/core/state.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package mserialize; + +message LogMessage{ + string content = 1; +} + +message State{ + + enum StateType{ + READY = 0; + BUSY = 1; + KILLED = 2; + TERMINATED = 3; + } + + int32 id = 2; // state ID + StateType type = 3; // Type of state + string reason = 4; // Reason for execution stopping + int32 num_executing = 5; // number of executing instructions + int32 wait_time = 6; +} + +message StateList{ + repeated State states = 7; +} + +message MessageList{ + repeated LogMessage messages = 8; +} diff --git a/manticore/core/state_pb2.py b/manticore/core/state_pb2.py new file mode 100644 index 000000000..05e942adb --- /dev/null +++ b/manticore/core/state_pb2.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: state.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="state.proto", + package="mserialize", + syntax="proto3", + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x0bstate.proto\x12\nmserialize"\x1d\n\nLogMessage\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t"\xb6\x01\n\x05State\x12\n\n\x02id\x18\x02 \x01(\x05\x12)\n\x04type\x18\x03 \x01(\x0e\x32\x1b.mserialize.State.StateType\x12\x0e\n\x06reason\x18\x04 \x01(\t\x12\x15\n\rnum_executing\x18\x05 \x01(\x05\x12\x11\n\twait_time\x18\x06 \x01(\x05"<\n\tStateType\x12\t\n\x05READY\x10\x00\x12\x08\n\x04\x42USY\x10\x01\x12\n\n\x06KILLED\x10\x02\x12\x0e\n\nTERMINATED\x10\x03".\n\tStateList\x12!\n\x06states\x18\x07 \x03(\x0b\x32\x11.mserialize.State"7\n\x0bMessageList\x12(\n\x08messages\x18\x08 \x03(\x0b\x32\x16.mserialize.LogMessageb\x06proto3', +) + + +_STATE_STATETYPE = _descriptor.EnumDescriptor( + name="StateType", + full_name="mserialize.State.StateType", + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name="READY", + index=0, + number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="BUSY", + index=1, + number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="KILLED", + index=2, + number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.EnumValueDescriptor( + name="TERMINATED", + index=3, + number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key, + ), + ], + containing_type=None, + serialized_options=None, + serialized_start=181, + serialized_end=241, +) +_sym_db.RegisterEnumDescriptor(_STATE_STATETYPE) + + +_LOGMESSAGE = _descriptor.Descriptor( + name="LogMessage", + full_name="mserialize.LogMessage", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="content", + full_name="mserialize.LogMessage.content", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=27, + serialized_end=56, +) + + +_STATE = _descriptor.Descriptor( + name="State", + full_name="mserialize.State", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="id", + full_name="mserialize.State.id", + index=0, + number=2, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + 
is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="type", + full_name="mserialize.State.type", + index=1, + number=3, + type=14, + cpp_type=8, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="reason", + full_name="mserialize.State.reason", + index=2, + number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="num_executing", + full_name="mserialize.State.num_executing", + index=3, + number=5, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="wait_time", + full_name="mserialize.State.wait_time", + index=4, + number=6, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + _STATE_STATETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=59, + serialized_end=241, +) + + +_STATELIST = _descriptor.Descriptor( + name="StateList", + full_name="mserialize.StateList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="states", + full_name="mserialize.StateList.states", + index=0, + number=7, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=243, + serialized_end=289, +) + + +_MESSAGELIST = _descriptor.Descriptor( + name="MessageList", + full_name="mserialize.MessageList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="messages", + full_name="mserialize.MessageList.messages", + index=0, + number=8, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, 
+ is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=291, + serialized_end=346, +) + +_STATE.fields_by_name["type"].enum_type = _STATE_STATETYPE +_STATE_STATETYPE.containing_type = _STATE +_STATELIST.fields_by_name["states"].message_type = _STATE +_MESSAGELIST.fields_by_name["messages"].message_type = _LOGMESSAGE +DESCRIPTOR.message_types_by_name["LogMessage"] = _LOGMESSAGE +DESCRIPTOR.message_types_by_name["State"] = _STATE +DESCRIPTOR.message_types_by_name["StateList"] = _STATELIST +DESCRIPTOR.message_types_by_name["MessageList"] = _MESSAGELIST +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +LogMessage = _reflection.GeneratedProtocolMessageType( + "LogMessage", + (_message.Message,), + { + "DESCRIPTOR": _LOGMESSAGE, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.LogMessage) + }, +) +_sym_db.RegisterMessage(LogMessage) + +State = _reflection.GeneratedProtocolMessageType( + "State", + (_message.Message,), + { + "DESCRIPTOR": _STATE, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.State) + }, +) +_sym_db.RegisterMessage(State) + +StateList = _reflection.GeneratedProtocolMessageType( + "StateList", + (_message.Message,), + { + "DESCRIPTOR": _STATELIST, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.StateList) + }, +) +_sym_db.RegisterMessage(StateList) + +MessageList = _reflection.GeneratedProtocolMessageType( + "MessageList", + (_message.Message,), + { + "DESCRIPTOR": _MESSAGELIST, + "__module__": "state_pb2" + # @@protoc_insertion_point(class_scope:mserialize.MessageList) + }, +) +_sym_db.RegisterMessage(MessageList) + + +# @@protoc_insertion_point(module_scope) diff --git a/manticore/core/worker.py b/manticore/core/worker.py index 1e9d74012..5da1fa436 100644 --- a/manticore/core/worker.py +++ b/manticore/core/worker.py @@ -1,11 +1,22 @@ from ..utils.nointerrupt import WithKeyboardInterruptAs from .state import Concretize, TerminateState +from ..core.plugin import Plugin, StateDescriptor +from .state_pb2 import StateList, MessageList, State, LogMessage +from ..utils.log import register_log_callback +from ..utils import config +from ..utils.enums import StateStatus, StateLists +from datetime import datetime import logging import multiprocessing import threading +from collections import deque import os +import socketserver import typing +consts = config.get_group("core") +consts.add("HOST", "localhost", "Address to bind the log & state servers to") +consts.add("PORT", 3214, "Port to use for the log server. 
State server runs one port higher.") logger = logging.getLogger(__name__) # logger.setLevel(9) @@ -255,3 +266,137 @@ def start(self, target: typing.Optional[typing.Callable] = None): self._t = threading.Thread(target=self.run if target is None else target, args=(self,)) self._t.daemon = True self._t.start() + + +class DumpTCPHandler(socketserver.BaseRequestHandler): + """ TCP Handler that calls the `dump` method bound to the server """ + + def handle(self): + self.request.sendall(self.server.dump()) + + +class ReusableTCPServer(socketserver.TCPServer): + """ Custom socket server that gracefully allows the address to be reused """ + + allow_reuse_address = True + dump: typing.Optional[typing.Callable] = None + + +class LogCaptureWorker(DaemonThread): + """ Extended DaemonThread that runs a TCP server that dumps the captured logs """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.activated = False #: Whether a client has ever connected + register_log_callback(self.log_callback) + + def log_callback(self, msg): + q = self.manticore._log_queue + try: + q.append(msg) + except AttributeError: + # Appending to a deque with maxlen=n is about 25x faster than checking if a queue.Queue is full, + # popping if so, and appending. For that reason, we use a deque in the threading and single, but + # a manager.Queue in multiprocessing (since that's all it supports). Catching an AttributeError + # is slightly faster than using `isinstance` for the default case (threading) but does slow down + # log throughput by about 20% (on top of the 25x slowdown) when using Multiprocessing instead of + # threading + if q.full(): + q.get() + q.put(msg) + + def dump_logs(self): + """ + Converts captured logs into protobuf format + """ + self.activated = True + serialized = MessageList() + q = self.manticore._log_queue + i = 0 + while i < 50 and not q.empty(): + msg = LogMessage(content=q.get()) + serialized.messages.append(msg) + i += 1 + return serialized.SerializeToString() + + def run(self, *args): + logger.debug( + "Capturing Logs via Thread %d. Pid %d Tid %d).", + self.id, + os.getpid(), + threading.get_ident(), + ) + + m = self.manticore + + try: + with ReusableTCPServer((consts.HOST, consts.PORT), DumpTCPHandler) as server: + server.dump = self.dump_logs # type: ignore + server.serve_forever() + except OSError as e: + # TODO - this should be logger.warning, but we need to rewrite several unit tests that depend on + # specific stdout output in order to do that. 
+ logger.info("Could not start log capture server: %s", str(e)) + + +def render_state_descriptors(desc: typing.Dict[int, StateDescriptor]): + """ + Converts the built-in list of state descriptors into a StateList from Protobuf + + :param desc: Output from ManticoreBase.introspect + :return: Protobuf StateList to send over the wire + """ + out = StateList() + for st in desc.values(): + if st.status != StateStatus.destroyed: + now = datetime.now() + out.states.append( + State( + id=st.state_id, + type={ + StateLists.ready: State.READY, # type: ignore + StateLists.busy: State.BUSY, # type: ignore + StateLists.terminated: State.TERMINATED, # type: ignore + StateLists.killed: State.KILLED, # type: ignore + }[ + getattr(st, "state_list", StateLists.killed) + ], # If the state list is missing, assume it's killed + reason=st.termination_msg, + num_executing=st.own_execs, + wait_time=int( + (now - st.field_updated_at.get("state_list", now)).total_seconds() * 1000 + ), + ) + ) + return out + + +def state_monitor(self: DaemonThread): + """ + Daemon thread callback that runs a server that listens for incoming TCP connections and + dumps the list of state descriptors. + + :param self: DeamonThread created to run the server + """ + logger.debug( + "Monitoring States via Thread %d. Pid %d Tid %d).", + self.id, + os.getpid(), + threading.get_ident(), + ) + + m = self.manticore + + def dump_states(): + sts = m.introspect() + sts = render_state_descriptors(sts) + return sts.SerializeToString() + + try: + with ReusableTCPServer((consts.HOST, consts.PORT + 1), DumpTCPHandler) as server: + server.dump = dump_states # type: ignore + server.serve_forever() + except OSError as e: + # TODO - this should be logger.warning, but we need to rewrite several unit tests that depend on + # specific stdout output in order to do that. + logger.info("Could not start state monitor server: %s", str(e)) diff --git a/manticore/utils/helpers.py b/manticore/utils/helpers.py index 5c918b3ec..6c1d431f1 100644 --- a/manticore/utils/helpers.py +++ b/manticore/utils/helpers.py @@ -1,3 +1,4 @@ +import collections import logging import pickle import string @@ -200,3 +201,13 @@ def pretty_print_state_descriptors(desc: Dict): print(tab) print() + + +class deque(collections.deque): + """ A wrapper around collections.deque that adds a few APIs present in SyncManager.Queue """ + + def empty(self) -> bool: + return len(self) == 0 + + def get(self): + return self.popleft() diff --git a/manticore/utils/log.py b/manticore/utils/log.py index f49595f0a..c9a03ec75 100644 --- a/manticore/utils/log.py +++ b/manticore/utils/log.py @@ -1,5 +1,6 @@ import logging import sys +import io from typing import List, Set, Tuple @@ -13,6 +14,23 @@ handler.setFormatter(formatter) +class CallbackStream(io.TextIOBase): + def __init__(self, callback): + self.callback = callback + + def write(self, log_str): + self.callback(log_str) + + +def register_log_callback(cb): + for name in all_loggers: + logger = logging.getLogger(name) + handler_internal = logging.StreamHandler(CallbackStream(cb)) + if name.startswith("manticore"): + handler_internal.setFormatter(formatter) + logger.addHandler(handler_internal) + + class ContextFilter(logging.Filter): """ This is a filter which injects contextual information into the log. 
@@ -101,7 +119,7 @@ def get_levels() -> List[List[Tuple[str, int]]]: ("manticore.core.worker", logging.INFO), ("manticore.platforms.*", logging.DEBUG), ("manticore.ethereum", logging.DEBUG), - ("manticore.core.plugin", logging.DEBUG), + ("manticore.core.plugin", logging.INFO), ("manticore.wasm.*", logging.INFO), ("manticore.utils.emulate", logging.INFO), ], @@ -112,6 +130,7 @@ def get_levels() -> List[List[Tuple[str, int]]]: ("manticore.native.memory", logging.DEBUG), ("manticore.native.cpu.*", logging.DEBUG), ("manticore.native.cpu.*.registers", logging.DEBUG), + ("manticore.core.plugin", logging.DEBUG), ("manticore.utils.helpers", logging.INFO), ], # 5 (-vvvv) diff --git a/mypy.ini b/mypy.ini index 7f7279360..cbe17228e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -39,3 +39,6 @@ ignore_missing_imports = True [mypy-wasm.*] ignore_missing_imports = True + +[mypy-manticore.core.state_pb2] +ignore_errors = True diff --git a/setup.py b/setup.py index bc99357a6..350b66311 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,8 @@ def rtd_dependent_deps(): python_requires=">=3.6", install_requires=[ "pyyaml", + "protobuf", + # evm dependencies "pysha3", "prettytable", "ply", diff --git a/tests/native/test_logging.py b/tests/native/test_logging.py new file mode 100644 index 000000000..0fae68a46 --- /dev/null +++ b/tests/native/test_logging.py @@ -0,0 +1,19 @@ +import unittest +import logging + +from manticore.utils.log import get_verbosity, set_verbosity, DEFAULT_LOG_LEVEL + + +class ManticoreLogger(unittest.TestCase): + """Make sure we set the logging levels correctly""" + + _multiprocess_can_split_ = True + + def test_logging(self): + set_verbosity(1) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.WARNING) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.INFO) + + set_verbosity(0) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), DEFAULT_LOG_LEVEL) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), DEFAULT_LOG_LEVEL) diff --git a/tests/native/test_manticore.py b/tests/native/test_manticore.py index 8217e45b7..6531905e3 100644 --- a/tests/native/test_manticore.py +++ b/tests/native/test_manticore.py @@ -134,20 +134,3 @@ def test_integration_basic_stdin(self): else: self.assertTrue(a <= 0x41) self.assertTrue(b > 0x41) - - -class ManticoreLogger(unittest.TestCase): - """Make sure we set the logging levels correctly""" - - _multiprocess_can_split_ = True - - def test_logging(self): - set_verbosity(5) - self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.DEBUG) - self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.DEBUG) - - set_verbosity(1) - self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.WARNING) - self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.INFO) - - set_verbosity(0) diff --git a/tests/other/test_tui_api.py b/tests/other/test_tui_api.py new file mode 100644 index 000000000..08c546e2d --- /dev/null +++ b/tests/other/test_tui_api.py @@ -0,0 +1,143 @@ +import unittest +import threading +import socket +import select +import subprocess +import logging +import time +import sys + +from google.protobuf.message import DecodeError +from manticore.core.state_pb2 import StateList, State, MessageList +from pathlib import Path + + +PYTHON_BIN: str = sys.executable + +HOST = "localhost" +PORT = 4123 + +ms_file = str( + Path(__file__).parent.parent.parent.joinpath("examples", "linux", "binaries", "multiple-styles") +) + +finished = False +logs = [] 
+state_captures = [] + + +def fetch_update(): + logger = logging.getLogger("FetchThread") + while not finished: + try: + # Attempts to (re)connect to manticore server + log_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + logger.debug("Connecting to %s:%s", HOST, PORT) + log_sock.connect((HOST, PORT)) + logger.info("Connected to %s:%s", HOST, PORT) + + read_sockets, write_sockets, error_sockets = select.select([log_sock], [], [], 60) + + serialized = b"" + if read_sockets: + serialized = read_sockets[0].recv(10000) + logger.info("Pulled {} bytes".format(len(serialized))) + + try: + m = MessageList() + m.ParseFromString(serialized) + logs.extend(m.messages) + logger.info("Deserialized LogMessage") + + except DecodeError: + logger.info("Unable to deserialize message, malformed response") + + read_sockets[0].shutdown(socket.SHUT_RDWR) + + log_sock.close() + except socket.error: + logger.warning("Log Socket disconnected") + log_sock.close() + + try: + # Attempts to (re)connect to manticore server + state_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + logger.debug("Connecting to %s:%s", HOST, PORT + 1) + state_sock.connect((HOST, PORT + 1)) + logger.info("Connected to %s:%s", HOST, PORT + 1) + + read_sockets, write_sockets, error_sockets = select.select([state_sock], [], [], 60) + + serialized = b"" + if read_sockets: + serialized = read_sockets[0].recv(10000) + logger.info("Pulled {} bytes".format(len(serialized))) + + try: + m = StateList() + m.ParseFromString(serialized) + + logger.info("Got %d states", len(m.states)) + + state_captures.append(m.states) + + except DecodeError: + logger.info("Unable to deserialize message, malformed response") + + state_sock.shutdown(socket.SHUT_RDWR) + + state_sock.close() + except socket.error: + logger.warning("State Socket disconnected") + state_sock.close() + + time.sleep(0.5) + + +class MyTestCase(unittest.TestCase): + def test_something(self): + global finished + + fetch_thread = threading.Thread(target=fetch_update) + fetch_thread.start() + + cmd = [ + PYTHON_BIN, + "-m", + "manticore", + "-v", + "--no-color", + "--core.procs", + str(10), + "--core.seed", + str(100), + "--core.PORT", + str(PORT), + ms_file, + ] + + self.assertEqual(subprocess.check_call(cmd), 0, "Manticore had a non-zero exit code") + + finished = True + + # Check that logs look right + self.assertTrue(any("you got it!" in i.content for i in logs)) + self.assertTrue(any("Program finished with exit status: 0" in i.content for i in logs)) + self.assertEqual( + sum(1 if "Program finished with exit status: 1" in i.content else 0 for i in logs), 17 + ) + + # Check that state lists seem correct + self.assertEqual( + max(len(list(filter(lambda x: x.type == State.BUSY, i))) for i in state_captures), 10 + ) # At most ten running states + + self.assertEqual( + min(len(list(filter(lambda x: x.type == State.BUSY, i))) for i in state_captures), 0 + ) # No running states at the end + + self.assertEqual(max(len(i) for i in state_captures), 18) # Should have 18 states in total + + +if __name__ == "__main__": + unittest.main()
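The servers added in worker.py are exercised end-to-end by tests/other/test_tui_api.py above; for poking at them by hand, a minimal client could look like the sketch below. This is an illustration only, not part of the change set: the helper names (fetch, show_logs, show_states) are hypothetical, it assumes the default core.HOST/core.PORT settings (localhost, port 3214 for logs, one port higher for states), and it relies on DumpTCPHandler sending a single serialized MessageList or StateList per connection before the server closes it.

import socket

from manticore.core.state_pb2 import MessageList, StateList, State

HOST = "localhost"  # core.HOST default from worker.py
LOG_PORT = 3214  # core.PORT default; the state server listens on PORT + 1
STATE_PORT = LOG_PORT + 1


def fetch(port: int) -> bytes:
    # Each connection receives one protobuf blob from DumpTCPHandler; read until the server closes it.
    with socket.create_connection((HOST, port), timeout=5) as sock:
        chunks = []
        while True:
            data = sock.recv(4096)
            if not data:
                break
            chunks.append(data)
    return b"".join(chunks)


def show_logs() -> None:
    # LogCaptureWorker.dump_logs returns at most 50 buffered messages per request.
    msgs = MessageList()
    msgs.ParseFromString(fetch(LOG_PORT))
    for m in msgs.messages:
        # Messages generally arrive with their trailing newline from the logging StreamHandler.
        print(m.content, end="")


def show_states() -> None:
    # state_monitor serializes the introspection data built by render_state_descriptors.
    states = StateList()
    states.ParseFromString(fetch(STATE_PORT))
    busy = [s for s in states.states if s.type == State.BUSY]
    print("%d states known, %d busy" % (len(states.states), len(busy)))


if __name__ == "__main__":
    show_logs()
    show_states()

Two design notes: the first log request flips LogCaptureWorker.activated, which is what makes wait_for_log_purge() hold shutdown for up to two seconds so the last buffered messages can be drained; and if state.proto changes, state_pb2.py is presumably regenerated with protoc (--python_out) rather than edited by hand.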