diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 2b75a73ffa..d717c0a85f 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -108,40 +108,40 @@ def to_json(self, parent): def __repr__(self): return type(self).__name__ + ' (' + self.__str__() + ')' - def add_in_connector(self, connector_name: str, dtype: dtypes.typeclass = None, force: bool = False): + def add_in_connector(self, connector_name: str, dtype: Any = None, force: bool = False): """ Adds a new input connector to the node. The operation will fail if a connector (either input or output) with the same name already exists in the node. :param connector_name: The name of the new connector. :param dtype: The type of the connector, or None for auto-detect. - :param force: Add connector even if output connector already exists. + :param force: Add connector even if input or output connector of that name already exists. :return: True if the operation is successful, otherwise False. """ if (not force and (connector_name in self.in_connectors or connector_name in self.out_connectors)): return False - connectors = self.in_connectors - connectors[connector_name] = dtype - self.in_connectors = connectors + if not isinstance(dtype, dace.typeclass): + dtype = dace.typeclass(dtype) + self.in_connectors[connector_name] = dtype return True - def add_out_connector(self, connector_name: str, dtype: dtypes.typeclass = None, force: bool = False): + def add_out_connector(self, connector_name: str, dtype: Any = None, force: bool = False,) -> bool: """ Adds a new output connector to the node. The operation will fail if a connector (either input or output) with the same name already exists in the node. :param connector_name: The name of the new connector. :param dtype: The type of the connector, or None for auto-detect. - :param force: Add connector even if input connector already exists. + :param force: Add connector even if input or output connector of that name already exists. :return: True if the operation is successful, otherwise False. """ if (not force and (connector_name in self.in_connectors or connector_name in self.out_connectors)): return False - connectors = self.out_connectors - connectors[connector_name] = dtype - self.out_connectors = connectors + if not isinstance(dtype, dace.typeclass): + dtype = dace.typeclass(dtype) + self.out_connectors[connector_name] = dtype return True def remove_in_connector(self, connector_name: str):