From 1ae15686d985850f2d58fa627a0b04112213b4af Mon Sep 17 00:00:00 2001
From: mrava87
Date: Mon, 21 Oct 2024 23:24:56 +0300
Subject: [PATCH 1/2] feat: improved handling of shapes

---
 pylops_mpi/DistributedArray.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 6e5a471..4caa2db 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name
 
 
@@ -78,7 +79,7 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples of integers representing local shapes at each rank.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +89,7 @@ class DistributedArray:
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +101,13 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+
+        if local_shapes is not None:
+            local_shapes = [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)

From db0fba5a4462a5f79b4db5677c4f70a073bca70c Mon Sep 17 00:00:00 2001
From: mrava87
Date: Tue, 22 Oct 2024 00:06:10 +0300
Subject: [PATCH 2/2] minor: small code simplification

---
 pylops_mpi/DistributedArray.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 4caa2db..0b2e2c6 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -79,7 +79,7 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples of integers representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -106,8 +106,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         self._partition = partition
         self._axis = axis
 
-        if local_shapes is not None:
-            local_shapes = [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
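
Editor's note: below is a minimal usage sketch of what this series enables; it is
not part of the patches themselves. It assumes the post-patch DistributedArray
API, two MPI ranks, and illustrative shape values (global_shape=10,
local_shapes=[6, 4]).

    # demo_shapes.py -- run with, e.g.: mpirun -n 2 python demo_shapes.py
    import numpy as np
    from pylops_mpi import DistributedArray

    # Before this series, both arguments had to be shape tuples, e.g.
    # global_shape=(10,) and local_shapes=[(6,), (4,)]; plain integers are
    # now normalized internally via _value_or_sized_to_tuple.
    x = DistributedArray(global_shape=10, local_shapes=[6, 4],
                         dtype=np.float64)
    x[:] = 1.0  # fill the local portion owned by each rank
    print(x.local_shape, x.global_shape)  # rank 0 prints: (6,) (10,)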