diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index cc3c738e7f..a15db92c22 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -453,7 +453,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: s0, s1 = arr0.split, arr1.split # no splits, local concat if s0 is None and s1 is None: - return factories.array( + return DNDarray( torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm ) @@ -464,7 +464,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis): _, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split) _, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split) - out = factories.array( + out = DNDarray( torch.cat((arr0.larray[arr0_slice], arr1.larray[arr1_slice]), dim=axis), dtype=out_dtype, is_split=s1 if s1 is not None else s0, @@ -479,7 +479,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: # the axis is different than the split axis, this case can be easily implemented # torch cat arrays together and return a new array that is_split - out = factories.array( + out = DNDarray( torch.cat((arr0.larray, arr1.larray), dim=axis), dtype=out_dtype, is_split=s0, @@ -649,7 +649,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray: t_arr1.unsqueeze_(axis) res = torch.cat((t_arr0, t_arr1), dim=axis) - out = factories.array( + out = DNDarray( res, is_split=s0 if s0 is not None else s1, dtype=out_dtype, @@ -732,7 +732,7 @@ def diag(a: DNDarray, offset: int = 0) -> DNDarray: local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device) local[indices_x, indices_y] = a.larray[indices_x] - return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + return DNDarray(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray: @@ -805,10 +805,11 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDa else: vz = 1 if a.split == dim1 else -1 off, _, _ = a.comm.chunk(a.shape, a.split) + return DNDarray(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) result = torch.diagonal( a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2 ).contiguous() - return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) + return DNDarray(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm) def dsplit(x: Sequence[DNDarray, ...], indices_or_sections: Iterable) -> List[DNDarray, ...]: @@ -959,14 +960,14 @@ def flatten(a: DNDarray) -> DNDarray: sanitation.sanitize_in(a) if a.split is None: - return factories.array( + return DNDarray( torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm ) if a.split > 0: a = resplit(a, 0) - a = factories.array( + a = DNDarray( torch.flatten(a.larray), dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) a.balance_() @@ -1017,9 +1018,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: flipped = torch.flip(a.larray, axis) if a.split not in axis: - return factories.array( - flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm - ) + return DNDarray(flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) # Need to redistribute tensors on split axis # Get local shapes @@ -1032,7 +1031,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray: received = torch.empty(new_lshape, dtype=a.larray.dtype, device=a.device.torch_device) a.comm.Recv(received, source=dest_proc) - res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) + res = DNDarray(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm) res.balance_() # after swapping, first processes may be empty req.Wait() return res @@ -1589,7 +1588,7 @@ def pad( padded_torch_tensor, pad_tuple, mode, value_tuple[i] ) - padded_tensor = factories.array( + padded_tensor = DNDarray( padded_torch_tensor, dtype=array.dtype, is_split=array.split, @@ -1631,7 +1630,7 @@ def ravel(a: DNDarray) -> DNDarray: sanitation.sanitize_in(a) if a.split is None: - return factories.array( + return DNDarray( torch.flatten(a._DNDarray__array), dtype=a.dtype, copy=False, @@ -1644,7 +1643,7 @@ def ravel(a: DNDarray) -> DNDarray: if a.split != 0: return flatten(a) - result = factories.array( + result = DNDarray( torch.flatten(a._DNDarray__array), dtype=a.dtype, copy=False, @@ -1754,9 +1753,9 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr # sanitation `a` if not isinstance(a, DNDarray): if isinstance(a, (int, float)): - a = factories.array([a]) + a = DNDarray([a]) elif isinstance(a, (tuple, list, np.ndarray)): - a = factories.array(a) + a = DNDarray(a) else: raise TypeError( f"`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {type(a)}" @@ -1784,7 +1783,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr if repeats.dtype == types.int64: pass elif types.can_cast(repeats.dtype, types.int64): - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=repeats.split, @@ -1800,7 +1799,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr raise TypeError( f"Invalid dtype for np.ndarray `repeats`. Has to be integer, but was {repeats.dtype.type}" ) - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm ) elif not all(isinstance(r, int) for r in repeats): @@ -1808,7 +1807,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr "Invalid type within `repeats`. All components of `repeats` must be integers." ) else: - repeats = factories.array( + repeats = DNDarray( repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm ) @@ -1900,7 +1899,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr ) repeats = resplit(repeats, 0) flatten_repeats_t = torch.flatten(repeats._DNDarray__array) - repeats = factories.array( + repeats = DNDarray( flatten_repeats_t, is_split=repeats.split, device=repeats.device, @@ -1936,7 +1935,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarr a._DNDarray__array, repeats._DNDarray__array, axis ) - repeated_array = factories.array( + repeated_array = DNDarray( repeated_array_torch, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) repeated_array.balance_() @@ -2332,9 +2331,7 @@ def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarra raise ValueError(f"Axes={axes} out of range for array of ndim={m.ndim}.") if m.split is None: - return factories.array( - torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm - ) + return DNDarray(torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm) try: k = int(k) @@ -2619,14 +2616,14 @@ def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DN val = tmp_indices[idx] final_indices[idx] = second_indices[val.item()][idx[1:]] final_indices = final_indices.transpose(0, axis) - return_indices = factories.array( + return_indices = DNDarray( final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm ) if out is not None: out.larray = final_result return return_indices else: - tensor = factories.array( + tensor = DNDarray( final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm ) return tensor, return_indices @@ -2723,7 +2720,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND elif isinstance(indices_or_sections, (list, tuple, DNDarray)): if isinstance(indices_or_sections, (list, tuple)): - indices_or_sections = factories.array(indices_or_sections) + indices_or_sections = DNDarray(indices_or_sections) if len(indices_or_sections.gshape) != 1: raise ValueError( f"Expected indices_or_sections to be 1-dimensional, but was {len(indices_or_sections.gshape) - 1}-dimensional instead." @@ -2806,7 +2803,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND indices_or_sections_t = arithmetics.diff( indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop ) - indices_or_sections_t = factories.array( + indices_or_sections_t = DNDarray( indices_or_sections_t, dtype=types.int64, is_split=indices_or_sections_t.split, @@ -2843,7 +2840,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND indices_or_sections_t = arithmetics.diff( indices_or_sections_t, prepend=0, append=x.gshape[axis] ) - indices_or_sections_t = factories.array( + indices_or_sections_t = DNDarray( indices_or_sections_t, dtype=types.int64, is_split=indices_or_sections_t.split, @@ -2856,7 +2853,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DND sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis) sub_arrays_ht = [ - factories.array(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm) + DNDarray(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm) for sub_DNDarray in sub_arrays_t ] @@ -3192,13 +3189,10 @@ def unique( ) if isinstance(torch_output, tuple): heat_output = tuple( - factories.array(i, dtype=a.dtype, split=None, device=a.device, comm=a.comm) - for i in torch_output + DNDarray(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output ) else: - heat_output = factories.array( - torch_output, dtype=a.dtype, split=None, device=a.device, comm=a.comm - ) + heat_output = DNDarray(torch_output, dtype=a.dtype, split=None, device=a.device) return heat_output local_data = a.larray @@ -3344,7 +3338,7 @@ def unique( gres = gres.transpose(0, axis) split = split if a.split < len(gres.shape) else None - result = factories.array( + result = DNDarray( gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split ) if split is not None: @@ -3476,7 +3470,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: if axis == arr.split: return arr.copy() if not arr.is_distributed(): - return factories.array(arr.larray, split=axis, device=arr.device, comm=arr.comm, copy=True) + return DNDarray(arr.larray, split=axis, device=arr.device, copy=True) if axis is None: # new_arr = arr.copy() @@ -3485,9 +3479,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray: ) counts, displs = arr.counts_displs() arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split) - new_arr = factories.array( - gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype - ) + new_arr = DNDarray(gathered, is_split=axis, device=arr.device, dtype=arr.dtype) return new_arr arr_tiles = tiling.SplitTiles(arr) new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device) @@ -3734,7 +3726,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: _ = x.shape raise TypeError(f"Input can be a DNDarray or a scalar, is {type(x)}") except AttributeError: - x = factories.array(x).reshape(1) + x = DNDarray(x).reshape(1) x_proxy = x.__torch_proxy__() @@ -4059,11 +4051,9 @@ def local_topk(*args, **kwargs): is_split = a.split split = None - final_array = factories.array( - gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split - ) - final_indices = factories.array( - gindices, dtype=types.int64, device=a.device, comm=a.comm, split=split, is_split=is_split + final_array = DNDarray(gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split) + final_indices = DNDarray( + gindices, dtype=types.int64, device=a.device, split=split, is_split=is_split ) if out is not None: