diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 34468c213..6cabe4eac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,9 +16,9 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Python 3.6 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: - python-version: 3.6 + python-version: '3.6' - name: Lint if: github.event_name == 'pull_request' env: @@ -41,14 +41,15 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - type: ["ethereum_truffle", "ethereum_bench", "examples", "ethereum", "ethereum_vm", "native", "wasm", "wasm_sym", "other"] + type: ["ethereum_truffle", "ethereum_bench", "examples", "ethereum", "ethereum_vm", "native", "other", "wasm", "wasm_sym"] steps: - uses: actions/checkout@v1 - name: Set up Python 3.6 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: - python-version: 3.6 + python-version: '3.6' - name: Install NPM uses: actions/setup-node@v1 with: @@ -98,9 +99,9 @@ jobs: steps: - uses: actions/checkout@v2 - name: Set up Python 3.6 - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: - python-version: 3.6 + python-version: '3.6' - name: Build Dist run: | python3 -m pip install wheel diff --git a/.gitignore b/.gitignore index 98ddc3600..01c2fbe8c 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ nosetests.xml coverage.xml *,cover .hypothesis/ +mcore_* # Translations *.mo diff --git a/examples/evm/asm_to_smtlib.py b/examples/evm/asm_to_smtlib.py index 3af7944f7..9e6b94704 100644 --- a/examples/evm/asm_to_smtlib.py +++ b/examples/evm/asm_to_smtlib.py @@ -7,7 +7,9 @@ from manticore.utils import log # log.set_verbosity(9) -config.out_of_gas = 1 +# FIXME: What's the equivalent? 
+consts = config.get_group("evm") +consts["oog"] = "complete" def printi(instruction): @@ -42,11 +44,11 @@ def printi(instruction): ) -data = constraints.new_array(index_bits=256, name="array") +data = constraints.new_array(index_size=256, name="array") -class callbacks: - initial_stack = [] +class Callbacks: + initial_stack: List[BitVec] = [] def will_execute_instruction(self, pc, instr): for i in range(len(evm.stack), instr.pops): @@ -57,8 +59,8 @@ def will_execute_instruction(self, pc, instr): class DummyWorld: def __init__(self, constraints): - self.balances = constraints.new_array(index_bits=256, value_bits=256, name="balances") - self.storage = constraints.new_array(index_bits=256, value_bits=256, name="storage") + self.balances = constraints.new_array(index_size=256, value_size=256, name="balances") + self.storage = constraints.new_array(index_size=256, value_size=256, name="storage") self.origin = constraints.new_bitvec(256, name="origin") self.price = constraints.new_bitvec(256, name="price") self.timestamp = constraints.new_bitvec(256, name="timestamp") @@ -112,10 +114,10 @@ def send_funds(self, address, recipient, value): value = constraints.new_bitvec(256, name="value") world = DummyWorld(constraints) -callbacks = callbacks() +callbacks = Callbacks() # evm = world.current_vm -evm = EVM(constraints, 0x41424344454647484950, data, caller, value, code, world=world, gas=1000000) +evm = EVM(0x41424344454647484950, data, caller, value, code, world=world, gas=1000000) evm.subscribe("will_execute_instruction", callbacks.will_execute_instruction) print("CODE:") @@ -138,5 +140,5 @@ def send_funds(self, address, recipient, value): print(constraints) print( - f"PC: {translate_to_smtlib(evm.pc)} {solver.get_all_values(constraints, evm.pc, maxcnt=3, silent=True)}" + f"PC: {translate_to_smtlib(evm.pc)} {SMTLIBSolver.get_all_values(constraints, evm.pc, maxcnt=3, silent=True)}" ) diff --git a/examples/evm/coverage.py b/examples/evm/coverage.py index 9704dcbd7..e2240471f 
100644 --- a/examples/evm/coverage.py +++ b/examples/evm/coverage.py @@ -8,9 +8,8 @@ user_account = m.create_account(balance=1000) -bytecode = m.compile(source_code) # Initialize contract -contract_account = m.create_contract(owner=user_account, balance=0, init=bytecode) +contract_account = m.solidity_create_contract(source_code, owner=user_account, balance=0) m.transaction( caller=user_account, diff --git a/examples/evm/mappingchallenge.py b/examples/evm/mappingchallenge.py index 3de1fc785..b454b5192 100644 --- a/examples/evm/mappingchallenge.py +++ b/examples/evm/mappingchallenge.py @@ -30,22 +30,24 @@ class StopAtDepth(Detector): """ This just aborts explorations that are too deep """ def will_run_callback(self, *args): - with self.manticore.locked_context("seen_rep", dict) as reps: - reps.clear() + if self.manticore: + with self.manticore.locked_context("seen_rep", dict) as reps: + reps.clear() def will_decode_instruction_callback(self, state, pc): world = state.platform - with self.manticore.locked_context("seen_rep", dict) as reps: - item = ( - world.current_transaction.sort == "CREATE", - world.current_transaction.address, - pc, - ) - if not item in reps: - reps[item] = 0 - reps[item] += 1 - if reps[item] > 2: - state.abandon() + if self.manticore: + with self.manticore.locked_context("seen_rep", dict) as reps: + item = ( + world.current_transaction.sort == "CREATE", + world.current_transaction.address, + pc, + ) + if not item in reps: + reps[item] = 0 + reps[item] += 1 + if reps[item] > 2: + state.abandon() m.register_plugin(StopAtDepth()) diff --git a/examples/evm/minimal-json.py b/examples/evm/minimal-json.py index 2792ce8f7..639605cd6 100644 --- a/examples/evm/minimal-json.py +++ b/examples/evm/minimal-json.py @@ -1047,7 +1047,9 @@ user_account = m.create_account(balance=1000, name="user_account") print("[+] Creating a user account", user_account.name_) -contract_account = m.json_create_contract(truffle_json, owner=user_account, name="contract_account") 
+contract_account = m.solidity_create_contract( + truffle_json, owner=user_account, name="contract_account" +) print("[+] Creating a contract account", contract_account.name_) contract_account.sendCoin(1, 1) diff --git a/examples/evm/minimal.py b/examples/evm/minimal.py index 565330186..4a4c939ec 100644 --- a/examples/evm/minimal.py +++ b/examples/evm/minimal.py @@ -24,13 +24,14 @@ } } """ -user_account = m.create_account(balance=m.make_symbolic_value(), name="user_account") + +user_account = m.create_account(balance=10 ** 10, name="user_account") print("[+] Creating a user account", user_account.name_) contract_account = m.solidity_create_contract( source_code, owner=user_account, name="contract_account" ) -print("[+] Creating a contract account", contract_account.name_) +print(f"[+] Creating a contract account {contract_account}") contract_account.named_func(1) print("[+] Now the symbolic values") diff --git a/examples/evm/use_def.py b/examples/evm/use_def.py index 7fafb8938..8ecb23c58 100644 --- a/examples/evm/use_def.py +++ b/examples/evm/use_def.py @@ -41,6 +41,8 @@ class EVMUseDef(Plugin): def did_evm_write_storage_callback(self, state, address, offset, value): m = self.manticore + if not m: + return world = state.platform tx = world.all_transactions[-1] md = m.get_metadata(tx.address) @@ -55,6 +57,8 @@ def did_evm_write_storage_callback(self, state, address, offset, value): def did_evm_read_storage_callback(self, state, address, offset, value): m = self.manticore + if not m: + return world = state.platform tx = world.all_transactions[-1] md = m.get_metadata(tx.address) diff --git a/examples/linux/binaries/concrete_solve.py b/examples/linux/binaries/concrete_solve.py index f98767e0e..da95a8d46 100644 --- a/examples/linux/binaries/concrete_solve.py +++ b/examples/linux/binaries/concrete_solve.py @@ -1,4 +1,4 @@ -from manticore import Manticore +from manticore.native import Manticore def fixme(): @@ -6,11 +6,9 @@ def fixme(): # Let's initialize the manticore 
control object -m = Manticore("multiple-styles") - # First, let's give it some fake data for the input. Anything the same size as # the real flag should work fine! -m.concrete_data = "infiltrate miami!" +m = Manticore("multiple-styles", concrete_start="infiltrate miami!") # Now we're going to want to execute a few different hooks and share data, so # let's use the m.context dict to keep our solution in diff --git a/examples/linux/binaries/symbolic_solve.py b/examples/linux/binaries/symbolic_solve.py index 747364b2d..8de19f81b 100644 --- a/examples/linux/binaries/symbolic_solve.py +++ b/examples/linux/binaries/symbolic_solve.py @@ -1,4 +1,4 @@ -from manticore import Manticore +from manticore.native import Manticore def fixme(): diff --git a/examples/script/concolic.py b/examples/script/concolic.py index 75f7eecae..84c2d1101 100755 --- a/examples/script/concolic.py +++ b/examples/script/concolic.py @@ -16,6 +16,7 @@ import queue import struct import itertools +from typing import Any from manticore import set_verbosity from manticore.native import Manticore @@ -234,7 +235,7 @@ def constraints_are_sat(cons): def get_new_constrs_for_queue(oldcons, newcons): - ret = [] + ret: List[Any] = [] # i'm pretty sure its correct to assume newcons is a superset of oldcons @@ -299,7 +300,7 @@ def concrete_input_to_constraints(ci, prev=None): def main(): - q = queue.Queue() + q: queue.Queue = queue.Queue() # todo randomly generated concrete start stdin = ints2inp(0, 5, 0) diff --git a/examples/script/introduce_symbolic_bytes.py b/examples/script/introduce_symbolic_bytes.py index 0ad3a187c..7d12ea1cd 100755 --- a/examples/script/introduce_symbolic_bytes.py +++ b/examples/script/introduce_symbolic_bytes.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import sys +import capstone from manticore import issymbolic from manticore.native import Manticore @@ -64,7 +65,6 @@ def introduce_sym(state): state.cpu.write_int(state.cpu.RBP - 0xC, val, 32) def has_tainted_operands(operands, taint_id): - # 
type: (list[manticore.core.cpu.abstractcpu.Operand], object) -> bool for operand in operands: op = operand.read() if issymbolic(op) and taint_id in op.taint: diff --git a/manticore/__main__.py b/manticore/__main__.py index 8f3bba29d..9eef5c569 100644 --- a/manticore/__main__.py +++ b/manticore/__main__.py @@ -26,6 +26,8 @@ def main() -> None: + import pdb + """ Dispatches execution into one of Manticore's engines: evm or native. """ diff --git a/manticore/core/smtlib/constraints.py b/manticore/core/smtlib/constraints.py index 533ad4ed7..3525be5b5 100644 --- a/manticore/core/smtlib/constraints.py +++ b/manticore/core/smtlib/constraints.py @@ -1,7 +1,6 @@ import itertools import sys import copy -from typing import Optional from ...utils.helpers import PickleSerializer from ...exceptions import SmtlibError from .expression import ( @@ -13,20 +12,21 @@ Bool, BitVec, BoolConstant, - ArrayProxy, + MutableArray, BoolEqual, Variable, Constant, ) from .visitors import ( + GetBindings, GetDeclarations, TranslatorSmtlib, get_variables, simplify, replace, - pretty_print, + CountExpressionUse, + translate_to_smtlib, ) -from ...utils import config import logging logger = logging.getLogger(__name__) @@ -40,6 +40,10 @@ class ConstraintException(SmtlibError): pass +class Model: + pass + + class ConstraintSet: """Constraint Sets @@ -78,6 +82,7 @@ def __enter__(self) -> "ConstraintSet": return self._child def __exit__(self, ty, value, traceback) -> None: + assert self._child is not None self._child._parent = None self._child = None @@ -93,9 +98,10 @@ def add(self, constraint) -> None: :param constraint: The constraint to add to the set. """ if isinstance(constraint, bool): - constraint = BoolConstant(constraint) + constraint = BoolConstant(value=constraint) assert isinstance(constraint, Bool) constraint = simplify(constraint) + # If self._child is not None this constraint set has been forked and a # a derived constraintset may be using this. 
So we can't add any more # constraints to this one. After the child constraintSet is deleted @@ -133,8 +139,8 @@ def related_to(self, *related_to) -> "ConstraintSet": Slices this ConstraintSet keeping only the related constraints. Two constraints are independient if they can be expressed full using a disjoint set of variables. - Todo: Research. constraints refering differen not overlapping parts of the same array - should be considered independient. + Todo: Research. constraints referring different not overlapping parts of the same array + should be considered independient. :param related_to: An expression :return: """ @@ -177,7 +183,24 @@ def related_to(self, *related_to) -> "ConstraintSet": return cs def to_string(self, replace_constants: bool = False) -> str: + replace_constants = True variables, constraints = self.get_declared_variables(), self.constraints + # rep_bindings = {} + # extra_variables, extra_constraints = [], [] + # counts = CountExpressionUse() + # for c in constraints: + # counts.visit(c) + # for exp, count in counts.counts.items(): + # if count > 1: + # if isinstance(exp, BitVec): + # new_var = BitVecVariable(size=exp.size, name=new_name) + # if isinstance(exp, Bool): + # new_var = BoolVariable(name=new_name) + # if isinstance(exp, Array): + # new_var = ArrayVariable(name=new_name, index_size=exp.index_size, value_size=exp.value_size) + # extra_constraints.append(new_var == exp) + # extra_variables.append(new_var) + # rep_bindings[exp] = new_var if replace_constants: constant_bindings = {} @@ -185,51 +208,56 @@ def to_string(self, replace_constants: bool = False) -> str: if ( isinstance(expression, BoolEqual) and isinstance(expression.operands[0], Variable) - and isinstance(expression.operands[1], (*Variable, *Constant)) + and isinstance(expression.operands[1], Constant) ): constant_bindings[expression.operands[0]] = expression.operands[1] - - tmp = set() - result = "" - for var in variables: - # FIXME - # band aid hack around the fact that we 
are double declaring stuff :( :( - if var.declaration in tmp: - logger.warning("Variable '%s' was copied twice somewhere", var.name) - continue - tmp.add(var.declaration) - result += var.declaration + "\n" + if ( + isinstance(expression, BoolEqual) + and isinstance(expression.operands[1], Variable) + and isinstance(expression.operands[0], Constant) + ): + constant_bindings[expression.operands[1]] = expression.operands[0] translator = TranslatorSmtlib(use_bindings=True) + # gb = GetBindings() + for v in variables: + translator.visit_Variable(v) + # for v in extra_variables: + # translator.visit_Variable(v) + # if constraints: + # for constraint in constraints: + # gb.visit(constraint) for constraint in constraints: if replace_constants: constraint = simplify(replace(constraint, constant_bindings)) - # if no variables then it is a constant - if isinstance(constraint, Constant) and constraint.value == True: - continue + if ( + isinstance(constraint, BoolEqual) + and isinstance(constraint.operands[0], Variable) + and isinstance(constraint.operands[1], Variable) + and constraint.operands[1] in constant_bindings + ): + constraint = simplify(replace(constraint, constant_bindings)) + # constraint = simplify(replace(constraint, rep_bindings)) + # if no variables then it is a constant + if isinstance(constraint, Constant) and constraint.value == True: + continue + # Translate one constraint translator.visit(constraint) + if replace_constants: for k, v in constant_bindings.items(): translator.visit(k == v) - for name, exp, smtlib in translator.bindings: - if isinstance(exp, BitVec): - result += f"(declare-fun {name} () (_ BitVec {exp.size}))" - elif isinstance(exp, Bool): - result += f"(declare-fun {name} () Bool)" - elif isinstance(exp, Array): - result += f"(declare-fun {name} () (Array (_ BitVec {exp.index_bits}) (_ BitVec {exp.value_bits})))" - else: - raise ConstraintException(f"Type not supported {exp!r}") - result += f"(assert (= {name} {smtlib}))\n" + # for constraint 
in extra_constraints: + # if replace_constants: + # constraint = simplify(replace(constraint, constant_bindings)) + # if isinstance(constraint, Constant) and constraint.value == True: + # continue + # # Translate one constraint + # translator.visit(constraint) - constraint_str = translator.pop() - while constraint_str is not None: - if constraint_str != "true": - result += f"(assert {constraint_str})\n" - constraint_str = translator.pop() - return result + return translator.smtlib() def _declare(self, var): """ Declare the variable `var` """ @@ -238,6 +266,10 @@ def _declare(self, var): self._declarations[var.name] = var return var + @property + def variables(self): + return self._declarations.values() + def get_declared_variables(self): """ Returns the variable expressions of this constraint set """ return self._declarations.values() @@ -348,11 +380,11 @@ def migrate(self, expression, name_migration_map=None): elif isinstance(foreign_var, Array): # Note that we are discarding the ArrayProxy encapsulation new_var = self.new_array( - index_max=foreign_var.index_max, - index_bits=foreign_var.index_bits, - value_bits=foreign_var.value_bits, + length=foreign_var.length, + index_size=foreign_var.index_size, + value_size=foreign_var.value_size, name=migrated_name, - ).array + ) else: raise NotImplementedError( f"Unknown expression type {type(foreign_var)} encountered during expression migration" @@ -380,7 +412,7 @@ def new_bool(self, name=None, taint=frozenset(), avoid_collisions=False): name = self._make_unique_name(name) if not avoid_collisions and name in self._declarations: raise ValueError(f"Name {name} already used") - var = BoolVariable(name, taint=taint) + var = BoolVariable(name=name, taint=taint) return self._declare(var) def new_bitvec(self, size, name=None, taint=frozenset(), avoid_collisions=False): @@ -392,7 +424,7 @@ def new_bitvec(self, size, name=None, taint=frozenset(), avoid_collisions=False) :return: a fresh BitVecVariable """ if size <= 0: - raise 
ValueError(f"Bitvec size ({size}) can't be equal to or less than 0") + raise ValueError(f"BitVec size ({size}) can't be equal to or less than 0") if name is None: name = "BV" avoid_collisions = True @@ -400,25 +432,25 @@ def new_bitvec(self, size, name=None, taint=frozenset(), avoid_collisions=False) name = self._make_unique_name(name) if not avoid_collisions and name in self._declarations: raise ValueError(f"Name {name} already used") - var = BitVecVariable(size, name, taint=taint) + var = BitVecVariable(size=size, name=name, taint=taint) return self._declare(var) def new_array( self, - index_bits=32, + index_size=32, name=None, - index_max=None, - value_bits=8, + length=None, + value_size=8, taint=frozenset(), avoid_collisions=False, default=None, ): - """Declares a free symbolic array of value_bits long bitvectors in the constraint store. - :param index_bits: size in bits for the array indexes one of [32, 64] - :param value_bits: size in bits for the array values + """Declares a free symbolic array of value_size long bitvectors in the constraint store. 
+ :param index_size: size in bits for the array indexes one of [32, 64] + :param value_size: size in bits for the array values :param name: try to assign name to internal variable representation, if not unique, a numeric nonce will be appended - :param index_max: upper limit for indexes on this array (#FIXME) + :param length: upper limit for indexes on this array (#FIXME) :param avoid_collisions: potentially avoid_collisions the variable to avoid name collisions if True :param default: default for not initialized values :return: a fresh ArrayProxy @@ -430,5 +462,14 @@ def new_array( name = self._make_unique_name(name) if not avoid_collisions and name in self._declarations: raise ValueError(f"Name {name} already used") - var = self._declare(ArrayVariable(index_bits, index_max, value_bits, name, taint=taint)) - return ArrayProxy(var, default=default) + var = self._declare( + ArrayVariable( + index_size=index_size, + length=length, + value_size=value_size, + name=name, + taint=taint, + default=default, + ) + ) + return var diff --git a/manticore/core/smtlib/expression.py b/manticore/core/smtlib/expression.py index 5aea0210c..5cf70775f 100644 --- a/manticore/core/smtlib/expression.py +++ b/manticore/core/smtlib/expression.py @@ -1,13 +1,40 @@ +""" Module for Symbolic Expression + +ConstraintSets are considered a factory for new symbolic variables of type: +BoolVariable, BitVecVariable and ArrayVariable. 
+ +Normal python operators are overloaded in each class, complex expression trees +are built operating over expression variables and constants + + cs = ConstraintSet() + x = cs.new_bitvec(name="SOMEVARNAME", size=32) + y = x + 199 + condition1 = y < 1000 + condition2 = x > 0 + + cs.add( condition1 ) + cs.add( condition2 ) + +""" +from abc import ABC, abstractmethod from functools import reduce -from ...exceptions import SmtlibError import uuid - import re import copy -from typing import Union, Optional, Dict, List +from typing import overload, Union, Optional, Tuple, List, FrozenSet, Dict, Any, Set +from functools import lru_cache + + +def local_simplify(e): + from .visitors import simplify as visitor_simplify + x = visitor_simplify(e) + if isinstance(e, Bool): + assert isinstance(x, Bool) + return x -class ExpressionException(SmtlibError): + +class ExpressionError(Exception): """ Expression exception """ @@ -15,115 +42,190 @@ class ExpressionException(SmtlibError): pass -class Expression: +class XSlotted(type): + """Metaclass that will propagate slots on multi-inheritance classes + Every class should define __xslots__ (instead of __slots__) + + class Base(object, metaclass=XSlotted, abstract=True): + pass + + class A(Base, abstract=True): + __xslots__ = ('a',) + pass + + class B(Base, abstract=True): + __xslots__ = ('b',) + pass + + class C(A, B): + pass + + # Normal case / baseline + class X(object): + __slots__ = ('a', 'b') + + c = C() + c.a = 1 + c.b = 2 + + x = X() + x.a = 1 + x.b = 2 + + import sys + print (sys.getsizeof(c),sys.getsizeof(x)) #same value + + """ + + def __new__(cls, clsname, bases, attrs, abstract=False): + + xslots = frozenset(attrs.get("__xslots__", ())) + # merge the xslots of all the bases with the one defined here + for base in bases: + xslots = xslots.union(getattr(base, "__xslots__", ())) + attrs["__xslots__"] = tuple(xslots) + if abstract: + attrs["__slots__"] = tuple() + else: + attrs["__slots__"] = tuple(map(lambda attr: 
attr.split("#", 1)[0], attrs["__xslots__"])) + """ + def h(self): + print(self.__class__, self.__slots__) + s = [] + for name in self.__slots__: + if name in ("_concrete_cache", "_written"): + continue + try: + print (name) + x = getattr(self, name) + hash(x) + s.append(name) + except Exception as e: + print ("AAAAAAAAAAAAAAAAAAAAERRR", name, e) + return hash((clsname, tuple(getattr(self, name) for name in s))) + + attrs["__hash__"] = h + """ + attrs["__hash__"] = object.__hash__ + # attrs["__hash__"] = lambda self : hash((clsname, tuple(getattr(self, name) for name in self.__slots__ if name not in ("_concrete_cache", "_written")))) + return super().__new__(cls, clsname, bases, attrs) + + +class Expression(object, metaclass=XSlotted, abstract=True): """ Abstract taintable Expression. """ - __slots__ = ["_taint"] + __xslots__: Tuple[str, ...] = ("_taint",) - def __init__(self, taint: Union[tuple, frozenset] = ()): - if self.__class__ is Expression: - raise TypeError + def __init__(self, *, taint: FrozenSet[str] = frozenset()): + """ + An abstract Unmutable Taintable Expression + :param taint: A frozenzset of taints. Normally strings + """ + self._taint = taint super().__init__() - self._taint = frozenset(taint) def __repr__(self): - return "<{:s} at {:x}{:s}>".format(type(self).__name__, id(self), self.taint and "-T" or "") + return "<{:s} at {:x}{:s}>".format( + type(self).__name__, id(self), self._taint and "-T" or "" + ) @property - def is_tainted(self): - return len(self._taint) != 0 + def is_tainted(self) -> bool: + return bool(self._taint) @property - def taint(self): + def taint(self) -> FrozenSet[str]: return self._taint + @property + def operands(self) -> Optional[Tuple["Expression"]]: + """ Hack so we can use any Expression as a node """ + ... 
+ + def __getstate__(self): + return {attr: getattr(self, attr) for attr in self.__slots__} -def issymbolic(value) -> bool: - """ - Helper to determine whether an object is symbolic (e.g checking - if data read from memory is symbolic) + def __setstate__(self, state): + for attr in self.__slots__: + setattr(self, attr, state[attr]) - :param object value: object to check - :return: whether `value` is symbolic - :rtype: bool - """ - return isinstance(value, Expression) +class Variable(Expression, abstract=True): + """ Variable is an Expression that has a name """ -def istainted(arg, taint=None): - """ - Helper to determine whether an object if tainted. - :param arg: a value or Expression - :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. - """ + __xslots__: Tuple[str, ...] = ("_name",) - if not issymbolic(arg): - return False - if taint is None: - return len(arg.taint) != 0 - for arg_taint in arg.taint: - m = re.match(taint, arg_taint, re.DOTALL | re.IGNORECASE) - if m: - return True - return False + def __init__(self, *, name: str, **kwargs): + """Variable is an Expression that has a name + :param name: The Variable name + """ + super().__init__(**kwargs) + self._name = name + @property + def name(self) -> str: + return self._name + + def __repr__(self): + return "<{:s}({:s}) at {:x}>".format(type(self).__name__, self.name, id(self)) -def get_taints(arg, taint=None): """ - Helper to list an object taints. - :param arg: a value or Expression - :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. 
+ def __eq__(self, other): + # Ignore the taint for eq comparison + return self._name==other._name and super().__eq__(other) """ - if not issymbolic(arg): - return - for arg_taint in arg.taint: - if taint is not None: - m = re.match(taint, arg_taint, re.DOTALL | re.IGNORECASE) - if m: - yield arg_taint - else: - yield arg_taint - return +class Constant(Expression, abstract=True): + """ Constants expressions have a concrete python value. """ -def taint_with(arg, *taints, value_bits=256, index_bits=256): - """ - Helper to taint a value. - :param arg: a value or Expression - :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. - """ + __xslots__: Tuple[str, ...] = ("_value",) - tainted_fset = frozenset(tuple(taints)) - if not issymbolic(arg): - if isinstance(arg, int): - arg = BitVecConstant(value_bits, arg) - arg._taint = tainted_fset - else: - raise ValueError("type not supported") + def __init__(self, *, value: Union[bool, int], **kwargs): + """A constant expression has a value - else: - if isinstance(arg, BitVecVariable): - arg = arg + BitVecConstant(value_bits, 0, taint=tainted_fset) - else: - arg = copy.copy(arg) - arg._taint |= tainted_fset + :param value: The constant value + """ + super().__init__(**kwargs) + self._value = value - return arg + @property + def value(self): + return self._value + + +class Operation(Expression, abstract=True): + """ Operation expressions contain operands which are also Expressions. """ + + __xslots__: Tuple[str, ...] 
= ("_operands",) + + def __init__(self, *, operands: Tuple[Expression, ...], **kwargs): + """An operation has operands + + :param operands: A tuple of expression operands + """ + self._operands = operands + taint = kwargs.get("taint") + # If taint was not forced by a keyword argument, calculate default + if taint is None: + kwargs["taint"] = frozenset({y for x in operands for y in x.taint}) + super().__init__(**kwargs) + + @property + def operands(self): + return self._operands ############################################################################### # Booleans -class Bool(Expression): - __slots__: List[str] = [] - - def __init__(self, *operands, **kwargs): - super().__init__(*operands, **kwargs) +class Bool(Expression, abstract=True): + """Bool expression represent symbolic value of truth""" def cast(self, value: Union["Bool", int, bool], **kwargs) -> Union["BoolConstant", "Bool"]: + """ Cast any type into a Bool or fail """ if isinstance(value, Bool): return value - return BoolConstant(bool(value), **kwargs) + return BoolConstant(value=bool(value), **kwargs) def __cmp__(self, *args): raise NotImplementedError("CMP for Bool") @@ -132,11 +234,10 @@ def __invert__(self): return BoolNot(self) def __eq__(self, other): + # A class that overrides __eq__() and does not define __hash__() + # will have its __hash__() implicitly set to None. return BoolEqual(self, self.cast(other)) - def __hash__(self): - return object.__hash__(self) - def __ne__(self, other): return BoolNot(self == self.cast(other)) @@ -159,126 +260,101 @@ def __rxor__(self, other): return BoolXor(self.cast(other), self) def __bool__(self): - # try to be forgiving. 
Allow user to use Bool in an IF sometimes - from .visitors import simplify - - x = simplify(self) - if isinstance(x, Constant): - return x.value - raise NotImplementedError("__bool__ for Bool") - - -class BoolVariable(Bool): - __slots__ = ["_name"] - - def __init__(self, name: str, *args, **kwargs): - assert " " not in name - super().__init__(*args, **kwargs) - self._name = name - - @property - def name(self): - return self._name - - def __copy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") - - def __deepcopy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") - - def __repr__(self): - return "<{:s}({:s}) at {:x}>".format(type(self).__name__, self.name, id(self)) - - @property - def declaration(self): - return f"(declare-fun {self.name} () Bool)" - - -class BoolConstant(Bool): - __slots__ = ["_value"] + raise ExpressionError( + "You tried to use a Bool Expression as a boolean constant. Expressions could represent a set of concrete values." 
+ ) - def __init__(self, value: bool, *args, **kwargs): - self._value = value - super().__init__(*args, **kwargs) - def __bool__(self): - return self.value +class BoolVariable(Bool, Variable): + pass - @property - def value(self): - return self._value +class BoolConstant(Bool, Constant): + def __init__(self, *, value: bool, **kwargs): + super().__init__(value=bool(value), **kwargs) -class BoolOperation(Bool): - __slots__ = ["_operands"] + def __bool__(self) -> bool: + return bool(self._value) - def __init__(self, *operands, **kwargs): - self._operands = operands - # If taint was not forced by a keyword argument, calculate default - kwargs.setdefault("taint", reduce(lambda x, y: x.union(y.taint), operands, frozenset())) +class BoolOperation(Bool, Operation, abstract=True): + """ It's an operation that results in a Bool """ - super().__init__(**kwargs) + pass + # def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) - @property - def operands(self): - return self._operands + # def __xbool__(self): + # # FIXME: TODO: re-think is we want to be this forgiving every use of + # # local_simplify looks hacky + # simplified = self # local_simplify(self) + # if isinstance(simplified, Constant): + # return simplified.value + # raise ExpressionError("BoolOperation can not be reduced to a constant") class BoolNot(BoolOperation): - def __init__(self, value, **kwargs): - super().__init__(value, **kwargs) + def __init__(self, operand: Bool, **kwargs): + super().__init__(operands=(operand,), **kwargs) class BoolAnd(BoolOperation): - def __init__(self, a, b, **kwargs): - super().__init__(a, b, **kwargs) + def __init__(self, operanda: Bool, operandb: Bool, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) class BoolOr(BoolOperation): - def __init__(self, a: "Bool", b: "Bool", **kwargs): - super().__init__(a, b, **kwargs) + def __init__(self, operanda: Bool, operandb: Bool, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) 
class BoolXor(BoolOperation): - def __init__(self, a, b, **kwargs): - super().__init__(a, b, **kwargs) + def __init__(self, operanda: Bool, operandb: Bool, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) class BoolITE(BoolOperation): - def __init__(self, cond: "Bool", true: "Bool", false: "Bool", **kwargs): - super().__init__(cond, true, false, **kwargs) + def __init__(self, cond: Bool, true: Bool, false: Bool, **kwargs): + super().__init__(operands=(cond, true, false), **kwargs) + +class BitVec(Expression, abstract=True): + """ BitVector expressions have a fixed bit size """ -class BitVec(Expression): - """ This adds a bitsize to the Expression class """ + __xslots__: Tuple[str, ...] = ("_size",) - __slots__ = ["size"] + def __init__(self, size: int, **kwargs): + """This is bit vector expression + + :param size: number of buts used + """ + super().__init__(**kwargs) + self._size = size - def __init__(self, size, *operands, **kwargs): - super().__init__(*operands, **kwargs) - self.size = size + @property + def size(self) -> int: + return self._size @property - def mask(self): + def mask(self) -> int: return (1 << self.size) - 1 @property - def signmask(self): + def signmask(self) -> int: return 1 << (self.size - 1) - def cast( - self, value: Union["BitVec", str, int, bytes], **kwargs - ) -> Union["BitVecConstant", "BitVec"]: + def cast(self, value: Union["BitVec", str, int, bytes], **kwargs) -> "BitVec": + """ Cast a value int a BitVec """ if isinstance(value, BitVec): - assert value.size == self.size + if value.size != self.size: + raise ExpressionError("BitVector of unexpected size") return value if isinstance(value, (str, bytes)) and len(value) == 1: value = ord(value) # Try to support not Integral types that can be casted to int + value = int(value) & self.mask if not isinstance(value, int): - value = int(value) - # FIXME? 
Assert it fits in the representation + raise ExpressionError("Not cast-able to BitVec") return BitVecConstant(self.size, value, **kwargs) def __add__(self, other): @@ -385,41 +461,42 @@ def __invert__(self): # x>=y calls x.__ge__(y). def __lt__(self, other): - return LessThan(self, self.cast(other)) + return BoolLessThan(operanda=self, operandb=self.cast(other)) def __le__(self, other): - return LessOrEqual(self, self.cast(other)) + return BoolLessOrEqualThan(self, self.cast(other)) def __eq__(self, other): + # A class that overrides __eq__() and does not define __hash__() + # will have its __hash__() implicitly set to None. return BoolEqual(self, self.cast(other)) - def __hash__(self): - return object.__hash__(self) - def __ne__(self, other): + # A class that overrides __eq__() and does not define __hash__() + # will have its __hash__() implicitly set to None. return BoolNot(BoolEqual(self, self.cast(other))) def __gt__(self, other): - return GreaterThan(self, self.cast(other)) + return BoolGreaterThan(self, self.cast(other)) def __ge__(self, other): - return GreaterOrEqual(self, self.cast(other)) + return BoolGreaterOrEqualThan(self, self.cast(other)) def __neg__(self): return BitVecNeg(self) # Unsigned comparisons def ugt(self, other): - return UnsignedGreaterThan(self, self.cast(other)) + return BoolUnsignedGreaterThan(self, self.cast(other)) def uge(self, other): - return UnsignedGreaterOrEqual(self, self.cast(other)) + return BoolUnsignedGreaterOrEqualThan(self, self.cast(other)) def ult(self, other): - return UnsignedLessThan(self, self.cast(other)) + return BoolUnsignedLessThan(self, self.cast(other)) def ule(self, other): - return UnsignedLessOrEqual(self, self.cast(other)) + return BoolUnsignedLessOrEqualThan(self, self.cast(other)) def udiv(self, other): return BitVecUnsignedDiv(self, self.cast(other)) @@ -455,500 +532,713 @@ def Bool(self): return self != 0 -class BitVecVariable(BitVec): - __slots__ = ["_name"] - - def __init__(self, size: int, name: 
str, *args, **kwargs): - assert " " not in name - super().__init__(size, *args, **kwargs) - self._name = name - - @property - def name(self): - return self._name - - def __copy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") - - def __deepcopy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") - - def __repr__(self): - return "<{:s}({:s}) at {:x}>".format(type(self).__name__, self.name, id(self)) - - @property - def declaration(self): - return f"(declare-fun {self.name} () (_ BitVec {self.size}))" - - -class BitVecConstant(BitVec): - __slots__ = ["_value"] - - def __init__(self, size: int, value: int, *args, **kwargs): - MASK = (1 << size) - 1 - self._value = value & MASK - super().__init__(size, *args, **kwargs) +class BitVecVariable(BitVec, Variable): + pass - def __bool__(self): - return self.value != 0 - def __eq__(self, other): - if self.taint: - return super().__eq__(other) - return self.value == other +class BitVecConstant(BitVec, Constant): + def __init__(self, size: int, value: int, **kwargs): + """ A bitvector constant """ + value &= (1 << size) - 1 # Can not use self.mask yet + super().__init__(size=size, value=value, **kwargs) - def __hash__(self): - return super().__hash__() + def __bool__(self) -> bool: + return bool(self.value) - @property - def value(self): - return self._value + def __int__(self) -> int: + return self.value @property - def signed_value(self): + def signed_value(self) -> int: + """ Gives signed python int representation """ if self._value & self.signmask: return self._value - (1 << self.size) else: return self._value + def __eq__(self, other): + """ If tainted keep a tainted symbolic value""" + if self.taint: + return BoolEqual(self, self.cast(other)) + # Ignore the taint for eq comparison + return self._value == other -class BitVecOperation(BitVec): - __slots__ = ["_operands"] - - def __init__(self, size, *operands, **kwargs): - self._operands = operands + def 
__repr__(self): + return f"BitVecConstant<{self.size}, {self.value}>" - # If taint was not forced by a keyword argument, calculate default - kwargs.setdefault("taint", reduce(lambda x, y: x.union(y.taint), operands, frozenset())) - super().__init__(size, **kwargs) +class BitVecOperation(BitVec, Operation, abstract=True): + """ Operations that result in a BitVec """ - @property - def operands(self): - return self._operands + def __init__(self, *, operands: Tuple[Expression, ...], **kwargs): + super().__init__(operands=operands, **kwargs) class BitVecAdd(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + assert operanda.size == operandb.size + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecSub(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecMul(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecDiv(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecUnsignedDiv(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecMod(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - 
super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecRem(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecUnsignedRem(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecShiftLeft(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecShiftRight(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecArithmeticShiftLeft(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecArithmeticShiftRight(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecAnd(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, *args, **kwargs): + 
super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecOr(BitVecOperation): - def __init__(self, a: BitVec, b: BitVec, *args, **kwargs): - assert a.size == b.size - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda: BitVec, operandb: BitVec, *args, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecXor(BitVecOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a.size, a, b, *args, **kwargs) + def __init__(self, operanda, operandb, **kwargs): + super().__init__(size=operanda.size, operands=(operanda, operandb), **kwargs) class BitVecNot(BitVecOperation): - def __init__(self, a, **kwargs): - super().__init__(a.size, a, **kwargs) + def __init__(self, operanda, **kwargs): + super().__init__(size=operanda.size, operands=(operanda,), **kwargs) class BitVecNeg(BitVecOperation): - def __init__(self, a, *args, **kwargs): - super().__init__(a.size, a, *args, **kwargs) + def __init__(self, operanda, **kwargs): + super().__init__(size=operanda.size, operands=(operanda,), **kwargs) # Comparing two bitvectors results in a Bool -class LessThan(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a, b, *args, **kwargs) +class BoolLessThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -class LessOrEqual(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a, b, *args, **kwargs) +class BoolLessOrEqualThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) class BoolEqual(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - if isinstance(a, BitVec) or isinstance(b, BitVec): - assert a.size == b.size - super().__init__(a, b, *args, **kwargs) + @overload + def __init__(self, operanda: BitVec, 
operandb: BitVec, **kwargs): + ... + @overload + def __init__(self, operanda: Bool, operandb: Bool, **kwargs): + ... -class GreaterThan(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - assert a.size == b.size - super().__init__(a, b, *args, **kwargs) + @overload + def __init__(self, operanda: "Array", operandb: "Array", **kwargs): + ... + def __init__(self, operanda, operandb, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -class GreaterOrEqual(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - assert a.size == b.size - super().__init__(a, b, *args, **kwargs) + def __bool__(self): + from .visitors import simplify + x = simplify(self) + if isinstance(x, Constant): + return bool(x.value) + # NOTE: True is the default return value for objects that don't implement __bool__ + # https://docs.python.org/3.6/reference/datamodel.html#object.__bool__ + return True -class UnsignedLessThan(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - super().__init__(a, b, *args, **kwargs) - assert a.size == b.size +class BoolGreaterThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -class UnsignedLessOrEqual(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - assert a.size == b.size - super().__init__(a, b, *args, **kwargs) +class BoolGreaterOrEqualThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, *args, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -class UnsignedGreaterThan(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - assert a.size == b.size - super().__init__(a, b, *args, **kwargs) +class BoolUnsignedLessThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -class UnsignedGreaterOrEqual(BoolOperation): - def __init__(self, a, b, *args, **kwargs): - 
assert a.size == b.size - super(UnsignedGreaterOrEqual, self).__init__(a, b, *args, **kwargs) +class BoolUnsignedLessOrEqualThan(BoolOperation): + def __init__(self, operanda: BitVec, operandb: BitVec, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) -############################################################################### -# Array BV32 -> BV8 or BV64 -> BV8 -class Array(Expression): - __slots__ = ["_index_bits", "_index_max", "_value_bits"] - def __init__( - self, index_bits: int, index_max: Optional[int], value_bits: int, *operands, **kwargs - ): - assert index_bits in (32, 64, 256) - assert value_bits in (8, 16, 32, 64, 256) - assert index_max is None or index_max >= 0 and index_max < 2 ** index_bits - self._index_bits = index_bits - self._index_max = index_max - self._value_bits = value_bits - super().__init__(*operands, **kwargs) - assert type(self) is not Array, "Abstract class" - - def _get_size(self, index): - start, stop = self._fix_index(index) - size = stop - start - if isinstance(size, BitVec): - from .visitors import simplify +class BoolUnsignedGreaterThan(BoolOperation): + def __init__(self, operanda, operandb, *args, **kwargs): + super().__init__(operands=(operanda, operandb), **kwargs) - size = simplify(size) - else: - size = BitVecConstant(self.index_bits, size) - assert isinstance(size, BitVecConstant) - return size.value - def _fix_index(self, index): - """ - :param slice index: - """ - stop, start = index.stop, index.start - if start is None: - start = 0 - if stop is None: - stop = len(self) - return start, stop +class BoolUnsignedGreaterOrEqualThan(BoolOperation): + def __init__(self, operanda, operandb, **kwargs): + super(BoolUnsignedGreaterOrEqualThan, self).__init__( + operands=(operanda, operandb), **kwargs + ) - def cast(self, possible_array): - if isinstance(possible_array, bytearray): - # FIXME This should be related to a constrainSet - arr = ArrayVariable( - self.index_bits, len(possible_array), 8, 
"cast{}".format(uuid.uuid1()) - ) - for pos, byte in enumerate(possible_array): - arr = arr.store(pos, byte) - return arr - raise ValueError # cast not implemented - def cast_index(self, index: Union[int, "BitVec"]) -> Union["BitVecConstant", "BitVec"]: - if isinstance(index, int): - # assert self.index_max is None or index >= 0 and index < self.index_max - return BitVecConstant(self.index_bits, index) - assert index.size == self.index_bits - return index - - def cast_value( - self, value: Union["BitVec", str, bytes, int] - ) -> Union["BitVecConstant", "BitVec"]: - if isinstance(value, BitVec): - assert value.size == self.value_bits - return value - if isinstance(value, (str, bytes)) and len(value) == 1: - value = ord(value) - if not isinstance(value, int): - value = int(value) - return BitVecConstant(self.value_bits, value) +class Array(Expression, abstract=True): + """An Array expression is an unmutable mapping from bitvector to bitvector - def __len__(self): - if self.index_max is None: - raise ExpressionException("Array max index not set") - return self.index_max + array.index_size is the number of bits used for addressing a value + array.value_size is the number of bits used in the values + array.length counts the valid indexes starting at 0. Accessing outside the bound is undefined + + """ @property - def index_bits(self): - return self._index_bits + @abstractmethod + def index_size(self) -> int: + """ The bit size of the index part. Must be overloaded by a more specific class""" + ... @property - def value_bits(self): - return self._value_bits + def value_size(self) -> int: + """ The bit size of the value part. Must be overloaded by a more specific class""" + raise NotImplementedError @property - def index_max(self): - return self._index_max + def length(self) -> int: + """ Number of defined items. 
Must be overloaded by a more specific class""" + raise NotImplementedError - def select(self, index): - index = self.cast_index(index) - return ArraySelect(self, index) + def select(self, index) -> Union[BitVec, int]: + """ Gets a bitvector element from the Array que la""" + raise NotImplementedError - def store(self, index, value): - return ArrayStore(self, self.cast_index(index), self.cast_value(value)) + def store(self, index, value) -> "Array": + """ Create a new array that contains the updated value""" + raise NotImplementedError - def write(self, offset, buf): - if not isinstance(buf, (Array, bytearray)): - raise TypeError("Array or bytearray expected got {:s}".format(type(buf))) - arr = self - for i, val in enumerate(buf): - arr = arr.store(offset + i, val) - return arr + @property + def default(self) -> Optional[Union[BitVec, int]]: + """If defined, reading from an uninitialized index return the default value. + Otherwise, reading from an uninitialized index gives a symbol (normal Array behavior) + """ + raise NotImplementedError - def read(self, offset, size): - return ArraySlice(self, offset, size) + @property + def written(self): + """Returns the set of potentially symbolic indexes that were written in + this array. - def __getitem__(self, index): - if isinstance(index, slice): - start, stop = self._fix_index(index) - size = self._get_size(index) - return ArraySlice(self, start, size) - else: - if self.index_max is not None: - if not isinstance(index, Expression) and index >= self.index_max: - raise IndexError - return self.select(self.cast_index(index)) + Note that as you could overwrite an index so this could have more elements + than total elements in the array. 
+ """ + raise NotImplementedError - def __iter__(self): - for i in range(len(self)): - yield self[i] + def is_known(self, index) -> Union[Bool, bool]: + """ Returned Boolean Expression holds when the index was used""" + raise NotImplementedError - def __eq__(self, other): - # FIXME taint - def compare_buffers(a, b): - if len(a) != len(b): - return BoolConstant(False) - cond = BoolConstant(True) - for i in range(len(a)): - cond = BoolAnd(cond.cast(a[i] == b[i]), cond) - if cond is BoolConstant(False): - return BoolConstant(False) - return cond - - return compare_buffers(self, other) + # Following methods are implemented on top of the abstract methods ^ + def in_bounds(self, index: Union[BitVec, int]) -> Union[Bool, bool]: + """ True if the index points inside the array (or array is unbounded)""" + if self.length is not None: + return 0 <= index < self.length + return True - def __ne__(self, other): - return BoolNot(self == other) + def __len__(self): + """ Number of values. """ + return self.length - def __hash__(self): - return super().__hash__() + def cast(self, array) -> "Array": + """Builds an Array from bytes or bytearray + FIXME: this assigns a random name to a new variable and does not use + a ConstraintSet as a Factory + """ + if isinstance(array, Array): + return array + arr = ArrayVariable( + index_size=self.index_size, + length=len(array), + default=0, + value_size=self.value_size, + name=f"cast{uuid.uuid1()}", + ) + for pos, byte in enumerate(array): + arr = arr.store(pos, byte) + return arr - @property - def underlying_variable(self): - array = self - while not isinstance(array, ArrayVariable): - array = array.array - return array + def cast_index(self, index: Union[int, BitVec]) -> BitVec: + """Forgiving casting method that will translate compatible values into + a compliant BitVec for indexing""" + if isinstance(index, int): + return BitVecConstant(self.index_size, index) + if not isinstance(index, BitVec) or index.size != self.index_size: + raise 
ExpressionError(f"Expected BitVector of size {self.index_size}") + if isinstance(index, Constant): + return index + return local_simplify(index) - def read_BE(self, address, size): + def cast_value(self, value: Union[BitVec, bytes, int]) -> BitVec: + """Forgiving casting method that will translate compatible values into + a compliant BitVec to be used as a value""" + if not isinstance(value, (BitVec, bytes, int)): + raise TypeError + if isinstance(value, BitVec): + if value.size != self.value_size: + raise ValueError + return value + if isinstance(value, bytes) and len(value) == 1: + value = ord(value) + if not isinstance(value, int): + value = int(value) + return BitVecConstant(self.value_size, value) + + def write(self, offset: Union[BitVec, int], buf: Union["Array", bytes]) -> "Array": + """Builds a new unmutable Array instance on top of current array by + writing buf at offset""" + array = self + for i, value in enumerate(buf): + array = array.store(offset + i, value) + return array + + def read(self, offset: int, size: int) -> "Array": + """ A projection of the current array. """ + return ArraySlice(self, offset=offset, size=size) + + @overload + def __getitem__(self, index: Union[BitVec, int]) -> Union[BitVec, int]: + ... + + @overload + def __getitem__(self, index: slice) -> "Array": + ... + + def __getitem__(self, index): + """__getitem__ allows for pythonic access + A = ArrayVariable(index_size=32, value_size=8) + A[10] := a symbol representing the value under index 10 in array A + A[10:20] := a symbol representing a slice of array A + """ + if isinstance(index, slice): + start, stop, size = self._fix_slice(index) + return self.read(start, size) + return self.select(index) + + def __iter__(self): + """ In a bounded array iterates over all elements. 
""" + for i in range(len(self)): + yield self[i] + + @staticmethod + def _compare_buffers(a: "Array", b: "Array") -> Bool: + """ Builds an expression that represents equality between the two arrays.""" + if a.length != b.length: + return BoolConstant(value=False) + cond: Bool = BoolConstant(value=True) + for i in range(a.length): + cond = BoolAnd(cond.cast(a[i] == b[i]), cond) + if cond is BoolConstant(value=False): + return BoolConstant(value=False) + return cond + + def __eq__(self, other): + """If both arrays has the same elements they are equal. + The difference in taints are ignored.""" + return self._compare_buffers(self, self.cast(other)) + + def __ne__(self, other): + return BoolNot(self == other) + + def _fix_slice(self, index: slice): + """Used to calculate the size of slices""" + stop, start = index.stop, index.start + if start is None: + start = 0 + if stop is None: + stop = len(self) + size = stop - start + if isinstance(size, BitVec): + size = local_simplify(size) + else: + size = BitVecConstant(self.index_size, size) + if not isinstance(size, BitVecConstant): + raise ExpressionError("Size could not be simplified to a constant in a slice operation") + return start, stop, size.value + + def _concatenate(self, array_a: "Array", array_b: "Array") -> "Array": + """Build a new array from the concatenation of the operands""" + new_arr = ArrayVariable( + index_size=self.index_size, + length=len(array_a) + len(array_b), + value_size=self.value_size, + name=f"concatenation{uuid.uuid1()}", + ) + for index in range(len(array_a)): + new_arr = new_arr.store(index, local_simplify(array_a[index])) + for index in range(len(array_b)): + new_arr = new_arr.store(index + len(array_a), local_simplify(array_b[index])) + return new_arr + + def __add__(self, other): + return self._concatenate(self, other) + + def __radd__(self, other): + return self._concatenate(other, self) + + @lru_cache(maxsize=128, typed=True) + def read_BE(self, address: Union[int, BitVec], size: int) 
-> Union[BitVec, int]: address = self.cast_index(address) bytes = [] for offset in range(size): - bytes.append(self.get(address + offset, 0)) - return BitVecConcat(size * self.value_bits, *bytes) + bytes.append(self.cast_value(self[address + offset])) + return BitVecConcat(operands=tuple(bytes)) - def read_LE(self, address, size): + @lru_cache(maxsize=128, typed=True) + def read_LE(self, address: Union[int, BitVec], size: int) -> Union[BitVec, int]: address = self.cast_index(address) bytes = [] for offset in range(size): - bytes.append(self.get(address + offset, 0)) - return BitVecConcat(size * self.value_bits, *reversed(bytes)) + bytes.append(self.cast_value(self[address + offset])) + return BitVecConcat(operands=tuple(reversed(bytes))) - def write_BE(self, address, value, size): + def write_BE( + self, address: Union[int, BitVec], value: Union[int, BitVec], size: int + ) -> "Array": address = self.cast_index(address) - value = BitVec(size * self.value_bits).cast(value) + value = BitVecConstant(size=size * self.value_size, value=0).cast(value) array = self for offset in range(size): array = array.store( address + offset, - BitVecExtract(value, (size - 1 - offset) * self.value_bits, self.value_bits), + BitVecExtract(value, (size - 1 - offset) * self.value_size, self.value_size), ) return array - def write_LE(self, address, value, size): + def write_LE( + self, address: Union[int, BitVec], value: Union[int, BitVec], size: int + ) -> "Array": address = self.cast_index(address) - value = BitVec(size * self.value_bits).cast(value) + value = BitVec(size * self.value_size).cast(value) array = self for offset in reversed(range(size)): array = array.store( address + offset, - BitVecExtract(value, (size - 1 - offset) * self.value_bits, self.value_bits), + BitVecExtract(value, (size - 1 - offset) * self.value_size, self.value_size), ) return array - def __add__(self, other): - if not isinstance(other, (Array, bytearray)): - raise TypeError("can't concat Array to 
{}".format(type(other))) - if isinstance(other, Array): - if self.index_bits != other.index_bits or self.value_bits != other.value_bits: - raise ValueError("Array sizes do not match for concatenation") - from .visitors import simplify +class ArrayConstant(Array, Constant): + __xslots__: Tuple[str, ...] = ("_index_size", "_value_size") - # FIXME This should be related to a constrainSet - new_arr = ArrayProxy( - ArrayVariable( - self.index_bits, - self.index_max + len(other), - self.value_bits, - "concatenation{}".format(uuid.uuid1()), - ) - ) - for index in range(self.index_max): - new_arr[index] = simplify(self[index]) - for index in range(len(other)): - new_arr[index + self.index_max] = simplify(other[index]) - return new_arr + def __init__( + self, + *, + index_size: int, + value_size: int, + **kwargs, + ): + self._index_size = index_size + self._value_size = value_size + super().__init__(**kwargs) - def __radd__(self, other): - if not isinstance(other, (Array, bytearray, bytes)): - raise TypeError("can't concat Array to {}".format(type(other))) - if isinstance(other, Array): - if self.index_bits != other.index_bits or self.value_bits != other.value_bits: - raise ValueError("Array sizes do not match for concatenation") + @property + def index_size(self) -> int: + return self._index_size - from .visitors import simplify + @property + def value_size(self) -> int: + return self._value_size + + @property + def length(self) -> int: + return len(self.value) + + def select(self, index): + """ ArrayConstant get """ + index = self.cast_index(index) + return self._select(index) + + @lru_cache(maxsize=128, typed=True) + def _select(self, index): + """ ArrayConstant get """ + index = self.cast_index(index) + if isinstance(index, Constant): + return BitVecConstant( + size=self.value_size, value=self.value[index.value], taint=self.taint + ) - # FIXME This should be related to a constrainSet - new_arr = ArrayProxy( - ArrayVariable( - self.index_bits, - self.index_max + 
len(other), - self.value_bits, - "concatenation{}".format(uuid.uuid1()), + # Index being symbolic generates a symbolic result ! + result: BitVec = BitVecConstant(size=self.value_size, value=0, taint=("out_of_bounds")) + for i, c in enumerate(self.value): + result = BitVecITE( + index == i, BitVecConstant(size=self.value_size, value=c), result, taint=self.taint ) + return result + + def read(self, offset, size): + assert len(self.value[offset : offset + size]) == size + return ArrayConstant( + index_size=self.index_size, + value_size=self.value_size, + value=self.value[offset : offset + size], ) - for index in range(len(other)): - new_arr[index] = simplify(other[index]) - _concrete_cache = new_arr._concrete_cache - for index in range(self.index_max): - new_arr[index + len(other)] = simplify(self[index]) - new_arr._concrete_cache.update(_concrete_cache) - return new_arr -class ArrayVariable(Array): - __slots__ = ["_name"] +class ArrayVariable(Array, Variable): + """ + An Array expression is a mapping from bitvector of index_size bits + into bitvectors of value_size bits. - def __init__(self, index_bits, index_max, value_bits, name, *args, **kwargs): - assert " " not in name - super().__init__(index_bits, index_max, value_bits, *args, **kwargs) - self._name = name + If a default value is provided reading from an unused index will return the + default. Otherwise each unused position in the array represents a free bitvector. - @property - def name(self): - return self._name + If a length maximum index is provided, accessing over the max is undefined. + Otherwise the array is unbounded. - def __copy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") + """ - def __deepcopy__(self, memo=""): - raise ExpressionException("Copying of Variables is not allowed.") + __xslots__: Tuple[str, ...] 
= ( + "_index_size", + "_value_size", + "_length", + "_default", + ) - def __repr__(self): - return "<{:s}({:s}) at {:x}>".format(type(self).__name__, self.name, id(self)) + @property + def length(self): + return self._length + + def __init__( + self, + *, + index_size: int, + value_size: int, + length: Optional[int] = None, + default: Optional[int] = None, + **kwargs, + ): + """ + This is a mapping from BV to BV. Normally used to represent a memory. + + :param index_size: Number of bits in the addressing side + :param length: Max address allowed + :param value_size: Number of bits in tha value side + :param default: Reading from an uninitialized index will return default + if provided. If not the behaivor mimics thtat from smtlib, + the returned value is a free symbol. + :param kwargs: Used in other parent classes + """ + assert index_size in (32, 64, 256) + assert value_size in (8, 16, 32, 64, 256) + assert length is None or 0 <= length < 2 ** index_size + self._index_size = index_size + self._length = length + self._value_size = value_size + self._default = default + super().__init__(**kwargs) @property - def declaration(self): - return f"(declare-fun {self.name} () (Array (_ BitVec {self.index_bits}) (_ BitVec {self.value_bits})))" + def index_size(self): + return self._index_size + @property + def value_size(self): + return self._value_size -class ArrayOperation(Array): - __slots__ = ["_operands"] + @property + def default(self): + return self._default - def __init__(self, array: Array, *operands, **kwargs): - self._operands = (array, *operands) + @lru_cache(maxsize=128, typed=True) + def select(self, index): + """Gets an element from an empty Array.""" + default = self._default + if default is not None: + return default + index = self.cast_index(index) + return ArraySelect(self, index) - # If taint was not forced by a keyword argument, calculate default - kwargs.setdefault("taint", reduce(lambda x, y: x.union(y.taint), operands, frozenset())) + def 
store(self, index, value): + index = self.cast_index(index) + value = local_simplify(self.cast_value(value)) + return ArrayStore(array=self, index=index, value=value) + + @property + def written(self): + return frozenset() - super().__init__(array.index_bits, array.index_max, array.value_bits, **kwargs) + def is_known(self, index): + return False @property - def operands(self): - return self._operands + def underlying_variable(self): + array = self + while not isinstance(array, ArrayVariable): + array = array.array + return array + + +class ArrayOperation(Array, Operation, abstract=True): + """ It's an operation that results in an Array""" + + pass + + +def get_items(array): + if isinstance(array, ArrayStore): + yield from get_items_array_store(array) + elif isinstance(array, ArraySlice): + yield from get_items_array_slice(array) + elif isinstance(array, ArrayConstant): + yield from get_items_array_constant(array) + return + + +def get_items_array_slice(array): + assert isinstance(array, ArraySlice) + for offset, value in get_items(array.array): + yield offset + array.offset, value + + +def get_items_array_store(array): + assert isinstance(array, ArrayStore) + while isinstance(array, ArrayStore): + yield array.index, array.value + array = array.array + yield from get_items(array) + + +def get_items_array_constant(array): + assert isinstance(array, ArrayConstant) + for index, value in enumerate(array.value): + yield index, value + + +def get_items_array_variable(array): + assert isinstance(array, ArrayVariable) + raise GeneratorExit class ArrayStore(ArrayOperation): - def __init__(self, array: "Array", index: "BitVec", value: "BitVec", *args, **kwargs): - assert index.size == array.index_bits - assert value.size == array.value_bits - super().__init__(array, index, value, *args, **kwargs) + __xslots__: Tuple[str, ...] 
= ( + "_written#v", + "_concrete_cache#v", + "_length#v", + "_default#v", + ) + + def __init__(self, array: Array, index: BitVec, value: BitVec, **kwargs): + assert index.size == array.index_size + assert value.size == array.value_size + self._written: Optional[Set[Any]] = None # Cache of the known indexes + self._concrete_cache: Optional[Dict[Any, Any]] = None + self._length = array.length + self._default = array.default + + # recreate and reuse cache + # if isinstance(index, Constant) and isinstance(array, ArrayStore) and array._concrete_cache is not None: + # self._concrete_cache = dict(array._concrete_cache) + # self._concrete_cache[index.value] = value + + super().__init__( + operands=(array, index, value), + **kwargs, + ) + + @property + def concrete_cache(self): + if self._concrete_cache is None: + self._concrete_cache = {} + for index, value in get_items(self): + if not isinstance(index, Constant): + break + if index.value not in self._concrete_cache: + self._concrete_cache[index.value] = value + return self._concrete_cache + + @property + def written(self): + # Calculate only first time + # This can have repeated and reused written indexes. 
+ if not self._written: + self._written = {offset for offset, _ in get_items(self)} + return self._written + + def is_known(self, index): + if isinstance(index, Constant) and index.value in self.concrete_cache: + return BoolConstant(value=True) + + is_known_index: Bool = BoolConstant(value=False) + written = self.written + for known_index in written: + if isinstance(index, Constant) and isinstance(known_index, Constant): + if known_index.value == index.value: + return BoolConstant(value=True) + is_known_index = BoolOr(is_known_index.cast(index == known_index), is_known_index) + return is_known_index + + @property + def length(self): + return self._length + + @property + def default(self): + return self._default + + @property + def index_size(self): + return self.index.size + + @property + def value_size(self): + return self.value.size @property def array(self): @@ -966,279 +1256,247 @@ def index(self): def value(self): return self.operands[2] - def __getstate__(self): - state = {} - array = self - items = [] - while isinstance(array, ArrayStore): - items.append((array.index, array.value)) - array = array.array - state["_array"] = array - state["_items"] = items - return state + def select(self, index): - def __setstate__(self, state): - array = state["_array"] - for index, value in reversed(state["_items"][0:]): - array = array.store(index, value) - self._index_bits = array.index_bits - self._index_max = array.index_max - self._value_bits = array.value_bits - index, value = state["_items"][0] - self._operands = (array, index, value) + """Gets an element from the Array. + If the element was not previously the default is used. 
+ """ + index = local_simplify(self.cast_index(index)) + + # Emulate list[-1] + has_length = self.length is not None + if has_length: + index = local_simplify(BitVecITE(index < 0, self.length + index, index)) + + if isinstance(index, Constant): + if has_length and index.value >= self.length: + raise IndexError + if index.value in self.concrete_cache: + return self.concrete_cache[index.value] + + default = self.default + if default is None: + # No default. Returns normal array select + return ArraySelect(self, index) + + # if a default is defined we need to check if the index was previously written + return local_simplify( + BitVecITE(self.is_known(index), ArraySelect(self, index), self.cast_value(default)) + ) + + def store(self, index, value): + casted = self.cast_index(index) + index = local_simplify(casted) + value = self.cast_value(value) + new_array = ArrayStore(self, index, value) + return new_array + + def __eq__(self, other): + return self._compare_buffers(self, self.cast(other)) class ArraySlice(ArrayOperation): - def __init__( - self, array: Union["Array", "ArrayProxy"], offset: int, size: int, *args, **kwargs - ): + """Provides a projection of an underlying array. + Lets you slice an array without copying it. 
+ (It needs to be simplified out before translating it smtlib) + """ + + def __init__(self, array: "Array", offset: int, size: int, **kwargs): if not isinstance(array, Array): raise ValueError("Array expected") - if isinstance(array, ArrayProxy): - array = array._array - super().__init__(array, **kwargs) - self._slice_offset = offset - self._slice_size = size + super().__init__( + operands=(array, array.cast_index(offset), array.cast_index(size)), + **kwargs, + ) @property def array(self): return self.operands[0] @property - def underlying_variable(self): - return self.array.underlying_variable + def offset(self): + return self.operands[1] @property - def index_bits(self): - return self.array.index_bits + def length(self): + return self.operands[2].value @property - def index_max(self): - return self._slice_size + def index_size(self): + return self.array.index_size @property - def value_bits(self): - return self.array.value_bits + def value_size(self): + return self.array.value_size + + @property + def underlying_variable(self): + return self.array.underlying_variable def select(self, index): - return self.array.select(index + self._slice_offset) + index = self.cast_index(index) + if isinstance(index, Constant): + length = self.length + if length is not None and index.value >= length: + raise IndexError + return self.array.select(local_simplify(index + self.offset)) def store(self, index, value): return ArraySlice( - self.array.store(index + self._slice_offset, value), - self._slice_offset, - self._slice_size, + self.array.store(index + self.offset, value), + offset=self.offset, + size=len(self), ) + @property + def default(self): + return self.array.default -class ArrayProxy(Array): - def __init__(self, array: Array, default: Optional[int] = None): - self._default = default - self._concrete_cache: Dict[int, int] = {} - self._written = None - if isinstance(array, ArrayProxy): - # copy constructor - super().__init__(array.index_bits, array.index_max, 
array.value_bits) - self._array: Array = array._array - self._name: str = array._name - if default is None: - self._default = array._default - self._concrete_cache = dict(array._concrete_cache) - self._written = set(array.written) - elif isinstance(array, ArrayVariable): - # fresh array proxy - super().__init__(array.index_bits, array.index_max, array.value_bits) - self._array = array - self._name = array.name - else: - # arrayproxy for a prepopulated array - super().__init__(array.index_bits, array.index_max, array.value_bits) - self._name = array.underlying_variable.name - self._array = array + def __eq__(self, other): + # Ignore the taint for eq comparison + return self._compare_buffers(self, self.cast(other)) - @property - def underlying_variable(self): - return self._array.underlying_variable + +class MutableArray: + """ + Arrayproxy is a layer on top of an array that provides mutability and some + simple optimizations for concrete indexes. + + It is not hasheable. + Think: + bytearray <-> MutableArray ::: not hasheable, mutable + bytes <-> Array (ArraySlice, ArrayVariable, ArrayStore) ::: hasheable, notmutable + + """ + + def __init__(self, array: Array): + if isinstance(array, MutableArray): + array = array._array + + self._array: Array = array + + def __eq__(self, other): + """ Comparing the inner array of a MutableArray with other""" + return self.array == other + + def __hash__(self): + raise NotImplementedError() @property - def array(self): + def underlying_variable(self): + if isinstance(self._array, ArrayVariable): + return self._array.underlying_variable + # NOTE: What to do here? + assert False return self._array @property def name(self): - return self._name + if isinstance(self._array, ArrayVariable): + return self._array.name + # NOTE: What to do here? 
+ assert False + return None + + @property + def array(self): + return self._array @property def operands(self): return (self._array,) @property - def index_bits(self): - return self._array.index_bits + def index_size(self): + return self._array.index_size @property - def index_max(self): - return self._array.index_max + def length(self): + return self._array.length @property - def value_bits(self): - return self._array.value_bits + def value_size(self): + return self._array.value_size @property def taint(self): return self._array.taint - def select(self, index): - return self.get(index) + @property + def default(self): + return self._array.default - def store(self, index, value): - if not isinstance(index, Expression): - index = self.cast_index(index) - if not isinstance(value, Expression): - value = self.cast_value(value) - from .visitors import simplify + def __len__(self): + return len(self._array) - index = simplify(index) - if isinstance(index, Constant): - self._concrete_cache[index.value] = value - else: - # delete all cache as we do not know what this may overwrite. 
- self._concrete_cache = {} + def select(self, index): + return self._array.select(index) - # potentially generate and update .written set - # if index is symbolic it may overwrite other previous indexes - self.written.add(index) + def store(self, index, value): self._array = self._array.store(index, value) + assert self._array is not None return self + @property + def written(self): + return self._array.written + def __getitem__(self, index): + result = self._array[index] if isinstance(index, slice): - start, stop = self._fix_index(index) - size = self._get_size(index) - array_proxy_slice = ArrayProxy(ArraySlice(self, start, size), default=self._default) - array_proxy_slice._concrete_cache = {} - for k, v in self._concrete_cache.items(): - if k >= start and k < start + size: - array_proxy_slice._concrete_cache[k - start] = v - - for i in self.written: - array_proxy_slice.written.add(i - start) - return array_proxy_slice - else: - if self.index_max is not None: - if not isinstance(index, Expression) and index >= self.index_max: - raise IndexError - return self.get(index, self._default) + return MutableArray(result) + return result def __setitem__(self, index, value): if isinstance(index, slice): - start, stop = self._fix_index(index) - size = self._get_size(index) + start, stop, size = self._array._fix_slice(index) assert len(value) == size for i in range(size): - self.store(start + i, value[i]) + self._array = self._array.store(start + i, value[i]) else: - self.store(index, value) - - def __getstate__(self): - state = {} - state["_default"] = self._default - state["_array"] = self._array - state["name"] = self.name - state["_concrete_cache"] = self._concrete_cache - return state - - def __setstate__(self, state): - self._default = state["_default"] - self._array = state["_array"] - self._name = state["name"] - self._concrete_cache = state["_concrete_cache"] - self._written = None - - def __copy__(self): - return ArrayProxy(self) - - @property - def written(self): 
- # Calculate only first time - if self._written is None: - written = set() - # take out Proxy sleve - array = self._array - offset = 0 - while not isinstance(array, ArrayVariable): - if isinstance(array, ArraySlice): - # if it is a proxy over a slice take out the slice too - offset += array._slice_offset - else: - # The index written to underlaying Array are displaced when sliced - written.add(array.index - offset) - array = array.array - assert isinstance(array, ArrayVariable) - self._written = written - return self._written - - def is_known(self, index): - if isinstance(index, Constant) and index.value in self._concrete_cache: - return BoolConstant(True) - - is_known_index = BoolConstant(False) - written = self.written - if isinstance(index, Constant): - for i in written: - # check if the concrete index is explicitly in written - if isinstance(i, Constant) and index.value == i.value: - return BoolConstant(True) - - # Build an expression to check if our concrete index could be the - # solution for anyof the used symbolic indexes - is_known_index = BoolOr(is_known_index.cast(index == i), is_known_index) - return is_known_index - - # The index is symbolic we need to compare it agains it all - for known_index in written: - is_known_index = BoolOr(is_known_index.cast(index == known_index), is_known_index) + self._array = self._array.store(index, value) + assert self._array is not None + return self - return is_known_index + def write_BE(self, address, value, size): + self._array = self._array.write_BE(address, value, size) + assert self._array is not None + return self - def get(self, index, default=None): - if default is None: - default = self._default - index = self.cast_index(index) + def read_BE(self, address, size): + return self._array.read_BE(address, size) - if self.index_max is not None: - from .visitors import simplify + def write(self, offset, buf): + self._array = self._array.write(offset, buf) + assert self._array is not None - index = simplify( - 
BitVecITE(self.index_bits, index < 0, self.index_max + index + 1, index) - ) - if isinstance(index, Constant) and index.value in self._concrete_cache: - return self._concrete_cache[index.value] + return self - if default is not None: - default = self.cast_value(default) - is_known = self.is_known(index) - if isinstance(is_known, Constant) and is_known.value == False: - return default - else: - return self._array.select(index) + def read(self, offset, size): + return MutableArray(self._array[offset : offset + size]) - value = self._array.select(index) - return BitVecITE(self._array.value_bits, is_known, value, default) + def __ne__(self, other): + return BoolNot(self == other) + def __add__(self, other): + if isinstance(other, MutableArray): + other = other.array + return MutableArray(self.array + other) -class ArraySelect(BitVec): - __slots__ = ["_operands"] + def __radd__(self, other): + if isinstance(other, MutableArray): + other = other.array + return MutableArray(other + self.array) - def __init__(self, array: "Array", index: "BitVec", *operands, **kwargs): - assert index.size == array.index_bits - self._operands = (array, index, *operands) - # If taint was not forced by a keyword argument, calculate default - kwargs.setdefault("taint", reduce(lambda x, y: x.union(y.taint), operands, frozenset())) +class ArraySelect(BitVecOperation): + __xslots__ = BitVecOperation.__xslots__ - super().__init__(array.value_bits, **kwargs) + def __init__(self, array: "Array", index: "BitVec", *args, **kwargs): + assert isinstance(array, Array) + assert index.size == array.index_size + super().__init__(size=array.value_size, operands=(array, index), **kwargs) @property def array(self): @@ -1248,34 +1506,35 @@ def array(self): def index(self): return self.operands[1] - @property - def operands(self): - return self._operands - def __repr__(self): - return f"" + return f"" class BitVecSignExtend(BitVecOperation): - def __init__(self, operand: "BitVec", size_dest: int, *args, 
**kwargs): - assert size_dest >= operand.size - super().__init__(size_dest, operand, *args, **kwargs) - self.extend = size_dest - operand.size + def __init__(self, operand: BitVec, size: int, *args, **kwargs): + super().__init__(size=size, operands=(operand,), **kwargs) + + @property + def extend(self): + return self.size - self.operands[0].size class BitVecZeroExtend(BitVecOperation): - def __init__(self, size_dest: int, operand: "BitVec", *args, **kwargs): - assert size_dest >= operand.size - super().__init__(size_dest, operand, *args, **kwargs) - self.extend = size_dest - operand.size + def __init__(self, size: int, operand: BitVec, *args, **kwargs): + super().__init__(size=size, operands=(operand,), **kwargs) + + @property + def extend(self): + return self.size - self.operands[0].size class BitVecExtract(BitVecOperation): + __xslots__ = ("_offset",) + def __init__(self, operand: "BitVec", offset: int, size: int, *args, **kwargs): assert offset >= 0 and offset + size <= operand.size - super().__init__(size, operand, *args, **kwargs) - self._begining = offset - self._end = offset + size - 1 + super().__init__(size=size, operands=(operand,), **kwargs) + self._offset = offset @property def value(self): @@ -1283,33 +1542,33 @@ def value(self): @property def begining(self): - return self._begining + return self._offset @property def end(self): - return self._end + return self.begining + self.size - 1 class BitVecConcat(BitVecOperation): - def __init__(self, size_dest: int, *operands, **kwargs): - assert all(isinstance(x, BitVec) for x in operands) - assert size_dest == sum(x.size for x in operands) - super().__init__(size_dest, *operands, **kwargs) + def __init__(self, operands: Tuple[BitVec, ...], **kwargs): + size = sum(x.size for x in operands) + super().__init__(size=size, operands=operands, **kwargs) class BitVecITE(BitVecOperation): + __xslots__ = BitVecOperation.__xslots__ + def __init__( self, - size: int, - condition: Union["Bool", bool], - true_value: 
"BitVec", - false_value: "BitVec", - *args, + condition: Bool, + true_value: BitVec, + false_value: BitVec, **kwargs, ): - assert true_value.size == size - assert false_value.size == size - super().__init__(size, condition, true_value, false_value, *args, **kwargs) + + super().__init__( + size=true_value.size, operands=(condition, true_value, false_value), **kwargs + ) @property def condition(self): @@ -1324,6 +1583,72 @@ def false_value(self): return self.operands[2] -Constant = (BitVecConstant, BoolConstant) -Variable = (BitVecVariable, BoolVariable, ArrayVariable) -Operation = (BitVecOperation, BoolOperation, ArrayOperation, ArraySelect) +# auxiliary functions. Maybe move to operators +def issymbolic(value) -> bool: + """ + Helper to determine whether an object is symbolic (e.g checking + if data read from memory is symbolic) + + :param object value: object to check + :return: whether `value` is symbolic + :rtype: bool + """ + return isinstance(value, (Expression, MutableArray)) + + +def istainted(arg, taint=None): + """ + Helper to determine whether an object if tainted. + :param arg: a value or Expression + :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. + """ + + if not issymbolic(arg): + return False + if taint is None: + return len(arg.taint) != 0 + for arg_taint in arg.taint: + m = re.match(taint, arg_taint, re.DOTALL | re.IGNORECASE) + if m: + return True + return False + + +def get_taints(arg, taint=None): + """ + Helper to list an object taints. + :param arg: a value or Expression + :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. 
+ """ + + if not issymbolic(arg): + return + for arg_taint in arg.taint: + if taint is not None: + m = re.match(taint, arg_taint, re.DOTALL | re.IGNORECASE) + if m: + yield arg_taint + else: + yield arg_taint + return + + +def taint_with(arg, *taints, value_size=256, index_size=256): + """ + Helper to taint a value. + :param arg: a value or Expression + :param taint: a regular expression matching a taint value (eg. 'IMPORTANT.*'). If None, this function checks for any taint value. + """ + tainted_fset = frozenset(tuple(taints)) + if not issymbolic(arg): + if isinstance(arg, int): + arg = BitVecConstant(value_size, arg) + arg._taint = tainted_fset + else: + raise ValueError("type not supported") + + else: + arg = copy.copy(arg) + arg._taint |= tainted_fset + + return arg diff --git a/manticore/core/smtlib/operators.py b/manticore/core/smtlib/operators.py index 969679b56..0fc1715fc 100644 --- a/manticore/core/smtlib/operators.py +++ b/manticore/core/smtlib/operators.py @@ -163,7 +163,7 @@ def cast(x): return BitVecConstant(arg_size, x) return x - return BitVecConcat(total_size, *list(map(cast, args))) + return BitVecConcat(operands=tuple(map(cast, args))) else: return args[0] else: @@ -184,10 +184,10 @@ def ITE(cond, true_value, false_value): return false_value if isinstance(true_value, bool): - true_value = BoolConstant(true_value) + true_value = BoolConstant(value=true_value) if isinstance(false_value, bool): - false_value = BoolConstant(false_value) + false_value = BoolConstant(value=false_value) return BoolITE(cond, true_value, false_value) @@ -217,7 +217,7 @@ def ITEBV(size, cond, true_value, false_value): if isinstance(false_value, int): false_value = BitVecConstant(size, false_value) - return BitVecITE(size, cond, true_value, false_value) + return BitVecITE(cond, true_value, false_value) def UDIV(dividend, divisor): @@ -237,14 +237,6 @@ def SDIV(a, b): return int(math.trunc(float(a) / float(b))) -def SMOD(a, b): - if isinstance(a, BitVec): - return a.smod(b) - 
elif isinstance(b, BitVec): - return b.rsmod(a) - return int(math.fmod(a, b)) - - def SREM(a, b): if isinstance(a, BitVec): return a.srem(b) diff --git a/manticore/core/smtlib/solver.py b/manticore/core/smtlib/solver.py index 7e61217ba..a49df3cee 100644 --- a/manticore/core/smtlib/solver.py +++ b/manticore/core/smtlib/solver.py @@ -22,7 +22,7 @@ import shlex import time from functools import lru_cache -from typing import Dict, Tuple, Sequence, Optional +from typing import Dict, Tuple, Sequence, Optional, NamedTuple from subprocess import PIPE, Popen, check_output import re from . import operators as Operators @@ -163,7 +163,10 @@ def minmax(self, constraints, x, iters=10000): return x, x -Version = collections.namedtuple("Version", "major minor patch") +class Version(NamedTuple): + major: int + minor: int + path: int class SmtlibProc: @@ -214,7 +217,7 @@ def __readline_and_count(self): assert self._proc.stdout buf = self._proc.stdout.readline() # No timeout enforced here # If debug is enabled check if the solver reports a syntax error - # Error messages may contain an unbalanced parenthesis situation + # print (">",buf) if self._debug: if "(error" in buf: raise SolverException(f"Error in smtlib: {buf}") @@ -223,10 +226,11 @@ def __readline_and_count(self): def send(self, cmd: str) -> None: """ - Send a string to the solver. + Send a string to the solver.s :param cmd: a SMTLIBv2 command (ex. 
(check-sat)) """ + # print ("<",cmd) if self._debug: logger.debug(">%s", cmd) self._proc.stdout.flush() # type: ignore @@ -244,6 +248,7 @@ def recv(self) -> str: buf = "".join(bufl).strip() + # print (">",buf) if self._debug: logger.debug("<%s", buf) @@ -396,6 +401,52 @@ def _pop(self): """Recall the last pushed constraint store and state.""" self._smtlib.send("(pop 1)") + @lru_cache(maxsize=32) + def get_model(self, constraints: ConstraintSet): + self._reset(constraints.to_string()) + self._smtlib.send("(check-sat)") + self._smtlib.recv() + + model = {} + for variable in constraints.variables: + value = None + if isinstance(variable, BoolVariable): + value = self.__getvalue_bool(variable.name) + elif isinstance(variable, BitVecVariable): + value = self.__getvalue_bv(variable.name) + elif isinstance(variable, Array): + try: + if variable.length is not None: + value = [] + for i in range(len(variable)): + variable_i = variable[i] + if issymbolic(variable_i): + value.append(self.__getvalue_bv(translate_to_smtlib(variable_i))) + else: + value.append(variable_i) + value = bytes(value) + else: + # Only works if we know the max index of the arrray + used_indexes = map(self.__getvalue_bv, variable.written) + valued = {} + for i in used_indexes: + valued[i] = self.__getvalue_bv(variable[i]) + + class A: + def __init__(self, d, default): + self._d = d + self._default = default + + def __getitem__(self, index): + return self._d.get(index, self._default) + + value = A(valued, variable.default) + except Exception as e: + value = None # We failed to get the model from the solver + + model[variable.name] = value + return model + @lru_cache(maxsize=32) def can_be_true(self, constraints: ConstraintSet, expression: Union[bool, Bool] = True) -> bool: """Check if two potentially symbolic values can be equal""" @@ -503,7 +554,6 @@ def get_all_values( if not isinstance(expression, Expression): return [expression] assert isinstance(expression, Expression) - expression = 
simplify(expression) if maxcnt is None: maxcnt = consts.maxsolutions if isinstance(expression, Bool) and consts.maxsolutions > 1: @@ -518,15 +568,15 @@ def get_all_values( var = temp_cs.new_bitvec(expression.size) elif isinstance(expression, Array): var = temp_cs.new_array( - index_max=expression.index_max, - value_bits=expression.value_bits, + length=expression.length, + index_size=expression.index_size, + value_size=expression.value_size, taint=expression.taint, - ).array + ) else: raise NotImplementedError( f"get_all_values only implemented for {type(expression)} expression type." ) - temp_cs.add(var == expression) self._reset(temp_cs.to_string()) result = [] @@ -549,7 +599,7 @@ def get_all_values( logger.info("Timeout searching for all solutions") return list(result) raise SolverError("Timeout") - # Sometimes adding a new contraint after a check-sat eats all the mem + # Sometimes adding a new constraint after a check-sat eats all the mem if self._multiple_check: self._smtlib.send(f"(assert {translate_to_smtlib(var != value)})") else: @@ -575,9 +625,11 @@ def _optimize_fancy(self, constraints: ConstraintSet, x: BitVec, goal: str, max_ X = temp_cs.new_bitvec(x.size) temp_cs.add(X == x) aux = temp_cs.new_bitvec(X.size, name="optimized_") + + temp_cs.add(operation(X, aux)) self._reset(temp_cs.to_string()) - self._assert(operation(X, aux)) + # self._assert(operation(X, aux)) self._smtlib.send("(%s %s)" % (goal, aux.name)) self._smtlib.send("(check-sat)") _status = self._smtlib.recv() @@ -590,10 +642,25 @@ def get_value(self, constraints: ConstraintSet, *expressions): Ask the solver for one possible result of given expressions using given set of constraints. 
""" + self._cache = getattr(self, "_cache", {}) + model = self.get_model(constraints) + + #################### values = [] start = time.time() - with constraints.related_to(*expressions) as temp_cs: + # with constraints.related_to(*expressions) as temp_cs: + with constraints as temp_cs: for expression in expressions: + bucket = self._cache.setdefault(hash(constraints), {}) + cached_result = bucket.get(hash(expression)) + if cached_result is not None: + values.append(cached_result) + continue + elif isinstance(expression, Variable): + if model[expression.name] is not None: + values.append(model[expression.name]) + continue + if not issymbolic(expression): values.append(expression) continue @@ -605,21 +672,28 @@ def get_value(self, constraints: ConstraintSet, *expressions): elif isinstance(expression, Array): var = [] result = [] - for i in range(expression.index_max): - subvar = temp_cs.new_bitvec(expression.value_bits) - var.append(subvar) - temp_cs.add(subvar == simplify(expression[i])) + for i in range(len(expression)): + expression_i = expression[i] + if issymbolic(expression_i): + subvar = temp_cs.new_bitvec(expression.value_size) + temp_cs.add(subvar == expression[i]) + var.append(subvar) + else: + var.append(expression_i) self._reset(temp_cs.to_string()) if not self._is_sat(): raise SolverError( "Solver could not find a value for expression under current constraint set" ) - - for i in range(expression.index_max): - result.append(self.__getvalue_bv(var[i].name)) + for i in range(expression.length): + if issymbolic(var[i]): + result.append(self.__getvalue_bv(var[i].name)) + else: + result.append(var[i]) values.append(bytes(result)) + bucket[hash(expression)] = values[-1] if time.time() - start > consts.timeout: - raise SolverError("Timeout") + raise SolverError(f"Timeout {expressions}") continue temp_cs.add(var == expression) @@ -635,6 +709,7 @@ def get_value(self, constraints: ConstraintSet, *expressions): values.append(self.__getvalue_bool(var.name)) if 
isinstance(expression, BitVec): values.append(self.__getvalue_bv(var.name)) + bucket[hash(expression)] = values[-1] if time.time() - start > consts.timeout: raise SolverError("Timeout") @@ -690,7 +765,7 @@ def __autoconfig(self): # Certain version of Z3 fails to handle multiple check-sat # https://gist.github.com/feliam/0f125c00cb99ef05a6939a08c4578902 - multiple_check = self.version < Version(4, 8, 7) + multiple_check = False # self.version < Version(4, 8, 7) return init, support_minmax, support_reset, multiple_check def _solver_version(self) -> Version: @@ -701,8 +776,8 @@ def _solver_version(self) -> Version: Anticipated version_cmd_output format: 'Z3 version 4.4.2' 'Z3 version 4.4.5 - 64 bit - build hashcode $Z3GITHASH' """ + received_version = check_output([f"{consts.z3_bin}", "--version"]) try: - received_version = check_output([f"{consts.z3_bin}", "--version"]) Z3VERSION = re.compile( r".*(?P([0-9]+))\.(?P([0-9]+))\.(?P([0-9]+)).*" ) @@ -715,7 +790,7 @@ def _solver_version(self) -> Version: logger.warning( f"Could not parse Z3 version: '{str(received_version)}'. Assuming compatibility." ) - parsed_version = Version(float("inf"), float("inf"), float("inf")) + parsed_version = Version(sys.maxsize, sys.maxsize, sys.maxsize) return parsed_version diff --git a/manticore/core/smtlib/visitors.py b/manticore/core/smtlib/visitors.py index ec6e446e8..da256cbd5 100644 --- a/manticore/core/smtlib/visitors.py +++ b/manticore/core/smtlib/visitors.py @@ -1,34 +1,43 @@ +from typing import Optional, Dict, Set, Type, TYPE_CHECKING + from ...utils.helpers import CacheDict from ...exceptions import SmtlibError from .expression import * from functools import lru_cache import copy +import io import logging import operator +import time import math -import threading from decimal import Decimal +if TYPE_CHECKING: + from . 
import ConstraintException + logger = logging.getLogger(__name__) +class VisitorException(Exception): + pass + + class Visitor: """Class/Type Visitor Inherit your class visitor from this one and get called on a different visiting function for each type of expression. It will call the first implemented method for the __mro__ class order. - For example for a BitVecAdd it will try - visit_BitVecAdd() if not defined then it will try with - visit_BitVecOperation() if not defined then it will try with - visit_BitVec() if not defined then it will try with + For example for a BitVecAdd expression this will try + visit_BitVecAdd() if not defined(*) then it will try with + visit_BitVecOperation() if not defined(*) then it will try with + visit_BitVec() if not defined(*) then it will try with + visit_Operation() if not defined(*) then it will try with visit_Expression() - Other class named visitors are: - - visit_BitVec() - visit_Bool() - visit_Array() + (*) Or it is defined and it returns None for the node. + You can overload the visiting method to react to different semantic + aspects of an Exrpession. """ @@ -44,116 +53,137 @@ def push(self, value): def pop(self): if len(self._stack) == 0: return None - result = self._stack.pop() - return result + return self._stack.pop() @property def result(self): assert len(self._stack) == 1 return self._stack[-1] - def _method(self, expression, *args): - for cls in expression.__class__.__mro__[:-1]: - sort = cls.__name__ - methodname = "visit_%s" % sort - if hasattr(self, methodname): - value = getattr(self, methodname)(expression, *args) - if value is not None: - return value - return self._rebuild(expression, args) - - def visit(self, node, use_fixed_point=False): + def visit(self, node: Expression, use_fixed_point: bool = False): + assert isinstance(node, Expression) """ The entry point of the visitor. 
The exploration algorithm is a DFS post-order traversal The implementation used two stacks instead of a recursion - The final result is store in self.result + The final result is store in result or can be taken from the resultant + stack via pop() + + #TODO: paste example :param node: Node to explore :type node: Expression :param use_fixed_point: if True, it runs _methods until a fixed point is found - :type use_fixed_point: Bool """ - if isinstance(node, ArrayProxy): - node = node.array cache = self._cache - visited = set() - stack = [] - stack.append(node) - while stack: - node = stack.pop() - if node in cache: + visited: Set[Expression] = set() + local_stack = [node] # initially the stack contains only the visiting node + while local_stack: + node = local_stack.pop() + if node in cache: # If seen we do not really need to visit this self.push(cache[node]) - elif isinstance(node, Operation): - if node in visited: - operands = [self.pop() for _ in range(len(node.operands))] - value = self._method(node, *operands) - - visited.remove(node) - self.push(value) - cache[node] = value - else: - visited.add(node) - stack.append(node) - stack.extend(node.operands) + continue + if node in visited: + visited.remove(node) + # Visited! Then there is a visited version of the operands in the stack + operands: Tuple[Expression, ...] 
= tuple([]) + if node.operands: + operands = tuple([self.pop() for _ in range(len(node.operands))]) + # Actually process the node + value = self._method(node, *operands) + self.push(value) + cache[node] = value else: - self.push(self._method(node)) + visited.add(node) + local_stack.append(node) + if node.operands: + local_stack.extend(node.operands) + # Repeat until the result is not changed if use_fixed_point: - old_value = None + old_value = node new_value = self.pop() while old_value is not new_value: self.visit(new_value) old_value = new_value new_value = self.pop() - self.push(new_value) - @staticmethod - def _rebuild(expression, operands): - if isinstance(expression, Operation): - if any(x is not y for x, y in zip(expression.operands, operands)): - aux = copy.copy(expression) - aux._operands = operands - return aux - return expression - - -class Translator(Visitor): - """Simple visitor to translate an expression into something else""" + def _method(self, expression: Expression, *operands): + """ + Magic method to walk the mro looking for the first overloaded + visiting method that returns something. - def _method(self, expression, *args): - # Special case. Need to get the unsleeved version of the array - if isinstance(expression, ArrayProxy): - expression = expression.array + If no visiting methods are found the expression gets _rebuild using + the processed operands. + Iff the operands did not change the expression is left unchanged. 
+ :param expression: Expression node + :param operands: The already processed operands + :return: Optional resulting Expression + """ assert expression.__class__.__mro__[-1] is object - for cls in expression.__class__.__mro__: + for cls in expression.__class__.__mro__[:-1]: sort = cls.__name__ methodname = f"visit_{sort:s}" if hasattr(self, methodname): - value = getattr(self, methodname)(expression, *args) + value = getattr(self, methodname)(expression, *operands) if value is not None: return value - raise SmtlibError(f"No translation for this {expression}") + return self._rebuild(expression, operands) + + def _changed(self, expression: Expression, operands: Optional[Tuple[Expression, ...]]): + # False if no operands + changed = any(x is not y for x, y in zip(expression.operands or (), operands or ())) + return changed + + @lru_cache(maxsize=32, typed=True) + def _rebuild(self, expression: Operation, operands): + """Default operation used when no visiting method was successful for + this expression. If the operands have changed this rebuild the current expression + with the new operands. + + Assumes the stack is used for Expressions + """ + if self._changed(expression, operands): + aux = copy.copy(expression) + aux._operands = operands + return aux + # The expression is not modified in any way iff: + # - no visitor method is defined for the expression type + # - no operands were modified + return expression + + +class Translator(Visitor): + """Simple visitor to translate an expression into something else""" + + def _rebuild(self, expression, operands): + """The stack holds the translation of the expression. 
+ There is no default action + + :param expression: Current expression + :param operands: The translations of the nodes + :return: No + """ + raise VisitorException(f"No translation for this {expression}") -class GetDeclarations(Visitor): +class GetDeclarations(Translator): """Simple visitor to collect all variables in an expression or set of expressions """ + def _rebuild(self, expression, operands): + return expression + def __init__(self, **kwargs): super().__init__(**kwargs) self.variables = set() - def _visit_variable(self, expression): + def visit_Variable(self, expression): self.variables.add(expression) - visit_ArrayVariable = _visit_variable - visit_BitVecVariable = _visit_variable - visit_BoolVariable = _visit_variable - @property def result(self): return self.variables @@ -164,20 +194,18 @@ class GetDepth(Translator): expressions """ + def _rebuild(self, expression, operands): + return expression + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def visit_Expression(self, expression): return 1 - def _visit_operation(self, expression, *operands): + def visit_Operation(self, expression, *operands): return 1 + max(operands) - visit_ArraySelect = _visit_operation - visit_ArrayOperation = _visit_operation - visit_BoolOperation = _visit_operation - visit_BitVecOperation = _visit_operation - def get_depth(exp): visitor = GetDepth() @@ -188,13 +216,12 @@ def get_depth(exp): class PrettyPrinter(Visitor): def __init__(self, depth=None, **kwargs): super().__init__(**kwargs) - self.output = "" + self.output = io.StringIO() self.indent = 0 self.depth = depth def _print(self, s, e=None): - self.output += " " * self.indent + str(s) # + '(%016x)'%hash(e) - self.output += "\n" + self.output.write(" " * self.indent + str(s) + "\n") # + '(%016x)'%hash(e) def visit(self, expression): """ @@ -205,22 +232,7 @@ def visit(self, expression): """ self._method(expression) - def _method(self, expression, *args): - """ - Overload Visitor._method because we 
want to stop to iterate over the - visit_ functions as soon as a valid visit_ function is found - """ - assert expression.__class__.__mro__[-1] is object - for cls in expression.__class__.__mro__: - sort = cls.__name__ - methodname = "visit_%s" % sort - method = getattr(self, methodname, None) - if method is not None: - method(expression, *args) - return - return - - def _visit_operation(self, expression, *operands): + def visit_Operation(self, expression, *operands): self._print(expression.__class__.__name__, expression) self.indent += 2 if self.depth is None or self.indent < self.depth * 2: @@ -229,12 +241,7 @@ def _visit_operation(self, expression, *operands): else: self._print("...") self.indent -= 2 - return "" - - visit_ArraySelect = _visit_operation - visit_ArrayOperation = _visit_operation - visit_BoolOperation = _visit_operation - visit_BitVecOperation = _visit_operation + return True def visit_BitVecExtract(self, expression): self._print( @@ -248,26 +255,19 @@ def visit_BitVecExtract(self, expression): else: self._print("...") self.indent -= 2 - return "" + return True - def _visit_constant(self, expression): + def visit_Constant(self, expression): self._print(expression.value) - return "" - - visit_BitVecConstant = _visit_constant - visit_BoolConstant = _visit_constant + return True - def _visit_variable(self, expression): + def visit_Variable(self, expression): self._print(expression.name) - return "" - - visit_ArrayVariable = _visit_variable - visit_BitVecVariable = _visit_variable - visit_BoolVariable = _visit_variable + return True @property def result(self): - return self.output + return self.output.getvalue() def pretty_print(expression, **kwargs): @@ -279,8 +279,6 @@ def pretty_print(expression, **kwargs): class ConstantFolderSimplifier(Visitor): - def __init__(self, **kw): - super().__init__(**kw) operations = { BitVecMod: operator.__mod__, @@ -298,52 +296,53 @@ def __init__(self, **kw): BoolEqual: operator.__eq__, BoolOr: operator.__or__, 
BoolNot: operator.__not__, - UnsignedLessThan: operator.__lt__, - UnsignedLessOrEqual: operator.__le__, - UnsignedGreaterThan: operator.__gt__, - UnsignedGreaterOrEqual: operator.__ge__, + BoolUnsignedLessThan: operator.__lt__, + BoolUnsignedLessOrEqualThan: operator.__le__, + BoolUnsignedGreaterThan: operator.__gt__, + BoolUnsignedGreaterOrEqualThan: operator.__ge__, } def visit_BitVecUnsignedDiv(self, expression, *operands) -> Optional[BitVecConstant]: if all(isinstance(o, Constant) for o in operands): a = operands[0].value b = operands[1].value - if a == 0: + if b == 0: ret = 0 else: ret = math.trunc(Decimal(a) / Decimal(b)) return BitVecConstant(expression.size, ret, taint=expression.taint) return None - def visit_LessThan(self, expression, *operands) -> Optional[BoolConstant]: + def visit_BoolLessThan(self, expression, *operands) -> Optional[BoolConstant]: if all(isinstance(o, Constant) for o in operands): a = operands[0].signed_value b = operands[1].signed_value - return BoolConstant(a < b, taint=expression.taint) + return BoolConstant(value=a < b, taint=expression.taint) return None - def visit_LessOrEqual(self, expression, *operands) -> Optional[BoolConstant]: + def visit_BoolLessOrEqual(self, expression, *operands) -> Optional[BoolConstant]: if all(isinstance(o, Constant) for o in operands): a = operands[0].signed_value b = operands[1].signed_value - return BoolConstant(a <= b, taint=expression.taint) + return BoolConstant(value=a <= b, taint=expression.taint) return None - def visit_GreaterThan(self, expression, *operands) -> Optional[BoolConstant]: + def visit_BoolGreaterThan(self, expression, *operands) -> Optional[BoolConstant]: if all(isinstance(o, Constant) for o in operands): a = operands[0].signed_value b = operands[1].signed_value - return BoolConstant(a > b, taint=expression.taint) + return BoolConstant(value=a > b, taint=expression.taint) return None - def visit_GreaterOrEqual(self, expression, *operands) -> Optional[BoolConstant]: + def 
visit_BoolGreaterOrEqual(self, expression, *operands) -> Optional[BoolConstant]: if all(isinstance(o, Constant) for o in operands): a = operands[0].signed_value b = operands[1].signed_value - return BoolConstant(a >= b, taint=expression.taint) + return BoolConstant(value=a >= b, taint=expression.taint) return None def visit_BitVecDiv(self, expression, *operands) -> Optional[BitVecConstant]: + return None if all(isinstance(o, Constant) for o in operands): signmask = operands[0].signmask mask = operands[0].mask @@ -392,25 +391,16 @@ def visit_BoolAnd(self, expression, a, b): if isinstance(b, Constant) and b.value == True: return a - def _visit_operation(self, expression, *operands): + def visit_Operation(self, expression: Expression, *operands): """ constant folding, if all operands of an expression are a Constant do the math """ operation = self.operations.get(type(expression), None) if operation is not None and all(isinstance(o, Constant) for o in operands): - value = operation(*(x.value for x in operands)) + value = operation(*(x.value for x in operands)) # type: ignore if isinstance(expression, BitVec): return BitVecConstant(expression.size, value, taint=expression.taint) else: isinstance(expression, Bool) - return BoolConstant(value, taint=expression.taint) - else: - if any(operands[i] is not expression.operands[i] for i in range(len(operands))): - expression = self._rebuild(expression, operands) - return expression - - visit_ArraySelect = _visit_operation - visit_ArrayOperation = _visit_operation - visit_BoolOperation = _visit_operation - visit_BitVecOperation = _visit_operation + return BoolConstant(value=value, taint=expression.taint) @lru_cache(maxsize=128, typed=True) @@ -421,32 +411,17 @@ def constant_folder(expression): class ArithmeticSimplifier(Visitor): - def __init__(self, parent=None, **kw): - super().__init__(**kw) - @staticmethod def _same_constant(a, b): return isinstance(a, Constant) and isinstance(b, Constant) and a.value == b.value or a is b - 
@staticmethod - def _changed(expression, operands): - if isinstance(expression, Constant) and len(operands) > 0: - return True - arity = len(operands) - return any(operands[i] is not expression.operands[i] for i in range(arity)) - - def _visit_operation(self, expression, *operands): + def visit_Operation(self, expression, *operands): """ constant folding, if all operands of an expression are a Constant do the math """ + expression = self._rebuild(expression, operands) if all(isinstance(o, Constant) for o in operands): expression = constant_folder(expression) - if self._changed(expression, operands): - expression = self._rebuild(expression, operands) return expression - visit_ArrayOperation = _visit_operation - visit_BoolOperation = _visit_operation - visit_BitVecOperation = _visit_operation - def visit_BitVecZeroExtend(self, expression, *operands): if self._changed(expression, operands): return BitVecZeroExtend(expression.size, *operands, taint=expression.taint) @@ -522,7 +497,7 @@ def visit_BoolEqual(self, expression, *operands): return BoolNot(operands[0].operands[0], taint=expression.taint) if operands[0] is operands[1]: - return BoolConstant(True, taint=expression.taint) + return BoolConstant(value=True, taint=expression.taint) if isinstance(operands[0], BitVecExtract) and isinstance(operands[1], BitVecExtract): if ( @@ -531,7 +506,7 @@ def visit_BoolEqual(self, expression, *operands): and operands[0].begining == operands[1].begining ): - return BoolConstant(True, taint=expression.taint) + return BoolConstant(value=True, taint=expression.taint) def visit_BoolOr(self, expression, a, b): if isinstance(a, Constant): @@ -560,7 +535,7 @@ def visit_BitVecITE(self, expression, *operands): return result if self._changed(expression, operands): - return BitVecITE(expression.size, *operands, taint=expression.taint) + return BitVecITE(*operands, taint=expression.taint) def visit_BitVecConcat(self, expression, *operands): """concat( extract(k1, 0, a), extract(sizeof(a)-k1, 
k1, a)) ==> a @@ -594,12 +569,12 @@ def visit_BitVecConcat(self, expression, *operands): if last_o is not None: new_operands.append(last_o) if changed: - return BitVecConcat(expression.size, *new_operands) + return BitVecConcat(operands=tuple(new_operands)) op = operands[0] value = None - end = None - begining = None + end: Optional[int] = None + begining: Optional[int] = None for o in operands: # If found a non BitVecExtract, do not apply if not isinstance(o, BitVecExtract): @@ -622,7 +597,7 @@ def visit_BitVecConcat(self, expression, *operands): # update begining variable begining = o.begining - if value is not None: + if value is not None and end is not None and begining is not None: if end + 1 != value.size or begining != 0: return BitVecExtract(value, begining, end - begining + 1, taint=expression.taint) @@ -643,12 +618,13 @@ def visit_BitVecExtract(self, expression, *operands): elif isinstance(op, BitVecExtract): return BitVecExtract(op.value, op.begining + begining, size, taint=expression.taint) elif isinstance(op, BitVecConcat): - new_operands = [] + new_operands: List[BitVec] = [] for item in reversed(op.operands): + assert isinstance(item, BitVec) if size == 0: assert expression.size == sum([x.size for x in new_operands]) return BitVecConcat( - expression.size, *reversed(new_operands), taint=expression.taint + operands=tuple(reversed(new_operands)), taint=expression.taint ) if begining >= item.size: @@ -677,6 +653,20 @@ def visit_BitVecExtract(self, expression, *operands): taint=expression.taint, ) + def visit_BitVecMul(self, expression, *operands): + left = operands[0] + right = operands[1] + if isinstance(right, BitVecConstant): + if right.value == 1: + return left + if right.value == 0: + return right + if isinstance(left, BitVecConstant): + if left.value == 1: + return right + if left.value == 0: + return left + def visit_BitVecAdd(self, expression, *operands): """a + 0 ==> a 0 + a ==> a @@ -782,7 +772,7 @@ def visit_ArraySelect(self, expression, 
*operands): """ arr, index = operands if isinstance(arr, ArrayVariable): - return self._visit_operation(expression, *operands) + return None if isinstance(index, BitVecConstant): ival = index.value @@ -809,7 +799,6 @@ def visit_ArraySelect(self, expression, *operands): out = arr.select(index) if out is not None: return arr.select(index) - return self._visit_operation(expression, *operands) def visit_Expression(self, expression, *operands): assert len(operands) == 0 @@ -819,6 +808,9 @@ def visit_Expression(self, expression, *operands): @lru_cache(maxsize=128, typed=True) def arithmetic_simplify(expression): + start = time.time() + if not isinstance(expression, Expression): + return expression simp = ArithmeticSimplifier() simp.visit(expression, use_fixed_point=True) return simp.result @@ -829,15 +821,17 @@ def to_constant(expression): Iff the expression can be simplified to a Constant get the actual concrete value. This discards/ignore any taint """ + if isinstance(expression, MutableArray): + expression = expression.array value = simplify(expression) if isinstance(value, Expression) and value.taint: raise ValueError("Can not simplify tainted values to constant") if isinstance(value, Constant): return value.value elif isinstance(value, Array): - if expression.index_max: + if expression.length: ba = bytearray() - for i in range(expression.index_max): + for i in range(expression.length): value_i = simplify(value[i]) if not isinstance(value_i, Constant): break @@ -848,46 +842,60 @@ def to_constant(expression): return value -@lru_cache(maxsize=128, typed=True) -def simplify(expression): - expression = arithmetic_simplify(expression) - return expression +simplify = arithmetic_simplify + + +class CountExpressionUse(Visitor): + def _rebuild(self, expression, operands): + return expression + + def __init__(self, *args, counts: Optional[Dict[Expression, int]] = None, **kw): + super().__init__(*args, **kw) + if counts is not None: + self.counts = counts + else: + self.counts 
= {} + + def visit_Expression(self, expression, *args): + if not isinstance(expression, (Variable, Constant)): + try: + self.counts[expression] += 1 + except KeyError: + self.counts[expression] = 1 class TranslatorSmtlib(Translator): """Simple visitor to translate an expression to its smtlib representation""" - unique = 0 - unique_lock = threading.Lock() - - def __init__(self, use_bindings=False, *args, **kw): + def __init__(self, use_bindings=True, *args, **kw): assert "bindings" not in kw super().__init__(*args, **kw) + self._unique = 0 self.use_bindings = use_bindings - self._bindings_cache = {} - self._bindings = [] + self._bindings_cache_exp = {} + self._bindings: Dict[str, Any] = {} + self._variables = set() def _add_binding(self, expression, smtlib): - if not self.use_bindings or len(smtlib) <= 10: + if not self.use_bindings: return smtlib - if smtlib in self._bindings_cache: - return self._bindings_cache[smtlib] - - with TranslatorSmtlib.unique_lock: - TranslatorSmtlib.unique += 1 - name = "a_%d" % TranslatorSmtlib.unique - - self._bindings.append((name, expression, smtlib)) + if isinstance(expression, (Constant, Variable)): + return smtlib + if expression in self._bindings_cache_exp: + return self._bindings_cache_exp[expression] + self._unique += 1 + name = "!a_%d!" 
% self._unique - self._bindings_cache[expression] = name + self._bindings[name] = expression, smtlib + self._bindings_cache_exp[expression] = name return name @property def bindings(self): return self._bindings - translation_table = { + translation_table: Dict[Type[Operation], str] = { BoolNot: "not", BoolEqual: "=", BoolAnd: "and", @@ -911,14 +919,14 @@ def bindings(self): BitVecXor: "bvxor", BitVecNot: "bvnot", BitVecNeg: "bvneg", - LessThan: "bvslt", - LessOrEqual: "bvsle", - GreaterThan: "bvsgt", - GreaterOrEqual: "bvsge", - UnsignedLessThan: "bvult", - UnsignedLessOrEqual: "bvule", - UnsignedGreaterThan: "bvugt", - UnsignedGreaterOrEqual: "bvuge", + BoolLessThan: "bvslt", + BoolLessOrEqualThan: "bvsle", + BoolGreaterThan: "bvsgt", + BoolGreaterOrEqualThan: "bvsge", + BoolUnsignedLessThan: "bvult", + BoolUnsignedLessOrEqualThan: "bvule", + BoolUnsignedGreaterThan: "bvugt", + BoolUnsignedGreaterOrEqualThan: "bvuge", BitVecSignExtend: "(_ sign_extend %d)", BitVecZeroExtend: "(_ zero_extend %d)", BitVecExtract: "(_ extract %d %d)", @@ -938,49 +946,96 @@ def visit_BitVecConstant(self, expression): def visit_BoolConstant(self, expression): return expression.value and "true" or "false" - def _visit_variable(self, expression): + def visit_Variable(self, expression): + self._variables.add(expression) return expression.name - visit_ArrayVariable = _visit_variable - visit_BitVecVariable = _visit_variable - visit_BoolVariable = _visit_variable - def visit_ArraySelect(self, expression, *operands): array_smt, index_smt = operands - if isinstance(expression.array, ArrayStore): - array_smt = self._add_binding(expression.array, array_smt) - - return "(select %s %s)" % (array_smt, index_smt) + array_smt = self._add_binding(expression.array, array_smt) + return f"(select {array_smt} {index_smt})" - def _visit_operation(self, expression, *operands): + def visit_Operation(self, expression, *operands): operation = self.translation_table[type(expression)] if isinstance(expression, 
(BitVecSignExtend, BitVecZeroExtend)): operation = operation % expression.extend elif isinstance(expression, BitVecExtract): operation = operation % (expression.end, expression.begining) - operands = [self._add_binding(*x) for x in zip(expression.operands, operands)] - return "(%s %s)" % (operation, " ".join(operands)) - - visit_ArrayOperation = _visit_operation - visit_BoolOperation = _visit_operation - visit_BitVecOperation = _visit_operation + bound_operands: List[str] = [ + self._add_binding(*x) for x in zip(expression.operands, operands) + ] + return f"({operation} {' '.join(bound_operands)})" @property def result(self): - output = super().result - if self.use_bindings: - for name, expr, smtlib in reversed(self._bindings): - output = "( let ((%s %s)) %s )" % (name, smtlib, output) - return output + return f"{self.apply_bindings(self._stack[-1])}" + + def apply_bindings(self, main_smtlib_str): + # Python program to print topological sorting of a DAG + # Well-sortedness requirements + from toposort import toposort_flatten as toposort + + G: Dict[Any, Any] = {} + for name, (expr, smtlib) in self._bindings.items(): + variables = {v.name for v in get_variables(expr)} + variables.update(re.findall(r"!a_\d+!", smtlib)) + + for v_name in variables: + G.setdefault(name, set()).add(v_name) + if not variables: + G[name] = set() + + # Build let statement + output = main_smtlib_str + for name in reversed(toposort(G)): + if name not in self._bindings: + continue + expr, smtlib = self._bindings[name] + + # FIXME: too much string manipulation. 
Search occurrences in the Expression realm + if output.count(name) <= 1: + # output = f"let (({name} {smtlib})) ({output})" + output = output.replace(name, smtlib) + else: + output = f"(let (({name} {smtlib})) {output})" + + return f"{output}" + + def declarations(self): + result = "" + for exp in self._variables: + if isinstance(exp, BitVecVariable): + result += f"(declare-fun {exp.name} () (_ BitVec {exp.size}))\n" + elif isinstance(exp, BoolVariable): + result += f"(declare-fun {exp.name} () Bool)\n" + elif isinstance(exp, ArrayVariable): + result += f"(declare-fun {exp.name} () (Array (_ BitVec {exp.index_size}) (_ BitVec {exp.value_size})))\n" + else: + raise ConstraintException(f"Type not supported {exp!r}") + return result + + def smtlib(self): + result = self.declarations() + for constraint_str in self._stack: + result += f"(assert {self.apply_bindings(constraint_str)})\n" + return result -def translate_to_smtlib(expression, **kwargs): - translator = TranslatorSmtlib(**kwargs) +@lru_cache(maxsize=128, typed=True) +def _translate_to_smtlib(expression, use_bindings=True, **kwargs): + translator = TranslatorSmtlib(use_bindings=use_bindings, **kwargs) translator.visit(expression) return translator.result +def translate_to_smtlib(expression, use_bindings=True, **kwargs): + if isinstance(expression, MutableArray): + expression = expression.array + result = _translate_to_smtlib(expression, use_bindings=use_bindings, **kwargs) + return result + + class Replace(Visitor): """ Simple visitor to replaces expressions """ @@ -990,19 +1045,21 @@ def __init__(self, bindings=None, **kwargs): raise ValueError("bindings needed in replace") self._replace_bindings = bindings - def _visit_variable(self, expression): + def visit(self, *args, **kwargs): + return super().visit(*args, **kwargs) + + def visit_Variable(self, expression): if expression in self._replace_bindings: return self._replace_bindings[expression] return expression - visit_ArrayVariable = _visit_variable - 
visit_BitVecVariable = _visit_variable - visit_BoolVariable = _visit_variable - def replace(expression, bindings): if not bindings: return expression + if isinstance(expression, MutableArray): + expression = expression.array + visitor = Replace(bindings) visitor.visit(expression, use_fixed_point=True) result_expression = visitor.result @@ -1033,10 +1090,37 @@ def simplify_array_select(array_exp): return simplifier.stores -def get_variables(expression): - if isinstance(expression, ArrayProxy): - expression = expression.array - +@lru_cache(maxsize=128, typed=True) +def _get_variables(expression): visitor = GetDeclarations() visitor.visit(expression) return visitor.result + + +def get_variables(expression): + if isinstance(expression, MutableArray): + expression = expression.array + return _get_variables(expression) + + +class GetBindings(Visitor): + """Simple visitor to collect all variables in an expression or set of + expressions + """ + + def _rebuild(self, expression, operands): + return expression + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.expressions = {} + + def visit_Operation(self, expression, *operands): + try: + self.expressions[expression] += 1 + except KeyError as e: + self.expressions[expression] = 1 + + @property + def result(self): + return self.expressions diff --git a/manticore/core/state.py b/manticore/core/state.py index f5536cef6..3ddd22dda 100644 --- a/manticore/core/state.py +++ b/manticore/core/state.py @@ -1,12 +1,17 @@ import copy import logging +from typing import Any, Dict, List, Optional, TypeVar, TYPE_CHECKING -from .smtlib import solver, Bool, issymbolic, BitVecConstant +from .smtlib import Bool, ConstraintSet, Expression, issymbolic, BitVecConstant, MutableArray from ..utils.event import Eventful from ..utils.helpers import PickleSerializer from ..utils import config from .plugin import StateDescriptor +if TYPE_CHECKING: + from .manticore import ManticoreBase + from ..platforms.platform import Platform + 
consts = config.get_group("core") consts.add( "execs_per_intermittent_cb", @@ -167,21 +172,30 @@ class StateBase(Eventful): """ Representation of a unique program state/path. - :param ConstraintSet constraints: Initial constraints - :param Platform platform: Initial operating system state + :param constraints: Initial constraints + :param platform: Initial operating system state :ivar dict context: Local context for arbitrary data storage """ _published_events = {"execution_intermittent"} - def __init__(self, constraints, platform, **kwargs): + def __init__( + self, + *, + constraints: ConstraintSet, + platform: "Platform", + manticore: Optional["ManticoreBase"] = None, + **kwargs, + ): super().__init__(**kwargs) + self._manticore = manticore self._platform = platform self._constraints = constraints - self._platform.constraints = constraints - self._input_symbols = list() + self._input_symbols: List[Expression] = list() + self._child = None - self._context = dict() + self._context: Dict[str, Any] = dict() + self._terminated_by = None self._solver = EventSolver() self._total_exec = 0 @@ -189,15 +203,16 @@ def __init__(self, constraints, platform, **kwargs): # 33 # Events are lost in serialization and fork !! 
self.forward_events_from(self._solver) - self.forward_events_from(platform) + self._platform.set_state(self) def __getstate__(self): state = super().__getstate__() state["platform"] = self._platform state["constraints"] = self._constraints - state["input_symbols"] = self._input_symbols state["child"] = self._child state["context"] = self._context + state["input_symbols"] = self._input_symbols + state["terminated_by"] = self._terminated_by state["exec_counter"] = self._total_exec return state @@ -205,10 +220,13 @@ def __getstate__(self): def __setstate__(self, state): super().__setstate__(state) self._platform = state["platform"] + self._constraints = state["constraints"] - self._input_symbols = state["input_symbols"] self._child = state["child"] self._context = state["context"] + self._input_symbols = state["input_symbols"] + self._manticore = None + self._terminated_by = state["terminated_by"] self._total_exec = state["exec_counter"] self._own_exec = 0 @@ -216,7 +234,7 @@ def __setstate__(self, state): # 33 # Events are lost in serialization and fork !! self.forward_events_from(self._solver) - self.forward_events_from(self._platform) + self._platform.set_state(self) @property def id(self): @@ -229,24 +247,30 @@ def __repr__(self): # This need to change. this is the center of ALL the problems. re. 
CoW def __enter__(self): assert self._child is None - self._platform.constraints = None - new_state = self.__class__(self._constraints.__enter__(), self._platform) - self.platform.constraints = new_state.constraints - new_state._input_symbols = list(self._input_symbols) + self._platform._constraints = None + new_state = self.__class__( + constraints=self._constraints.__enter__(), + platform=self._platform, + manticore=self._manticore, + ) + # Keep the same constraint + self.platform._constraints = new_state.constraints + # backup copy of the context new_state._context = copy.copy(self._context) + new_state._input_symbols = self._input_symbols new_state._id = None + new_state._total_exec = self._total_exec self.copy_eventful_state(new_state) - self._child = new_state assert new_state.platform.constraints is new_state.constraints + # assert self.platform.constraints is self.constraints return new_state def __exit__(self, ty, value, traceback): self._constraints.__exit__(ty, value, traceback) self._child = None - self.platform.constraints = self.constraints @property def input_symbols(self): @@ -267,7 +291,6 @@ def constraints(self): @constraints.setter def constraints(self, constraints): self._constraints = constraints - self.platform.constraints = constraints def _update_state_descriptor(self, descriptor: StateDescriptor, *args, **kwargs): """ @@ -324,8 +347,8 @@ def new_symbolic_buffer(self, nbytes, **options): taint = options.get("taint", frozenset()) expr = self._constraints.new_array( name=label, - index_max=nbytes, - value_bits=8, + length=nbytes, + value_size=8, taint=taint, avoid_collisions=avoid_collisions, ) @@ -368,7 +391,7 @@ def concretize(self, symbolic, policy, maxcount=7): than `maxcount` feasible solutions, some states will be silently ignored.** """ - assert self.constraints == self.platform.constraints + # assert self.constraints is self.platform.constraints symbolic = self.migrate_expression(symbolic) vals = [] @@ -419,6 +442,8 @@ def 
concretize(self, symbolic, policy, maxcount=7): return tuple(set(vals)) def migrate_expression(self, expression): + if isinstance(expression, MutableArray): + expression = expression.array if not issymbolic(expression): return expression migration_map = self.context.get("migration_map") @@ -461,19 +486,10 @@ def solve_one_n(self, *exprs, constrain=False): :return: Concrete value or a tuple of concrete values :rtype: int """ - values = [] - for expr in exprs: - if not issymbolic(expr): - values.append(expr) - else: - expr = self.migrate_expression(expr) - value = self._solver.get_value(self._constraints, expr) - if constrain: - self.constrain(expr == value) - # Include forgiveness here - if isinstance(value, bytearray): - value = bytes(value) - values.append(value) + expressions = [self.migrate_expression(e) for e in exprs] + values = self._solver.get_value(self._constraints, *expressions) + if len(expressions) == 1: + values = (values,) return values def solve_n(self, expr, nsolves): @@ -573,7 +589,7 @@ def symbolicate_buffer( if wildcard in data: size = len(data) symb = self._constraints.new_array( - name=label, index_max=size, taint=taint, avoid_collisions=True + name=label, length=size, taint=taint, avoid_collisions=True ) self._input_symbols.append(symb) diff --git a/manticore/core/workspace.py b/manticore/core/workspace.py index a4a711ae3..9b269374a 100644 --- a/manticore/core/workspace.py +++ b/manticore/core/workspace.py @@ -602,6 +602,7 @@ def save_testcase(self, state: StateBase, testcase: Testcase, message: str = "") # The workspaces should override `save_testcase` method # # Below is native-only + self.save_summary(testcase, state, message) self.save_trace(testcase, state) self.save_constraints(testcase, state) diff --git a/manticore/ethereum/abi.py b/manticore/ethereum/abi.py index e3a1bc310..3db593007 100644 --- a/manticore/ethereum/abi.py +++ b/manticore/ethereum/abi.py @@ -11,7 +11,7 @@ Operators, BitVec, ArrayVariable, - ArrayProxy, + MutableArray, 
to_constant, issymbolic, ) @@ -285,7 +285,7 @@ def _serialize_uint(value, size=32, padding=0): assert isinstance(value, BitVec) # FIXME This temporary array variable should be obtained from a specific constraint store buffer = ArrayVariable( - index_bits=256, index_max=32, value_bits=8, name="temp{}".format(uuid.uuid1()) + index_size=256, length=32, value_size=8, name="temp{}".format(uuid.uuid1()) ) if value.size <= size * 8: value = Operators.ZEXTEND(value, size * 8) @@ -317,10 +317,10 @@ def _serialize_int(value: typing.Union[int, BitVec], size=32, padding=0): # Help mypy out. Can remove this by teaching it how issymbolic works assert isinstance(value, BitVec) buf = ArrayVariable( - index_bits=256, index_max=32, value_bits=8, name="temp{}".format(uuid.uuid1()) + index_size=256, length=32, value_size=8, name="temp{}".format(uuid.uuid1()) ) value = Operators.SEXTEND(value, value.size, size * 8) - return ArrayProxy(buf.write_BE(padding, value, size)) + return MutableArray(buf.write_BE(padding, value, size)) else: buf_arr = bytearray() for _ in range(padding): diff --git a/manticore/ethereum/detectors.py b/manticore/ethereum/detectors.py index fd90e0ba2..ded85c1b2 100644 --- a/manticore/ethereum/detectors.py +++ b/manticore/ethereum/detectors.py @@ -632,6 +632,10 @@ class DetectDelegatecall(Detector): def _to_constant(self, expression): if isinstance(expression, Constant): return expression.value + else: + expression = simplify(expression) + if isinstance(expression, Constant): + return expression.value return expression def will_evm_execute_instruction_callback(self, state, instruction, arguments): diff --git a/manticore/ethereum/manticore.py b/manticore/ethereum/manticore.py index ea0de52c9..494184a5a 100644 --- a/manticore/ethereum/manticore.py +++ b/manticore/ethereum/manticore.py @@ -17,7 +17,7 @@ from ..core.smtlib import ( ConstraintSet, Array, - ArrayProxy, + MutableArray, BitVec, Operators, BoolConstant, @@ -119,10 +119,10 @@ class 
ManticoreEVM(ManticoreBase): _published_events = {"solve"} - def make_symbolic_buffer(self, size, name=None, avoid_collisions=False): + def make_symbolic_buffer(self, size, name=None, avoid_collisions=False, default=None): """Creates a symbolic buffer of size bytes to be used in transactions. - You can operate on it normally and add constraints to manticore.constraints - via manticore.constrain(constraint_expression) + You can operate on it normally and add constraints to manticore.constraints + via manticore.constrain(constraint_expression) Example use:: @@ -138,12 +138,13 @@ def make_symbolic_buffer(self, size, name=None, avoid_collisions=False): avoid_collisions = True return self.constraints.new_array( - index_bits=256, + index_size=256, name=name, - index_max=size, - value_bits=8, + length=size, + value_size=8, taint=frozenset(), avoid_collisions=avoid_collisions, + default=default, ) def make_symbolic_value(self, nbits=256, name=None): @@ -293,7 +294,6 @@ def _compile_through_crytic_compile(filename, contract_name, libraries, crytic_c filename = crytic_compile.filename_of_contract(name).absolute with open(filename) as f: source_code = f.read() - return name, source_code, bytecode, runtime, srcmap, srcmap_runtime, hashes, abi except InvalidCompilation as e: @@ -390,7 +390,7 @@ def __init__(self, plugins=None, **kwargs): constraints = ConstraintSet() # make the ethereum world state world = evm.EVMWorld(constraints) - initial_state = State(constraints, world) + initial_state = State(constraints=constraints, platform=world, maticore=self) super().__init__(initial_state, **kwargs) if plugins is not None: for p in plugins: @@ -601,18 +601,9 @@ def solidity_create_contract( for state in self.ready_states: world = state.platform - - expr = Operators.UGE(world.get_balance(owner.address), balance) - self._publish("will_solve", None, self.constraints, expr, "can_be_true") - sufficient = SelectedSolver.instance().can_be_true( - self.constraints, - expr, - ) - 
self._publish( - "did_solve", None, self.constraints, expr, "can_be_true", sufficient - ) - - if not sufficient: + if not state.can_be_true( + Operators.UGE(world.get_balance(owner.address), balance) + ): raise EthereumError( f"Can't create solidity contract with balance ({balance}) " f"because the owner account ({owner}) has insufficient balance." @@ -773,6 +764,8 @@ def transaction(self, caller, address, value, data, gas=None, price=1): :param price: gas unit price :raises NoAliveStates: if there are no alive states to execute """ + if isinstance(data, MutableArray): + data = data.array self._transaction( "CALL", caller, value=value, address=address, data=data, gas=gas, price=price ) @@ -865,11 +858,9 @@ def _migrate_tx_expressions(self, state, caller, address, value, data, gas, pric value = state.migrate_expression(value) if issymbolic(data): - if isinstance(data, ArrayProxy): # FIXME is this necessary here? + if isinstance(data, MutableArray): # FIXME is this necessary here? data = data.array data = state.migrate_expression(data) - if isinstance(data, Array): - data = ArrayProxy(data) if issymbolic(gas): gas = state.migrate_expression(gas) @@ -1023,7 +1014,7 @@ def preconstraint_for_call_transaction( selectors = contract_metadata.function_selectors if not selectors or len(data) <= 4: - return BoolConstant(True) + return BoolConstant(value=True) symbolic_selector = data[:4] @@ -1096,7 +1087,8 @@ def multi_tx_analysis( logger.info("Starting symbolic transaction: %d", tx_no) # run_symbolic_tx - symbolic_data = self.make_symbolic_buffer(320) + symbolic_data = self.make_symbolic_buffer(320, default=None) + symbolic_data = MutableArray(symbolic_data) if tx_send_ether: value = self.make_symbolic_value() else: @@ -1415,6 +1407,7 @@ def match(state, func, symbolic_pairs, concrete_pairs, start=None): return state.can_be_true(True) def fix_unsound_symbolication(self, state): + logger.info(f"Starting unsound symbolication search for {state.id}") soundcheck = 
state.context.get("soundcheck", None) if soundcheck is not None: return soundcheck @@ -1424,6 +1417,7 @@ def fix_unsound_symbolication(self, state): state.context["soundcheck"] = self.fix_unsound_symbolication_fake(state) else: state.context["soundcheck"] = True + logger.info(f"Done unsound symbolication search for {state.id}") return state.context["soundcheck"] def _terminate_state_callback(self, state, e): diff --git a/manticore/ethereum/parsetab.py b/manticore/ethereum/parsetab.py index d71b9d2e7..3faba9d00 100644 --- a/manticore/ethereum/parsetab.py +++ b/manticore/ethereum/parsetab.py @@ -5,38 +5,348 @@ _lr_method = "LALR" -_lr_signature = "ADDRESS BOOL BYTES BYTESM COMMA FIXED FIXEDMN FUNCTION INT INTN LBRAKET LPAREN NUMBER RBRAKET RPAREN STRING UFIXED UFIXEDMN UINT UINTN\n T : UINTN\n T : UINT\n T : INTN\n T : INT\n T : ADDRESS\n T : BOOL\n T : FIXEDMN\n T : UFIXEDMN\n T : FIXED\n T : UFIXED\n T : BYTESM\n T : FUNCTION\n T : BYTES\n T : STRING\n\n \n TL : T\n \n TL : T COMMA TL\n \n T : LPAREN TL RPAREN\n \n T : LPAREN RPAREN\n \n T : T LBRAKET RBRAKET\n \n T : T LBRAKET NUMBER RBRAKET\n " +_lr_signature = "ADDRESS BOOL BYTES BYTESM COMMA FIXED FIXEDMN FUNCTION INT INTN LBRAKET LPAREN NUMBER RBRAKET RPAREN STRING UFIXED UFIXEDMN UINT UINTN\n T : UINTN\n T : UINT\n T : INTN\n T : INT\n T : ADDRESS\n T : BOOL\n T : FIXEDMN\n T : UFIXEDMN\n T : FIXED\n T : UFIXED\n T : BYTESM\n T : FUNCTION\n T : BYTES\n T : STRING\n\n \n TL : T\n \n TL : T COMMA TL\n \n T : LPAREN TL RPAREN\n \n T : LPAREN RPAREN\n \n T : T LBRAKET RBRAKET\n \n T : T LBRAKET NUMBER RBRAKET\n " _lr_action_items = { - "UINTN": ([0, 16, 24], [2, 2, 2]), - "UINT": ([0, 16, 24], [3, 3, 3]), - "INTN": ([0, 16, 24], [4, 4, 4]), - "INT": ([0, 16, 24], [5, 5, 5]), - "ADDRESS": ([0, 16, 24], [6, 6, 6]), - "BOOL": ([0, 16, 24], [7, 7, 7]), - "FIXEDMN": ([0, 16, 24], [8, 8, 8]), - "UFIXEDMN": ([0, 16, 24], [9, 9, 9]), - "FIXED": ([0, 16, 24], [10, 10, 10]), - "UFIXED": ([0, 16, 24], [11, 11, 11]), - 
"BYTESM": ([0, 16, 24], [12, 12, 12]), - "FUNCTION": ([0, 16, 24], [13, 13, 13]), - "BYTES": ([0, 16, 24], [14, 14, 14]), - "STRING": ([0, 16, 24], [15, 15, 15]), - "LPAREN": ([0, 16, 24], [16, 16, 16]), + "UINTN": ( + [ + 0, + 16, + 24, + ], + [ + 2, + 2, + 2, + ], + ), + "UINT": ( + [ + 0, + 16, + 24, + ], + [ + 3, + 3, + 3, + ], + ), + "INTN": ( + [ + 0, + 16, + 24, + ], + [ + 4, + 4, + 4, + ], + ), + "INT": ( + [ + 0, + 16, + 24, + ], + [ + 5, + 5, + 5, + ], + ), + "ADDRESS": ( + [ + 0, + 16, + 24, + ], + [ + 6, + 6, + 6, + ], + ), + "BOOL": ( + [ + 0, + 16, + 24, + ], + [ + 7, + 7, + 7, + ], + ), + "FIXEDMN": ( + [ + 0, + 16, + 24, + ], + [ + 8, + 8, + 8, + ], + ), + "UFIXEDMN": ( + [ + 0, + 16, + 24, + ], + [ + 9, + 9, + 9, + ], + ), + "FIXED": ( + [ + 0, + 16, + 24, + ], + [ + 10, + 10, + 10, + ], + ), + "UFIXED": ( + [ + 0, + 16, + 24, + ], + [ + 11, + 11, + 11, + ], + ), + "BYTESM": ( + [ + 0, + 16, + 24, + ], + [ + 12, + 12, + 12, + ], + ), + "FUNCTION": ( + [ + 0, + 16, + 24, + ], + [ + 13, + 13, + 13, + ], + ), + "BYTES": ( + [ + 0, + 16, + 24, + ], + [ + 14, + 14, + 14, + ], + ), + "STRING": ( + [ + 0, + 16, + 24, + ], + [ + 15, + 15, + 15, + ], + ), + "LPAREN": ( + [ + 0, + 16, + 24, + ], + [ + 16, + 16, + 16, + ], + ), "$end": ( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 19, 21, 23, 25], - [0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -18, -19, -17, -20], + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 19, + 21, + 23, + 25, + ], + [ + 0, + -1, + -2, + -3, + -4, + -5, + -6, + -7, + -8, + -9, + -10, + -11, + -12, + -13, + -14, + -18, + -19, + -17, + -20, + ], ), "LBRAKET": ( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 19, 20, 21, 23, 25], - [17, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -18, 17, -19, -17, -20], + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 19, + 20, + 21, + 23, + 25, + ], + [ + 17, + -1, + -2, + -3, 
+ -4, + -5, + -6, + -7, + -8, + -9, + -10, + -11, + -12, + -13, + -14, + -18, + 17, + -19, + -17, + -20, + ], ), "COMMA": ( - [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 19, 20, 21, 23, 25], - [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -18, 24, -19, -17, -20], + [ + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 19, + 20, + 21, + 23, + 25, + ], + [ + -1, + -2, + -3, + -4, + -5, + -6, + -7, + -8, + -9, + -10, + -11, + -12, + -13, + -14, + -18, + 24, + -19, + -17, + -20, + ], ), "RPAREN": ( - [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 23, 25, 26], + [ + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 18, + 19, + 20, + 21, + 23, + 25, + 26, + ], [ -1, -2, @@ -62,8 +372,24 @@ -16, ], ), - "RBRAKET": ([17, 22], [21, 25]), - "NUMBER": ([17], [22]), + "RBRAKET": ( + [ + 17, + 22, + ], + [ + 21, + 25, + ], + ), + "NUMBER": ( + [ + 17, + ], + [ + 22, + ], + ), } _lr_action = {} @@ -74,7 +400,30 @@ _lr_action[_x][_k] = _y del _lr_action_items -_lr_goto_items = {"T": ([0, 16, 24], [1, 20, 20]), "TL": ([16, 24], [18, 26])} +_lr_goto_items = { + "T": ( + [ + 0, + 16, + 24, + ], + [ + 1, + 20, + 20, + ], + ), + "TL": ( + [ + 16, + 24, + ], + [ + 18, + 26, + ], + ), +} _lr_goto = {} for _k, _v in _lr_goto_items.items(): diff --git a/manticore/native/cpu/aarch64.py b/manticore/native/cpu/aarch64.py index 2957c5fb3..ce861b753 100644 --- a/manticore/native/cpu/aarch64.py +++ b/manticore/native/cpu/aarch64.py @@ -40,7 +40,7 @@ class Aarch64InvalidInstruction(CpuException): # See "C1.2.4 Condition code". 
-Condspec = collections.namedtuple("CondSpec", "inverse func") +Condspec = collections.namedtuple("Condspec", "inverse func") COND_MAP = { cs.arm64.ARM64_CC_EQ: Condspec(cs.arm64.ARM64_CC_NE, lambda n, z, c, v: z == 1), cs.arm64.ARM64_CC_NE: Condspec(cs.arm64.ARM64_CC_EQ, lambda n, z, c, v: z == 0), @@ -75,7 +75,7 @@ class Aarch64InvalidInstruction(CpuException): class Aarch64RegisterFile(RegisterFile): - Regspec = collections.namedtuple("RegSpec", "parent size") + Regspec = collections.namedtuple("Regspec", "parent size") # Register table. _table = {} diff --git a/manticore/native/cpu/x86.py b/manticore/native/cpu/x86.py index c43b95f7b..730514df4 100644 --- a/manticore/native/cpu/x86.py +++ b/manticore/native/cpu/x86.py @@ -144,7 +144,7 @@ def new_method(cpu, *args, **kw_args): class AMD64RegFile(RegisterFile): - Regspec = collections.namedtuple("RegSpec", "register_id ty offset size reset") + Regspec = collections.namedtuple("Regspec", "register_id ty offset size reset") _flags = {"CF": 0, "PF": 2, "AF": 4, "ZF": 6, "SF": 7, "IF": 9, "DF": 10, "OF": 11} _table = { "CS": Regspec("CS", int, 0, 16, False), @@ -2079,7 +2079,7 @@ def NEG(cpu, dest): :param dest: destination operand. """ source = dest.read() - res = dest.write(-source) + res = dest.write((~source) + 1) cpu._calculate_logic_flags(dest.size, res) cpu.CF = source != 0 cpu.AF = (res & 0x0F) != 0x00 diff --git a/manticore/native/manticore.py b/manticore/native/manticore.py index b9cdee2af..1dcce0b4a 100644 --- a/manticore/native/manticore.py +++ b/manticore/native/manticore.py @@ -199,11 +199,8 @@ def _assertions_callback(self, state, pc, instruction): # (It may dereference pointers) assertion = parse(program, state.cpu.read_int, state.cpu.read_register) if not state.can_be_true(assertion): - logger.info(str(state.cpu)) - logger.info( - "Assertion %x -> {%s} does not hold. 
Aborting state.", state.cpu.PC, program - ) - raise TerminateState() + message = f"Assertion {state.cpu.PC:x} -> {program:s} does not hold. Aborting state." + raise TerminateState(message=message) # Everything is good add it. state.constraints.add(assertion) @@ -444,7 +441,7 @@ def _make_decree(program, concrete_start="", **kwargs): constraints = ConstraintSet() platform = decree.SDecree(constraints, program) - initial_state = State(constraints, platform) + initial_state = State(constraints=constraints, platform=platform) logger.info("Loading program %s", program) if concrete_start != "": @@ -491,7 +488,7 @@ def _make_linux( # TODO: use argv as arguments for function platform.set_entry(entry_pc) - initial_state = State(constraints, platform) + initial_state = State(constraints=constraints, platform=platform) if concrete_start != "": logger.info("Starting with concrete input: %s", concrete_start) diff --git a/manticore/native/memory.py b/manticore/native/memory.py index 5f8ca25c4..2417db6cc 100644 --- a/manticore/native/memory.py +++ b/manticore/native/memory.py @@ -12,6 +12,7 @@ expression, issymbolic, Expression, + MutableArray, ) from ..native.mappings import mmap, munmap from ..utils.helpers import interval_intersection @@ -346,8 +347,10 @@ def __init__( if backing_array is not None: self._array = backing_array else: - self._array = expression.ArrayProxy( - expression.ArrayVariable(index_bits, index_max=size, value_bits=8, name=name) + self._array = expression.MutableArray( + expression.ArrayVariable( + index_size=index_bits, length=size, value_size=8, name=name + ) ) def __reduce__(self): @@ -357,7 +360,7 @@ def __reduce__(self): self.start, len(self), self._perms, - self._array.index_bits, + self._array.index_size, self._array, self._array.name, ), @@ -376,16 +379,20 @@ def split(self, address: int): return self, None assert self.start < address < self.end - index_bits, value_bits = self._array.index_bits, self._array.value_bits + index_bits, value_bits = 
self._array.index_size, self._array.value_size left_size, right_size = address - self.start, self.end - address left_name, right_name = ["{}_{:d}".format(self._array.name, i) for i in range(2)] - head_arr = expression.ArrayProxy( - expression.ArrayVariable(index_bits, left_size, value_bits, name=left_name) + head_arr = expression.MutableArray( + expression.ArrayVariable( + index_size=index_bits, length=left_size, value_size=value_bits, name=left_name + ) ) - tail_arr = expression.ArrayProxy( - expression.ArrayVariable(index_bits, right_size, value_bits, name=right_name) + tail_arr = expression.MutableArray( + expression.ArrayVariable( + index_size=index_bits, length=right_size, value_size=value_bits, name=right_name + ) ) head = ArrayMap(self.start, left_size, self.perms, index_bits, head_arr, left_name) @@ -1376,7 +1383,7 @@ class LazySMemory(SMemory): def __init__(self, constraints, *args, **kwargs): super(LazySMemory, self).__init__(constraints, *args, **kwargs) - self.backing_array = constraints.new_array(index_bits=self.memory_bit_size) + self.backing_array = MutableArray(constraints.new_array(index_size=self.memory_bit_size)) self.backed_by_symbolic_store = set() def __reduce__(self): diff --git a/manticore/native/models.py b/manticore/native/models.py index 8c9b7e2c9..d50920277 100644 --- a/manticore/native/models.py +++ b/manticore/native/models.py @@ -5,6 +5,7 @@ from .cpu.abstractcpu import Cpu, ConcretizeArgument from .state import State from ..core.smtlib import issymbolic, BitVec +from ..core.smtlib.solver import SelectedSolver, issymbolic, BitVec from ..core.smtlib.operators import ITEBV, ZEXTEND from ..core.state import Concretize from typing import Union diff --git a/manticore/native/state.py b/manticore/native/state.py index c57031865..d8db868fa 100644 --- a/manticore/native/state.py +++ b/manticore/native/state.py @@ -1,17 +1,20 @@ import copy import logging from collections import namedtuple -from typing import Any, Callable, Dict, NamedTuple, 
Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Union, TYPE_CHECKING from .cpu.disasm import Instruction from .memory import ConcretizeMemory, MemoryException from .. import issymbolic from ..core.state import StateBase, Concretize, TerminateState -from ..core.smtlib import Expression +from ..core.smtlib import Expression, ConstraintSet from ..platforms import linux_syscalls +if TYPE_CHECKING: + from ..platforms.linux import Linux + from ..platforms.decree import Decree -HookCallback = Callable[[StateBase], None] +HookCallback = Callable[["State"], None] logger = logging.getLogger(__name__) @@ -21,8 +24,8 @@ class CheckpointData(NamedTuple): class State(StateBase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *, constraints: ConstraintSet, platform: Union["Linux", "Decree"], **kwargs): + super().__init__(constraints=constraints, platform=platform, **kwargs) self._hooks: Dict[Optional[int], Set[HookCallback]] = {} self._after_hooks: Dict[Optional[int], Set[HookCallback]] = {} self._sys_hooks: Dict[Optional[int], Set[HookCallback]] = {} @@ -327,7 +330,7 @@ def setstate(state: State, value): raise TerminateState(str(e), testcase=True) # Remove when code gets stable? - assert self.platform.constraints is self.constraints + # assert self.platform.constraints is self.constraints return result diff --git a/manticore/platforms/decree.py b/manticore/platforms/decree.py index be36986a1..d5057b4fe 100644 --- a/manticore/platforms/decree.py +++ b/manticore/platforms/decree.py @@ -7,7 +7,7 @@ from ..core.smtlib import * from ..core.state import TerminateState from ..binary import CGCElf -from ..platforms.platform import Platform +from ..platforms.platform import NativePlatform import logging import random @@ -76,7 +76,7 @@ def _transmit(self, buf): return len(buf) -class Decree(Platform): +class Decree(NativePlatform): """ A simple Decree Operating System. 
This class emulates the most common Decree system calls diff --git a/manticore/platforms/evm.py b/manticore/platforms/evm.py index 1630ba9e5..f74889d13 100644 --- a/manticore/platforms/evm.py +++ b/manticore/platforms/evm.py @@ -1,18 +1,18 @@ """Symbolic EVM implementation based on the yellow paper: http://gavwood.com/paper.pdf""" +import time import uuid import binascii import random import io import copy import inspect -from functools import wraps -from typing import List, Set, Tuple, Union +from typing import List, Set, Tuple, Union, Dict from ..platforms.platform import * from ..core.smtlib import ( SelectedSolver, BitVec, Array, - ArrayProxy, + MutableArray, Operators, Constant, ArrayVariable, @@ -20,20 +20,20 @@ BitVecConstant, translate_to_smtlib, to_constant, - simplify, get_depth, issymbolic, get_taints, istainted, taint_with, + simplify, ) from ..core.state import Concretize, TerminateState from ..utils.event import Eventful from ..utils.helpers import printable_bytes from ..utils import config -from ..core.smtlib.visitors import simplify from ..exceptions import EthereumError import pyevmasm as EVMAsm + import logging from collections import namedtuple import sha3 @@ -71,7 +71,7 @@ def globalfakesha3(data): "oog", default="ignore", description=( - "Default behavior for symbolic gas." + "Default behavior for symbolic out of gas exception." "pedantic: Fully faithful. Test at every instruction. Forks." "complete: Mostly faithful. Test at BB limit. Forks." "concrete: Incomplete. Concretize gas to MIN/MAX values. Forks." 
@@ -135,7 +135,8 @@ def ceil32(x): def to_signed(i): - return Operators.ITEBV(256, i < TT255, i, i - TT256) + i &= (1 << 256) - 1 + return Operators.ITEBV(256, Operators.EXTRACT(i, 255, 1) == 0, i, -((1 << 256) - i)) class Transaction: @@ -182,13 +183,24 @@ def concretize(self, state, constrain=False): :param state: a manticore state :param bool constrain: If True, constrain expr to concretized value """ - conc_caller = state.solve_one(self.caller, constrain=constrain) - conc_address = state.solve_one(self.address, constrain=constrain) - conc_value = state.solve_one(self.value, constrain=constrain) - conc_gas = state.solve_one(self.gas, constrain=constrain) - conc_data = state.solve_one(self.data, constrain=constrain) - conc_return_data = state.solve_one(self.return_data, constrain=constrain) - conc_used_gas = state.solve_one(self.used_gas, constrain=constrain) + ( + conc_caller, + conc_address, + conc_value, + conc_gas, + conc_data, + conc_return_data, + conc_used_gas, + ) = state.solve_one_n( + self.caller, + self.address, + self.value, + self.gas, + self.data, + self.return_data, + self.used_gas, + constrain=constrain, + ) return Transaction( self.sort, conc_address, @@ -490,6 +502,8 @@ def __init__(self, result, data=None): raise EVMException("Invalid end transaction result") if result is None and data is not None: raise EVMException("Invalid end transaction result") + if isinstance(data, MutableArray): + data = data.array if not isinstance(data, (type(None), Array, bytes)): raise EVMException("Invalid end transaction data type") self.result = result @@ -697,13 +711,12 @@ def pos(self, pos): def __init__( self, - constraints, address, data, caller, value, bytecode, - world=None, + world, gas=None, fork=DEFAULT_FORK, **kwargs, @@ -721,31 +734,30 @@ def __init__( :param gas: gas budget for this transaction """ super().__init__(**kwargs) + constraints = world.constraints if data is not None and not issymbolic(data): data_size = len(data) data_symbolic = 
constraints.new_array( - index_bits=256, - value_bits=8, - index_max=data_size, + index_size=256, + value_size=8, + length=data_size, name=f"DATA_{address:x}", avoid_collisions=True, default=0, ) - data_symbolic[0:data_size] = data - data = data_symbolic + data = data_symbolic.write(0, data) if bytecode is not None and not issymbolic(bytecode): bytecode_size = len(bytecode) bytecode_symbolic = constraints.new_array( - index_bits=256, - value_bits=8, - index_max=bytecode_size, + index_size=256, + value_size=8, + length=bytecode_size, name=f"BYTECODE_{address:x}", avoid_collisions=True, default=0, ) - bytecode_symbolic[0:bytecode_size] = bytecode - bytecode = bytecode_symbolic + data = bytecode_symbolic.write(0, bytecode) # TODO: Handle the case in which bytecode is symbolic (This happens at # CREATE instructions that has the arguments appended to the bytecode) @@ -779,12 +791,14 @@ def extend_with_zeroes(b): # raise EVMException("Need code") self._constraints = constraints # Uninitialized values in memory are 0 by spec - self.memory = constraints.new_array( - index_bits=256, - value_bits=8, - name=f"EMPTY_MEMORY_{address:x}", - avoid_collisions=True, - default=0, + self.memory = MutableArray( + constraints.new_array( + index_size=256, + value_size=8, + name=f"EMPTY_MEMORY_{address:x}", + avoid_collisions=True, + default=0, + ) ) self.address = address self.caller = ( @@ -817,7 +831,7 @@ def extend_with_zeroes(b): self._valid_jmpdests = set() self._sha3 = {} self._refund = 0 - self._temp_call_gas = None + self._temp_call_gas = 0 self._failed = False def fail_if(self, failed): @@ -826,7 +840,6 @@ def fail_if(self, failed): def is_failed(self): if isinstance(self._failed, bool): return self._failed - self._failed = simplify(self._failed) if isinstance(self._failed, Constant): return self._failed.value @@ -835,7 +848,7 @@ def setstate(state, value): state.platform._failed = value raise Concretize( - "Transaction failed", expression=self._failed, setstate=lambda a, b: 
None, policy="ALL" + "Transaction failed", expression=self._failed, setstate=setstate, policy="ALL" ) @property @@ -980,24 +993,36 @@ def read_code(self, address, size=1): return value def disassemble(self): - return EVMAsm.disassemble(self.bytecode) + return EVMAsm.disassemble(self._getcode(zerotail=False)) @property def PC(self): return self.pc - def _getcode(self, pc): + def _getcode(self, pc: int = 0, zerotail: bool = True): bytecode = self.bytecode for pc_i in range(pc, len(bytecode)): - yield simplify(bytecode[pc_i]).value - while True: - yield 0 # STOP opcode + c = bytecode[pc_i] + if issymbolic(c): + yield simplify(c).value + else: + yield c + if zerotail: + while True: + yield 0 # STOP opcode @property def instruction(self): """ Current instruction pointed by self.pc """ + return self.get_instruction(pc=self.pc) + + def get_instruction(self, pc: Union[BitVec, int]): + """ + Current instruction pointed by self.pc + """ + # FIXME check if pc points to invalid instruction # if self.pc >= len(self.bytecode): # return InvalidOpcode('Code out of range') @@ -1006,16 +1031,17 @@ def instruction(self): try: _decoding_cache = getattr(self, "_decoding_cache") except Exception: - self._decoding_cache = {} + self._decoding_cache: Dict[int, EVMAsm.Instruction] = {} _decoding_cache = self._decoding_cache - pc = self.pc if isinstance(pc, Constant): pc = pc.value if pc in _decoding_cache: return _decoding_cache[pc] + if isinstance(pc, BitVec): + raise EVMException("Trying to decode from symbolic pc") instruction = EVMAsm.disassemble_one(self._getcode(pc), pc=pc, fork=self.evmfork) _decoding_cache[pc] = instruction return instruction @@ -1042,7 +1068,7 @@ def _push(self, value): if isinstance(value, int): value = value & TT256M1 - value = simplify(value) + # value = simplify(value) if isinstance(value, Constant) and not value.taint: value = value.value @@ -1061,6 +1087,8 @@ def _pop(self): return self.stack.pop() def _consume(self, fee): + # if consts.oog == "ignore": + # 
return # Check type and bitvec size if isinstance(fee, int): if fee > (1 << 512) - 1: @@ -1068,13 +1096,6 @@ def _consume(self, fee): elif isinstance(fee, BitVec): if fee.size != 512: raise ValueError("Fees should be 512 bit long") - # This configuration variable allows the user to control and perhaps relax the gas calculation - # pedantic: gas is faithfully accounted and checked at instruction level. State may get forked in OOG/NoOOG - # complete: gas is faithfully accounted and checked at basic blocks limits. State may get forked in OOG/NoOOG - # concrete: Concretize gas. If the fee to be consumed gets to be symbolic. Choose some potential values and fork on those. - # optimistic: Try not to OOG. If it may be enough gas we ignore the OOG case. A constraint is added to assert the gas is enough and the OOG state is ignored. - # pesimistic: OOG soon. If it may NOT be enough gas we ignore the normal case. A constraint is added to assert the gas is NOT enough and the other state is ignored. - # ignore: Ignore gas. Do not account for it. Do not OOG. 
oog_condition = simplify(Operators.ULT(self._gas, fee)) self.fail_if(oog_condition) @@ -1167,6 +1188,10 @@ def _push_results(self, instruction, result): assert result is None def _calculate_gas(self, *arguments): + # if consts.oog == "ignore": + # return 0 + + start = time.time() current = self.instruction implementation = getattr(self, f"{current.semantics}_gas", None) if implementation is None: @@ -1242,13 +1267,15 @@ def _check_jmpdest(self): if isinstance(should_check_jumpdest, Constant): should_check_jumpdest = should_check_jumpdest.value elif issymbolic(should_check_jumpdest): - self._publish("will_solve", self.constraints, should_check_jumpdest, "get_all_values") + self._publish( + "will_solve", self.world.constraints, should_check_jumpdest, "get_all_values" + ) should_check_jumpdest_solutions = SelectedSolver.instance().get_all_values( - self.constraints, should_check_jumpdest + self.world.constraints, should_check_jumpdest ) self._publish( "did_solve", - self.constraints, + self.world.constraints, should_check_jumpdest, "get_all_values", should_check_jumpdest_solutions, @@ -1312,6 +1339,7 @@ def setstate(state, value): try: self._check_jmpdest() last_pc, last_gas, instruction, arguments, fee, allocated = self._checkpoint() + result = self._handler(*arguments) self._advance(result) except ConcretizeGas as ex: @@ -1396,6 +1424,7 @@ def setstate(state, value): raise def read_buffer(self, offset, size): + size = simplify(size) if issymbolic(size) and not isinstance(size, Constant): raise EVMException("Symbolic size not supported") if isinstance(size, Constant): @@ -1403,8 +1432,7 @@ def read_buffer(self, offset, size): if size == 0: return b"" self._allocate(offset, size) - data = self.memory[offset : offset + size] - return ArrayProxy(data) + return self.memory[offset : offset + size] def write_buffer(self, offset, data): self._allocate(offset, len(data)) @@ -1413,7 +1441,7 @@ def write_buffer(self, offset, data): def _load(self, offset, size=1): value = 
self.memory.read_BE(offset, size) - value = simplify(value) + # value = simplify(value) if isinstance(value, Constant) and not value.taint: value = value.value self._publish("did_evm_read_memory", offset, value, size) @@ -1685,19 +1713,17 @@ def CALLDATALOAD(self, offset): self.constraints.add(self.safe_add(offset, 32) <= len(self.data) + calldata_overflow) self._use_calldata(offset, 32) - data_length = len(self.data) bytes = [] for i in range(32): try: - c = simplify( - Operators.ITEBV( - 8, - Operators.ULT(self.safe_add(offset, i), data_length), - self.data[offset + i], - 0, - ) + c = Operators.ITEBV( + 8, + Operators.ULT(self.safe_add(offset, i), data_length), + self.data[offset + i], + 0, ) + c = simplify(c) except IndexError: # offset + i is concrete and outside data c = 0 @@ -1778,7 +1804,7 @@ def CALLDATACOPY(self, mem_offset, data_offset, size): if not issymbolic(c) or get_depth(c) < 3: x = c else: - # if te expression is deep enough lets replace it by a binding + # if the expression is deep enough lets replace it by a binding x = self.constraints.new_bitvec(8, name="temp{}".format(uuid.uuid1())) self.constraints.add(x == c) self._store(mem_offset + i, x) @@ -1793,7 +1819,6 @@ def CODECOPY_gas(self, mem_offset, code_offset, size): @concretized_args(code_offset="SAMPLED", size="SAMPLED") def CODECOPY(self, mem_offset, code_offset, size): """Copy code running in current environment to memory""" - self._allocate(mem_offset, size) GCOPY = 3 # cost to copy one 32 byte word copyfee = self.safe_mul(GCOPY, Operators.UDIV(self.safe_add(size, 31), 32)) @@ -1829,7 +1854,15 @@ def CODECOPY(self, mem_offset, code_offset, size): value = default else: value = self.bytecode[code_offset + i] + self._store(mem_offset + i, value) + + assert SelectedSolver.instance().must_be_true( + self.constraints, + self.memory[mem_offset : mem_offset + max_size] + == self.bytecode[code_offset : code_offset + max_size], + ) + self._publish("did_evm_read_code", self.address, code_offset, 
size) def GASPRICE(self): @@ -1981,11 +2014,7 @@ def SSTORE_gas(self, offset, value): self.fail_if(Operators.ULT(self.gas, SSSTORESENTRYGAS)) # Get the storage from the snapshot took before this call - try: - original_value = self.world._callstack[-1][-2].get(offset, 0) - except IndexError: - original_value = 0 - + original_value = self.world._callstack[-1][-2].select(offset) current_value = self.world.get_storage_data(storage_address, offset) def ITE(*args): @@ -2056,7 +2085,7 @@ def AND(*args): SstoreCleanRefund, 0, ) - self._refund += simplify(refund) + self._refund += refund return gascost def SSTORE(self, offset, value): @@ -2153,6 +2182,44 @@ def CREATE(self, value, offset, size): tx = self.world.last_transaction # At this point last and current tx are the same. return tx.return_value + def CREATE2_gas(self, value, offset, size): + return self._get_memfee(offset, size) + + @transact + def CREATE2(self, endowment, memory_start, memory_length, salt): + """Create a new account with associated code""" + + data = self.read_buffer(offset, size) + keccak_init = self.world.symbolic_function(globalsha3, data) + caller = msg.caller.read_BE(0, 20) + salt = salt.read_BE(0, 32) + address = self.world.symbolic_function(b"\xff" + caller + salt + keccak_init) & ( + (1 << 0x20) - 1 + ) + + self.world.start_transaction( + "CREATE", + address, + data=data, + caller=self.address, + value=value, + gas=self.gas, + ) + + raise StartTx() + + @CREATE.pos # type: ignore + def CREATE2(self, value, offset, size): + """Create a new account with associated code""" + tx = self.world.last_transaction # At this point last and current tx are the same. + address = tx.address + if tx.result == "RETURN": + self.world.set_code(tx.address, tx.return_data) + else: + self.world.delete_account(address) + address = 0 + return address + def CALL_gas(self, wanted_gas, address, value, in_offset, in_size, out_offset, out_size): """ Dynamic gas for CALL instruction. 
_arguably turing complete in itself_ """ GCALLVALUE = 9000 @@ -2363,9 +2430,8 @@ def __str__(self): sp = 0 for i in list(reversed(self.stack))[:10]: argname = args.get(sp, "") - r = "" if issymbolic(i): - r = "{:>12s} {:66s}".format(argname, repr(i)) + r = "{:>12s}".format(argname) else: r = "{:>12s} 0x{:064x}".format(argname, i) sp += 1 @@ -2390,6 +2456,12 @@ def __str__(self): else: result.append(f"Gas: {gas}") + """ + vals = SelectedSolver.instance().get_all_values(self.constraints, + self.data[0:4], maxcnt=2, + silent=True) + result.append(f"Data: {vals}") + """ return "\n".join(hex(self.address) + ": " + x for x in result) @@ -2414,11 +2486,11 @@ class EVMWorld(Platform): } def __init__(self, constraints, fork=DEFAULT_FORK, **kwargs): - super().__init__(path="NOPATH", **kwargs) + super().__init__(**kwargs) self._world_state = {} self._constraints = constraints self._callstack: List[ - Tuple[Transaction, List[EVMLog], Set[int], Union[bytearray, ArrayProxy], EVM] + Tuple[Transaction, List[EVMLog], Set[int], Union[bytearray, MutableArray], EVM] ] = [] self._deleted_accounts: Set[int] = set() self._logs: List[EVMLog] = list() @@ -2457,6 +2529,14 @@ def __setstate__(self, state): for _, _, _, _, vm in self._callstack: self.forward_events_from(vm) + @property + def constraints(self): + return self._constraints + + @constraints.setter + def constraints(self, constraints): + self._constraints = constraints + def try_simplify_to_constant(self, data): concrete_data = bytearray() # for c in data: @@ -2527,21 +2607,15 @@ def __str__(self): def logs(self): return self._logs - @property - def constraints(self): - return self._constraints - - @constraints.setter - def constraints(self, constraints): - self._constraints = constraints - if self.current_vm: - self.current_vm.constraints = constraints - @property def evmfork(self): return self._fork def _transaction_fee(self, sort, address, price, bytecode_or_data, caller, value): + return 21000 + if consts.tx_fee == "ignore": + 
return 0 + GTXCREATE = ( 32000 # Paid by all contract creating transactions after the Homestead transition. ) @@ -2553,25 +2627,27 @@ def _transaction_fee(self, sort, address, price, bytecode_or_data, caller, value else: tx_fee = GTRANSACTION # Simple transaction fee - zerocount = 0 - nonzerocount = 0 - if isinstance(bytecode_or_data, (Array, ArrayProxy)): - # if nothing was written we can assume all elements are default to zero - if len(bytecode_or_data.written) == 0: - zerocount = len(bytecode_or_data) - else: - for index in range(len(bytecode_or_data)): - try: - c = bytecode_or_data.get(index, 0) - except AttributeError: - c = bytecode_or_data[index] + # This is INSANE TODO FIXME + # This popcnt like thing is expensive when the bytecode or + # data has symbolic content - zerocount += Operators.ITEBV(256, c == 0, 1, 0) - nonzerocount += Operators.ITEBV(256, c == 0, 0, 1) + size = 1 + while 2 ** size < len(bytecode_or_data): + size += 1 + if size > 512: + raise Exception("hahaha") + zerocount = 0 + len_bytecode_or_data = len(bytecode_or_data) + for index in range(len_bytecode_or_data): + c = bytecode_or_data[index] + zerocount += Operators.ITEBV(2 ** size, c == 0, 1, 0) + + nonzerocount = len_bytecode_or_data - zerocount tx_fee += zerocount * GTXDATAZERO tx_fee += nonzerocount * GTXDATANONZERO - return simplify(tx_fee) + + return Operators.ZEXTEND(tx_fee, 512) def _make_vm_for_tx(self, tx): if tx.sort == "CREATE": @@ -2595,7 +2671,7 @@ def _make_vm_for_tx(self, tx): gas = tx.gas - vm = EVM(self._constraints, address, data, caller, value, bytecode, world=self, gas=gas) + vm = EVM(address, data, caller, value, bytecode, world=self, gas=gas) if self.depth == 0: # Only at human level we need to debit the tx_fee from the gas # In case of an internal tx the CALL-like instruction will @@ -2623,11 +2699,11 @@ def _open_transaction(self, sort, address, price, bytecode_or_data, caller, valu if sort not in {"CALL", "CREATE", "DELEGATECALL", "CALLCODE", "STATICCALL"}: raise 
EVMException(f"Transaction type '{sort}' not supported") - if caller not in self.accounts: - logger.info("Caller not in account") - raise EVMException( - f"Caller account {hex(caller)} does not exist; valid accounts: {list(map(hex, self.accounts))}" - ) + # if caller not in self.accounts: + # logger.info("Caller not in account") + # raise EVMException( + # f"Caller account {hex(caller)} does not exist; valid accounts: {list(map(hex, self.accounts))}" + # ) if sort == "CREATE": expected_address = self.new_address(sender=caller) @@ -2675,7 +2751,7 @@ def _open_transaction(self, sort, address, price, bytecode_or_data, caller, valu def _close_transaction(self, result, data=None, rollback=False): self._publish("will_close_transaction", self._callstack[-1][0]) tx, logs, deleted_accounts, account_storage, vm = self._callstack.pop() - assert self.constraints == vm.constraints + # Keep constraints gathered in the last vm self.constraints = vm.constraints @@ -2840,8 +2916,7 @@ def get_storage_data(self, storage_address, offset): :return: the value :rtype: int or BitVec """ - value = self._world_state[storage_address]["storage"].get(offset, 0) - return simplify(value) + return self._world_state[storage_address]["storage"].select(offset) def set_storage_data(self, storage_address, offset, value): """ @@ -2891,12 +2966,14 @@ def get_storage(self, address): :param address: account address :return: account storage - :rtype: bytearray or ArrayProxy + :rtype: bytearray or MutableArray """ return self._world_state[address]["storage"] def _set_storage(self, address, storage): """Private auxiliary function to replace the storage""" + if not isinstance(storage, MutableArray) or storage.default != 0: + raise TypeError self._world_state[address]["storage"] = storage def get_nonce(self, address): @@ -2914,12 +2991,12 @@ def increase_nonce(self, address): def set_balance(self, address, value): if isinstance(value, BitVec): value = Operators.ZEXTEND(value, 512) - 
self._world_state[int(address)]["balance"] = value + self._world_state[int(address)]["balance"] = simplify(value) def get_balance(self, address): if address not in self._world_state: return 0 - return Operators.EXTRACT(self._world_state[address]["balance"], 0, 256) + return simplify(Operators.EXTRACT(self._world_state[address]["balance"], 0, 256)) def account_exists(self, address): if address not in self._world_state: @@ -3136,17 +3213,19 @@ def create_account(self, address=None, balance=0, code=None, storage=None, nonce if storage is None: # Uninitialized values in a storage are 0 by spec - storage = self.constraints.new_array( - index_bits=256, - value_bits=256, - name=f"STORAGE_{address:x}", - avoid_collisions=True, - default=0, + storage = MutableArray( + self.constraints.new_array( + index_size=256, + value_size=256, + name=f"STORAGE_{address:x}", + avoid_collisions=True, + default=0, + ) ) else: - if isinstance(storage, ArrayProxy): + if isinstance(storage, MutableArray): if storage.index_bits != 256 or storage.value_bits != 256: - raise TypeError("An ArrayProxy 256bits -> 256bits is needed") + raise TypeError("An MutableArray 256bits -> 256bits is needed") else: if any((k < 0 or k >= 1 << 256 for k, v in storage.items())): raise TypeError( @@ -3217,6 +3296,10 @@ def start_transaction( """ assert self._pending_transaction is None, "Already started tx" assert caller is not None + if issymbolic(data): + assert data.length is not None + assert data.value_size == 8 + self._pending_transaction = PendingTransaction( sort, address, price, data, caller, value, gas, None ) @@ -3307,7 +3390,7 @@ def _pending_transaction_failed(self): sort, address, price, data, caller, value, gas, failed = self._pending_transaction # Initially the failed flag is not set. For now we need the caller to be - # concrete so the caller balance is easy to get. Initialize falied here + # concrete so the caller balance is easy to get. 
Initialize failed here if failed is None: # Check depth failed = self.depth >= 1024 @@ -3316,8 +3399,8 @@ def _pending_transaction_failed(self): aux_src_balance = Operators.ZEXTEND(self.get_balance(caller), 512) aux_value = Operators.ZEXTEND(value, 512) enough_balance = Operators.UGE(aux_src_balance, aux_value) - if self.depth == 0: - # take the gas from the balance + if self.depth == 0: # the tx_fee is taken at depth 0 + # take the gas from t"he balance aux_price = Operators.ZEXTEND(price, 512) aux_gas = Operators.ZEXTEND(gas, 512) aux_fee = aux_price * aux_gas @@ -3328,6 +3411,7 @@ def _pending_transaction_failed(self): failed = Operators.NOT(enough_balance) self._pending_transaction = sort, address, price, data, caller, value, gas, failed + # ok now failed exists ans it is initialized. Concretize or fork. if issymbolic(failed): # optimistic/pesimistic is inverted as the expresion represents fail policy = {"optimistic": "PESSIMISTIC", "pessimistic": "OPTIMISTIC"}.get( @@ -3354,14 +3438,6 @@ def set_failed(state, solution): policy=policy, ) - if self.depth != 0: - price = 0 - aux_price = Operators.ZEXTEND(price, 512) - aux_gas = Operators.ZEXTEND(gas, 512) - tx_fee = Operators.ITEBV(512, self.depth == 0, aux_price * aux_gas, 0) - aux_src_balance = Operators.ZEXTEND(self.get_balance(caller), 512) - aux_value = Operators.ZEXTEND(value, 512) - enough_balance = Operators.UGE(aux_src_balance, aux_value + tx_fee) return failed def _process_pending_transaction(self): diff --git a/manticore/platforms/linux.py b/manticore/platforms/linux.py index ae1fa758d..1b9a605a0 100644 --- a/manticore/platforms/linux.py +++ b/manticore/platforms/linux.py @@ -19,16 +19,16 @@ import os import random -from elftools.elf.descriptions import describe_symbol_type # Remove in favor of binary.py +from elftools.elf.descriptions import describe_symbol_type from elftools.elf.elffile import ELFFile from elftools.elf.sections import SymbolTableSection from . 
import linux_syscalls from .linux_syscall_stubs import SyscallStubs from ..core.state import TerminateState, Concretize -from ..core.smtlib import ConstraintSet, Operators, Expression, issymbolic, ArrayProxy +from ..core.smtlib import ConstraintSet, Operators, Expression, issymbolic, MutableArray from ..core.smtlib.solver import SelectedSolver from ..exceptions import SolverError from ..native.cpu.abstractcpu import Cpu, Syscall, ConcretizeArgument, Interruption @@ -43,7 +43,7 @@ InvalidMemoryAccess, ) from ..native.state import State -from ..platforms.platform import Platform, SyscallNotImplemented, unimplemented +from ..platforms.platform import NativePlatform, SyscallNotImplemented, unimplemented from typing import cast, Any, Deque, Dict, IO, Iterable, List, Optional, Set, Tuple, Union, Callable @@ -483,7 +483,7 @@ def __init__( # build the constraints array size = len(data) - self.array = constraints.new_array(name=self.name, index_max=size) + self.array = MutableArray(constraints.new_array(name=self.name, length=size)) symbols_cnt = 0 for i in range(size): @@ -737,7 +737,7 @@ def __init__( self.symb_name = name self.max_recv_symbolic = max_recv_symbolic # 0 for unlimited. Unlimited is not tested # Keep track of the symbolic inputs we create - self.inputs_recvd: List[ArrayProxy] = [] + self.inputs_recvd: List[MutableArray] = [] self.recv_pos = 0 # This is a meta-variable, of sorts, and it is responsible for # determining the symbolic length of the array during recv/read. 
@@ -776,7 +776,7 @@ def _next_symb_name(self) -> str: """ return f"{self.symb_name}-{len(self.inputs_recvd)}" - def receive(self, size: int) -> Union[ArrayProxy, List[bytes]]: + def receive(self, size: int) -> Union[MutableArray, List[bytes]]: """ Return a symbolic array of either `size` or rest of remaining symbolic bytes :param size: Size of receive @@ -794,7 +794,7 @@ def receive(self, size: int) -> Union[ArrayProxy, List[bytes]]: # Then do some forking with self._symb_len if self._symb_len is None: self._symb_len = self._constraints.new_bitvec( - 8, "_socket_symb_len", avoid_collisions=True + size=8, name="_socket_symb_len", avoid_collisions=True ) self._constraints.add(Operators.AND(self._symb_len >= 1, self._symb_len <= rx_bytes)) @@ -809,7 +809,7 @@ def setstate(state: State, value): policy="MINMAX", ) ret = self._constraints.new_array( - name=self._next_symb_name(), index_max=self._symb_len, avoid_collisions=True + name=self._next_symb_name(), length=self._symb_len, avoid_collisions=True ) logger.info(f"Setting recv symbolic length to {self._symb_len}") self.recv_pos += self._symb_len @@ -819,7 +819,7 @@ def setstate(state: State, value): return ret -class Linux(Platform): +class Linux(NativePlatform): """ A simple Linux Operating System Platform. This class emulates the most common Linux system calls diff --git a/manticore/platforms/platform.py b/manticore/platforms/platform.py index 9be3edcb3..e92194db9 100644 --- a/manticore/platforms/platform.py +++ b/manticore/platforms/platform.py @@ -1,9 +1,10 @@ import logging - from functools import wraps -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, Optional from ..utils.event import Eventful +from ..core.state import StateBase +from ..native.cpu.abstractcpu import Cpu logger = logging.getLogger(__name__) @@ -48,13 +49,21 @@ class Platform(Eventful): Base class for all platforms e.g. operating systems or virtual machines. 
""" + current: Any + _published_events = {"solve"} - def __init__(self, path, **kwargs): + def __init__(self, *, state: Optional[StateBase] = None, **kwargs): + self._state = state super().__init__(**kwargs) - def invoke_model(self, model, prefix_args=None): - self._function_abi.invoke(model, prefix_args) + def set_state(self, state: StateBase): + self._state = state + state.forward_events_from(self) + + @property + def constraints(self): + return self._state._constraints def __setstate__(self, state): super().__setstate__(state) @@ -63,5 +72,13 @@ def __getstate__(self): state = super().__getstate__() return state + +class NativePlatform(Platform): + def __init__(self, path, **kwargs): + super().__init__(**kwargs) + + def invoke_model(self, model, prefix_args=None): + self._function_abi.invoke(model, prefix_args) + def generate_workspace_files(self): return {} diff --git a/manticore/platforms/wasm.py b/manticore/platforms/wasm.py index 8b47dfaee..53039db16 100644 --- a/manticore/platforms/wasm.py +++ b/manticore/platforms/wasm.py @@ -1,4 +1,4 @@ -from .platform import Platform +from .platform import NativePlatform from ..wasm.structure import ( ModuleInstance, Store, @@ -33,7 +33,7 @@ def stub(arity, _state, *args): return [0 for _ in range(arity)] # TODO: Return symbolic values -class WASMWorld(Platform): +class WASMWorld(NativePlatform): """Manages global environment for a WASM state. 
Analagous to EVMWorld.""" def __init__(self, filename, name="self", **kwargs): @@ -65,6 +65,14 @@ def __init__(self, filename, name="self", **kwargs): self.forward_events_from(self.instance) self.forward_events_from(self.instance.executor) + @property + def constraints(self): + return self._constraints + + @constraints.setter + def constraints(self, constraints): + self._constraints = constraints + def __getstate__(self): state = super().__getstate__() state["modules"] = self.modules @@ -83,7 +91,7 @@ def __setstate__(self, state): self.store = state["store"] self.stack = state["stack"] self.advice = state["advice"] - self.constraints = state["constraints"] + self._constraints = state["constraints"] self.instantiated = state["instantiated"] self.module_names = state["module_names"] self.default_module = state["default_module"] diff --git a/manticore/wasm/manticore.py b/manticore/wasm/manticore.py index a90f89463..5e313dc21 100644 --- a/manticore/wasm/manticore.py +++ b/manticore/wasm/manticore.py @@ -244,6 +244,6 @@ def _make_wasm_bin(program, env={}, sup_env={}, **kwargs) -> State: exec_start=kwargs.get("exec_start", False), stub_missing=kwargs.get("stub_missing", True), ) - initial_state = State(constraints, platform) + initial_state = State(constraints=constraints, platform=platform) return initial_state diff --git a/manticore/wasm/structure.py b/manticore/wasm/structure.py index 0488524b8..f10f5b8bd 100644 --- a/manticore/wasm/structure.py +++ b/manticore/wasm/structure.py @@ -1797,7 +1797,7 @@ def has_type_on_top(self, t: typing.Union[type, typing.Tuple[type, ...]], n: int """ *Asserts* that the stack has at least n values of type t or type BitVec on the top - :param t: type of value to look for (Bitvec is always included as an option) + :param t: type of value to look for (BitVec is always included as an option) :param n: Number of values to check :return: True """ diff --git a/mypy.ini b/mypy.ini index cbe17228e..8597e3660 100644 --- a/mypy.ini +++ 
b/mypy.ini @@ -1,44 +1,16 @@ [mypy] python_version = 3.6 -files = manticore, tests +files = manticore, tests, examples/**/*.py +ignore_missing_imports = True +# TODO: LOTS OF ERRORS +# check_untyped_defs = True + +[mypy-manticore.core.smtlib.*] +check_untyped_defs = True # Generated file [mypy-manticore.ethereum.parsetab] ignore_errors = True -# 3rd-party libraries with no typing information -[mypy-capstone.*] -ignore_missing_imports = True - -[mypy-crytic_compile.*] -ignore_missing_imports = True - -[mypy-elftools.*] -ignore_missing_imports = True - -[mypy-sha3.*] -ignore_missing_imports = True - -[mypy-pyevmasm.*] -ignore_missing_imports = True - -[mypy-unicorn.*] -ignore_missing_imports = True - -[mypy-keystone.*] -ignore_missing_imports = True - -[mypy-ply.*] -ignore_missing_imports = True - -[mypy-rlp.*] -ignore_missing_imports = True - -[mypy-prettytable.*] -ignore_missing_imports = True - -[mypy-wasm.*] -ignore_missing_imports = True - [mypy-manticore.core.state_pb2] ignore_errors = True diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 9957a2182..508779e1a 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -114,7 +114,7 @@ run_tests_from_dir() { DIR=$1 COVERAGE_RCFILE=$GITHUB_WORKSPACE/.coveragerc echo "Running only the tests from 'tests/$DIR' directory" - pytest --durations=100 --cov=manticore --cov-config=$GITHUB_WORKSPACE/.coveragerc -n auto "tests/$DIR" + pytest --timeout=1200 --durations=100 --cov=manticore --cov-config=$GITHUB_WORKSPACE/.coveragerc -n auto "tests/$DIR" RESULT=$? 
return $RESULT } diff --git a/setup.py b/setup.py index 350b66311..4377f0e35 100644 --- a/setup.py +++ b/setup.py @@ -18,14 +18,22 @@ def rtd_dependent_deps(): # (we need to know how to import a given native dependency so we can check if native dependencies are installed) native_deps = ["capstone==4.0.1", "pyelftools", "unicorn==1.0.2rc2"] -lint_deps = ["black==20.8b1", "mypy==0.790"] +lint_deps = ["black==20.8b1", "mypy==0.812"] auto_test_deps = ["py-evm"] # Development dependencies without keystone dev_noks = ( native_deps - + ["coverage", "Sphinx", "pytest==5.3.0", "pytest-xdist==1.30.0", "pytest-cov==2.8.1", "jinja2"] + + [ + "coverage", + "Sphinx", + "pytest", + "pytest-timeout", + "pytest-xdist", + "pytest-cov", + "jinja2", + ] + lint_deps + auto_test_deps ) @@ -72,6 +80,7 @@ def rtd_dependent_deps(): "wasm", "dataclasses; python_version < '3.7'", "pyevmasm>=0.2.3", + "toposort", ] + rtd_dependent_deps(), extras_require=extra_require, diff --git a/tests/auto_generators/flags.py b/tests/auto_generators/flags.py index 62e8f7eeb..99f026ef2 100644 --- a/tests/auto_generators/flags.py +++ b/tests/auto_generators/flags.py @@ -1,4 +1,6 @@ -flags = { +from typing import Dict, List + +flags: Dict[str, Dict[str, List[str]]] = { "AAA": { "undefined": ["OF", "SF", "ZF", "PF"], "defined": ["AF", "CF"], diff --git a/tests/auto_generators/make_dump.py b/tests/auto_generators/make_dump.py index 259342ad3..76e3dd4dd 100644 --- a/tests/auto_generators/make_dump.py +++ b/tests/auto_generators/make_dump.py @@ -5,8 +5,10 @@ import sys import time import subprocess -from capstone import * -from capstone.x86 import * +from typing import Any, Dict +from capstone import Cs +from capstone.x86 import CS_ARCH_X86, CS_MODE_32, CS_MODE_64, X86_OP_MEM, X86_OP_REG, X86_OP_IMM +import capstone.x86 as csr from flags import flags flags_maks = { @@ -228,7 +230,7 @@ def read_operand(o): groups = map(instruction.group_name, instruction.groups) PC = {"i386": "EIP", "amd64": "RIP"}[arch] - 
registers = {PC: gdb.getR(PC)} + registers: Dict[Any, Any] = {PC: gdb.getR(PC)} memory = {} # save the encoded instruction @@ -246,11 +248,11 @@ def read_operand(o): if instruction.insn_name().upper() in ["PUSHF", "PUSHFD"]: registers["EFLAGS"] = gdb.getR("EFLAGS") - if instruction.insn_name().upper() in ["XLAT", "XLATB"]: - registers["AL"] = gdb.getR("AL") - registers[B] = gdb.getR(B) - address = registers[B] + registers["AL"] - memory[address] = chr(gdb.getByte(address)) + # if instruction.insn_name().upper() in ["XLAT", "XLATB"]: + # registers["AL"] = gdb.getR("AL") + # registers[B] = gdb.getR(B) + # address = registers[B] + registers["AL"] + # memory[address] = chr(gdb.getByte(address)) if instruction.insn_name().upper() in ["BTC", "BTR", "BTS", "BT"]: if instruction.operands[0].type == X86_OP_MEM: @@ -310,34 +312,34 @@ def read_operand(o): # registers[reg_name] = gdb.getR(reg_name) reg_sizes = { - X86_REG_AH: X86_REG_AX, - X86_REG_AL: X86_REG_AX, - X86_REG_AX: X86_REG_EAX, - X86_REG_EAX: X86_REG_RAX, - X86_REG_RAX: X86_REG_INVALID, - X86_REG_BH: X86_REG_BX, - X86_REG_BL: X86_REG_BX, - X86_REG_BX: X86_REG_EBX, - X86_REG_EBX: X86_REG_RBX, - X86_REG_RBX: X86_REG_INVALID, - X86_REG_CH: X86_REG_CX, - X86_REG_CL: X86_REG_CX, - X86_REG_CX: X86_REG_ECX, - X86_REG_ECX: X86_REG_RCX, - X86_REG_RCX: X86_REG_INVALID, - X86_REG_DH: X86_REG_DX, - X86_REG_DL: X86_REG_DX, - X86_REG_DX: X86_REG_EDX, - X86_REG_EDX: X86_REG_RDX, - X86_REG_RDX: X86_REG_INVALID, - X86_REG_DIL: X86_REG_EDI, - X86_REG_DI: X86_REG_EDI, - X86_REG_EDI: X86_REG_RDI, - X86_REG_RDI: X86_REG_INVALID, - X86_REG_SIL: X86_REG_ESI, - X86_REG_SI: X86_REG_ESI, - X86_REG_ESI: X86_REG_RSI, - X86_REG_RSI: X86_REG_INVALID, + csr.X86_REG_AH: csr.X86_REG_AX, + csr.X86_REG_AL: csr.X86_REG_AX, + csr.X86_REG_AX: csr.X86_REG_EAX, + csr.X86_REG_EAX: csr.X86_REG_RAX, + csr.X86_REG_RAX: csr.X86_REG_INVALID, + csr.X86_REG_BH: csr.X86_REG_BX, + csr.X86_REG_BL: csr.X86_REG_BX, + csr.X86_REG_BX: csr.X86_REG_EBX, + 
csr.X86_REG_EBX: csr.X86_REG_RBX, + csr.X86_REG_RBX: csr.X86_REG_INVALID, + csr.X86_REG_CH: csr.X86_REG_CX, + csr.X86_REG_CL: csr.X86_REG_CX, + csr.X86_REG_CX: csr.X86_REG_ECX, + csr.X86_REG_ECX: csr.X86_REG_RCX, + csr.X86_REG_RCX: csr.X86_REG_INVALID, + csr.X86_REG_DH: csr.X86_REG_DX, + csr.X86_REG_DL: csr.X86_REG_DX, + csr.X86_REG_DX: csr.X86_REG_EDX, + csr.X86_REG_EDX: csr.X86_REG_RDX, + csr.X86_REG_RDX: csr.X86_REG_INVALID, + csr.X86_REG_DIL: csr.X86_REG_EDI, + csr.X86_REG_DI: csr.X86_REG_EDI, + csr.X86_REG_EDI: csr.X86_REG_RDI, + csr.X86_REG_RDI: csr.X86_REG_INVALID, + csr.X86_REG_SIL: csr.X86_REG_ESI, + csr.X86_REG_SI: csr.X86_REG_ESI, + csr.X86_REG_ESI: csr.X86_REG_RSI, + csr.X86_REG_RSI: csr.X86_REG_INVALID, } # There is a capstone branch that should fix all these annoyances... soon # https://github.com/aquynh/capstone/tree/next @@ -387,7 +389,7 @@ def read_operand(o): registers[reg_name] = gdb.getR(reg_name) address += o.mem.scale * registers[reg_name] address = address & ({"i386": 0xFFFFFFFF, "amd64": 0xFFFFFFFFFFFFFFFF}[arch]) - for i in xrange(address, address + o.size): + for i in range(address, address + o.size): memory[i] = chr(gdb.getByte(i)) # gather PRE info diff --git a/tests/auto_generators/make_tests.py b/tests/auto_generators/make_tests.py index 318b35a27..b56af1ae7 100644 --- a/tests/auto_generators/make_tests.py +++ b/tests/auto_generators/make_tests.py @@ -1,5 +1,6 @@ from __future__ import print_function import sys +from typing import Any, Dict import random tests = [] @@ -15,7 +16,7 @@ random.shuffle(tests) -op_count = {} +op_count: Dict[str, int] = {} test_dic = {} for test in tests: try: diff --git a/tests/ethereum/contracts/simple_int_overflow.sol b/tests/ethereum/contracts/simple_int_overflow.sol index 419c2a1c6..8460a0018 100644 --- a/tests/ethereum/contracts/simple_int_overflow.sol +++ b/tests/ethereum/contracts/simple_int_overflow.sol @@ -1,7 +1,7 @@ -pragma solidity ^0.4.15; +pragma solidity ^0.4.24; contract Overflow { - uint 
private sellerBalance=0; + uint private sellerBalance=10; function add(uint value) public { sellerBalance += value; // complicated math with possible overflow diff --git a/tests/ethereum/test_detectors.py b/tests/ethereum/test_detectors.py index b0431bf22..8f17a0c45 100644 --- a/tests/ethereum/test_detectors.py +++ b/tests/ethereum/test_detectors.py @@ -29,15 +29,13 @@ from typing import Tuple, Type -consts = config.get_group("core") -consts.mprocessing = consts.mprocessing.single THIS_DIR = os.path.dirname(os.path.abspath(__file__)) def make_mock_evm_state(): cs = ConstraintSet() - fakestate = State(cs, EVMWorld(cs)) + fakestate = State(constraints=cs, platform=EVMWorld(cs)) return fakestate @@ -47,6 +45,8 @@ class EthDetectorTest(unittest.TestCase): def setUp(self): self.mevm = ManticoreEVM() + consts = config.get_group("core") + consts.mprocessing = consts.mprocessing.single self.mevm.register_plugin(KeepOnlyIfStorageChanges()) log.set_verbosity(0) self.worksp = self.mevm.workspace @@ -63,7 +63,7 @@ def _test(self, name: str, should_find, use_ctor_sym_arg=False): dir = os.path.join(THIS_DIR, "contracts", "detectors") filepath = os.path.join(dir, f"{name}.sol") - + print(filepath) if use_ctor_sym_arg: ctor_arg: Tuple = (mevm.make_symbolic_value(),) else: @@ -71,7 +71,7 @@ def _test(self, name: str, should_find, use_ctor_sym_arg=False): self.mevm.register_detector(self.DETECTOR_CLASS()) - with self.mevm.kill_timeout(240): + with self.mevm.kill_timeout(48000): mevm.multi_tx_analysis( filepath, contract_name="DetectThis", diff --git a/tests/ethereum/test_general.py b/tests/ethereum/test_general.py index 799456298..2242ea1f3 100644 --- a/tests/ethereum/test_general.py +++ b/tests/ethereum/test_general.py @@ -15,8 +15,8 @@ from manticore import ManticoreError from manticore.core.plugin import Plugin from manticore.core.smtlib import ConstraintSet, operators -from manticore.core.smtlib import Z3Solver -from manticore.core.smtlib.expression import BitVec +from 
manticore.core.smtlib import SelectedSolver +from manticore.core.smtlib.expression import BitVecVariable from manticore.core.smtlib.visitors import to_constant from manticore.core.state import TerminateState from manticore.ethereum import ( @@ -41,7 +41,7 @@ import contextlib -solver = Z3Solver.instance() +solver = SelectedSolver.instance() THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -379,7 +379,7 @@ def test_serialize_fixed_bytes_too_big(self): # test serializing symbolic buffer with bytesM def test_serialize_bytesM_symbolic(self): cs = ConstraintSet() - buf = cs.new_array(index_max=17) + buf = cs.new_array(length=17) ret = ABI.serialize("bytes32", buf) self.assertEqual(solver.minmax(cs, ret[0]), (0, 255)) self.assertEqual(solver.minmax(cs, ret[17]), (0, 0)) @@ -387,7 +387,7 @@ def test_serialize_bytesM_symbolic(self): # test serializing symbolic buffer with bytes def test_serialize_bytes_symbolic(self): cs = ConstraintSet() - buf = cs.new_array(index_max=17) + buf = cs.new_array(length=17) ret = ABI.serialize("bytes", buf) # does the offset field look right? 
@@ -412,10 +412,10 @@ def _make(self): price = 0 value = 10000 bytecode = b"\x05" - data = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + data = b"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" gas = 1000000 - new_vm = evm.EVM(constraints, address, data, caller, value, bytecode, gas=gas, world=world) + new_vm = evm.EVM(address, data, caller, value, bytecode, gas=gas, world=world) return constraints, world, new_vm def test_str(self): @@ -454,11 +454,12 @@ def test_SDIVS3(self): constraints, world, vm = self._make() xx = constraints.new_bitvec(256, name="x") yy = constraints.new_bitvec(256, name="y") - constraints.add(xx == 0x20) - constraints.add(yy == -1) + x, y = 0x20, -1 + constraints.add(xx == x) + constraints.add(yy == y) result = vm.SDIV(xx, yy) self.assertListEqual( - list(map(evm.to_signed, solver.get_all_values(constraints, result))), [-0x20] + list(map(evm.to_signed, solver.get_all_values(constraints, result))), [vm.SDIV(x, y)] ) def test_SDIVSx(self): @@ -466,14 +467,45 @@ def test_SDIVSx(self): constraints, world, vm = self._make() xx = constraints.new_bitvec(256, name="x") yy = constraints.new_bitvec(256, name="y") + zz = constraints.new_bitvec(256, name="z") constraints.add(xx == x) constraints.add(yy == y) result = vm.SDIV(xx, yy) + constraints.add(zz == result) + self.assertListEqual( + list(map(evm.to_signed, solver.get_all_values(constraints, zz))), [vm.SDIV(x, y)] + ) self.assertListEqual( list(map(evm.to_signed, solver.get_all_values(constraints, result))), [vm.SDIV(x, y)] ) + def test_to_sig(self): + self.assertEqual(evm.to_signed(-1), -1) + self.assertEqual(evm.to_signed(1), 1) + self.assertEqual(evm.to_signed(0), 0) + self.assertEqual(evm.to_signed(-2), -2) + self.assertEqual(evm.to_signed(2), 2) + self.assertEqual(evm.to_signed(0), 0) + self.assertEqual( + evm.to_signed(0x8000000000000000000000000000000000000000000000000000000000000001), + -0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + ) + self.assertEqual( + 
evm.to_signed(0x8000000000000000000000000000000000000000000000000000000000000002), + -0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE, + ) + self.assertEqual( + evm.to_signed(0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF), + 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + ) + self.assertEqual( + evm.to_signed(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF), -1 + ) + self.assertEqual( + evm.to_signed(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE), -2 + ) + class EthTests(unittest.TestCase): def setUp(self): @@ -776,7 +808,7 @@ def test_gen_testcase_only_if(self): for ext in ("summary", "constraints", "pkl", "tx.json", "tx", "trace", "logs") } - expected_files.add("state_00000001.pkl") + expected_files.add("state_00000002.pkl") actual_files = set((fn for fn in os.listdir(self.mevm.workspace) if not fn.startswith("."))) self.assertEqual(actual_files, expected_files) @@ -817,9 +849,11 @@ def test_function_name_with_signature(self): self.mevm.make_symbolic_value(), signature="(uint256,uint256)", ) + z = None for st in self.mevm.all_states: z = st.solve_one(st.platform.transactions[1].return_data) break + self.assertIsNot(z, None) self.assertEqual(ABI.deserialize("(uint256)", z)[0], 2) def test_migrate_integration(self): @@ -1293,7 +1327,6 @@ def will_decode_instruction_callback(self, state, pc): func_name, args = ABI.deserialize( "shutdown(string)", state.platform.current_transaction.data ) - print("Shutdown", to_constant(args[0])) self.manticore.shutdown() elif func_id == ABI.function_selector("can_be_true(bool)"): func_name, args = ABI.deserialize( @@ -1382,7 +1415,7 @@ def will_evm_execute_instruction_callback(self, state, i, *args, **kwargs): class EthHelpersTest(unittest.TestCase): def setUp(self): - self.bv = BitVec(256) + self.bv = BitVecVariable(size=256, name="A") def test_concretizer(self): policy = "SOME_NONSTANDARD_POLICY" @@ -1647,7 +1680,7 @@ def 
test_jmpdest_check(self): bytecode = world.get_code(address) gas = 100000 - new_vm = evm.EVM(constraints, address, data, caller, value, bytecode, world=world, gas=gas) + new_vm = evm.EVM(address, data, caller, value, bytecode, world=world, gas=gas) result = None returndata = "" diff --git a/tests/ethereum/test_regressions.py b/tests/ethereum/test_regressions.py index f55ff457a..a4c1bb99c 100644 --- a/tests/ethereum/test_regressions.py +++ b/tests/ethereum/test_regressions.py @@ -279,7 +279,7 @@ def test_addmod(self): caller = 0x42424242424242424242 value = 0 bytecode = "" - vm = evm.EVM(constraints, address, data, caller, value, bytecode, gas=23000) + vm = evm.EVM(address, data, caller, value, bytecode, evm.EVMWorld(constraints), gas=23000) self.assertEqual(vm.ADDMOD(12323, 2343, 20), 6) self.assertEqual(vm.ADDMOD(12323, 2343, 0), 0) @@ -336,7 +336,7 @@ def test_mulmod(self): caller = 0x42424242424242424242 value = 0 bytecode = "" - vm = evm.EVM(constraints, address, data, caller, value, bytecode, gas=23000) + vm = evm.EVM(address, data, caller, value, bytecode, evm.EVMWorld(constraints), gas=23000) self.assertEqual(vm.MULMOD(12323, 2343, 20), 9) self.assertEqual(vm.MULMOD(12323, 2343, 0), 0) diff --git a/tests/native/symfile b/tests/native/symfile new file mode 100644 index 000000000..9b26e9b10 --- /dev/null +++ b/tests/native/symfile @@ -0,0 +1 @@ ++ \ No newline at end of file diff --git a/tests/native/test_armv7unicorn.py b/tests/native/test_armv7unicorn.py index 41574328c..86fdd9efb 100644 --- a/tests/native/test_armv7unicorn.py +++ b/tests/native/test_armv7unicorn.py @@ -1493,7 +1493,7 @@ def get_state(cls): constraints = ConstraintSet() dirname = os.path.dirname(__file__) platform = linux.SLinux(os.path.join(dirname, "binaries", "basic_linux_amd64")) - cls.state = State(constraints, platform) + cls.state = State(constraints=constraints, platform=platform) cls.cpu = platform._mk_proc("armv7") return (cls.cpu, cls.state) diff --git 
a/tests/native/test_cpu_automatic.py b/tests/native/test_cpu_automatic.py index a3d97f3fd..d84009ae6 100644 --- a/tests/native/test_cpu_automatic.py +++ b/tests/native/test_cpu_automatic.py @@ -12674,6 +12674,7 @@ def test_ADD_1_symbolic(self): with cs as temp_cs: temp_cs.add(condition) + print(temp_cs) self.assertTrue(solver.check(temp_cs)) with cs as temp_cs: temp_cs.add(condition == False) diff --git a/tests/native/test_cpu_manual.py b/tests/native/test_cpu_manual.py index 15daf9ee9..2c4a6cc69 100644 --- a/tests/native/test_cpu_manual.py +++ b/tests/native/test_cpu_manual.py @@ -7,12 +7,12 @@ from manticore.native.cpu.x86 import AMD64Cpu from manticore.native.memory import * from manticore.core.smtlib import BitVecOr, operator, Bool -from manticore.core.smtlib.solver import Z3Solver +from manticore.core.smtlib.solver import SelectedSolver from functools import reduce from typing import List -solver = Z3Solver.instance() +solver = SelectedSolver.instance() sizes = { diff --git a/tests/native/test_driver.py b/tests/native/test_driver.py index e516fc012..a22eb2f44 100644 --- a/tests/native/test_driver.py +++ b/tests/native/test_driver.py @@ -26,7 +26,7 @@ def testCreating(self): m.log_file = "/dev/null" def test_issymbolic(self): - v = BitVecVariable(32, "sym") + v = BitVecVariable(size=32, name="sym") self.assertTrue(issymbolic(v)) def test_issymbolic_neg(self): diff --git a/tests/native/test_integration_native.py b/tests/native/test_integration_native.py index 2821c559d..cd34ee9a4 100644 --- a/tests/native/test_integration_native.py +++ b/tests/native/test_integration_native.py @@ -59,23 +59,22 @@ def test_timeout(self) -> None: workspace = os.path.join(self.test_dir, "workspace") t = time.time() with open(os.path.join(os.pardir, self.test_dir, "output.log"), "w") as output: - subprocess.check_call( - [ - "coverage", - "run", # PYTHON_BIN, - "-m", - "manticore", - "--workspace", - workspace, - "--core.timeout", - "1", - "--core.procs", - "4", - filename, - 
"+++++++++", - ], - stdout=output, - ) + cmd = [ + "coverage", + "run", # PYTHON_BIN, + "-m", + "manticore", + "--workspace", + workspace, + "--core.timeout", + "1", + "--core.procs", + "4", + filename, + "+++++++++", + ] + + subprocess.check_call(cmd, stdout=output) self.assertTrue(time.time() - t < 20) @@ -129,6 +128,7 @@ def _test_arguments_assertions_aux( cmd += ["--assertions", assertions] cmd += [filename, "+++++++++"] + output = subprocess.check_output(cmd).splitlines() # self.assertIn(b'm.c.manticore:INFO: Verbosity set to 1.', output[0]) diff --git a/tests/native/test_manticore.py b/tests/native/test_manticore.py index 1d91866fc..3245f0519 100644 --- a/tests/native/test_manticore.py +++ b/tests/native/test_manticore.py @@ -19,7 +19,7 @@ def setUp(self): def test_profiling_data(self): p = Profiler() - set_verbosity(0) + set_verbosity(1) self.m.register_plugin(p) self.m.run() self.m.finalize() diff --git a/tests/native/test_manticore_logger.py b/tests/native/test_manticore_logger.py new file mode 100644 index 000000000..635028093 --- /dev/null +++ b/tests/native/test_manticore_logger.py @@ -0,0 +1,24 @@ +import unittest +import os +import logging +import filecmp + +from manticore.native import Manticore +from manticore.utils.log import get_verbosity, set_verbosity + + +class ManticoreLogger(unittest.TestCase): + """Make sure we set the logging levels correctly""" + + _multiprocess_can_split_ = False + + def test_logging(self): + set_verbosity(5) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.DEBUG) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.DEBUG) + + set_verbosity(1) + self.assertEqual(get_verbosity("manticore.native.cpu.abstractcpu"), logging.WARNING) + self.assertEqual(get_verbosity("manticore.ethereum.abi"), logging.INFO) + + set_verbosity(0) # this is global and does not work in concurrent envs diff --git a/tests/native/test_models.py b/tests/native/test_models.py index 2b41cafbf..ffc531117 100644 
--- a/tests/native/test_models.py +++ b/tests/native/test_models.py @@ -47,7 +47,7 @@ def f(): class ModelTest(unittest.TestCase): dirname = os.path.dirname(__file__) l = linux.SLinux(os.path.join(dirname, "binaries", "basic_linux_amd64")) - state = State(ConstraintSet(), l) + state = State(constraints=ConstraintSet(), platform=l) stack_top = state.cpu.RSP def _clear_constraints(self): diff --git a/tests/native/test_register.py b/tests/native/test_register.py index e099fc292..0663e811b 100644 --- a/tests/native/test_register.py +++ b/tests/native/test_register.py @@ -1,6 +1,6 @@ import unittest -from manticore.core.smtlib import Bool, BitVecConstant +from manticore.core.smtlib import Bool, BoolVariable, BitVecConstant from manticore.native.cpu.register import Register @@ -47,7 +47,7 @@ def test_bool_write_nonflag(self): def test_Bool(self): r = Register(32) - b = Bool() + b = BoolVariable(name="B") r.write(b) self.assertIs(r.read(), b) diff --git a/tests/native/test_state.py b/tests/native/test_state.py index d330760df..cf70c3f46 100644 --- a/tests/native/test_state.py +++ b/tests/native/test_state.py @@ -4,6 +4,7 @@ from contextlib import redirect_stdout from manticore.core.state import StateBase +from manticore.platforms.platform import Platform from manticore.utils.event import Eventful from manticore.platforms import linux from manticore.native.state import State @@ -37,7 +38,7 @@ def memory(self): return self._memory -class FakePlatform(Eventful): +class FakePlatform(Platform): def __init__(self): super().__init__() self._constraints = None @@ -75,31 +76,31 @@ class StateTest(unittest.TestCase): def setUp(self): dirname = os.path.dirname(__file__) l = linux.Linux(os.path.join(dirname, "binaries", "basic_linux_amd64")) - self.state = State(ConstraintSet(), l) + self.state = State(constraints=ConstraintSet(), platform=l) def test_solve_one(self): val = 42 - expr = BitVecVariable(32, "tmp") + expr = BitVecVariable(size=32, name="tmp") self.state.constrain(expr == 
val) solved = self.state.solve_one(expr) self.assertEqual(solved, val) def test_solve_n(self): - expr = BitVecVariable(32, "tmp") + expr = BitVecVariable(size=32, name="tmp") self.state.constrain(expr > 4) self.state.constrain(expr < 7) solved = sorted(self.state.solve_n(expr, 2)) self.assertEqual(solved, [5, 6]) def test_solve_n2(self): - expr = BitVecVariable(32, "tmp") + expr = BitVecVariable(size=32, name="tmp") self.state.constrain(expr > 4) self.state.constrain(expr < 100) solved = self.state.solve_n(expr, 5) self.assertEqual(len(solved), 5) def test_solve_min_max(self): - expr = BitVecVariable(32, "tmp") + expr = BitVecVariable(size=32, name="tmp") self.state.constrain(expr > 4) self.state.constrain(expr < 7) self.assertEqual(self.state.solve_min(expr), 5) @@ -107,7 +108,7 @@ def test_solve_min_max(self): self.assertEqual(self.state.solve_minmax(expr), (5, 6)) def test_policy_one(self): - expr = BitVecVariable(32, "tmp") + expr = BitVecVariable(size=32, name="tmp") self.state.constrain(expr > 0) self.state.constrain(expr < 100) solved = self.state.concretize(expr, "ONE") @@ -116,7 +117,7 @@ def test_policy_one(self): def test_state(self): constraints = ConstraintSet() - initial_state = State(constraints, FakePlatform()) + initial_state = State(constraints=constraints, platform=FakePlatform()) arr = initial_state.symbolicate_buffer("+" * 100, label="SYMBA") initial_state.constrain(arr[0] > 0x41) @@ -153,15 +154,15 @@ def test_new_bad_symbolic_value(self): def test_tainted_symbolic_buffer(self): taint = ("TEST_TAINT",) expr = self.state.new_symbolic_buffer(64, taint=taint) - self.assertEqual(expr.taint, frozenset(taint)) + self.assertEqual(frozenset(expr.taint), frozenset(taint)) def test_tainted_symbolic_value(self): taint = ("TEST_TAINT",) expr = self.state.new_symbolic_value(64, taint=taint) - self.assertEqual(expr.taint, frozenset(taint)) + self.assertEqual(frozenset(expr.taint), frozenset(taint)) def test_state_hook(self): - initial_state = 
State(ConstraintSet(), FakePlatform()) + initial_state = State(constraints=ConstraintSet(), platform=FakePlatform()) def fake_hook(_: StateBase) -> None: return None @@ -224,7 +225,7 @@ def testContextSerialization(self): new_file = "" new_new_file = "" constraints = ConstraintSet() - initial_state = State(constraints, FakePlatform()) + initial_state = State(constraints=constraints, platform=FakePlatform()) initial_state.context["step"] = 10 initial_file = pickle_dumps(initial_state) with initial_state as new_state: diff --git a/tests/native/test_workspace.py b/tests/native/test_workspace.py index c8af276b1..43f5ac0cd 100644 --- a/tests/native/test_workspace.py +++ b/tests/native/test_workspace.py @@ -60,7 +60,7 @@ def setUp(self): # self.manager.start(lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) dirname = os.path.dirname(__file__) l = linux.Linux(os.path.join(dirname, "binaries", "basic_linux_amd64")) - self.state = State(ConstraintSet(), l) + self.state = State(constraints=ConstraintSet(), platform=l) # self.lock = self.manager.Condition() def test_workspace_save_load(self): diff --git a/tests/other/data/ErrRelated.pkl.gz b/tests/other/data/ErrRelated.pkl.gz deleted file mode 100644 index eb1036567..000000000 Binary files a/tests/other/data/ErrRelated.pkl.gz and /dev/null differ diff --git a/tests/other/test_smtlibv2.py b/tests/other/test_smtlibv2.py index 3050b900f..c43571799 100644 --- a/tests/other/test_smtlibv2.py +++ b/tests/other/test_smtlibv2.py @@ -1,6 +1,7 @@ import unittest import os - +import sys +import pickle from manticore.core.smtlib import ( ConstraintSet, Version, @@ -14,7 +15,7 @@ replace, BitVecConstant, ) -from manticore.core.smtlib.solver import Z3Solver, YicesSolver, CVC4Solver +from manticore.core.smtlib.solver import Z3Solver, YicesSolver, CVC4Solver, SelectedSolver from manticore.core.smtlib.expression import * from manticore.utils.helpers import pickle_dumps from manticore import config @@ -26,82 +27,364 @@ DIRPATH = 
os.path.dirname(__file__) -class RegressionTest(unittest.TestCase): - def test_related_to(self): - import gzip - import pickle, sys +class ExpressionTestNew(unittest.TestCase): + _multiprocess_can_split_ = True + + def setUp(self): + self.solver = Z3Solver.instance() + + def assertItemsEqual(self, a, b): + # Required for Python3 compatibility + self.assertEqual(sorted(a), sorted(b)) + + def test_xslotted(self): + """Test that XSlotted multi inheritance classes uses same amount + of memory than a single class object with slots + """ + + class Base(object, metaclass=XSlotted, abstract=True): + __xslots__ = ("t",) + pass + + class A(Base, abstract=True): + __xslots__ = ("a",) + pass + + class B(Base, abstract=True): + __xslots__ = ("b",) + pass - filename = os.path.abspath(os.path.join(DIRPATH, "data", "ErrRelated.pkl.gz")) + class C(A, B): + pass - # A constraint set and a contraint caught in the act of making related_to fail - constraints, constraint = pickle.loads(gzip.open(filename, "rb").read()) + class X(object): + __slots__ = ("t", "a", "b") - Z3Solver.instance().can_be_true.cache_clear() - ground_truth = Z3Solver.instance().can_be_true(constraints, constraint) - self.assertEqual(ground_truth, False) + c = C() + c.a = 1 + c.t = 10 - Z3Solver.instance().can_be_true.cache_clear() + c.b = 2 + c.t = 10 + + x = X() + x.a = 1 + x.b = 2 + x.t = 20 + + self.assertEqual(sys.getsizeof(c), sys.getsizeof(x)) + + def test_BitVec_ops(self): + a = BitVecVariable(size=32, name="BV") + b = BitVecVariable(size=32, name="BV1") + c = BitVecVariable(size=32, name="BV2") + x = BitVecConstant(size=32, value=100, taint=("T",)) + z = (b + 1) % b < a * x / c - 5 + self.assertSetEqual(z.taint, set(("T",))) self.assertEqual( - ground_truth, - Z3Solver.instance().can_be_true(constraints.related_to(constraints), constraint), + translate_to_smtlib(z), + "(bvslt (bvsmod (bvadd BV1 #x00000001) BV1) (bvsub (bvsdiv (bvmul BV #x00000064) BV2) #x00000005))", + ) + z = (1 + b) / b <= a - x * 5 + c 
+ self.assertSetEqual(z.taint, set(("T",))) + self.assertEqual( + translate_to_smtlib(z), + "(bvsle (bvsdiv (bvadd #x00000001 BV1) BV1) (bvadd (bvsub BV (bvmul #x00000064 #x00000005)) BV2))", ) - # Replace - new_constraint = Operators.UGE( - Operators.SEXTEND(BitVecConstant(256, 0x1A), 256, 512) * BitVecConstant(512, 1), - 0x00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000, + def test_ConstantArrayBitVec(self): + c = ArrayConstant(index_size=32, value_size=8, value=b"ABCDE") + self.assertEqual(c[0], ord("A")) + self.assertEqual(c[1], ord("B")) + self.assertEqual(c[2], ord("C")) + self.assertEqual(c[3], ord("D")) + self.assertEqual(c[4], ord("E")) + self.assertRaises(IndexError, c.__getitem__, 5) + + def test_ConstantArrayBitVec2(self): + c = MutableArray(ArrayVariable(index_size=32, value_size=8, length=5, name="ARR")) + c[1] = 10 + c[2] = 20 + c[3] = 30 + self.assertEqual(c[1], 10) + self.assertEqual(c[2], 20) + self.assertEqual(c[3], 30) + c[2] = 100 + self.assertEqual(c[2], 100) + + def test_ArrayDefault3(self): + c = MutableArray( + ArrayVariable(index_size=32, value_size=8, length=5, default=0, name="ARR") ) - self.assertEqual(translate_to_smtlib(constraint), translate_to_smtlib(new_constraint)) + self.assertEqual(c[1], 0) + self.assertEqual(c[2], 0) + self.assertEqual(c[3], 0) - Z3Solver.instance().can_be_true.cache_clear() - self.assertEqual(ground_truth, Z3Solver.instance().can_be_true(constraints, new_constraint)) + c[1] = 10 + c[3] = 30 + self.assertEqual(c[1], 10) + self.assertEqual(c[2], 0) + self.assertEqual(c[3], 30) - Z3Solver.instance().can_be_true.cache_clear() - self.assertEqual( - ground_truth, - Z3Solver.instance().can_be_true(constraints.related_to(new_constraint), new_constraint), + def test_ArrayDefault4(self): + cs = ConstraintSet() + a = MutableArray(cs.new_array(index_size=32, value_size=8, length=4, default=0, name="ARR")) + i = 
cs.new_bitvec(size=a.index_size) + SelectedSolver.instance().must_be_true(cs, 0 == a.default) + SelectedSolver.instance().must_be_true(cs, a[i] == a.default) + cs.add(i == 2) + SelectedSolver.instance().must_be_true(cs, 0 == a.default) + SelectedSolver.instance().must_be_true(cs, a[i] == a.default) + + b = a[:] + i = cs.new_bitvec(size=a.index_size) + SelectedSolver.instance().must_be_true(cs, 0 == b.default) + SelectedSolver.instance().must_be_true(cs, b[i] == b.default) + + a[1] = 10 + a[2] = 20 + a[3] = 30 + # a := 0 10 20 30 0 0 x x x x (x undefined) + SelectedSolver.instance().must_be_true(cs, a.default == 0) + SelectedSolver.instance().must_be_true(cs, a[0] == 0) + SelectedSolver.instance().must_be_true(cs, a[1] == 10) + SelectedSolver.instance().must_be_true(cs, a[2] == 20) + SelectedSolver.instance().must_be_true(cs, a[3] == 30) + # SelectedSolver.instance().must_be_true(cs, a[4] == 0) #undefined! + + b = a[:] + # b := 0 10 20 30 0 0 x x x x (x undefined) + SelectedSolver.instance().must_be_true(cs, b.default == 0) + SelectedSolver.instance().must_be_true(cs, b[0] == 0) + SelectedSolver.instance().must_be_true(cs, b[1] == 10) + SelectedSolver.instance().must_be_true(cs, b[2] == 20) + SelectedSolver.instance().must_be_true(cs, b[3] == 30) + + def test_Expression(self): + # Used to check if all Expression have test + checked = set() + + def check(ty, pickle_size=None, sizeof=None, **kwargs): + x = ty(**kwargs) + """ + print( + type(x), + "\n Pickle size:", + len(pickle_dumps(x)), + "\n GetSizeOf:", + sys.getsizeof(x), + "\n Slotted:", + not hasattr(x, "__dict__"), + ) + """ + # self.assertEqual(len(pickle_dumps(x)), pickle_size) + # self.assertEqual(sys.getsizeof(x), sizeof) + # The test numbers are taken from Python 3.8.5 older pythons use 8 more bytes sometimes + self.assertLessEqual(sys.getsizeof(x), sizeof + 8) + self.assertFalse(hasattr(x, "__dict__")) # slots! + self.assertTrue(hasattr(x, "_taint")) # taint! 
+ checked.add(ty) + + # Can not instantiate an Expression + for ty in ( + Expression, + Constant, + Variable, + Operation, + BoolOperation, + BitVecOperation, + ArrayOperation, + BitVec, + Bool, + Array, + ): + self.assertRaises(Exception, ty, ()) + self.assertTrue(hasattr(ty, "__doc__")) + self.assertTrue(ty.__doc__, ty) + checked.add(ty) + + check(BitVecVariable, size=32, name="name", pickle_size=111, sizeof=56) + check(BoolVariable, name="name", pickle_size=99, sizeof=48) + check( + ArrayVariable, + index_size=32, + value_size=8, + length=100, + name="name", + pickle_size=156, + sizeof=80, ) + check(BitVecConstant, size=32, value=10, pickle_size=107, sizeof=56) + check(BoolConstant, value=False, pickle_size=94, sizeof=48) + + # TODO! But you can instantiate an ArraConstant + """ + x = ArrayConstant(index_size=32, value_size=8, b"AAAAAAAAAAAAAAA") + self.assertLessEqual(len(pickle_dumps(x)), 76) #master 71 + self.assertLessEqual(sys.getsizeof(x), 64) #master 56 + self.assertFalse(hasattr(x, '__dict__')) #slots! 
+ """ + # But you can instantiate a BoolOr + x = BoolVariable(name="x") + y = BoolVariable(name="y") + z = BoolVariable(name="z") + check(BoolEqual, operanda=x, operandb=y, pickle_size=159, sizeof=48) + check(BoolAnd, operanda=x, operandb=y, pickle_size=157, sizeof=48) + check(BoolOr, operanda=x, operandb=y, pickle_size=156, sizeof=48) + check(BoolXor, operanda=x, operandb=y, pickle_size=157, sizeof=48) + check(BoolNot, operand=x, pickle_size=137, sizeof=48) + check(BoolITE, cond=z, true=x, false=y, pickle_size=130, sizeof=48) + + bvx = BitVecVariable(size=32, name="bvx") + bvy = BitVecVariable(size=32, name="bvy") + + check(BoolUnsignedGreaterThan, operanda=bvx, operandb=bvy, pickle_size=142, sizeof=48) + check(BoolGreaterThan, operanda=bvx, operandb=bvy, pickle_size=134, sizeof=48) + check( + BoolUnsignedGreaterOrEqualThan, operanda=bvx, operandb=bvy, pickle_size=149, sizeof=48 + ) + check(BoolGreaterOrEqualThan, operanda=bvx, operandb=bvy, pickle_size=141, sizeof=48) + check(BoolUnsignedLessThan, operanda=bvx, operandb=bvy, pickle_size=139, sizeof=48) + check(BoolLessThan, operanda=bvx, operandb=bvy, pickle_size=131, sizeof=48) + check(BoolUnsignedLessOrEqualThan, operanda=bvx, operandb=bvy, pickle_size=146, sizeof=48) + check(BoolLessOrEqualThan, operanda=bvx, operandb=bvy, pickle_size=138, sizeof=48) + + check(BitVecOr, operanda=bvx, operandb=bvy, pickle_size=129, sizeof=56) + check(BitVecXor, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecAnd, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecNot, operanda=bvx, pickle_size=112, sizeof=56) + check(BitVecNeg, operanda=bvx, pickle_size=112, sizeof=56) + check(BitVecAdd, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecMul, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecSub, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecDiv, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecMod, 
operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecUnsignedDiv, operanda=bvx, operandb=bvy, pickle_size=138, sizeof=56) + check(BitVecRem, operanda=bvx, operandb=bvy, pickle_size=130, sizeof=56) + check(BitVecUnsignedRem, operanda=bvx, operandb=bvy, pickle_size=138, sizeof=56) + + check(BitVecShiftLeft, operanda=bvx, operandb=bvy, pickle_size=136, sizeof=56) + check(BitVecShiftRight, operanda=bvx, operandb=bvy, pickle_size=137, sizeof=56) + check(BitVecArithmeticShiftLeft, operanda=bvx, operandb=bvy, pickle_size=146, sizeof=56) + check(BitVecArithmeticShiftRight, operanda=bvx, operandb=bvy, pickle_size=147, sizeof=56) + + check(BitVecZeroExtend, operand=bvx, size=122, pickle_size=119, sizeof=56) + check(BitVecSignExtend, operand=bvx, size=122, pickle_size=119, sizeof=56) + check(BitVecExtract, operand=bvx, offset=0, size=8, pickle_size=119, sizeof=64) + check(BitVecConcat, operands=(bvx, bvy), pickle_size=133, sizeof=56) + check(BitVecITE, condition=x, true_value=bvx, false_value=bvy, pickle_size=161, sizeof=56) + + a = ArrayVariable(index_size=32, value_size=32, length=324, name="name") + check(ArrayConstant, index_size=32, value_size=8, value=b"A", pickle_size=136, sizeof=64) + check(ArraySlice, array=a, offset=0, size=10, pickle_size=136, sizeof=48) + check(ArraySelect, array=a, index=bvx, pickle_size=161, sizeof=56) + check(ArrayStore, array=a, index=bvx, value=bvy, pickle_size=188, sizeof=80) + + def all_subclasses(cls): + return set((Expression,)).union( + set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in all_subclasses(c)] + ) + ) -""" -class Z3Specific(unittest.TestCase): + all_types = all_subclasses(Expression) + self.assertSetEqual(checked, all_types) + + def test_Expression_BitVecOp(self): + a = BitVecConstant(size=32, value=100) + b = BitVecConstant(size=32, value=101) + x = a + b + self.assertTrue(isinstance(x, BitVec)) + + def test_Expression_BoolTaint(self): + # Bool can not be instantiaated + 
self.assertRaises(Exception, Bool, ()) + + x = BoolConstant(value=True, taint=("red",)) + y = BoolConstant(value=False, taint=("blue",)) + z = BoolOr(x, y) + self.assertIn("red", x.taint) + self.assertIn("blue", y.taint) + self.assertIn("red", z.taint) + self.assertIn("blue", z.taint) + + def test_Expression_BitVecTaint(self): + # Bool can not be instantiaated + self.assertRaises(Exception, BitVec, ()) + + x = BitVecConstant(size=32, value=123, taint=("red",)) + y = BitVecConstant(size=32, value=124, taint=("blue",)) + z = BoolGreaterOrEqualThan(x, y) + self.assertIn("red", x.taint) + self.assertIn("blue", y.taint) + self.assertIn("red", z.taint) + self.assertIn("blue", z.taint) + + def test_Expression_Array(self): + # Bool can not be instantiaated + self.assertRaises(Exception, Array, ()) + + a = ArrayConstant(index_size=32, value_size=8, value=b"ABCDE") + a[0] == ord("A") + + x = BitVecConstant(size=32, value=123, taint=("red",)) + y = BitVecConstant(size=32, value=124, taint=("blue",)) + z = BoolGreaterOrEqualThan(x, y) + self.assertIn("red", x.taint) + self.assertIn("blue", y.taint) + self.assertIn("red", z.taint) + self.assertIn("blue", z.taint) + + +class ExpressionTestLoco(unittest.TestCase): _multiprocess_can_split_ = True def setUp(self): self.solver = Z3Solver.instance() + cs = ConstraintSet() + self.assertTrue(self.solver.check(cs)) + def assertItemsEqual(self, a, b): + # Required for Python3 compatibility + self.assertEqual(sorted(a), sorted(b)) - @patch('subprocess.check_output', mock_open()) - def test_check_solver_min(self, mock_check_output): - mock_check_output.return_value = ("output", "Error") - #mock_check_output.return_value='(:version "4.4.1")' - #mock_function = create_autospec(function, return_value='(:version "4.4.1")') - #with patch.object(subprocess, 'check_output' , return_value='(:version "4.4.1")'): - #test_patch.return_value = '(:version "4.4.1")' - print (self.solver._solver_version()) - self.assertTrue(self.solver._solver_version() 
== Version(major=4, minor=4, patch=1)) + def tearDown(self): + del self.solver - def test_check_solver_newer(self): - self.solver._received_version = '(:version "4.5.0")' - self.assertTrue(self.solver._solver_version() > Version(major=4, minor=4, patch=1)) + def test_signed_unsigned_LT_(self): + mask = (1 << 32) - 1 - def test_check_solver_long_format(self): - self.solver._received_version = '(:version "4.8.6 - build hashcode 78ed71b8de7d")' - self.assertTrue(self.solver._solver_version() == Version(major=4, minor=8, patch=6)) + cs = ConstraintSet() + a = cs.new_bitvec(32) + b = cs.new_bitvec(32) - def test_check_solver_undefined(self): - self.solver._received_version = '(:version "78ed71b8de7d")' - self.assertTrue( + cs.add(a == 0x1) + cs.add(b == (0x80000000 - 1)) - self.solver._solver_version() - == Version(major=float("inf"), minor=float("inf"), patch=float("inf")) - ) - self.assertTrue(self.solver._solver_version() > Version(major=4, minor=4, patch=1)) -""" + lt = b < a + ult = b.ult(a) + + self.assertFalse(self.solver.can_be_true(cs, ult)) + self.assertTrue(self.solver.must_be_true(cs, Operators.NOT(lt))) + + def test_signed_unsigned_LT_simple(self): + cs = ConstraintSet() + a = cs.new_bitvec(32) + b = cs.new_bitvec(32) + + cs.add(a == 0x1) + cs.add(b == 0x80000000) + + lt = b < a + ult = b.ult(a) + + self.assertFalse(self.solver.can_be_true(cs, ult)) + self.assertTrue(self.solver.must_be_true(cs, lt)) class ExpressionTest(unittest.TestCase): - _multiprocess_can_split_ = True + _multiprocess_can_split_ = False def setUp(self): self.solver = Z3Solver.instance() @@ -133,15 +416,15 @@ def testBasicAST_001(self): """ Can't build abstract classes """ a = BitVecConstant(32, 100) - self.assertRaises(TypeError, Expression, ()) - self.assertRaises(TypeError, Constant, 123) - self.assertRaises(TypeError, Variable, "NAME") - self.assertRaises(TypeError, Operation, a) + self.assertRaises(Exception, Expression, ()) + self.assertRaises(Exception, Constant, 123) + 
self.assertRaises(Exception, Variable, "NAME") + self.assertRaises(Exception, Operation, a) def testBasicOperation(self): """ Add """ - a = BitVecConstant(32, 100) - b = BitVecVariable(32, "VAR") + a = BitVecConstant(size=32, value=100) + b = BitVecVariable(size=32, name="VAR") c = a + b self.assertIsInstance(c, BitVecAdd) self.assertIsInstance(c, Operation) @@ -173,12 +456,12 @@ def test_cs_new_bitvec_invalid_size(self): with self.assertRaises(ValueError) as e: cs.new_bitvec(size=0) - self.assertEqual(str(e.exception), "Bitvec size (0) can't be equal to or less than 0") + self.assertEqual(str(e.exception), "BitVec size (0) can't be equal to or less than 0") with self.assertRaises(ValueError) as e: cs.new_bitvec(size=-23) - self.assertEqual(str(e.exception), "Bitvec size (-23) can't be equal to or less than 0") + self.assertEqual(str(e.exception), "BitVec size (-23) can't be equal to or less than 0") def testBasicConstraints(self): cs = ConstraintSet() @@ -195,29 +478,29 @@ def testSolver(self): def testBool1(self): cs = ConstraintSet() - bf = BoolConstant(False) - bt = BoolConstant(True) + bf = BoolConstant(value=False) + bt = BoolConstant(value=True) cs.add(Operators.AND(bf, bt)) self.assertFalse(self.solver.check(cs)) def testBool2(self): cs = ConstraintSet() - bf = BoolConstant(False) - bt = BoolConstant(True) + bf = BoolConstant(value=False) + bt = BoolConstant(value=True) cs.add(Operators.AND(bf, bt, bt, bt)) self.assertFalse(self.solver.check(cs)) def testBool3(self): cs = ConstraintSet() - bf = BoolConstant(False) - bt = BoolConstant(True) + bf = BoolConstant(value=False) + bt = BoolConstant(value=True) cs.add(Operators.AND(bt, bt, bf, bt)) self.assertFalse(self.solver.check(cs)) def testBool4(self): cs = ConstraintSet() - bf = BoolConstant(False) - bt = BoolConstant(True) + bf = BoolConstant(value=False) + bt = BoolConstant(value=True) cs.add(Operators.OR(True, bf)) cs.add(Operators.OR(bt, bt, False)) self.assertTrue(self.solver.check(cs)) @@ -259,10 
+542,76 @@ def testBasicArray(self): temp_cs.add(key == 1002) self.assertTrue(self.solver.check(temp_cs)) + def testBasicArraySelectCache(self): + cs = ConstraintSet() + # make array of 32->8 bits + array = cs.new_array(32, value_size=256) + # make free 32bit bitvector + key = cs.new_bitvec(32) + + expr1 = array[key] + expr2 = array[key] + self.assertTrue(hash(expr1) == hash(expr2)) + self.assertTrue(expr1 is expr2) + self.assertTrue(simplify(expr1) is simplify(expr2)) + d = {} + d[expr1] = 1 + d[expr2] = 2 + self.assertEqual(d[expr2], 2) + + """ + expr3 = expr1 + expr2 + expr4 = expr1 + expr2 + self.assertTrue(hash(expr3) == hash(expr4)) + + b1 = expr3 == expr4 + b2 = expr3 == expr4 + d = {} + d[expr3] = 3 + d[expr4] = 4 + + key1 = cs.new_bitvec(32) + key2 = cs.new_bitvec(32) + + expr1 = array[key1] + expr2 = array[key2] + self.assertTrue(expr1 is not expr2) + self.assertTrue(simplify(expr1) is not simplify(expr2)) + + + expr1 = array[1] + expr2 = array[1] + self.assertTrue(expr1 is expr2) + self.assertTrue(hash(expr1) == hash(expr2)) + self.assertTrue(simplify(expr1) is simplify(expr2)) + + expr1 = array[1] + expr2 = array[2] + self.assertTrue(expr1 is not expr2) + self.assertTrue(hash(expr1) != hash(expr2)) + self.assertTrue(simplify(expr1) is not simplify(expr2)) + + expr1 = cs.new_bitvec(size=256) + expr2 = cs.new_bitvec(size=32) + self.assertFalse(hash(expr1) == hash(expr2)) + + expr1 = cs.new_bitvec(size=256) + expr2 = copy.copy(expr1) + self.assertTrue(hash(expr1) == hash(expr2)) + + expr1 = BitVecConstant(size=32, value=10) + expr2 = BitVecConstant(size=32, value=10) + self.assertTrue(hash(expr1) == hash(expr2)) + + expr1 = BitVecConstant(size=32, value=10) + expr2 = BitVecConstant(size=256, value=10) + self.assertTrue(hash(expr1) != hash(expr2)) + """ + def testBasicArray256(self): cs = ConstraintSet() # make array of 32->8 bits - array = cs.new_array(32, value_bits=256) + array = cs.new_array(32, value_size=256) # make free 32bit bitvector key = 
cs.new_bitvec(32) @@ -270,7 +619,6 @@ def testBasicArray256(self): cs.add(array[key] == 11111111111111111111111111111111111111111111) # let's restrict key to be greater than 1000 cs.add(key.ugt(1000)) - with cs as temp_cs: # 1001 position of array can be 111...111 temp_cs.add(array[1001] == 11111111111111111111111111111111111111111111) @@ -314,9 +662,6 @@ def testBasicArrayStore(self): # 1001 position of array can be 'B' self.assertTrue(self.solver.can_be_true(cs, array.select(1001) == ord("B"))) - # name is correctly proxied - self.assertEqual(array.name, name) - with cs as temp_cs: # but if it is 'B' ... temp_cs.add(array.select(1001) == ord("B")) @@ -333,13 +678,13 @@ def testBasicArrayStore(self): def testBasicArraySymbIdx(self): cs = ConstraintSet() - array = cs.new_array(index_bits=32, value_bits=32, name="array") + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=0)) key = cs.new_bitvec(32, name="key") index = cs.new_bitvec(32, name="index") array[key] = 1 # Write 1 to a single location - cs.add(array.get(index, default=0) != 0) # Constrain index so it selects that location + cs.add(array.select(index) != 0) # Constrain index so it selects that location cs.add(index != key) # key and index are the same there is only one slot in 1 @@ -347,45 +692,63 @@ def testBasicArraySymbIdx(self): def testBasicArraySymbIdx2(self): cs = ConstraintSet() - array = cs.new_array(index_bits=32, value_bits=32, name="array") + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=0)) key = cs.new_bitvec(32, name="key") index = cs.new_bitvec(32, name="index") array[key] = 1 # Write 1 to a single location - cs.add(array.get(index, 0) != 0) # Constrain index so it selects that location + cs.add(array.select(index) != 0) # Constrain index so it selects that location a_index = self.solver.get_value(cs, index) # get a concrete solution for index - cs.add(array.get(a_index, 0) != 0) # now storage must have something at 
that location + cs.add(array.select(a_index) != 0) # now storage must have something at that location cs.add(a_index != index) # remove it from the solutions # It should not be another solution for index self.assertFalse(self.solver.check(cs)) + def testBasicArrayDefault(self): + cs = ConstraintSet() + array = cs.new_array(index_size=32, value_size=32, name="array", default=0) + key = cs.new_bitvec(32, name="key") + self.assertTrue(self.solver.must_be_true(cs, array[key] == 0)) + + def testBasicArrayDefault2(self): + cs = ConstraintSet() + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=0)) + index1 = cs.new_bitvec(32) + index2 = cs.new_bitvec(32) + value = cs.new_bitvec(32) + array[index2] = value + cs.add(index1 != index2) + cs.add(value != 0) + self.assertTrue(self.solver.must_be_true(cs, array[index1] == 0)) + + def testBasicArrayIndexConcrete(self): + cs = ConstraintSet() + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=0)) + array[0] = 100 + self.assertTrue(array[0] == 100) + def testBasicArrayConcatSlice(self): - hw = bytearray(b"Hello world!") + hw = b"Hello world!" cs = ConstraintSet() # make array of 32->8 bits - array = cs.new_array(32, index_max=12) - + array = cs.new_array(32, length=len(hw)) array = array.write(0, hw) - + self.assertEqual(len(array), len(hw)) self.assertTrue(self.solver.must_be_true(cs, array == hw)) + self.assertEqual(len(array.read(0, 12)), 12) + x = array.read(0, 12) == hw self.assertTrue(self.solver.must_be_true(cs, array.read(0, 12) == hw)) - + cs.add(array.read(6, 6) == hw[6:12]) self.assertTrue(self.solver.must_be_true(cs, array.read(6, 6) == hw[6:12])) + self.assertTrue(self.solver.must_be_true(cs, b"Hello " + array.read(6, 6) == hw)) - self.assertTrue(self.solver.must_be_true(cs, bytearray(b"Hello ") + array.read(6, 6) == hw)) + self.assertTrue(self.solver.must_be_true(cs, b"Hello " + array.read(6, 5) + b"!" 
== hw)) self.assertTrue( self.solver.must_be_true( - cs, bytearray(b"Hello ") + array.read(6, 5) + bytearray(b"!") == hw - ) - ) - - self.assertTrue( - self.solver.must_be_true( - cs, - array.read(0, 1) + bytearray(b"ello ") + array.read(6, 5) + bytearray(b"!") == hw, + cs, array.read(0, 1) + b"ello " + array.read(6, 5) + b"!" == hw ) ) @@ -393,16 +756,18 @@ def testBasicArrayConcatSlice(self): self.assertTrue(len(array[0:12]) == 12) + self.assertEqual(len(array[6:11]), 5) + results = [] for c in array[6:11]: results.append(c) - self.assertTrue(len(results) == 5) + self.assertEqual(len(results), 5) def testBasicArraySlice(self): - hw = bytearray(b"Hello world!") + hw = b"Hello world!" cs = ConstraintSet() # make array of 32->8 bits - array = cs.new_array(32, index_max=12) + array = MutableArray(cs.new_array(32, length=12)) array = array.write(0, hw) array_slice = array[0:2] self.assertTrue(self.solver.must_be_true(cs, array == hw)) @@ -422,22 +787,22 @@ def testBasicArraySlice(self): def testBasicArrayProxySymbIdx(self): cs = ConstraintSet() - array = ArrayProxy(cs.new_array(index_bits=32, value_bits=32, name="array"), default=0) + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=0)) key = cs.new_bitvec(32, name="key") index = cs.new_bitvec(32, name="index") array[key] = 1 # Write 1 to a single location - cs.add(array.get(index) != 0) # Constrain index so it selects that location + cs.add(array.select(index) != 0) # Constrain index so it selects that location a_index = self.solver.get_value(cs, index) # get a concrete solution for index - cs.add(array.get(a_index) != 0) # now storage must have something at that location + cs.add(array.select(a_index) != 0) # now storage must have something at that location cs.add(a_index != index) # remove it from the solutions # It should not be another solution for index self.assertFalse(self.solver.check(cs)) def testBasicArrayProxySymbIdx2(self): cs = ConstraintSet() - array = 
ArrayProxy(cs.new_array(index_bits=32, value_bits=32, name="array")) + array = MutableArray(cs.new_array(index_size=32, value_size=32, name="array", default=100)) key = cs.new_bitvec(32, name="key") index = cs.new_bitvec(32, name="index") @@ -446,20 +811,29 @@ def testBasicArrayProxySymbIdx2(self): solutions = self.solver.get_all_values(cs, array[0]) # get a concrete solution for index self.assertItemsEqual(solutions, (1, 2)) - solutions = self.solver.get_all_values( - cs, array.get(0, 100) + cs, array.select(0) ) # get a concrete solution for index 0 self.assertItemsEqual(solutions, (1, 2)) solutions = self.solver.get_all_values( - cs, array.get(1, 100) + cs, array.select(1) ) # get a concrete solution for index 1 (default 100) self.assertItemsEqual(solutions, (100, 2)) - self.assertTrue( - self.solver.can_be_true(cs, array[1] == 12345) - ) # no default so it can be anything + def testBasicConstatArray(self): + cs = ConstraintSet() + array1 = MutableArray( + cs.new_array(index_size=32, value_size=32, length=10, name="array1", default=0) + ) + array2 = MutableArray( + cs.new_array(index_size=32, value_size=32, length=10, name="array2", default=0) + ) + array1[0:10] = range(10) + self.assertTrue(array1[0] == 0) + # yeah right self.assertTrue(array1[0:10] == range(10)) + array_slice = array1[0:10] + self.assertTrue(array_slice[0] == 0) def testBasicPickle(self): import pickle @@ -478,7 +852,7 @@ def testBasicPickle(self): cs = pickle.loads(pickle_dumps(cs)) self.assertTrue(self.solver.check(cs)) - def testBitvector_add(self): + def testBitVector_add(self): cs = ConstraintSet() a = cs.new_bitvec(32) b = cs.new_bitvec(32) @@ -489,7 +863,7 @@ def testBitvector_add(self): self.assertTrue(self.solver.check(cs)) self.assertEqual(self.solver.get_value(cs, c), 11) - def testBitvector_add1(self): + def testBitVector_add1(self): cs = ConstraintSet() a = cs.new_bitvec(32) b = cs.new_bitvec(32) @@ -499,7 +873,7 @@ def testBitvector_add1(self): 
self.assertEqual(self.solver.check(cs), True) self.assertEqual(self.solver.get_value(cs, c), 11) - def testBitvector_add2(self): + def testBitVector_add2(self): cs = ConstraintSet() a = cs.new_bitvec(32) b = cs.new_bitvec(32) @@ -508,7 +882,7 @@ def testBitvector_add2(self): self.assertTrue(self.solver.check(cs)) self.assertEqual(self.solver.get_value(cs, a), 1) - def testBitvector_max(self): + def testBitVector_max(self): cs = ConstraintSet() a = cs.new_bitvec(32) cs.add(a <= 200) @@ -527,15 +901,15 @@ def testBitvector_max(self): self.assertEqual(self.solver.minmax(cs, a), (100, 200)) consts.optimize = True - def testBitvector_max_noop(self): + def testBitVector_max_noop(self): from manticore import config consts = config.get_group("smt") consts.optimize = False - self.testBitvector_max() + self.testBitVector_max() consts.optimize = True - def testBitvector_max1(self): + def testBitVector_max1(self): cs = ConstraintSet() a = cs.new_bitvec(32) cs.add(a < 200) @@ -543,22 +917,22 @@ def testBitvector_max1(self): self.assertTrue(self.solver.check(cs)) self.assertEqual(self.solver.minmax(cs, a), (101, 199)) - def testBitvector_max1_noop(self): + def testBitVector_max1_noop(self): from manticore import config consts = config.get_group("smt") consts.optimize = False - self.testBitvector_max1() + self.testBitVector_max1() consts.optimize = True def testBool_nonzero(self): - self.assertTrue(BoolConstant(True).__bool__()) - self.assertFalse(BoolConstant(False).__bool__()) + self.assertTrue(BoolConstant(value=True).__bool__()) + self.assertFalse(BoolConstant(value=False).__bool__()) def test_visitors(self): solver = Z3Solver.instance() cs = ConstraintSet() - arr = cs.new_array(name="MEM") + arr = MutableArray(cs.new_array(name="MEM")) a = cs.new_bitvec(32, name="VAR") self.assertEqual(get_depth(a), 1) cond = Operators.AND(a < 200, a > 100) @@ -576,17 +950,22 @@ def test_visitors(self): aux = arr[a + Operators.ZEXTEND(arr[a], 32)] self.assertEqual(get_depth(aux), 9) - 
self.maxDiff = 1500 + self.maxDiff = 5500 self.assertEqual( - translate_to_smtlib(aux), + translate_to_smtlib(aux, use_bindings=False), "(select (store (store (store MEM #x00000000 #x61) #x00000001 #x62) #x00000003 (select (store (store MEM #x00000000 #x61) #x00000001 #x62) (bvadd VAR #x00000001))) (bvadd VAR ((_ zero_extend 24) (select (store (store (store MEM #x00000000 #x61) #x00000001 #x62) #x00000003 (select (store (store MEM #x00000000 #x61) #x00000001 #x62) (bvadd VAR #x00000001))) VAR))))", ) + self.assertEqual( + translate_to_smtlib(aux, use_bindings=True), + "(let ((!a_2! (store (store MEM #x00000000 #x61) #x00000001 #x62))) (let ((!a_4! (store !a_2! #x00000003 (select !a_2! (bvadd VAR #x00000001))))) (select !a_4! (bvadd VAR ((_ zero_extend 24) (select !a_4! VAR))))))", + ) + values = arr[0:2] self.assertEqual(len(values), 2) self.assertItemsEqual(solver.get_all_values(cs, values[0]), [ord("a")]) self.assertItemsEqual(solver.get_all_values(cs, values[1]), [ord("b")]) - arr[1:3] = "cd" + arr[1:3] = b"cd" values = arr[0:3] self.assertEqual(len(values), 3) @@ -690,8 +1069,8 @@ def test_constant_folding_udiv(self): def test_simplify_OR(self): cs = ConstraintSet() - bf = BoolConstant(False) - bt = BoolConstant(True) + bf = BoolConstant(value=False) + bt = BoolConstant(value=True) var = cs.new_bool() cs.add(simplify(Operators.OR(var, var)) == var) cs.add(simplify(Operators.OR(var, bt)) == bt) @@ -699,9 +1078,9 @@ def test_simplify_OR(self): def testBasicReplace(self): """ Add """ - a = BitVecConstant(32, 100) - b1 = BitVecVariable(32, "VAR1") - b2 = BitVecVariable(32, "VAR2") + a = BitVecConstant(size=32, value=100) + b1 = BitVecVariable(size=32, name="VAR1") + b2 = BitVecVariable(size=32, name="VAR2") c = a + b1 @@ -869,6 +1248,7 @@ def test_ORD_proper_extract(self): def test_CHR(self): solver = Z3Solver.instance() cs = ConstraintSet() + self.assertTrue(solver.check(cs)) a = cs.new_bitvec(8) cs.add(Operators.CHR(a) == Operators.CHR(0x41)) @@ -972,16 +1352,32 
@@ def test_UDIV(self): b = cs.new_bitvec(8) c = cs.new_bitvec(8) d = cs.new_bitvec(8) - cs.add(b == 0x86) # 134 cs.add(c == 0x11) # 17 cs.add(a == Operators.UDIV(b, c)) cs.add(d == b.udiv(c)) cs.add(a == d) - self.assertTrue(solver.check(cs)) self.assertEqual(solver.get_value(cs, a), 7) + solver = Z3Solver.instance() + cs1 = ConstraintSet() + + a1 = cs1.new_bitvec(8) + b1 = cs1.new_bitvec(8) + c1 = cs1.new_bitvec(8) + d1 = cs1.new_bitvec(8) + self.assertTrue(solver.check(cs1)) + + cs1.add(b1 == 0x86) # -122 + cs1.add(c1 == 0x11) # 17 + cs1.add(a1 == Operators.SREM(b1, c1)) + cs1.add(d1 == b1.srem(c1)) + cs1.add(a1 == d1) + + self.assertTrue(solver.check(cs1)) + self.assertEqual(solver.get_value(cs1, a1), -3 & 0xFF) + def test_SDIV(self): solver = Z3Solver.instance() cs = ConstraintSet() @@ -1080,6 +1476,7 @@ def testRelated(self): # and send in all the other stuff self.assertNotIn("AA", cs.related_to(bb1 == bb1).to_string()) + @unittest.skip("FIXME") def test_API(self): """ As we've split up the Constant, Variable, and Operation classes to avoid using multiple inheritance, diff --git a/tests/wasm/generate_tests.sh b/tests/wasm/generate_tests.sh index 0b5082bd4..c11d273aa 100755 --- a/tests/wasm/generate_tests.sh +++ b/tests/wasm/generate_tests.sh @@ -4,14 +4,14 @@ touch __init__.py if ! [ -x "$(command -v wast2json)" ]; then - wget -nc -nv -O wabt.tgz -c https://github.com/WebAssembly/wabt/releases/download/1.0.12/wabt-1.0.12-linux.tar.gz + wget -nv -O wabt.tgz -c https://github.com/WebAssembly/wabt/releases/download/1.0.12/wabt-1.0.12-linux.tar.gz tar --wildcards --strip=1 -xf wabt.tgz 'wabt-*/wast2json' rm wabt.tgz else cp "$(command -v wast2json)" . fi -wget -nc -nv -O spec.zip -c https://github.com/WebAssembly/spec/archive/opam-1.1.zip +wget -nv -O spec.zip -c https://github.com/WebAssembly/spec/archive/opam-1.1.zip unzip -q -j spec.zip 'spec-*/test/core/*' -d . 
rm run.py README.md diff --git a/tests/wasm/json2mc.py b/tests/wasm/json2mc.py index 3611cc1c4..a5165c438 100644 --- a/tests/wasm/json2mc.py +++ b/tests/wasm/json2mc.py @@ -60,7 +60,7 @@ def escape_null(in_str: str): template = env.get_template("test_template.jinja2") -modules = [] +modules = [] # type: ignore registered_modules = {} imports = [] current_module = None @@ -100,7 +100,7 @@ def escape_null(in_str: str): mod_name=d["action"].get("module", None), ) elif d["action"]["type"] == "get": - modules[current_module].add_test( + modules[current_module].add_test( # type: ignore d["action"]["field"], d["line"], [], @@ -156,9 +156,9 @@ def escape_null(in_str: str): imports.append({"type": "alias", "alias": d["as"], "orig": maybe_name}) else: # This is an alias for the current module imports.append( - {"type": "import", "name": d["as"], "filename": modules[current_module].filename} + {"type": "import", "name": d["as"], "filename": modules[current_module].filename} # type: ignore ) - modules[current_module].registered_name = d["as"] + modules[current_module].registered_name = d["as"] # type: ignore if current_module: modules[current_module].imports = copy(imports)