Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement integer array indexing #651

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 109 additions & 9 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,94 @@ static mp_obj_t ndarray_from_boolean_index(ndarray_obj_t *ndarray, ndarray_obj_t
return MP_OBJ_FROM_PTR(results);
}


static mp_obj_t ndarray_from_integer_index(ndarray_obj_t *ndarray, ndarray_obj_t *index) {
if(ndarray->ndim > 1) {
mp_raise_ValueError(MP_ERROR_TEXT("only supports 1-dim target arrays"));
}

if(!ndarray_is_dense(ndarray)) {
mp_raise_ValueError(MP_ERROR_TEXT("only supports dense target arrays"));
}

// TODO: range-check index values against ndarray->shape[ULAB_MAX_DIMS-1]
// TODO: normalize or handle negative indices in loop (without modifying index)

int32_t *strides = strides_from_shape(index->shape, ndarray->dtype);
ndarray_obj_t *results = ndarray_new_ndarray(index->ndim, index->shape, strides, ndarray->dtype);

uint8_t *larray = (uint8_t *)ndarray->array;
uint8_t *iarray = (uint8_t *)index->array;

if (ndarray->dtype == NDARRAY_UINT8) {
if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, uint8_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, uint8_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, uint8_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, uint8_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
} else if (ndarray->dtype == NDARRAY_INT8) {
if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, int8_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, int8_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, int8_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, int8_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
} else if (ndarray->dtype == NDARRAY_UINT16) {
if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, uint16_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, uint16_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, uint16_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, uint16_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
} else if (ndarray->dtype == NDARRAY_INT16) {
if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, int16_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, int16_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, int16_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, int16_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
} else if (ndarray->dtype == NDARRAY_FLOAT) {
if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, mp_float_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, mp_float_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, mp_float_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, mp_float_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
#if ULAB_SUPPORTS_COMPLEX
} else if (ndarray->dtype == NDARRAY_COMPLEX) {
struct complex_t { float a; float b; };

if (index->dtype == NDARRAY_UINT8) {
INDEX_LOOP(results, struct complex_t, uint8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT8) {
INDEX_LOOP(results, struct complex_t, int8_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_UINT16) {
INDEX_LOOP(results, struct complex_t, uint16_t, larray, ndarray->strides, iarray, index->strides);
} else if (index->dtype == NDARRAY_INT16) {
INDEX_LOOP(results, struct complex_t, int16_t, larray, ndarray->strides, iarray, index->strides);
}
#endif
}

return MP_OBJ_FROM_PTR(results);
}

static mp_obj_t ndarray_assign_from_boolean_index(ndarray_obj_t *ndarray, ndarray_obj_t *index, ndarray_obj_t *values) {
// assigns values to a Boolean-indexed array
// first we have to find out how many trues there are
Expand Down Expand Up @@ -1313,16 +1401,28 @@ static mp_obj_t ndarray_assign_from_boolean_index(ndarray_obj_t *ndarray, ndarra
static mp_obj_t ndarray_get_slice(ndarray_obj_t *ndarray, mp_obj_t index, ndarray_obj_t *values) {
if(mp_obj_is_type(index, &ulab_ndarray_type)) {
ndarray_obj_t *nindex = MP_OBJ_TO_PTR(index);
if((nindex->ndim > 1) || (nindex->boolean == false)) {
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is implemented for 1D Boolean arrays only"));
}
if(values == NULL) { // return value(s)
return ndarray_from_boolean_index(ndarray, nindex);
} else { // assign value(s)
ndarray_assign_from_boolean_index(ndarray, nindex, values);

if(nindex->boolean) {
if(nindex->ndim > 1) {
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
}

if(values == NULL) { // return value(s)
return ndarray_from_boolean_index(ndarray, index);
} else { // assign value(s)
ndarray_assign_from_boolean_index(ndarray, nindex, values);
}
} else if ((nindex->dtype == NDARRAY_UINT8) || (nindex->dtype == NDARRAY_INT8) ||
(nindex->dtype == NDARRAY_UINT16) || (nindex->dtype == NDARRAY_INT16)) {
if(values == NULL) { // return value(s)
return ndarray_from_integer_index(ndarray, nindex);
} else { // assign value(s)
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
}
} else {
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is only implemented for integer arrays or 1D Boolean arrays"));
}
}
if(mp_obj_is_type(index, &mp_type_tuple) || mp_obj_is_int(index) || mp_obj_is_type(index, &mp_type_slice)) {
} else if(mp_obj_is_type(index, &mp_type_tuple) || mp_obj_is_int(index) || mp_obj_is_type(index, &mp_type_slice)) {
mp_obj_tuple_t *tuple;
if(mp_obj_is_type(index, &mp_type_tuple)) {
tuple = MP_OBJ_TO_PTR(index);
Expand Down
92 changes: 92 additions & 0 deletions code/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,17 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\

#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
uint8_t *array = (uint8_t *)results->array;\
size_t l = 0;\
do {\
size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
*((type_left *)array) = *((type_left *)(larray + offset));\
array += results->strides[ULAB_MAX_DIMS - 1];\
iarray += istrides[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\

#endif /* ULAB_MAX_DIMS == 1 */

#if ULAB_MAX_DIMS == 2
Expand Down Expand Up @@ -464,6 +475,25 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\

#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
uint8_t *array = (uint8_t *)results->array;\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
*((type_left *)array) = *((type_left *)(larray + offset));\
array += results->strides[ULAB_MAX_DIMS - 1];\
iarray += istrides[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(array) += (results->strides)[ULAB_MAX_DIMS - 2];\
(iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\

#endif /* ULAB_MAX_DIMS == 2 */

#if ULAB_MAX_DIMS == 3
Expand Down Expand Up @@ -569,6 +599,33 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\

#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
uint8_t *array = (uint8_t *)results->array;\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
*((type_left *)array) = *((type_left *)(larray + offset));\
array += results->strides[ULAB_MAX_DIMS - 1];\
iarray += istrides[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(array) += (results->strides)[ULAB_MAX_DIMS - 2];\
(iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(array) += (results->strides)[ULAB_MAX_DIMS - 3];\
(iarray) -= (istrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(iarray) += (istrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\

#endif /* ULAB_MAX_DIMS == 3 */

#if ULAB_MAX_DIMS == 4
Expand Down Expand Up @@ -706,6 +763,41 @@ ndarray_obj_t *ndarray_from_mp_obj(mp_obj_t , uint8_t );
i++;\
} while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\

#define INDEX_LOOP(results, type_left, type_index, larray, lstrides, iarray, istrides)\
uint8_t *array = (uint8_t *)results->array;\
size_t i = 0;\
do {\
size_t j = 0;\
do {\
size_t k = 0;\
do {\
size_t l = 0;\
do {\
size_t offset = lstrides[ULAB_MAX_DIMS - 1] * *((type_index *)(iarray));\
*((type_left *)array) = *((type_left *)(larray + offset));\
array += results->strides[ULAB_MAX_DIMS - 1];\
iarray += istrides[ULAB_MAX_DIMS - 1];\
l++;\
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(array) += (results->strides)[ULAB_MAX_DIMS - 2];\
(iarray) -= (istrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS-1];\
(iarray) += (istrides)[ULAB_MAX_DIMS - 2];\
k++;\
} while(k < (results)->shape[ULAB_MAX_DIMS - 2]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(array) += (results->strides)[ULAB_MAX_DIMS - 3];\
(iarray) -= (istrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS-2];\
(iarray) += (istrides)[ULAB_MAX_DIMS - 3];\
j++;\
} while(j < (results)->shape[ULAB_MAX_DIMS - 3]);\
(array) -= (results->strides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(array) += (results->strides)[ULAB_MAX_DIMS - 4];\
(iarray) -= (lstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS-3];\
(iarray) += (lstrides)[ULAB_MAX_DIMS - 4];\
i++;\
} while(i < (results)->shape[ULAB_MAX_DIMS - 4]);\

#endif /* ULAB_MAX_DIMS == 4 */
#endif /* ULAB_HAS_FUNCTION_ITERATOR */

Expand Down
21 changes: 21 additions & 0 deletions tests/1d/numpy/advanced_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from ulab import numpy as np

a = np.array(range(0, 100, 10), dtype=np.uint8)
b = np.array([0.5, 1.5, 0.2, 4.3], dtype=np.float)

# integer array indexing
print(a[np.array([0, 4, 2], dtype=np.uint8)])
print(b[np.array([3, 2, 2, 3], dtype=np.int16)])
# TODO: test negative indices
# TODO: check range checking

# boolean array indexing
print(a[a >= 50])
print(b[b > 1])

# boolean array index assignment
a[a > 1] = 0
print(a)

b[b > 50] += 5
print(b)
6 changes: 6 additions & 0 deletions tests/1d/numpy/advanced_indexing.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
array([0, 40, 20], dtype=uint8)
array([4.3, 0.2, 0.2, 4.3], dtype=float64)
array([50, 60, 70, 80, 90], dtype=uint8)
array([1.5, 4.3], dtype=float64)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint8)
array([0.5, 1.5, 0.2, 4.3], dtype=float64)
10 changes: 10 additions & 0 deletions tests/2d/numpy/advanced_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ulab import numpy as np

a = np.array(range(0, 100, 10), dtype=np.uint8)
b = np.array([0.5, 1.5, 0.2, 4.3], dtype=np.float)

# integer array indexing
print(a[np.array([[0, 4], [1, 2]], dtype=np.uint8)])
print(b[np.array([[3, 2], [2, 3]], dtype=np.uint8)])
# TODO: test negative indices
# TODO: check range checking
4 changes: 4 additions & 0 deletions tests/2d/numpy/advanced_indexing.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
array([[0, 40],
[10, 20]], dtype=uint8)
array([[4.3, 0.2],
[0.2, 4.3]], dtype=float64)
Loading