Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: updated _parse_query to use ivy's set_item to use get_item as it's optimized now #28582

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def shape(
*,
as_array: bool = False,
) -> Union[tf.Tensor, ivy.Shape, ivy.Array]:
if as_array:
if as_array or not tf.executing_eagerly():
return ivy.array(tf.shape(x), dtype=ivy.default_int_dtype())
else:
return ivy.Shape(x.shape)
Expand Down
274 changes: 21 additions & 253 deletions ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# global
import gc
import inspect
import itertools
import math
from functools import wraps
from numbers import Number
Expand Down Expand Up @@ -2802,19 +2801,11 @@ def get_item(
query = ivy.nonzero(query, as_tuple=False)
ret = ivy.gather_nd(x, query)
else:
query, target_shape, vector_inds = _parse_query(
query, ivy.shape(x, as_array=True)
)
if vector_inds is not None:
x = ivy.permute_dims(
x,
axes=[
*vector_inds,
*[i for i in range(len(x.shape)) if i not in vector_inds],
],
)
ret = ivy.gather_nd(x, query)
ret = ivy.reshape(ret, target_shape) if target_shape != list(ret.shape) else ret
indices, target_shape = _parse_query(query, ivy.shape(x))
if indices is None:
return ivy.empty(target_shape, dtype=x.dtype)
ret = ivy.gather_nd(x, indices)
ret = ivy.reshape(ret, target_shape)
return ret


Expand Down Expand Up @@ -2889,9 +2880,7 @@ def set_item(
query = ivy.tile(query, (x.shape[0],))
indices = ivy.nonzero(query, as_tuple=False)
else:
indices, target_shape, _ = _parse_query(
query, ivy.shape(x, as_array=True), scatter=True
)
indices, target_shape = _parse_query(query, ivy.shape(x))
if indices is None:
return x
val = val.astype(x.dtype)
Expand All @@ -2909,252 +2898,31 @@ def set_item(
}


def _parse_query(query, x_shape, scatter=False):
query = (query,) if not isinstance(query, tuple) else query
def _parse_query(query, x_shape):
query = query if isinstance(query, tuple) else (query,)

# sequence and integer queries are dealt with as array queries
query = [ivy.array(q) if isinstance(q, (tuple, list, int)) else q for q in query]
# array containing all of x's flat indices
x_ = ivy.arange(0, _numel(x_shape)).reshape(x_shape)

# check if non-slice queries are in consecutive positions
# if so, they have to be moved to the front
# https://numpy.org/neps/nep-0021-advanced-indexing.html#mixed-indexing
non_slice_q_idxs = [i for i, q in enumerate(query) if ivy.is_array(q)]
to_front = (
len(non_slice_q_idxs) > 1
and any(ivy.diff(non_slice_q_idxs) != 1)
and non_slice_q_idxs[-1] < len(x_shape)
)
# use numpy's __getitem__ to get the queried indices
x_idxs = x_[query]
target_shape = x_idxs.shape

# extract newaxis queries
new_axes = [i for i, q in enumerate(query) if q is None]
query = [q for q in query if q is not None]
query = [Ellipsis] if query == [] else query

# parse ellipsis
ellipsis_inds = None
if any(q is Ellipsis for q in query):
query, ellipsis_inds = _parse_ellipsis(query, len(x_shape))

# broadcast array queries
array_inds = [i for i, v in enumerate(query) if ivy.is_array(v)]
if array_inds:
array_queries = ivy.broadcast_arrays(
*[v for i, v in enumerate(query) if i in array_inds]
)
array_queries = [
ivy.nonzero(q, as_tuple=False)[0] if ivy.is_bool_dtype(q) else q
for q in array_queries
]
array_queries = [
(
ivy.where(arr < 0, arr + x_shape[i], arr).astype(ivy.int64)
if arr.size
else arr.astype(ivy.int64)
)
for arr, i in zip(array_queries, array_inds)
]
for idx, arr in zip(array_inds, array_queries):
query[idx] = arr

# convert slices to range arrays
query = [
_parse_slice(q, x_shape[i]).astype(ivy.int64) if isinstance(q, slice) else q
for i, q in enumerate(query)
]
if 0 in x_idxs.shape or int(ivy.prod(x_shape)) == 0:
return None, target_shape

# fill in missing queries
if len(query) < len(x_shape):
query += [ivy.arange(0, s, 1).astype(ivy.int64) for s in x_shape[len(query) :]]
# convert the flat indices to multi-D indices
x_idxs = ivy.unravel_index(x_idxs, x_shape)

# calculate target_shape, i.e. the shape the gathered/scattered values should have
if len(array_inds) and to_front:
target_shape = (
[list(array_queries[0].shape)]
+ [list(query[i].shape) for i in range(len(query)) if i not in array_inds]
+ [[] for _ in range(len(array_inds) - 1)]
)
elif len(array_inds):
target_shape = (
[list(query[i].shape) for i in range(0, array_inds[0])]
+ [list(ivy.shape(array_queries[0], as_array=True))]
+ [[] for _ in range(len(array_inds) - 1)]
+ [list(query[i].shape) for i in range(array_inds[-1] + 1, len(query))]
)
else:
target_shape = [list(q.shape) for q in query]
if ellipsis_inds is not None:
target_shape = (
target_shape[: ellipsis_inds[0]]
+ [target_shape[ellipsis_inds[0] : ellipsis_inds[1]]]
+ target_shape[ellipsis_inds[1] :]
)
for i, ax in enumerate(new_axes):
if len(array_inds) and to_front:
ax -= sum(1 for x in array_inds if x < ax) - 1
ax = ax + i
target_shape = [*target_shape[:ax], 1, *target_shape[ax:]]
target_shape = _deep_flatten(target_shape)

# calculate the indices mesh (indices in gather_nd/scatter_nd format)
query = [ivy.expand_dims(q) if not len(q.shape) else q for q in query]
if len(array_inds):
array_queries = [
(
arr.reshape((-1,))
if len(arr.shape) > 1
else ivy.expand_dims(arr) if not len(arr.shape) else arr
)
for arr in array_queries
]
array_queries = ivy.stack(array_queries, axis=1)
if len(array_inds) == len(query): # advanced indexing
indices = array_queries.reshape((*target_shape, len(x_shape)))
elif len(array_inds) == 0: # basic indexing
indices = ivy.stack(ivy.meshgrid(*query, indexing="ij"), axis=-1).reshape(
(*target_shape, len(x_shape))
)
else: # mixed indexing
if to_front:
post_array_queries = (
ivy.stack(
ivy.meshgrid(
*[v for i, v in enumerate(query) if i not in array_inds],
indexing="ij",
),
axis=-1,
).reshape((-1, len(query) - len(array_inds)))
if len(array_inds) < len(query)
else ivy.empty((1, 0))
)
indices = ivy.array(
[
(*arr, *post)
for arr, post in itertools.product(
array_queries, post_array_queries
)
]
).reshape((*target_shape, len(x_shape)))
else:
pre_array_queries = (
ivy.stack(
ivy.meshgrid(
*[v for i, v in enumerate(query) if i < array_inds[0]],
indexing="ij",
),
axis=-1,
).reshape((-1, array_inds[0]))
if array_inds[0] > 0
else ivy.empty((1, 0))
)
post_array_queries = (
ivy.stack(
ivy.meshgrid(
*[v for i, v in enumerate(query) if i > array_inds[-1]],
indexing="ij",
),
axis=-1,
).reshape((-1, len(query) - 1 - array_inds[-1]))
if array_inds[-1] < len(query) - 1
else ivy.empty((1, 0))
)
indices = ivy.array(
[
(*pre, *arr, *post)
for pre, arr, post in itertools.product(
pre_array_queries, array_queries, post_array_queries
)
]
).reshape((*target_shape, len(x_shape)))

return (
indices.astype(ivy.int64),
target_shape,
array_inds if len(array_inds) and to_front else None,
)


def _parse_ellipsis(so, ndims):
pre = list()
for s in so:
if s is Ellipsis:
break
pre.append(s)
post = list()
for s in reversed(so):
if s is Ellipsis:
break
post.append(s)
ret = list(
pre
+ [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))]
+ list(reversed(post))
)
return ret, (len(pre), ndims - len(post))


