diff --git a/dace/symbolic.py b/dace/symbolic.py index beb8ccb288..d1c1b89ee2 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from functools import lru_cache +import sys import sympy import pickle import re @@ -982,6 +983,32 @@ def _process_is(elem: Union[Is, IsNot]): return expr +# Depending on the Python version we need to handle different AST nodes to correctly interpret and detect falsy / truthy +# values. +if sys.version_info < (3, 8): + _SimpleASTNode = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + _SimpleASTNodeT = Union[ast.Constant, ast.Name, ast.NameConstant, ast.Num] + + def __comp_convert_truthy_falsy(node: _SimpleASTNodeT): + if isinstance(node, ast.Num): + node_val = node.n + elif isinstance(node, ast.Name): + node_val = node.id + else: + node_val = node.value + return ast.copy_location(ast.NameConstant(bool(node_val)), node) +else: + _SimpleASTNode = (ast.Constant, ast.Name) + _SimpleASTNodeT = Union[ast.Constant, ast.Name] + + def __comp_convert_truthy_falsy(node: _SimpleASTNodeT): + return ast.copy_location(ast.Constant(bool(node.value)), node) + +# Convert simple AST node (constant) into a falsy / truthy. Anything other than 0, None, and an empty string '' is +# considered a truthy, while the listed exceptions are considered falsy values - following the semantics of Python's +# bool() builtin. +_convert_truthy_falsy = __comp_convert_truthy_falsy + class PythonOpToSympyConverter(ast.NodeTransformer): """ Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation. @@ -1078,10 +1105,10 @@ def visit_Compare(self, node: ast.Compare): if self.interpret_numeric_booleans: # Ensure constant values in boolean comparisons are interpreted als booleans. - if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], ast.Constant): - arguments[1] = ast.copy_location(ast.Constant(bool(node.comparators[0].value)), node.comparators[0]) - elif isinstance(node.left, ast.Constant) and isinstance(node.comparators[0], ast.Compare): - arguments[0] = ast.copy_location(ast.Constant(bool(node.left.value)), node.left) + if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], _SimpleASTNode): + arguments[1] = _convert_truthy_falsy(node.comparators[0]) + elif isinstance(node.left, _SimpleASTNode) and isinstance(node.comparators[0], ast.Compare): + arguments[0] = _convert_truthy_falsy(node.left) func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])