Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a function add_scope_connectors() to the Map{Entry, Exit} #1829

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def __init__(self, wrapped_type, typename=None):
# Convert python basic types
if isinstance(wrapped_type, str):
try:

if wrapped_type == "bool":
wrapped_type = numpy.bool_
else:
Expand Down
42 changes: 42 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,43 @@ def add_out_connector(self, connector_name: str, dtype: dtypes.typeclass = None,
self.out_connectors = connectors
return True

def _add_scope_connectors(
self,
connector_name: str,
dtype: Optional[dtypes.typeclass] = None,
force: bool = False,
) -> None:
""" Adds input and output connector names to `self` in one step.

The function will add an input connector with name `'IN_' + connector_name`
and an output connector with name `'OUT_' + connector_name`.
The function is a shorthand for calling `add_in_connector()` and `add_out_connector()`.

:param connector_name: The base name of the new connectors.
:param dtype: The type of the connectors, or `None` for auto-detect.
:param force: Add connector even if input or output connector of that name already exists.
:return: True if the operation is successful, otherwise False.
"""
in_connector_name = "IN_" + connector_name
out_connector_name = "OUT_" + connector_name
if not force:
if in_connector_name in self.in_connectors or in_connector_name in self.out_connectors:
return False
if out_connector_name in self.in_connectors or out_connector_name in self.out_connectors:
return False
# We force unconditionally because we have performed the tests above.
self.add_in_connector(
connector_name=in_connector_name,
dtype=dtype,
force=True,
)
self.add_out_connector(
connector_name=out_connector_name,
dtype=dtype,
force=True,
)
return True

def remove_in_connector(self, connector_name: str):
""" Removes an input connector from the node.

Expand Down Expand Up @@ -741,6 +778,9 @@ class EntryNode(Node):
def validate(self, sdfg, state):
self.map.validate(sdfg, state, self)

add_scope_connectors = Node._add_scope_connectors



# ------------------------------------------------------------------------------

Expand All @@ -752,6 +792,8 @@ class ExitNode(Node):
def validate(self, sdfg, state):
self.map.validate(sdfg, state, self)

add_scope_connectors = Node._add_scope_connectors


# ------------------------------------------------------------------------------

Expand Down
35 changes: 35 additions & 0 deletions tests/sdfg/nodes_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import dace

def test_add_scope_connectors():
sdfg = dace.SDFG("add_scope_connectors_sdfg")
state = sdfg.add_state(is_start_block=True)
me: dace.nodes.MapEntry
mx: dace.nodes.MapExit
me, mx = state.add_map("test_map", ndrange={"__i0": "0:10"})
assert all(
len(mn.in_connectors) == 0 and len(mn.out_connectors) == 0
for mn in [me, mx]
)
me.add_in_connector("IN_T", dtype=dace.float64)
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Because there is already an `IN_T` this call will fail.
assert not me.add_scope_connectors("T")
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Now it will work, because we specify force, however, the current type for `IN_T` will be overridden.
assert me.add_scope_connectors("T", force=True)
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"].type is None
assert len(me.out_connectors) == 1 and me.out_connectors["OUT_T"].type is None
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Now tries to the full adding.
assert mx.add_scope_connectors("B", dtype=dace.int64)
assert len(mx.in_connectors) == 1 and mx.in_connectors["IN_B"] is dace.int64
assert len(mx.out_connectors) == 1 and mx.out_connectors["OUT_B"] is dace.int64


if __name__ == "__main__":
test_add_scope_connectors()
Loading