Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pystr_to_symbolic not correctly interpreting constants as boolean values in boolean comparisons #1756

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from functools import lru_cache
import sys
import sympy
import pickle
import re
Expand Down Expand Up @@ -982,6 +983,32 @@ def _process_is(elem: Union[Is, IsNot]):
return expr


# Depending on the Python version we need to handle different AST nodes to correctly interpret and detect falsy / truthy
# values.
if sys.version_info < (3, 8):
_SimpleASTNode = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
_SimpleASTNodeT = Union[ast.Constant, ast.Name, ast.NameConstant, ast.Num]

def __comp_convert_truthy_falsy(node: _SimpleASTNodeT):
if isinstance(node, ast.Num):
node_val = node.n
elif isinstance(node, ast.Name):
node_val = node.id
else:
node_val = node.value
return ast.copy_location(ast.NameConstant(bool(node_val)), node)
else:
_SimpleASTNode = (ast.Constant, ast.Name)
_SimpleASTNodeT = Union[ast.Constant, ast.Name]

def __comp_convert_truthy_falsy(node: _SimpleASTNodeT):
return ast.copy_location(ast.Constant(bool(node.value)), node)

# Convert simple AST node (constant) into a falsy / truthy. Anything other than 0, None, and an empty string '' is
# considered a truthy, while the listed exceptions are considered falsy values - following the semantics of Python's
# bool() builtin.
_convert_truthy_falsy = __comp_convert_truthy_falsy

class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
Expand Down Expand Up @@ -1067,6 +1094,13 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]

# Ensure constant values in boolean comparisons are interpreted als booleans.
if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], _SimpleASTNode):
arguments[1] = _convert_truthy_falsy(node.comparators[0])
elif isinstance(node.left, _SimpleASTNode) and isinstance(node.comparators[0], ast.Compare):
arguments[0] = _convert_truthy_falsy(node.left)

func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)
Expand Down
23 changes: 22 additions & 1 deletion tests/passes/dead_code_elimination_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Various tests for dead code elimination passes. """

import numpy as np
Expand Down Expand Up @@ -45,6 +45,26 @@ def test_dse_unconditional():
assert set(sdfg.states()) == {s, s2, e}


def test_dse_edge_condition_with_integer_as_boolean_regression():
"""
This is a regression test for issue #1129, which describes dead state elimination incorrectly eliminating interstate
edges when integers are used as boolean values in interstate edge conditions. Code taken from issue #1129.
"""
sdfg = dace.SDFG('dse_edge_condition_with_integer_as_boolean_regression')
sdfg.add_scalar('N', dtype=dace.int32, transient=True)
sdfg.add_scalar('result', dtype=dace.int32)
state_init = sdfg.add_state()
state_middle = sdfg.add_state()
state_end = sdfg.add_state()
sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))',
assignments={'result': 'N'}))
sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)'))
sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'}))

res = DeadStateElimination().apply_pass(sdfg, {})
assert res is None


def test_dde_simple():

@dace.program
Expand Down Expand Up @@ -307,6 +327,7 @@ def test_dce_add_type_hint_of_variable(dtype):
if __name__ == '__main__':
test_dse_simple()
test_dse_unconditional()
test_dse_edge_condition_with_integer_as_boolean_regression()
test_dde_simple()
test_dde_libnode()
test_dde_access_node_in_scope(False)
Expand Down