From e0505602bc708d53380464a533acd65a77092566 Mon Sep 17 00:00:00 2001 From: saivythik <60150574+saivythik@users.noreply.github.com> Date: Mon, 3 Apr 2023 00:11:39 +0530 Subject: [PATCH 1/2] replace factories.array() with DNDarray construct --- heat/core/manipulations.py | 70 +++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index d5e1529586..4119a306d1 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -299,7 +299,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: s0, s1 = arr0.split, arr1.split # no splits, local concat if s0 is None and s1 is None: - return factories.array( + return DNDarray( torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm ) @@ -313,7 +313,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis): _, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split) _, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split) - out = factories.array( + out = DNDarray( torch.cat((arr0.larray[arr0_slice], arr1.larray[arr1_slice]), dim=axis), dtype=out_dtype, is_split=s1 if s1 is not None else s0, @@ -328,7 +328,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: # the axis is different than the split axis, this case can be easily implemented # torch cat arrays together and return a new array that is_split - out = factories.array( + out = DNDarray( torch.cat((arr0.larray, arr1.larray), dim=axis), dtype=out_dtype, is_split=s0, @@ -498,7 +498,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: t_arr1.unsqueeze_(axis) res = torch.cat((t_arr0, t_arr1), dim=axis) - out = factories.array( + out = DNDarray( res, is_split=s0 if s0 is not None else s1, dtype=out_dtype, @@ -581,7 +581,7 @@ def diag(a: DNDarray, offset: int = 0) -> DNDarray: local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device) local[indices_x, indices_y] = a.larray[indices_x] - return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + return DNDarray(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray: @@ -655,7 +655,7 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDa vz = 1 if a.split == dim1 else -1 off, _, _ = a.comm.chunk(a.shape, a.split) result = torch.diagonal(a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2) - return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) + return DNDarray(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) def dsplit(x: Sequence[DNDarray, ...], indices_or_sections: Iterable) -> List[DNDarray, ...]: @@ -806,14 +806,14 @@ def flatten(a: DNDarray) -> DNDarray: sanitation.sanitize_in(a) if a.split is None: - return factories.array( + return DNDarray( torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm ) if a.split > 0: a = resplit(a, 0) - a = factories.array( + a = DNDarray( torch.flatten(a.larray), dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) a.balance_() @@ -864,7 +864,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: flipped = torch.flip(a.larray, axis) if a.split not in axis: - return factories.array( + return DNDarray( flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) @@ -879,7 +879,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: received = torch.empty(new_lshape, dtype=a.larray.dtype, device=a.device.torch_device) a.comm.Recv(received, source=dest_proc) - res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + res = DNDarray(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) res.balance_() # after swapping, first processes may be empty req.Wait() return res @@ -1443,7 +1443,7 @@ def pad( padded_torch_tensor, pad_tuple, mode, value_tuple[i] ) - padded_tensor = factories.array( + padded_tensor = DNDarray( padded_torch_tensor, dtype=array.dtype, is_split=array.split, @@ -1485,7 +1485,7 @@ def ravel(a: DNDarray) -> DNDarray: sanitation.sanitize_in(a) if a.split is None: - return factories.array( + return DNDarray( torch.flatten(a._DNDarray__array), dtype=a.dtype, copy=False, @@ -1498,7 +1498,7 @@ def ravel(a: DNDarray) -> DNDarray: if a.split != 0: return flatten(a) - result = factories.array( + result = DNDarray( torch.flatten(a._DNDarray__array), dtype=a.dtype, copy=False, @@ -1608,9 +1608,9 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr # sanitation `a` if not isinstance(a, DNDarray): if isinstance(a, (int, float)): - a = factories.array([a]) + a = DNDarray([a]) elif isinstance(a, (tuple, list, np.ndarray)): - a = factories.array(a) + a = DNDarray(a) else: raise TypeError( "`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {}".format( @@ -1644,7 +1644,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr if repeats.dtype == types.int64: pass elif types.can_cast(repeats.dtype, types.int64): - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=repeats.split, @@ -1662,7 +1662,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr "Invalid dtype for np.ndarray `repeats`. Has to be integer," " but was {}".format(repeats.dtype.type) ) - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm ) # invalid list/tuple @@ -1672,7 +1672,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr ) # valid list/tuple else: - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm ) @@ -1771,7 +1771,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr ) repeats = resplit(repeats, 0) flatten_repeats_t = torch.flatten(repeats._DNDarray__array) - repeats = factories.array( + repeats = DNDarray( flatten_repeats_t, is_split=repeats.split, device=repeats.device, @@ -1810,7 +1810,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr a._DNDarray__array, repeats._DNDarray__array, axis ) - repeated_array = factories.array( + repeated_array = DNDarray( repeated_array_torch, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) repeated_array.balance_() @@ -2218,7 +2218,7 @@ def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarra raise ValueError("Axes={} out of range for array of ndim={}.".format(axes, m.ndim)) if m.split is None: - return factories.array( + return DNDarray( torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm ) @@ -2505,14 +2505,14 @@ def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DN val = tmp_indices[idx] final_indices[idx] = second_indices[val.item()][idx[1:]] final_indices = final_indices.transpose(0, axis) - return_indices = factories.array( + return_indices = DNDarray( final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm ) if out is not None: out.larray = final_result return return_indices else: - tensor = factories.array( + tensor = DNDarray( final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) return tensor, return_indices @@ -2613,7 +2613,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND elif isinstance(indices_or_sections, (list, tuple, DNDarray)): if isinstance(indices_or_sections, (list, tuple)): - indices_or_sections = factories.array(indices_or_sections) + indices_or_sections = DNDarray(indices_or_sections) if len(indices_or_sections.gshape) != 1: raise ValueError( "Expected indices_or_sections to be 1-dimensional, but was {}-dimensional instead.".format( @@ -2702,7 +2702,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND indices_or_sections_t = arithmetics.diff( indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop ) - indices_or_sections_t = factories.array( + indices_or_sections_t = DNDarray( indices_or_sections_t, dtype=types.int64, is_split=indices_or_sections_t.split, @@ -2739,7 +2739,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND indices_or_sections_t = arithmetics.diff( indices_or_sections_t, prepend=0, append=x.gshape[axis] ) - indices_or_sections_t = factories.array( + indices_or_sections_t = DNDarray( indices_or_sections_t, dtype=types.int64, is_split=indices_or_sections_t.split, @@ -2753,7 +2753,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis) sub_arrays_ht = [ - factories.array(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm) + DNDarray(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm) for sub_DNDarray in sub_arrays_t ] @@ -3091,10 +3091,10 @@ def unique( ) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output + DNDarray(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output ) else: - heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device) + heat_output = DNDarray(torch_output, dtype=a.dtype, split=None, device=a.device) return heat_output local_data = a.larray @@ -3240,7 +3240,7 @@ def unique( gres = gres.transpose(0, axis) split = split if a.split < len(gres.shape) else None - result = factories.array( + result = DNDarray( gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split ) if split is not None: @@ -3372,7 +3372,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: if axis == arr.split: return arr.copy() if not arr.is_distributed(): - return factories.array(arr.larray, split=axis, device=arr.device, copy=True) + return DNDarray(arr.larray, split=axis, device=arr.device, copy=True) if axis is None: # new_arr = arr.copy() @@ -3381,7 +3381,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: ) counts, displs = arr.counts_displs() arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split) - new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) + new_arr = DNDarray(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) return new_arr arr_tiles = tiling.SplitTiles(arr) @@ -3629,7 +3629,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: _ = x.shape raise TypeError("Input can be a DNDarray or a scalar, is {}".format(type(x))) except AttributeError: - x = factories.array(x).reshape(1) + x = DNDarray(x).reshape(1) x_proxy = x.__torch_proxy__() @@ -3956,10 +3956,10 @@ def local_topk(*args, **kwargs): is_split = a.split split = None - final_array = factories.array( + final_array = DNDarray( gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split ) - final_indices = factories.array( + final_indices = DNDarray( gindices, dtype=types.int64, device=a.device, split=split, is_split=is_split ) From 8612040847d3045d1a301b9284401decbe5008ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 2 Apr 2023 18:43:08 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- heat/core/manipulations.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 4119a306d1..d1bc2c219b 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -864,9 +864,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: flipped = torch.flip(a.larray, axis) if a.split not in axis: - return DNDarray( - flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm - ) + return DNDarray(flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) # Need to redistribute tensors on split axis # Get local shapes @@ -2218,9 +2216,7 @@ def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarra raise ValueError("Axes={} out of range for array of ndim={}.".format(axes, m.ndim)) if m.split is None: - return DNDarray( - torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm - ) + return DNDarray(torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm) try: k = int(k) @@ -3956,9 +3952,7 @@ def local_topk(*args, **kwargs): is_split = a.split split = None - final_array = DNDarray( - gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split - ) + final_array = DNDarray(gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split) final_indices = DNDarray( gindices, dtype=types.int64, device=a.device, split=split, is_split=is_split )