From e1eaa9745823f07d87c716895c972051904194e1 Mon Sep 17 00:00:00 2001 From: Audrey Dutcher Date: Thu, 12 Dec 2024 01:06:15 -0700 Subject: [PATCH] Typecheck more things (#437) --------- Co-authored-by: Fish --- pyvex/block.py | 14 +++++++------- pyvex/const.py | 13 ++++++++++--- pyvex/expr.py | 24 +++++++++++++++--------- pyvex/stmt.py | 24 +++++++++++++++--------- 4 files changed, 47 insertions(+), 28 deletions(-) diff --git a/pyvex/block.py b/pyvex/block.py index 03caea9f..f584e719 100644 --- a/pyvex/block.py +++ b/pyvex/block.py @@ -4,12 +4,12 @@ from typing import Optional from . import expr, stmt -from .const import get_type_size +from .const import U1, get_type_size from .const_val import ConstVal from .data_ref import DataRef from .enums import VEXObject from .errors import SkipStatementsError -from .expr import RdTmp +from .expr import Const, RdTmp from .native import pvc from .stmt import ( CAS, @@ -50,7 +50,7 @@ class IRSB(VEXObject): :ivar int addr: The address of this basic block, i.e. the address in the first IMark """ - __slots__ = ( + __slots__ = [ "addr", "arch", "statements", @@ -65,7 +65,7 @@ class IRSB(VEXObject): "_instruction_addresses", "data_refs", "const_vals", - ) + ] # The following constants shall match the defs in pyvex.h MAX_EXITS = 400 @@ -129,9 +129,9 @@ def __init__( self.arch: Arch = arch self.statements: list[IRStmt] = [] - self.next: IRExpr | None = None + self.next: IRExpr = Const(U1(0)) self._tyenv: Optional["IRTypeEnv"] = None - self.jumpkind: str | None = None + self.jumpkind: str = "UNSET" self._direct_next: bool | None = None self._size: int | None = None self._instructions: int | None = None @@ -651,7 +651,7 @@ def __init__(self, arch, types=None): def __str__(self): return " ".join(("t%d:%s" % (i, t)) for i, t in enumerate(self.types)) - def lookup(self, tmp): + def lookup(self, tmp: int) -> str: """ Return the type of temporary variable `tmp` as an enum string """ diff --git a/pyvex/const.py b/pyvex/const.py index 7b00826a..f9802c74 100644 --- a/pyvex/const.py +++ b/pyvex/const.py @@ -1,4 +1,6 @@ +# pylint:disable=missing-class-docstring,raise-missing-from,not-callable import re +from abc import ABC from .enums import VEXObject, get_enum_from_int from .errors import PyVEXError @@ -6,17 +8,17 @@ # IRConst hierarchy -class IRConst(VEXObject): +class IRConst(VEXObject, ABC): __slots__ = ["_value"] type: str - size: int | None = None + size: int tag: str c_constructor = None _value: int def pp(self): - print(self.__str__()) + print(str(self)) @property def value(self) -> int: @@ -215,6 +217,7 @@ class F32(IRConst): tag = "Ico_F32" op_format = "F32" c_constructor = pvc.IRConst_F32 + size = 32 def __init__(self, value): self._value = value @@ -234,6 +237,7 @@ class F32i(IRConst): tag = "Ico_F32i" op_format = "F32" c_constructor = pvc.IRConst_F32i + size = 32 def __init__(self, value): self._value = value @@ -253,6 +257,7 @@ class F64(IRConst): tag = "Ico_F64" op_format = "F64" c_constructor = pvc.IRConst_F64 + size = 64 def __init__(self, value): self._value = value @@ -272,6 +277,7 @@ class F64i(IRConst): tag = "Ico_F64i" op_format = "F64" c_constructor = pvc.IRConst_F64i + size = 64 def __init__(self, value): self._value = value @@ -291,6 +297,7 @@ class V128(IRConst): tag = "Ico_V128" op_format = "V128" c_constructor = pvc.IRConst_V128 + size = 128 def __init__(self, value): self._value = value diff --git a/pyvex/expr.py b/pyvex/expr.py index 1a1b399f..8a069326 100644 --- a/pyvex/expr.py +++ b/pyvex/expr.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import logging import re -from typing import Optional +from typing import TYPE_CHECKING from .const import U8, U16, U32, U64, IRConst, get_type_size from .enums import IRCallee, IRRegArray, VEXObject, get_enum_from_int, get_int_from_enum from .errors import PyVEXError from .native import ffi, pvc +if TYPE_CHECKING: + from .block import IRTypeEnv + log = logging.getLogger("pyvex.expr") @@ -30,7 +35,7 @@ def _pp_str(self) -> str: raise NotImplementedError @property - def child_expressions(self) -> list["IRExpr"]: + def child_expressions(self) -> list[IRExpr]: """ A list of all of the expressions that this expression ends up evaluating. """ @@ -56,10 +61,10 @@ def constants(self): constants.append(v) return constants - def result_size(self, tyenv): + def result_size(self, tyenv: IRTypeEnv): return get_type_size(self.result_type(tyenv)) - def result_type(self, tyenv): + def result_type(self, tyenv: IRTypeEnv): raise NotImplementedError() def replace_expression(self, replacements): @@ -95,7 +100,7 @@ def replace_expression(self, replacements): v.replace_expression(replacements) @staticmethod - def _from_c(c_expr) -> Optional["IRExpr"]: + def _from_c(c_expr) -> IRExpr | None: if c_expr == ffi.NULL or c_expr[0] == ffi.NULL: return None @@ -282,7 +287,7 @@ class Get(IRExpr): tag = "Iex_Get" - def __init__(self, offset, ty: str, ty_int: int | None = None): + def __init__(self, offset: int, ty: str, ty_int: int | None = None): self.offset = offset if ty_int is None: self.ty_int = get_int_from_enum(ty) @@ -520,7 +525,7 @@ class Unop(IRExpr): tag = "Iex_Unop" - def __init__(self, op, args): + def __init__(self, op: str, args: list[IRExpr]): self.op = op self.args = args @@ -616,14 +621,14 @@ class Const(IRExpr): tag = "Iex_Const" - def __init__(self, con: "IRConst"): + def __init__(self, con: IRConst): self._con = con def _pp_str(self): return str(self.con) @property - def con(self) -> "IRConst": + def con(self) -> IRConst: return self._con @staticmethod @@ -849,6 +854,7 @@ def cmp_signature(op): if (m is None) == (m2 is None): raise PyvexOpMatchException() mfound = m if m is not None else m2 + assert mfound is not None size = int(mfound.group("size")) size_type = int_type_for_size(size) return (int_type_for_size(1), (size_type, size_type)) diff --git a/pyvex/stmt.py b/pyvex/stmt.py index 54528aa9..c1dd1c25 100644 --- a/pyvex/stmt.py +++ b/pyvex/stmt.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging from collections.abc import Iterator +from typing import TYPE_CHECKING from . import expr from .const import IRConst @@ -8,6 +11,9 @@ from .expr import Const, Get, IRExpr from .native import ffi, pvc +if TYPE_CHECKING: + from .block import IRTypeEnv + log = logging.getLogger("pyvex.stmt") @@ -16,7 +22,7 @@ class IRStmt(VEXObject): IR statements in VEX represents operations with side-effects. """ - tag: str | None = None + tag: str tag_int = 0 # set automatically at bottom of file __slots__ = [] @@ -25,7 +31,7 @@ def pp(self): print(str(self)) @property - def child_expressions(self) -> Iterator["IRExpr"]: + def child_expressions(self) -> Iterator[IRExpr]: for k in self.__slots__: v = getattr(self, k) if isinstance(v, IRExpr): @@ -54,7 +60,7 @@ def _from_c(c_stmt): raise PyVEXError("Unknown/unsupported IRStmtTag %s.\n" % get_enum_from_int(c_stmt.tag)) return stmt_class._from_c(c_stmt) - def typecheck(self, tyenv): # pylint: disable=unused-argument,no-self-use + def typecheck(self, tyenv: IRTypeEnv) -> bool: # pylint: disable=unused-argument,no-self-use return True def replace_expression(self, replacements): @@ -165,7 +171,7 @@ class Put(IRStmt): tag = "Ist_Put" - def __init__(self, data: "IRExpr", offset): + def __init__(self, data: IRExpr, offset: int): self.data = data self.offset = offset @@ -234,7 +240,7 @@ class WrTmp(IRStmt): tag = "Ist_WrTmp" - def __init__(self, tmp, data: "IRExpr"): + def __init__(self, tmp, data: IRExpr): self.tmp = tmp self.data = data @@ -272,7 +278,7 @@ class Store(IRStmt): tag = "Ist_Store" - def __init__(self, addr: "IRExpr", data: "IRExpr", end: str): + def __init__(self, addr: IRExpr, data: IRExpr, end: str): self.addr = addr self.data = data self.end = end @@ -403,7 +409,7 @@ class LLSC(IRStmt): tag = "Ist_LLSC" - def __init__(self, addr, storedata, result, end): + def __init__(self, addr: IRExpr, storedata: IRExpr, result: int, end: str): self.addr = addr self.storedata = storedata self.result = result @@ -527,7 +533,7 @@ class Exit(IRStmt): tag = "Ist_Exit" - def __init__(self, guard, dst, jk, offsIP): + def __init__(self, guard: IRExpr, dst: IRConst, jk: str, offsIP: int): self.guard = guard self.dst = dst self.offsIP = offsIP @@ -581,7 +587,7 @@ class LoadG(IRStmt): tag = "Ist_LoadG" - def __init__(self, end, cvt, dst, addr, alt, guard): + def __init__(self, end: str, cvt: str, dst: int, addr: IRExpr, alt: IRExpr, guard: IRExpr): self.addr = addr self.alt = alt self.guard = guard