Skip to content

Commit

Permalink
AIL: Model dirty expression properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
ltfish committed Oct 19, 2024
1 parent 0767b5e commit 3e24a4e
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 59 deletions.
47 changes: 41 additions & 6 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ def _handle_MultiStatementExpression(
def _handle_DirtyExpression(
    self, expr_idx: int, expr: DirtyExpression, stmt_idx: int, stmt: Statement, block: Block | None
):
    """
    Visit every sub-expression of a dirty expression.

    Operands are visited first, each under its position in ``expr.operands``. The optional guard and
    result expression are visited afterwards under the indices ``len(operands) + 1`` and
    ``len(operands) + 2``.

    :param expr_idx:    Index of ``expr`` inside its parent (not used here).
    :param expr:        The dirty expression whose children are visited.
    :param stmt_idx:    Index of the enclosing statement, forwarded to ``_handle_expr``.
    :param stmt:        The enclosing statement, forwarded to ``_handle_expr``.
    :param block:       The enclosing block (may be None), forwarded to ``_handle_expr``.
    :return:            None.
    """
    # enumerate() provides the operand index; iterating the operand list directly would try to
    # unpack every operand into an (idx, operand) pair and fail.
    for idx, operand in enumerate(expr.operands):
        self._handle_expr(idx, operand, stmt_idx, stmt, block)
    if expr.guard is not None:
        self._handle_expr(len(expr.operands) + 1, expr.guard, stmt_idx, stmt, block)
    if expr.result_expr is not None:
        self._handle_expr(len(expr.operands) + 2, expr.result_expr, stmt_idx, stmt, block)

def _handle_VEXCCallExpression(
self, expr_idx: int, expr: VEXCCallExpression, stmt_idx: int, stmt: Statement, block: Block | None
Expand Down Expand Up @@ -570,11 +575,41 @@ def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, b
def _handle_DirtyExpression(
    self, expr_idx: int, expr: DirtyExpression, stmt_idx: int, stmt: Statement, block: Block | None
):
    """
    Rewrite the sub-expressions of a dirty expression.

    Each operand, the result expression and the guard are passed to ``_handle_expr``. A handler result
    of None, or the very same object, means "unchanged". If at least one child changed, a new
    DirtyExpression carrying the updated children is returned.

    :param expr_idx:    Index of ``expr`` inside its parent (not used here).
    :param expr:        The dirty expression to process.
    :param stmt_idx:    Index of the enclosing statement, forwarded to ``_handle_expr``.
    :param stmt:        The enclosing statement, forwarded to ``_handle_expr``.
    :param block:       The enclosing block (may be None), forwarded to ``_handle_expr``.
    :return:            A new DirtyExpression if any child changed, otherwise None.
    """
    changed = False

    # NOTE(review): every operand is handed to _handle_expr with expression index 0 -- confirm that
    # handlers do not rely on a per-operand index here.
    new_operands = []
    for operand in expr.operands:
        new_operand = self._handle_expr(0, operand, stmt_idx, stmt, block)
        if new_operand is not None and new_operand is not operand:
            changed = True
            new_operands.append(new_operand)
        else:
            new_operands.append(operand)

    # Only adopt the handler's result when it is a real replacement. A None result means "unchanged"
    # and must not overwrite the existing result expression when the expression is rebuilt below.
    new_result_expr = expr.result_expr
    if expr.result_expr is not None:
        handled_result = self._handle_expr(1, expr.result_expr, stmt_idx, stmt, block)
        if handled_result is not None and handled_result is not expr.result_expr:
            new_result_expr = handled_result
            changed = True

    # Same rule for the guard: keep the original guard unless the handler produced a replacement.
    new_guard = expr.guard
    if expr.guard is not None:
        handled_guard = self._handle_expr(2, expr.guard, stmt_idx, stmt, block)
        if handled_guard is not None and handled_guard is not expr.guard:
            new_guard = handled_guard
            changed = True

    if not changed:
        return None

    # maddr and msize are carried over as-is; they are not visited by this handler.
    return DirtyExpression(
        expr.idx,
        expr.callee,
        new_operands,
        guard=new_guard,
        result_expr=new_result_expr,
        mfx=expr.mfx,
        maddr=expr.maddr,
        msize=expr.msize,
        bits=expr.bits,
        **expr.tags,
    )

def _handle_VEXCCallExpression(
Expand Down
48 changes: 23 additions & 25 deletions ailment/converter_vex.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,40 +44,24 @@ def convert(expr, manager): # pylint:disable=arguments-differ
:param expr:
:return:
"""
if isinstance(expr, pyvex.const.IRConst):
return VEXExprConverter.const_n(expr, manager)

func = EXPRESSION_MAPPINGS.get(type(expr))
if func is not None:
# When something goes wrong, return a DirtyExpression instead of crashing the program
try:
return func(expr, manager)
except UnsupportedIROpError:
log.warning("VEXExprConverter: Unsupported IROp %s.", expr.op)
return DirtyExpression(manager.next_atom(), expr, bits=expr.result_size(manager.tyenv))

if isinstance(expr, pyvex.const.IRConst):
return VEXExprConverter.const_n(expr, manager)

if isinstance(expr, pyvex.IRExpr.CCall):
operands = tuple(VEXExprConverter.convert(arg, manager) for arg in expr.args)
ccall = VEXCCallExpression(
manager.next_atom(),
expr.cee.name,
operands,
bits=expr.result_size(manager.tyenv),
ins_addr=manager.ins_addr,
vex_block_addr=manager.block_addr,
vex_stmt_idx=manager.vex_stmt_idx,
)
return DirtyExpression(
manager.next_atom(),
ccall,
bits=expr.result_size(manager.tyenv),
ins_addr=manager.ins_addr,
vex_block_addr=manager.block_addr,
vex_stmt_idx=manager.vex_stmt_idx,
)
return DirtyExpression(
manager.next_atom(), f"unsupported_{expr.op}", [], bits=expr.result_size(manager.tyenv)
)

log.warning("VEXExprConverter: Unsupported VEX expression of type %s.", type(expr))
return DirtyExpression(manager.next_atom(), expr, bits=expr.result_size(manager.tyenv))
return DirtyExpression(
manager.next_atom(), f"unsupported_{str(type(expr))}", [], bits=expr.result_size(manager.tyenv)
)

@staticmethod
def convert_list(exprs, manager):
Expand Down Expand Up @@ -397,6 +381,19 @@ def ITE(expr, manager):
vex_stmt_idx=manager.vex_stmt_idx,
)

@staticmethod
def CCall(expr: pyvex.IRExpr.CCall, manager):
    """
    Convert a VEX CCall expression into a VEXCCallExpression.

    Every argument of the call is converted with ``VEXExprConverter.convert`` first; the converted
    arguments become the operands of the resulting expression.

    :param expr:    The VEX CCall expression to convert.
    :param manager: The conversion manager providing atom indices, the type environment and the
                    current instruction / block / statement location.
    :return:        The converted VEXCCallExpression.
    """
    converted_args = []
    for arg in expr.args:
        converted_args.append(VEXExprConverter.convert(arg, manager))

    # The atom for the call itself is allocated only after all arguments have been converted.
    return VEXCCallExpression(
        manager.next_atom(),
        expr.cee.name,
        converted_args,
        bits=expr.result_size(manager.tyenv),
        ins_addr=manager.ins_addr,
        vex_block_addr=manager.block_addr,
        vex_stmt_idx=manager.vex_stmt_idx,
    )


EXPRESSION_MAPPINGS = {
pyvex.IRExpr.RdTmp: VEXExprConverter.RdTmp,
Expand All @@ -409,6 +406,7 @@ def ITE(expr, manager):
pyvex.const.U64: VEXExprConverter.const_n,
pyvex.IRExpr.Load: VEXExprConverter.Load,
pyvex.IRExpr.ITE: VEXExprConverter.ITE,
pyvex.IRExpr.CCall: VEXExprConverter.CCall,
}


Expand Down
155 changes: 127 additions & 28 deletions ailment/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,44 +1308,135 @@ def copy(self) -> ITE:

class DirtyExpression(Expression):
# Attribute slots of a dirty expression; one entry per attribute assigned in __init__.
__slots__ = (
    "callee",
    "guard",
    "operands",
    "result_expr",
    "mfx",
    "maddr",
    "msize",
    "bits",
)

def __init__(
    self,
    idx,
    callee: str,
    operands: list[Expression],
    *,
    guard: Expression | None = None,
    result_expr: Expression | None = None,
    mfx: str | None = None,
    maddr: Expression | None = None,
    msize: Expression | None = None,
    # TODO: fxstate (guest state effects) is not modeled yet
    bits=None,
    **kwargs,
):
    """
    Create a dirty expression.

    :param idx:         Index of this expression, passed to the base class.
    :param callee:      Name of the callee.
    :param operands:    Operand expressions. The list is stored as-is, not copied.
    :param guard:       Optional guard expression.
    :param result_expr: Optional result expression.
    :param mfx:         Optional string stored in ``mfx`` (presumably the memory-effect kind of the
                        call -- confirm against the caller).
    :param maddr:       Optional expression stored in ``maddr``.
    :param msize:       Optional expression stored in ``msize``.
    :param bits:        Size of the expression in bits.
    :param kwargs:      Extra keyword arguments forwarded to the base class.
    """
    # The second base-class argument is fixed to 1 regardless of the operands.
    super().__init__(idx, 1, **kwargs)

    self.callee = callee
    self.guard = guard
    self.operands = operands
    self.result_expr = result_expr
    self.mfx = mfx
    self.maddr = maddr
    self.msize = msize
    self.bits = bits

def likes(self, other):
    """
    Check whether ``other`` is a dirty expression that is alike to this one.

    ``callee``, ``mfx`` and ``bits`` are compared for equality. ``guard``, ``result_expr``, ``maddr``
    and ``msize`` are compared with ``is_none_or_likeable``. Operands must be equal in number and
    are compared pairwise with ``likes``.

    :param other:   The object to compare against.
    :return:        True if all compared fields agree, False otherwise.
    """
    return (
        # The exact-type check comes first so the attribute accesses below are safe.
        type(other) is DirtyExpression
        and other.callee == self.callee
        and is_none_or_likeable(other.guard, self.guard)
        and len(self.operands) == len(other.operands)
        and all(op1.likes(op2) for op1, op2 in zip(self.operands, other.operands))
        and is_none_or_likeable(other.result_expr, self.result_expr)
        and other.mfx == self.mfx
        and is_none_or_likeable(other.maddr, self.maddr)
        and is_none_or_likeable(other.msize, self.msize)
        and self.bits == other.bits
    )

def matches(self, other):
    """
    Check whether ``other`` is a dirty expression that matches this one.

    ``callee``, ``mfx`` and ``bits`` are compared for equality. ``guard``, ``result_expr``, ``maddr``
    and ``msize`` are compared with ``is_none_or_matchable``. Operands must be equal in number and
    are compared pairwise with ``matches``.

    :param other:   The object to compare against.
    :return:        True if all compared fields agree, False otherwise.
    """
    return (
        # The exact-type check comes first so the attribute accesses below are safe.
        type(other) is DirtyExpression
        and other.callee == self.callee
        and is_none_or_matchable(other.guard, self.guard)
        and len(self.operands) == len(other.operands)
        and all(op1.matches(op2) for op1, op2 in zip(self.operands, other.operands))
        and is_none_or_matchable(other.result_expr, self.result_expr)
        and other.mfx == self.mfx
        and is_none_or_matchable(other.maddr, self.maddr)
        and is_none_or_matchable(other.msize, self.msize)
        and self.bits == other.bits
    )
__hash__ = TaggedObject.__hash__

def _hash_core(self):
    """
    Compute the hash of this dirty expression with ``stable_hash``.

    The hashed tuple consists of the class itself followed by callee, guard, operands, result
    expression, mfx, maddr, msize and bits.

    :return: The value returned by ``stable_hash``.
    """
    return stable_hash(
        (
            DirtyExpression,
            self.callee,
            self.guard,
            # operands is a list; it is converted to a tuple before being placed in the hashed tuple
            tuple(self.operands),
            self.result_expr,
            self.mfx,
            self.maddr,
            self.msize,
            self.bits,
        )
    )

def __repr__(self):
    """
    Return ``[D] callee(op1, op2, ...)``, where every operand is rendered with ``repr``.
    """
    rendered_operands = ", ".join(repr(op) for op in self.operands)
    return f"[D] {self.callee}({rendered_operands})"

def __str__(self):
    """
    Return ``[D] callee(op1, op2, ...)``, where every operand is rendered with ``repr``.
    """
    rendered_operands = ", ".join(repr(op) for op in self.operands)
    return f"[D] {self.callee}({rendered_operands})"

def copy(self) -> DirtyExpression:
    """
    Create a new DirtyExpression with the same index, fields and tags as this one.

    The copy is shallow: the operand list and all child expressions are shared with the original.

    :return: The new DirtyExpression.
    """
    return DirtyExpression(
        self.idx,
        self.callee,
        self.operands,
        guard=self.guard,
        result_expr=self.result_expr,
        mfx=self.mfx,
        maddr=self.maddr,
        msize=self.msize,
        bits=self.bits,
        **self.tags,
    )

def replace(self, old_expr: Expression, new_expr: Expression):
    """
    Replace occurrences of ``old_expr`` in the operands with ``new_expr``.

    An operand that compares equal to ``old_expr`` is substituted directly. Any other operand is
    asked to perform the replacement inside itself via its own ``replace`` method.

    :param old_expr:    The expression to look for.
    :param new_expr:    The expression to put in its place.
    :return:            ``(True, new_dirty_expression)`` if at least one operand was replaced,
                        otherwise ``(False, self)``.
    """
    new_operands = []
    replaced = False
    for op in self.operands:
        if old_expr == op:
            replaced = True
            new_operands.append(new_expr)
            continue

        r, new_op = op.replace(old_expr, new_expr)
        if r:
            replaced = True
            new_operands.append(new_op)
        else:
            new_operands.append(op)

    if not replaced:
        return False, self

    # NOTE(review): guard, result_expr, maddr and msize are carried over unchanged and are not
    # searched for old_expr -- confirm that this is intended.
    return True, DirtyExpression(
        self.idx,
        self.callee,
        new_operands,
        guard=self.guard,
        result_expr=self.result_expr,
        mfx=self.mfx,
        maddr=self.maddr,
        msize=self.msize,
        bits=self.bits,
        **self.tags,
    )

Expand All @@ -1356,21 +1447,29 @@ def size(self):

class VEXCCallExpression(Expression):
# Attribute slots of a VEX ccall expression; one entry per attribute assigned in __init__.
__slots__ = (
    "callee",
    "operands",
    "bits",
)

def __init__(self, idx, callee: str, operands: list[Expression], bits=None, **kwargs):
    """
    Create a VEX ccall expression.

    :param idx:         Index of this expression, passed to the base class.
    :param callee:      Name of the callee.
    :param operands:    Operand expressions. The list is stored as-is, not copied.
    :param bits:        Size of the expression in bits.
    :param kwargs:      Extra keyword arguments forwarded to the base class.
    """
    # The second base-class argument is the largest depth among the operands. default=0 keeps a
    # call without operands from raising ValueError in max().
    super().__init__(idx, max((operand.depth for operand in operands), default=0), **kwargs)
    self.callee = callee
    self.operands = operands
    self.bits = bits

@property
def op(self) -> str:
    """
    The operation name of this expression, which is the name of the callee.
    """
    callee_name = self.callee
    return callee_name

@property
def verbose_op(self) -> str:
    """
    The verbose operation name of this expression; identical to ``op``.
    """
    name = self.op
    return name

def likes(self, other):
return (
type(other) is VEXCCallExpression
and other.cee_name == self.cee_name
and other.callee == self.callee
and len(self.operands) == len(other.operands)
and self.bits == other.bits
and all(op1.likes(op2) for op1, op2 in zip(other.operands, self.operands))
Expand All @@ -1379,7 +1478,7 @@ def likes(self, other):
def matches(self, other):
return (
type(other) is VEXCCallExpression
and other.cee_name == self.cee_name
and other.callee == self.callee
and len(self.operands) == len(other.operands)
and self.bits == other.bits
and all(op1.matches(op2) for op1, op2 in zip(other.operands, self.operands))
Expand All @@ -1388,17 +1487,17 @@ def matches(self, other):
__hash__ = TaggedObject.__hash__

def _hash_core(self):
    """
    Compute the hash of this ccall expression with ``stable_hash``.

    The hashed tuple consists of the class itself, the callee name, the bit size and the operands.

    :return: The value returned by ``stable_hash``.
    """
    # operands is converted to a tuple before being placed in the hashed tuple
    return stable_hash((VEXCCallExpression, self.callee, self.bits, tuple(self.operands)))

def __repr__(self):
    """
    Return ``VEXCCallExpression [callee(op1, op2, ...)]``, where every operand is rendered with
    ``repr``.
    """
    rendered_operands = ", ".join(repr(op) for op in self.operands)
    return f"VEXCCallExpression [{self.callee}({rendered_operands})]"

def __str__(self):
    """
    Return ``callee(op1, op2, ...)``, where every operand is rendered with ``repr``.
    """
    operands_str = ", ".join(repr(op) for op in self.operands)
    return f"{self.callee}({operands_str})"

def copy(self) -> VEXCCallExpression:
    """
    Create a new VEXCCallExpression with the same index, callee, operands, bit size and tags.

    The copy is shallow: the operand container and the operand expressions are shared with the
    original.

    :return: The new VEXCCallExpression.
    """
    return VEXCCallExpression(self.idx, self.callee, self.operands, bits=self.bits, **self.tags)

def replace(self, old_expr, new_expr):
new_operands = []
Expand All @@ -1416,7 +1515,7 @@ def replace(self, old_expr, new_expr):
new_operands.append(operand)

if replaced:
return True, VEXCCallExpression(self.idx, self.cee_name, tuple(new_operands), bits=self.bits, **self.tags)
return True, VEXCCallExpression(self.idx, self.callee, list(new_operands), bits=self.bits, **self.tags)
else:
return False, self

Expand Down

0 comments on commit 3e24a4e

Please sign in to comment.