Skip to content

Commit

Permalink
add default tags to send/recv nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed Feb 6, 2024
1 parent d2b5498 commit dc1421b
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pytato/distributed/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

from pytato.array import (
Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT,
_get_default_axes, ConvertibleToShape, normalize_shape)
_get_default_axes, _get_default_tags, ConvertibleToShape, normalize_shape)

CommTagType = Hashable

Expand Down Expand Up @@ -220,7 +220,7 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp
DistributedSend:
"""Make a :class:`DistributedSend` object."""
return DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag,
tags=send_tags)
tags=(send_tags | _get_default_tags()))


def make_distributed_send_ref_holder(
Expand All @@ -230,7 +230,8 @@ def make_distributed_send_ref_holder(
) -> DistributedSendRefHolder:
"""Make a :class:`DistributedSendRefHolder` object."""
return DistributedSendRefHolder(
send=send, passthrough_data=passthrough_data, tags=tags)
send=send, passthrough_data=passthrough_data,
tags=(tags | _get_default_tags()))


def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType,
Expand All @@ -241,8 +242,9 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT
"""Make a :class:`DistributedSend` object wrapped in a
:class:`DistributedSendRefHolder` object."""
return make_distributed_send_ref_holder(
send=DistributedSend(data=sent_data, dest_rank=dest_rank,
comm_tag=comm_tag, tags=send_tags),
send=make_distributed_send(
sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag,
send_tags=send_tags),
passthrough_data=stapled_to,
tags=ref_holder_tags)

Expand All @@ -261,7 +263,7 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType,
dtype = np.dtype(dtype)
return DistributedRecv(
src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype,
tags=tags, axes=axes)
axes=axes, tags=(tags | _get_default_tags()))

# }}}

Expand Down

0 comments on commit dc1421b

Please sign in to comment.