def _parse_slice(idx, s):
    """Turn the slice ``idx`` into an explicit array of indices for size ``s``.

    ``start``/``stop``/``step`` are read from ``idx``, missing values are
    filled in, out-of-range values are clamped or wrapped by ``s``, and the
    resulting range is built with ``ivy.arange``. Only indices ``q`` with
    ``0 <= q < s`` are kept.

    NOTE(review): the block nesting below was reconstructed from a flattened
    diff view in which indentation was lost -- confirm against the repository.

    Parameters
    ----------
    idx
        Object exposing ``start``, ``stop`` and ``step`` (presumably a
        ``slice``).
    s
        Size of the axis being indexed; used as the bound for clamping.

    Returns
    -------
    ret
        ``ivy.array`` of the in-bounds indices, or ``ivy.arange(0, s, 1)`` in
        the fallback case at the end of the function.
    """
    step = 1 if idx.step is None else idx.step
    # Anything that is not strictly positive (including 0) takes the
    # "negative step" branch below.
    if step > 0:
        start = 0 if idx.start is None else idx.start
        if start >= s:
            # Start is past the end: make the range empty.
            stop = start
        else:
            if start <= -s:
                start = 0
            elif start < 0:
                # Negative start counts from the end of the axis.
                start = start + s
            stop = s if idx.stop is None else idx.stop
            if stop > s:
                stop = s
            # NOTE(review): `start` was already normalised just above, so this
            # test may have been intended for `stop` -- confirm.
            elif start <= -s:
                stop = 0
            elif stop < 0:
                stop = stop + s
    else:
        start = s - 1 if idx.start is None else idx.start
        if start < -s:
            # Start is before the beginning: make the range empty.
            stop = start
        else:
            if start >= s:
                start = s - 1
            elif start < 0:
                start = start + s
            if idx.stop is None:
                # -1 lets the descending range include index 0.
                stop = -1
            else:
                stop = idx.stop
                if stop > s:
                    stop = s
                elif stop < -s:
                    stop = -1
                elif stop == -s:
                    stop = 0
                elif stop < 0:
                    stop = stop + s
    q_i = ivy.arange(start, stop, step)
    # Drop anything that still falls outside [0, s).
    q_i = [q for q in q_i if 0 <= q < s]
    # Fall back to the full ascending range only when nothing survived, the
    # range was not deliberately empty (start != stop) and no explicit stop
    # was given. NOTE(review): the intent of this fallback is not evident
    # from this function alone -- confirm with callers.
    q_i = (
        ivy.array(q_i)
        if len(q_i) or start == stop or idx.stop is not None
        else ivy.arange(0, s, 1)
    )
    return q_i


def _deep_flatten(iterable):
def _flatten_gen(iterable):
for item in iterable:
if isinstance(item, list):
yield from _flatten_gen(item)
else:
yield item
# stack the multi-D indices to bring them to gather_nd/scatter_nd format
x_idxs = ivy.stack(x_idxs, axis=-1).astype(ivy.int64)

return list(_flatten_gen(iterable))
return x_idxs, target_shape


def _numel(shape):
shape = tuple(shape)
return ivy.prod(shape).to_scalar() if shape != () else 1
return int(ivy.prod(shape)) if shape != () else 1


def _broadcast_to(input, target_shape):
Expand Down
Loading