From 81a1385fd56a9a12b84153756b4c0bb046808c50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 6 Nov 2024 13:33:24 +0000 Subject: [PATCH] fix: Update number of ports for PartialOps, and sanitize orderd edges (#1635) The number of ports in a `PartialOp` dataflow node didn't get updated after connecting it. We didn't see any problem with this because by default the number of output ports is grown when connecting new outputs, so the count ended up being correct. #1625 found a bug where an `UnpackTuple` didn't connect its last output, so the serialization though the node had less output ports than it should have, and connected the order edge at an incorrect offset. As part of this change I added various checks for the order edge index. We use a special `-1` offset to identify them, so we should panic when that is seen on invalid places. - drive-by: Add a `Hugr::has_link` method - drive-by: Avoid adding duplicated order edges Closes #1625 --- hugr-py/src/hugr/build/dfg.py | 11 +++++- hugr-py/src/hugr/hugr/base.py | 60 +++++++++++++++++++++++++------- hugr-py/src/hugr/ops.py | 5 +++ hugr-py/tests/test_hugr_build.py | 14 ++++++++ 4 files changed, 76 insertions(+), 14 deletions(-) diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 8af4ae1ad..48a710df7 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -510,7 +510,10 @@ def add_state_order(self, src: Node, dst: Node) -> None: [Node(2)] """ # adds edge to the right of all existing edges - self.hugr.add_link(src.out(-1), dst.inp(-1)) + source = src.out(-1) + target = dst.inp(-1) + if not self.hugr.has_link(source, target): + self.hugr.add_link(source, target) def load( self, const: ToNode | val.Value, const_parent: ToNode | None = None @@ -616,6 +619,12 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops._PartialOp): op._set_in_types(tys) + if isinstance(op, ops.DataflowOp): + # Update the node's input and output port count + sig = op.outer_signature() + self.hugr._update_port_count( + node, num_inps=len(sig.input), num_outs=len(sig.output) + ) return tys def _get_dataflow_type(self, wire: Wire) -> tys.Type: diff --git a/hugr-py/src/hugr/hugr/base.py b/hugr-py/src/hugr/hugr/base.py index 979021057..14e40e795 100644 --- a/hugr-py/src/hugr/hugr/base.py +++ b/hugr-py/src/hugr/hugr/base.py @@ -182,12 +182,31 @@ def _update_node_outs(self, node: Node, num_outs: int | None) -> Node: Returns: The updated node. """ - self[node]._num_outs = num_outs or 0 - node = replace(node, _num_out_ports=num_outs) - parent = self[node].parent - if parent is not None: - pos = self[parent].children.index(node) - self[parent].children[pos] = node + return self._update_port_count(node, num_outs=num_outs) + + def _update_port_count( + self, node: Node, *, num_inps: int | None = None, num_outs: int | None + ) -> Node: + """Update the number of incoming and outgoing ports for a node. + + If `num_inps` or `num_outs` is None, the corresponding count is not updated. + + Returns: + The updated node. + """ + if num_inps is None and num_outs is None: + return node + + if num_inps is not None: + self[node]._num_inps = num_inps + if num_outs is not None: + self[node]._num_outs = num_outs + node = replace(node, _num_out_ports=num_outs) + parent = self[node].parent + if parent is not None: + pos = self[parent].children.index(node) + self[parent].children[pos] = node + return node def add_node( @@ -284,6 +303,24 @@ def _unused_sub_offset(self, port: P) -> _SubPort[P]: sub_port = sub_port.next_sub_offset() return sub_port + def has_link(self, src: OutPort, dst: InPort) -> bool: + """Check if there is a link between two ports. + + Args: + src: Source port. + dst: Destination port. + + Returns: + True if there is a link, False otherwise. + + Examples: + >>> df = dfg.Dfg(tys.Bool) + >>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0)) + >>> df.hugr.has_link(df.input_node.out(0), df.output_node.inp(0)) + True + """ + return dst in self.linked_ports(src) + def add_link(self, src: OutPort, dst: InPort) -> None: """Add a link (edge) between two nodes to the HUGR, from an outgoing port to an incoming port. @@ -604,14 +641,11 @@ def _serialize_link( ) def _constrain_offset(self, p: P) -> PortOffset: - # negative offsets are used to refer to the last port + # An offset of -1 is a special case, indicating an order edge, + # not counted in the number of ports. if p.offset < 0: - match p.direction: - case Direction.INCOMING: - current = self.num_incoming(p.node) - case Direction.OUTGOING: - current = self.num_outgoing(p.node) - offset = current + p.offset + 1 + assert p.offset == -1, "Only order edges are allowed with offset < 0" + offset = self.num_ports(p.node, p.direction) else: offset = p.offset diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index a97470adf..8c5b845a8 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -72,6 +72,11 @@ def name(self) -> str: def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: + """Get the type of the given dataflow port given the signature of the operation.""" + if port.offset == -1: + # Order port + msg = "Order port has no type." + raise ValueError(msg) if port.direction == Direction.INCOMING: return sig.input[port.offset] return sig.output[port.offset] diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 3f7145a87..2607a4773 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -316,3 +316,17 @@ def test_alias() -> None: _dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Copyable) validate(mod.hugr) + + +# https://github.com/CQCL/hugr/issues/1625 +def test_dfg_unpack() -> None: + dfg = Dfg(tys.Tuple(tys.Bool, tys.Bool)) + bool1, _unused_bool2 = dfg.add_op(ops.UnpackTuple(), *dfg.inputs()) + cond = dfg.add_conditional(bool1) + with cond.add_case(0) as case: + case.set_outputs(bool1) + with cond.add_case(1) as case: + case.set_outputs(bool1) + dfg.set_outputs(*cond.outputs()) + + validate(dfg.hugr)