Merge pull request #112 from PyLops/patch-tupleshapes
feat: improved handling of shapes
mrava87 authored Oct 24, 2024
2 parents 1ae276c + db0fba5 commit 57b793e
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pylops_mpi/DistributedArray.py
@@ -5,6 +5,7 @@
 from enum import Enum
 
 from pylops.utils import DTypeLike, NDArray
+from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_module, get_array_module, get_module_name


@@ -78,7 +79,7 @@ class DistributedArray:
     axis : :obj:`int`, optional
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
-        List of tuples representing local shapes at each rank.
+        List of tuples or integers representing local shapes at each rank.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
     dtype : :obj:`str`, optional
@@ -88,7 +89,7 @@
     def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
-                 local_shapes: Optional[List[Tuple]] = None,
+                 local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
@@ -100,10 +101,12 @@ def __init__(self, global_shape: Union[Tuple, Integral],
             raise ValueError(f"Should be either {Partition.BROADCAST} "
                              f"or {Partition.SCATTER}")
         self.dtype = dtype
-        self._global_shape = global_shape
+        self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
         self._partition = partition
         self._axis = axis
+
+        local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
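
With this change, both global_shape and local_shapes accept plain integers as well as tuples, since every value is normalized through _value_or_sized_to_tuple (which wraps a scalar such as 10 into the tuple (10,) and converts any sized input to a tuple). Below is a minimal usage sketch, not part of the commit: it assumes a two-rank MPI run (e.g. mpiexec -n 2 python example.py), and the shapes are illustrative.

# Sketch only: assumes 2 MPI ranks; shapes are illustrative, not from the commit.
import numpy as np
from pylops_mpi import DistributedArray

# An integer global_shape is now normalized to a tuple internally,
# so these two constructions are equivalent:
x_int = DistributedArray(global_shape=10, dtype=np.float64)
x_tup = DistributedArray(global_shape=(10,), dtype=np.float64)

# local_shapes entries may likewise be integers or tuples
# (one entry per rank; 6 + 4 matches the global size of 10):
x_loc = DistributedArray(global_shape=10, local_shapes=[6, 4])

print(x_int.local_shape, x_loc.local_shape)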
