diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index b18b351a3..38fec57eb 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -61,7 +61,7 @@ from pytato.array import ( Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT, - _get_default_axes, ConvertibleToShape, normalize_shape) + _get_default_axes, _get_default_tags, ConvertibleToShape, normalize_shape) CommTagType = Hashable @@ -220,7 +220,7 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp DistributedSend: """Make a :class:`DistributedSend` object.""" return DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, - tags=send_tags) + tags=(send_tags | _get_default_tags())) def make_distributed_send_ref_holder( @@ -230,7 +230,8 @@ def make_distributed_send_ref_holder( ) -> DistributedSendRefHolder: """Make a :class:`DistributedSendRefHolder` object.""" return DistributedSendRefHolder( - send=send, passthrough_data=passthrough_data, tags=tags) + send=send, passthrough_data=passthrough_data, + tags=(tags | _get_default_tags())) def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -241,8 +242,9 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT """Make a :class:`DistributedSend` object wrapped in a :class:`DistributedSendRefHolder` object.""" return make_distributed_send_ref_holder( - send=DistributedSend(data=sent_data, dest_rank=dest_rank, - comm_tag=comm_tag, tags=send_tags), + send=make_distributed_send( + sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, + send_tags=send_tags), passthrough_data=stapled_to, tags=ref_holder_tags) @@ -261,7 +263,7 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, dtype = np.dtype(dtype) return DistributedRecv( src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype, - tags=tags, axes=axes) + axes=axes, tags=(tags | _get_default_tags())) # }}}