From 09b28e706c5d3e03628e79f5ef17dde24576d266 Mon Sep 17 00:00:00 2001 From: Daniel Fairhead Date: Thu, 31 Oct 2024 09:10:05 +0000 Subject: [PATCH] Playing with adding typeing info... --- pyproject.toml | 14 ++++++ simpleeval.py | 130 +++++++++++++++++++++++++++---------------------- 2 files changed, 86 insertions(+), 58 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3fa7662..e6af3c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,3 +78,17 @@ disable = [ "unnecessary-pass", "bad-super-call", ] + +[tool.mypy] +# strict = true +# Start off with these +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +# Getting these passing should be easy +strict_equality = true +strict_concatenate = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true diff --git a/simpleeval.py b/simpleeval.py index fa3909e..4da43fd 100644 --- a/simpleeval.py +++ b/simpleeval.py @@ -103,11 +103,12 @@ """ import ast +from collections.abc import Callable import operator as op import sys import warnings from random import random -from typing import Type, Dict, Set, Union +from typing import Any, List, Optional, Tuple, Type, Dict, Set, Union ######################################## # Module wide 'globals' @@ -138,7 +139,7 @@ # builtins is a dict in python >3.6 but a module before DISALLOW_FUNCTIONS = {type, isinstance, eval, getattr, setattr, repr, compile, open, exec} if hasattr(__builtins__, "help") or ( - hasattr(__builtins__, "__contains__") and "help" in __builtins__ # type: ignore + hasattr(__builtins__, "__contains__") and "help" in __builtins__ ): # PyInstaller environment doesn't include this module. DISALLOW_FUNCTIONS.add(help) @@ -294,13 +295,11 @@ class FunctionNotDefined(InvalidExpression): """sorry! That function isn't defined!""" def __init__(self, func_name, expression): - self.message = "Function '{0}' not defined," " for expression '{1}'.".format( - func_name, expression - ) + self.message = f"Function '{func_name}' not defined, for expression '{expression}'." setattr(self, "func_name", func_name) # bypass 2to3 confusion. self.expression = expression - super(InvalidExpression, self).__init__(self.message) + super().__init__(self.message) class NameNotDefined(InvalidExpression): @@ -308,34 +307,32 @@ class NameNotDefined(InvalidExpression): def __init__(self, name, expression): self.name = name - self.message = "'{0}' is not defined for expression '{1}'".format(name, expression) + self.message = f"'{name}' is not defined for expression '{expression}'" self.expression = expression - super(InvalidExpression, self).__init__(self.message) + super().__init__(self.message) class AttributeDoesNotExist(InvalidExpression): """attribute does not exist""" def __init__(self, attr, expression): - self.message = "Attribute '{0}' does not exist in expression '{1}'".format( - attr, expression - ) + self.message = f"Attribute '{attr}' does not exist in expression '{expression}'" self.attr = attr self.expression = expression - super(InvalidExpression, self).__init__(self.message) + super().__init__(self.message) class OperatorNotDefined(InvalidExpression): """operator does not exist""" def __init__(self, attr, expression): - self.message = "Operator '{0}' does not exist in expression '{1}'".format(attr, expression) + self.message = f"Operator '{attr}' does not exist in expression '{expression}'" self.attr = attr self.expression = expression - super(InvalidExpression, self).__init__(self.message) + super().__init__(self.message) class FeatureNotAvailable(InvalidExpression): @@ -383,7 +380,7 @@ def safe_power(a, b): # pylint: disable=invalid-name """a limited exponent/to-the-power-of function, for safety reasons""" if abs(a) > MAX_POWER or abs(b) > MAX_POWER: - raise NumberTooHigh("Sorry! I don't want to evaluate {0} ** {1}".format(a, b)) + raise NumberTooHigh(f"Sorry! I don't want to evaluate {a} ** {b}") return a**b @@ -404,7 +401,7 @@ def safe_add(a, b): # pylint: disable=invalid-name if hasattr(a, "__len__") and hasattr(b, "__len__"): if len(a) + len(b) > MAX_STRING_LENGTH: raise IterableTooLong( - "Sorry, adding those two together would" " make something too long." + "Sorry, adding those two together would make something too long." ) return a + b @@ -412,21 +409,21 @@ def safe_add(a, b): # pylint: disable=invalid-name def safe_rshift(a, b): # pylint: disable=invalid-name """rshift, but with input limits""" if abs(b) > MAX_SHIFT or abs(a) > MAX_SHIFT_BASE: - raise NumberTooHigh("Sorry! I don't want to evaluate {0} >> {1}".format(a, b)) + raise NumberTooHigh(f"Sorry! I don't want to evaluate {a} >> {b}") return a >> b def safe_lshift(a, b): # pylint: disable=invalid-name """lshift, but with input limits""" if abs(b) > MAX_SHIFT or abs(a) > MAX_SHIFT_BASE: - raise NumberTooHigh("Sorry! I don't want to evaluate {0} << {1}".format(a, b)) + raise NumberTooHigh(f"Sorry! I don't want to evaluate {a} << {b}") return a << b ######################################## # Defaults for the evaluator: -DEFAULT_OPERATORS = { +DEFAULT_OPERATORS: Dict[Type, Callable] = { ast.Add: safe_add, ast.Sub: op.sub, ast.Mult: safe_mult, @@ -455,7 +452,7 @@ def safe_lshift(a, b): # pylint: disable=invalid-name ast.IsNot: lambda x, y: x is not y, } -DEFAULT_FUNCTIONS = { +DEFAULT_FUNCTIONS: Dict[str, Callable] = { "rand": random, "randint": random_int, "int": int, @@ -463,7 +460,7 @@ def safe_lshift(a, b): # pylint: disable=invalid-name "str": str, } -DEFAULT_NAMES = {"True": True, "False": False, "None": None} +DEFAULT_NAMES: Dict[str, Any] = {"True": True, "False": False, "None": None} ATTR_INDEX_FALLBACK = True @@ -472,19 +469,27 @@ def safe_lshift(a, b): # pylint: disable=invalid-name # And the actual evaluator: -class SimpleEval(object): # pylint: disable=too-few-public-methods +class SimpleEval: # pylint: disable=too-few-public-methods """A very simple expression parser. >>> s = SimpleEval() >>> s.eval("20 + 30 - ( 10 * 5)") 0 """ - expr = "" + expr: str + nodes: Dict[Type, Callable] - def __init__(self, operators=None, functions=None, names=None, allowed_attrs=None): + def __init__( + self, + operators: Optional[Dict[Type, Callable]] = None, + functions: Optional[Dict[str, Callable]] = None, + names: Optional[Dict[str, Any]] = None, + allowed_attrs: Optional[Dict[Union[Type, None], Set]] = None, + ): """ Create the evaluator instance. Set up valid operators (+,-, etc) - functions (add, random, get_val, whatever) and names.""" + functions (add, random, get_val, whatever) and names. + """ if operators is None: operators = DEFAULT_OPERATORS.copy() @@ -544,10 +549,10 @@ def __init__(self, operators=None, functions=None, names=None, allowed_attrs=Non raise FeatureNotAvailable("This function {} is a really bad idea.".format(f)) def __del__(self): - self.nodes = None + self.nodes = None # type: ignore @staticmethod - def parse(expr): + def parse(expr: str) -> ast.stmt: """parse an expression into a node tree""" parsed = ast.parse(expr.strip()) @@ -561,7 +566,7 @@ def parse(expr): ) return parsed.body[0] - def eval(self, expr, previously_parsed=None): + def eval(self, expr: str, previously_parsed: Optional[ast.expr] = None) -> Any: """evaluate an expresssion, using the operators, functions and names previously set up.""" @@ -570,7 +575,7 @@ def eval(self, expr, previously_parsed=None): return self._eval(previously_parsed or self.parse(expr)) - def _eval(self, node): + def _eval(self, node: ast.AST): """The internal evaluator used on each node in the parsed tree.""" try: @@ -582,16 +587,16 @@ def _eval(self, node): return handler(node) - def _eval_expr(self, node): + def _eval_expr(self, node: ast.Expr): return self._eval(node.value) - def _eval_assign(self, node): + def _eval_assign(self, node: ast.Assign): warnings.warn( "Assignment ({}) attempted, but this is ignored".format(self.expr), AssignmentAttempted ) return self._eval(node.value) - def _eval_aug_assign(self, node): + def _eval_aug_assign(self, node: ast.AugAssign): warnings.warn( "Assignment ({}) attempted, but this is ignored".format(self.expr), AssignmentAttempted ) @@ -602,11 +607,11 @@ def _eval_import(node): raise FeatureNotAvailable("Sorry, 'import' is not allowed.") @staticmethod - def _eval_num(node): + def _eval_num(node: ast.Num): return node.n @staticmethod - def _eval_str(node): + def _eval_str(node: ast.Str): if len(node.s) > MAX_STRING_LENGTH: raise IterableTooLong( "String Literal in statement is too long! ({0}, when {1} is max)".format( @@ -616,7 +621,7 @@ def _eval_str(node): return node.s @staticmethod - def _eval_constant(node): + def _eval_constant(node: ast.Constant): if hasattr(node.value, "__len__") and len(node.value) > MAX_STRING_LENGTH: raise IterableTooLong( "Literal in statement is too long! ({0}, when {1} is max)".format( @@ -639,7 +644,7 @@ def _eval_binop(self, node): raise OperatorNotDefined(node.op, self.expr) return operator(self._eval(node.left), self._eval(node.right)) - def _eval_boolop(self, node): + def _eval_boolop(self, node: ast.BoolOp): to_return = False if isinstance(node.op, ast.And): for value in node.values: @@ -653,7 +658,7 @@ def _eval_boolop(self, node): break return to_return - def _eval_compare(self, node): + def _eval_compare(self, node: ast.Compare): right = self._eval(node.left) to_return = True for operation, comp in zip(node.ops, node.comparators): @@ -664,17 +669,18 @@ def _eval_compare(self, node): to_return = self.operators[type(operation)](left, right) return to_return - def _eval_ifexp(self, node): + def _eval_ifexp(self, node: ast.IfExp): return self._eval(node.body) if self._eval(node.test) else self._eval(node.orelse) - def _eval_call(self, node): + def _eval_call(self, node: ast.Call): if isinstance(node.func, ast.Attribute): func = self._eval(node.func) else: try: - func = self.functions[node.func.id] + # is a Named function, ast node thinks func exprs don't have id... + func = self.functions[node.func.id] # type: ignore except KeyError: - raise FunctionNotDefined(node.func.id, self.expr) + raise FunctionNotDefined(node.func.id, self.expr) # type: ignore except AttributeError: raise FeatureNotAvailable("Lambda Functions not implemented") @@ -685,10 +691,10 @@ def _eval_call(self, node): *(self._eval(a) for a in node.args), **dict(self._eval(k) for k in node.keywords) ) - def _eval_keyword(self, node): + def _eval_keyword(self, node: ast.keyword): return node.arg, self._eval(node.value) - def _eval_name(self, node): + def _eval_name(self, node: ast.Name): try: # This happens at least for slicing # This is a safe thing to do because it is impossible @@ -717,14 +723,14 @@ def _eval_name(self, node): raise NameNotDefined(node.id, self.expr) - def _eval_subscript(self, node): + def _eval_subscript(self, node: ast.Subscript): container = self._eval(node.value) key = self._eval(node.slice) # Currently if there's a KeyError, that gets raised straight up. # TODO: Should that be wrapped in an InvalidExpression? return container[key] - def _eval_attribute(self, node): + def _eval_attribute(self, node: ast.Attribute): # DISALLOW_PREFIXES & DISALLOW_METHODS are global, there's never any access to # attrs with these names, so we can bail early: for prefix in DISALLOW_PREFIXES: @@ -754,7 +760,7 @@ def _eval_attribute(self, node): f"Sorry, attribute access not allowed on '{type_to_check}'" f" (attempted to access `.{node.attr}`)" ) - if node.attr not in allowed_attrs: + if node.attr not in allowed_attrs: # type: ignore raise FeatureNotAvailable( f"Sorry, '.{node.attr}' access not allowed on '{type_to_check}'" ) @@ -775,10 +781,10 @@ def _eval_attribute(self, node): # If it is neither, raise an exception raise AttributeDoesNotExist(node.attr, self.expr) - def _eval_index(self, node): - return self._eval(node.value) + def _eval_index(self, node: ast.Index): + return self._eval(node.value) # type: ignore - def _eval_slice(self, node): + def _eval_slice(self, node: ast.Slice): lower = upper = step = None if node.lower is not None: lower = self._eval(node.lower) @@ -788,7 +794,7 @@ def _eval_slice(self, node): step = self._eval(node.step) return slice(lower, upper, step) - def _eval_joinedstr(self, node): + def _eval_joinedstr(self, node: ast.JoinedStr): length = 0 evaluated_values = [] for n in node.values: @@ -798,7 +804,7 @@ def _eval_joinedstr(self, node): evaluated_values.append(val) return "".join(evaluated_values) - def _eval_formattedvalue(self, node): + def _eval_formattedvalue(self, node: ast.FormattedValue): if node.format_spec: fmt = "{:" + self._eval(node.format_spec) + "}" return fmt.format(self._eval(node.value)) @@ -813,8 +819,14 @@ class EvalWithCompoundTypes(SimpleEval): _max_count = 0 - def __init__(self, operators=None, functions=None, names=None): - super(EvalWithCompoundTypes, self).__init__(operators, functions, names) + def __init__( + self, + operators: Optional[Dict[Type, Callable]] = None, + functions: Optional[Dict[str, Callable]] = None, + names: Optional[Dict[str, Any]] = None, + allowed_attrs: Optional[Dict[Union[Type, None], Set]] = None, + ): + super().__init__(operators, functions, names, allowed_attrs) self.functions.update(list=list, tuple=tuple, dict=dict, set=set) @@ -833,7 +845,7 @@ def __init__(self, operators=None, functions=None, names=None): def eval(self, expr, previously_parsed=None): # reset _max_count for each eval run self._max_count = 0 - return super(EvalWithCompoundTypes, self).eval(expr, previously_parsed) + return super().eval(expr, previously_parsed) def _eval_dict(self, node): result = {} @@ -858,19 +870,21 @@ def _eval_list(self, node): return result - def _eval_tuple(self, node): + def _eval_tuple(self, node) -> Tuple[Any]: return tuple(self._eval(x) for x in node.elts) - def _eval_set(self, node): + def _eval_set(self, node) -> Set[Any]: return set(self._eval(x) for x in node.elts) def _eval_comprehension(self, node): + to_return: Union[Dict[Any, Any], List[Any]] + if isinstance(node, ast.DictComp): to_return = {} else: to_return = [] - extra_names = {} + extra_names: Dict[str, Any] = {} previous_name_evaller = self.nodes[ast.Name] @@ -920,7 +934,7 @@ def do_generator(gi=0): return to_return -def simple_eval(expr, operators=None, functions=None, names=None, allowed_attrs=None): +def simple_eval(expr: str, operators=None, functions=None, names=None, allowed_attrs=None): """Simply evaluate an expresssion""" s = SimpleEval( operators=operators,