replace factories.array() with DNDarray construct #1138

Closed
84 changes: 37 additions & 47 deletions heat/core/manipulations.py
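Every hunk below follows the same pattern: a function in manipulations.py already holds a process-local torch.Tensor with known dtype, split axis, device, and communicator, and the PR wraps that tensor directly in a DNDarray instead of routing it through factories.array(), which would re-run the factory's input sanitation and global-shape checks on metadata that is already consistent. A minimal before/after sketch of the pattern (the keyword form of the DNDarray constructor is taken from the calls in this diff and may not match the constructor outside this PR; local_result is an illustrative name):

import torch
import heat as ht
from heat.core.dndarray import DNDarray

a = ht.arange(8, split=0)               # distributed DNDarray, one chunk per process
local_result = torch.flatten(a.larray)  # purely local torch op on this process' chunk

# before: the factory re-validates the input and re-derives the global shape
out_old = ht.array(local_result, is_split=a.split, device=a.device, comm=a.comm)

# after: wrap the local tensor directly, passing the already-known metadata through
# (keyword form as used in the calls below; released constructors may differ)
out_new = DNDarray(local_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)

Direct construction trades the factory's validation for speed, so it is only safe where the caller can guarantee that the local tensors and the split metadata already agree across processes.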
@@ -453,7 +453,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray:
s0, s1 = arr0.split, arr1.split
# no splits, local concat
if s0 is None and s1 is None:
return factories.array(
return DNDarray(
torch.cat((arr0.larray, arr1.larray), dim=axis), device=arr0.device, comm=arr0.comm
)

@@ -464,7 +464,7 @@ elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
_, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split)
_, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split)
out = factories.array(
out = DNDarray(
torch.cat((arr0.larray[arr0_slice], arr1.larray[arr1_slice]), dim=axis),
dtype=out_dtype,
is_split=s1 if s1 is not None else s0,
@@ -479,7 +479,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray:
# the axis is different than the split axis, this case can be easily implemented
# torch cat arrays together and return a new array that is_split

out = factories.array(
out = DNDarray(
torch.cat((arr0.larray, arr1.larray), dim=axis),
dtype=out_dtype,
is_split=s0,
@@ -649,7 +649,7 @@ def concatenate(arrays: Sequence[DNDarray, ...], axis: int = 0) -> DNDarray:
t_arr1.unsqueeze_(axis)

res = torch.cat((t_arr0, t_arr1), dim=axis)
out = factories.array(
out = DNDarray(
res,
is_split=s0 if s0 is not None else s1,
dtype=out_dtype,
@@ -732,7 +732,7 @@ def diag(a: DNDarray, offset: int = 0) -> DNDarray:
local = torch.zeros(lshape, dtype=a.dtype.torch_type(), device=a.device.torch_device)
local[indices_x, indices_y] = a.larray[indices_x]

return factories.array(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
return DNDarray(local, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)


def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray:
@@ -805,10 +805,11 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDarray:
else:
vz = 1 if a.split == dim1 else -1
off, _, _ = a.comm.chunk(a.shape, a.split)
return DNDarray(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm)
result = torch.diagonal(
a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2
).contiguous()
return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm)
return DNDarray(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm)


def dsplit(x: Sequence[DNDarray, ...], indices_or_sections: Iterable) -> List[DNDarray, ...]:
@@ -959,14 +960,14 @@ def flatten(a: DNDarray) -> DNDarray:
sanitation.sanitize_in(a)

if a.split is None:
return factories.array(
return DNDarray(
torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm
)

if a.split > 0:
a = resplit(a, 0)

a = factories.array(
a = DNDarray(
torch.flatten(a.larray), dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
a.balance_()
@@ -1017,9 +1018,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray:
flipped = torch.flip(a.larray, axis)

if a.split not in axis:
return factories.array(
flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
return DNDarray(flipped, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)

# Need to redistribute tensors on split axis
# Get local shapes
@@ -1032,7 +1031,7 @@ def flip(a: DNDarray, axis: Union[int, Tuple[int, ...]] = None) -> DNDarray:
received = torch.empty(new_lshape, dtype=a.larray.dtype, device=a.device.torch_device)
a.comm.Recv(received, source=dest_proc)

res = factories.array(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
res = DNDarray(received, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm)
res.balance_() # after swapping, first processes may be empty
req.Wait()
return res
@@ -1589,7 +1588,7 @@ def pad(
padded_torch_tensor, pad_tuple, mode, value_tuple[i]
)

padded_tensor = factories.array(
padded_tensor = DNDarray(
padded_torch_tensor,
dtype=array.dtype,
is_split=array.split,
@@ -1631,7 +1630,7 @@ def ravel(a: DNDarray) -> DNDarray:
sanitation.sanitize_in(a)

if a.split is None:
return factories.array(
return DNDarray(
torch.flatten(a._DNDarray__array),
dtype=a.dtype,
copy=False,
@@ -1644,7 +1643,7 @@ def ravel(a: DNDarray) -> DNDarray:
if a.split != 0:
return flatten(a)

result = factories.array(
result = DNDarray(
torch.flatten(a._DNDarray__array),
dtype=a.dtype,
copy=False,
@@ -1754,9 +1753,9 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
# sanitation `a`
if not isinstance(a, DNDarray):
if isinstance(a, (int, float)):
a = factories.array([a])
a = DNDarray([a])
elif isinstance(a, (tuple, list, np.ndarray)):
a = factories.array(a)
a = DNDarray(a)
else:
raise TypeError(
f"`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {type(a)}"
@@ -1784,7 +1783,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
if repeats.dtype == types.int64:
pass
elif types.can_cast(repeats.dtype, types.int64):
repeats = factories.array(
repeats = DNDarray(
repeats,
dtype=types.int64,
is_split=repeats.split,
@@ -1800,15 +1799,15 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
raise TypeError(
f"Invalid dtype for np.ndarray `repeats`. Has to be integer, but was {repeats.dtype.type}"
)
repeats = factories.array(
repeats = DNDarray(
repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
)
elif not all(isinstance(r, int) for r in repeats):
raise TypeError(
"Invalid type within `repeats`. All components of `repeats` must be integers."
)
else:
repeats = factories.array(
repeats = DNDarray(
repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
)

@@ -1900,7 +1899,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
)
repeats = resplit(repeats, 0)
flatten_repeats_t = torch.flatten(repeats._DNDarray__array)
repeats = factories.array(
repeats = DNDarray(
flatten_repeats_t,
is_split=repeats.split,
device=repeats.device,
@@ -1936,7 +1935,7 @@ def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
a._DNDarray__array, repeats._DNDarray__array, axis
)

repeated_array = factories.array(
repeated_array = DNDarray(
repeated_array_torch, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
repeated_array.balance_()
@@ -2332,9 +2331,7 @@ def rot90(m: DNDarray, k: int = 1, axes: Sequence[int, int] = (0, 1)) -> DNDarray:
raise ValueError(f"Axes={axes} out of range for array of ndim={m.ndim}.")

if m.split is None:
return factories.array(
torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm
)
return DNDarray(torch.rot90(m.larray, k, axes), dtype=m.dtype, device=m.device, comm=m.comm)

try:
k = int(k)
@@ -2619,14 +2616,14 @@ def sort(a: DNDarray, axis: int = -1, descending: bool = False, out: Optional[DN
val = tmp_indices[idx]
final_indices[idx] = second_indices[val.item()][idx[1:]]
final_indices = final_indices.transpose(0, axis)
return_indices = factories.array(
return_indices = DNDarray(
final_indices, dtype=types.int32, is_split=a.split, device=a.device, comm=a.comm
)
if out is not None:
out.larray = final_result
return return_indices
else:
tensor = factories.array(
tensor = DNDarray(
final_result, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
return tensor, return_indices
@@ -2723,7 +2720,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DNDarray, ...]:

elif isinstance(indices_or_sections, (list, tuple, DNDarray)):
if isinstance(indices_or_sections, (list, tuple)):
indices_or_sections = factories.array(indices_or_sections)
indices_or_sections = DNDarray(indices_or_sections)
if len(indices_or_sections.gshape) != 1:
raise ValueError(
f"Expected indices_or_sections to be 1-dimensional, but was {len(indices_or_sections.gshape) - 1}-dimensional instead."
@@ -2806,7 +2803,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DNDarray, ...]:
indices_or_sections_t = arithmetics.diff(
indices_or_sections_t, prepend=slice_axis.start, append=slice_axis.stop
)
indices_or_sections_t = factories.array(
indices_or_sections_t = DNDarray(
indices_or_sections_t,
dtype=types.int64,
is_split=indices_or_sections_t.split,
@@ -2843,7 +2840,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DNDarray, ...]:
indices_or_sections_t = arithmetics.diff(
indices_or_sections_t, prepend=0, append=x.gshape[axis]
)
indices_or_sections_t = factories.array(
indices_or_sections_t = DNDarray(
indices_or_sections_t,
dtype=types.int64,
is_split=indices_or_sections_t.split,
@@ -2856,7 +2853,7 @@ def split(x: DNDarray, indices_or_sections: Iterable, axis: int = 0) -> List[DNDarray, ...]:

sub_arrays_t = torch.split(x._DNDarray__array, indices_or_sections_t, axis)
sub_arrays_ht = [
factories.array(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm)
DNDarray(sub_DNDarray, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm)
for sub_DNDarray in sub_arrays_t
]

@@ -3192,13 +3189,10 @@ def unique(
)
if isinstance(torch_output, tuple):
heat_output = tuple(
factories.array(i, dtype=a.dtype, split=None, device=a.device, comm=a.comm)
for i in torch_output
DNDarray(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output
)
else:
heat_output = factories.array(
torch_output, dtype=a.dtype, split=None, device=a.device, comm=a.comm
)
heat_output = DNDarray(torch_output, dtype=a.dtype, split=None, device=a.device)
return heat_output

local_data = a.larray
@@ -3344,7 +3338,7 @@ def unique(
gres = gres.transpose(0, axis)

split = split if a.split < len(gres.shape) else None
result = factories.array(
result = DNDarray(
gres, dtype=a.dtype, device=a.device, comm=a.comm, split=split, is_split=is_split
)
if split is not None:
@@ -3476,7 +3470,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
if axis == arr.split:
return arr.copy()
if not arr.is_distributed():
return factories.array(arr.larray, split=axis, device=arr.device, comm=arr.comm, copy=True)
return DNDarray(arr.larray, split=axis, device=arr.device, copy=True)

if axis is None:
# new_arr = arr.copy()
@@ -3485,9 +3479,7 @@
)
counts, displs = arr.counts_displs()
arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
new_arr = factories.array(
gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype
)
new_arr = DNDarray(gathered, is_split=axis, device=arr.device, dtype=arr.dtype)
return new_arr
arr_tiles = tiling.SplitTiles(arr)
new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device)
@@ -3734,7 +3726,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
_ = x.shape
raise TypeError(f"Input can be a DNDarray or a scalar, is {type(x)}")
except AttributeError:
x = factories.array(x).reshape(1)
x = DNDarray(x).reshape(1)

x_proxy = x.__torch_proxy__()

@@ -4059,11 +4051,9 @@ def local_topk(*args, **kwargs):
is_split = a.split
split = None

final_array = factories.array(
gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split
)
final_indices = factories.array(
gindices, dtype=types.int64, device=a.device, comm=a.comm, split=split, is_split=is_split
final_array = DNDarray(gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split)
final_indices = DNDarray(
gindices, dtype=types.int64, device=a.device, split=split, is_split=is_split
)

if out is not None: