Skip to content

Commit

Permalink
fix: Update number of ports for PartialOps, and sanitize orderd edges (
Browse files Browse the repository at this point in the history
…#1635)

The number of ports in a `PartialOp` dataflow node didn't get updated
after connecting it.
We didn't see any problem with this because by default the number of
output ports is grown when connecting new outputs, so the count ended up
being correct.

#1625 found a bug where an `UnpackTuple` didn't connect its last output,
so the serialization though the node had less output ports than it
should have, and connected the order edge at an incorrect offset.

As part of this change I added various checks for the order edge index.
We use a special `-1` offset to identify them, so we should panic when
that is seen on invalid places.

- drive-by: Add a `Hugr::has_link` method
- drive-by: Avoid adding duplicated order edges

Closes #1625
  • Loading branch information
aborgna-q authored Nov 6, 2024
1 parent bf9d58b commit 81a1385
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
11 changes: 10 additions & 1 deletion hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,10 @@ def add_state_order(self, src: Node, dst: Node) -> None:
[Node(2)]
"""
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))
source = src.out(-1)
target = dst.inp(-1)
if not self.hugr.has_link(source, target):
self.hugr.add_link(source, target)

def load(
self, const: ToNode | val.Value, const_parent: ToNode | None = None
Expand Down Expand Up @@ -616,6 +619,12 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops._PartialOp):
op._set_in_types(tys)
if isinstance(op, ops.DataflowOp):
# Update the node's input and output port count
sig = op.outer_signature()
self.hugr._update_port_count(
node, num_inps=len(sig.input), num_outs=len(sig.output)
)
return tys

def _get_dataflow_type(self, wire: Wire) -> tys.Type:
Expand Down
60 changes: 47 additions & 13 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,31 @@ def _update_node_outs(self, node: Node, num_outs: int | None) -> Node:
Returns:
The updated node.
"""
self[node]._num_outs = num_outs or 0
node = replace(node, _num_out_ports=num_outs)
parent = self[node].parent
if parent is not None:
pos = self[parent].children.index(node)
self[parent].children[pos] = node
return self._update_port_count(node, num_outs=num_outs)

def _update_port_count(
self, node: Node, *, num_inps: int | None = None, num_outs: int | None
) -> Node:
"""Update the number of incoming and outgoing ports for a node.
If `num_inps` or `num_outs` is None, the corresponding count is not updated.
Returns:
The updated node.
"""
if num_inps is None and num_outs is None:
return node

if num_inps is not None:
self[node]._num_inps = num_inps
if num_outs is not None:
self[node]._num_outs = num_outs
node = replace(node, _num_out_ports=num_outs)
parent = self[node].parent
if parent is not None:
pos = self[parent].children.index(node)
self[parent].children[pos] = node

return node

def add_node(
Expand Down Expand Up @@ -284,6 +303,24 @@ def _unused_sub_offset(self, port: P) -> _SubPort[P]:
sub_port = sub_port.next_sub_offset()
return sub_port

def has_link(self, src: OutPort, dst: InPort) -> bool:
"""Check if there is a link between two ports.
Args:
src: Source port.
dst: Destination port.
Returns:
True if there is a link, False otherwise.
Examples:
>>> df = dfg.Dfg(tys.Bool)
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.has_link(df.input_node.out(0), df.output_node.inp(0))
True
"""
return dst in self.linked_ports(src)

def add_link(self, src: OutPort, dst: InPort) -> None:
"""Add a link (edge) between two nodes to the HUGR,
from an outgoing port to an incoming port.
Expand Down Expand Up @@ -604,14 +641,11 @@ def _serialize_link(
)

def _constrain_offset(self, p: P) -> PortOffset:
# negative offsets are used to refer to the last port
# An offset of -1 is a special case, indicating an order edge,
# not counted in the number of ports.
if p.offset < 0:
match p.direction:
case Direction.INCOMING:
current = self.num_incoming(p.node)
case Direction.OUTGOING:
current = self.num_outgoing(p.node)
offset = current + p.offset + 1
assert p.offset == -1, "Only order edges are allowed with offset < 0"
offset = self.num_ports(p.node, p.direction)
else:
offset = p.offset

Expand Down
5 changes: 5 additions & 0 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def name(self) -> str:


def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type:
"""Get the type of the given dataflow port given the signature of the operation."""
if port.offset == -1:
# Order port
msg = "Order port has no type."
raise ValueError(msg)
if port.direction == Direction.INCOMING:
return sig.input[port.offset]
return sig.output[port.offset]
Expand Down
14 changes: 14 additions & 0 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,17 @@ def test_alias() -> None:
_dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Copyable)

validate(mod.hugr)


# https://github.com/CQCL/hugr/issues/1625
def test_dfg_unpack() -> None:
dfg = Dfg(tys.Tuple(tys.Bool, tys.Bool))
bool1, _unused_bool2 = dfg.add_op(ops.UnpackTuple(), *dfg.inputs())
cond = dfg.add_conditional(bool1)
with cond.add_case(0) as case:
case.set_outputs(bool1)
with cond.add_case(1) as case:
case.set_outputs(bool1)
dfg.set_outputs(*cond.outputs())

validate(dfg.hugr)

0 comments on commit 81a1385

Please sign in to comment.