Skip to content

Commit

Permalink
Merge branch 'main' into jpivarski/generalize-Index.ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Aug 15, 2024
2 parents 5168aaa + e15518f commit 4f282ca
Show file tree
Hide file tree
Showing 15 changed files with 410 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ ERROR awkward_UnionArray8_64_nestedfill_tags_index_64(
return awkward_UnionArray_nestedfill_tags_index<int8_t, int64_t, int64_t>(
totags, toindex, tmpstarts, tag, fromcounts, length);
}
// Concrete entry point for the templated kernel
// awkward_UnionArray_nestedfill_tags_index in which the tags buffer, the
// index/starts buffers and the counts buffer are all int64_t.
//
// Every argument is forwarded unchanged to the template instantiated with
// <int64_t, int64_t, int64_t>, and the template's ERROR result is returned
// as-is. Unlike the 8_64 wrapper directly above (first template argument
// int8_t), `totags` here is an int64_t buffer.
ERROR awkward_UnionArray64_64_nestedfill_tags_index_64(
int64_t* totags,
int64_t* toindex,
int64_t* tmpstarts,
int64_t tag,
const int64_t* fromcounts,
int64_t length) {
return awkward_UnionArray_nestedfill_tags_index<int64_t, int64_t, int64_t>(
totags, toindex, tmpstarts, tag, fromcounts, length);
}
25 changes: 25 additions & 0 deletions awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,28 @@ ERROR awkward_UnionArray8_64_simplify8_64_to8_64(
length,
base);
}
// Concrete entry point for the templated kernel awkward_UnionArray_simplify
// with int64_t outer tags and outer index, int8_t inner tags with an int64_t
// inner index, and int8_t output tags with an int64_t output index.
//
// Every argument is forwarded unchanged and the template's ERROR result is
// returned as-is.
// NOTE(review): the template arguments appear to be ordered
// <outer tags, outer index, inner tags, inner index, to tags, to index>,
// judging by the pointer types below -- confirm against the template
// definition, which is not visible here.
ERROR awkward_UnionArray64_64_simplify8_64_to8_64(
int8_t* totags,
int64_t* toindex,
const int64_t* outertags,
const int64_t* outerindex,
const int8_t* innertags,
const int64_t* innerindex,
int64_t towhich,
int64_t innerwhich,
int64_t outerwhich,
int64_t length,
int64_t base) {
return awkward_UnionArray_simplify<int64_t, int64_t, int8_t, int64_t, int8_t, int64_t>(
totags,
toindex,
outertags,
outerindex,
innertags,
innerindex,
towhich,
innerwhich,
outerwhich,
length,
base);
}
19 changes: 19 additions & 0 deletions awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,22 @@ ERROR awkward_UnionArray8_64_simplify_one_to8_64(
length,
base);
}
// Concrete entry point for the templated kernel
// awkward_UnionArray_simplify_one with int64_t source tags and index, and
// int8_t output tags with an int64_t output index.
//
// Every argument is forwarded unchanged and the template's ERROR result is
// returned as-is.
// NOTE(review): the template arguments appear to be ordered
// <from tags, from index, to tags, to index>, judging by the pointer types
// below -- confirm against the template definition, which is not visible here.
ERROR awkward_UnionArray64_64_simplify_one_to8_64(
int8_t* totags,
int64_t* toindex,
const int64_t* fromtags,
const int64_t* fromindex,
int64_t towhich,
int64_t fromwhich,
int64_t length,
int64_t base) {
return awkward_UnionArray_simplify_one<int64_t, int64_t, int8_t, int64_t>(
totags,
toindex,
fromtags,
fromindex,
towhich,
fromwhich,
length,
base);
}
6 changes: 4 additions & 2 deletions awkward-cpp/src/libawkward/forth/ForthMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3075,7 +3075,7 @@ namespace awkward {
uint64_t tmp;
uint8_t tmpbyte;

if (items_remaining != 0) {
if (items_remaining > 0) {
tmpbyte = input->read_byte(current_error_);
if (current_error_ != util::ForthError::none) {
return;
Expand All @@ -3087,7 +3087,7 @@ namespace awkward {
}
data = tmp;
}
while (items_remaining != 0) {
while (items_remaining > 0) {
if (bits_wnd_r >= 8) {
bits_wnd_r -= 8;
bits_wnd_l -= 8;
Expand Down Expand Up @@ -3230,6 +3230,7 @@ namespace awkward {
break; \
}

if (num_items < 0) num_items = 0;
switch (format) {
case READ_BOOL: WRITE_DIRECTLY(bool, bool)
case READ_INT8: WRITE_DIRECTLY(int8_t, int8)
Expand Down Expand Up @@ -3311,6 +3312,7 @@ namespace awkward {
break; \
}

if (num_items < 0) num_items = 0;
switch (format) {
case READ_BOOL: WRITE_TO_STACK(bool)
case READ_INT8: WRITE_TO_STACK(int8_t)
Expand Down
14 changes: 14 additions & 0 deletions docs/reference/awkwardforth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,20 @@ and
but the second is faster because it involves two Forth instructions and one ``memcpy``.

If the number of items to read is negative then it is interpreted as zero.

.. code-block:: python

    >>> vm = ForthMachine32("""
    ... input x
    ... output y float32
    ...
    ... -1000000 x #d-> y
    ... """)
    >>> vm.run({"x": np.arange(1000000) * 1.1})
    >>> np.asarray(vm["y"])
    array([], dtype=float32)

Type codes
""""""""""

Expand Down
31 changes: 31 additions & 0 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3372,6 +3372,14 @@ kernels:

- name: awkward_UnionArray_nestedfill_tags_index
specializations:
- name: awkward_UnionArray64_64_nestedfill_tags_index_64
args:
- {name: totags, type: "List[int64_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: tmpstarts, type: "List[int64_t]", dir: out}
- {name: tag, type: "int64_t", dir: in, role: default}
- {name: fromcounts, type: "Const[List[int64_t]]", dir: in, role: default}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_nestedfill_tags_index_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down Expand Up @@ -3503,6 +3511,19 @@ kernels:

- name: awkward_UnionArray_simplify
specializations:
- name: awkward_UnionArray64_64_simplify8_64_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: outertags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: outerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index}
- {name: innertags, type: "Const[List[int8_t]]", dir: in, role: UnionArray2-tags}
- {name: innerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray2-index}
- {name: towhich, type: "int64_t", dir: in, role: default}
- {name: innerwhich, type: "int64_t", dir: in, role: UnionArray1-which}
- {name: outerwhich, type: "int64_t", dir: in, role: UnionArray2-which}
- {name: length, type: "int64_t", dir: in, role: default}
- {name: base, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_simplify8_32_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down Expand Up @@ -3645,6 +3666,16 @@ kernels:

- name: awkward_UnionArray_simplify_one
specializations:
- name: awkward_UnionArray64_64_simplify_one_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: fromindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index}
- {name: towhich, type: "int64_t", dir: in, role: default}
- {name: fromwhich, type: "int64_t", dir: in, role: UnionArray-which}
- {name: length, type: "int64_t", dir: in, role: default}
- {name: base, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_simplify_one_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down
5 changes: 4 additions & 1 deletion src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def array_equal(
assert not isinstance(x1, PlaceholderArray)
assert not isinstance(x2, PlaceholderArray)
if equal_nan:
both_nan = self._module.logical_and(x1 == np.nan, x2 == np.nan)
# Only newer numpy.array_equal supports the equal_nan parameter.
both_nan = self._module.logical_and(
self._module.isnan(x1), self._module.isnan(x2)
)
both_equal = x1 == x2
return self._module.all(self._module.logical_or(both_equal, both_nan))
else:
Expand Down
26 changes: 17 additions & 9 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

np = NumpyMetadata.instance()
numpy = Numpy.instance()
MAX_UNION_CONTENTS = 2**7 # We use int8 tags, 0-127


@final
Expand Down Expand Up @@ -230,6 +231,10 @@ def simplified(
parameters=None,
mergebool=False,
):
# Note: to help merge more than 128 arrays, tags *can* have type ak.index.Index64.
# This is only supported when index is also Index64,
# and all indexed contents are also Index64.
# We still require that this reduces to no more than 128 variants.
self_index = index
self_tags = tags
self_contents = contents
Expand Down Expand Up @@ -299,6 +304,10 @@ def simplified(

# Did we fail to merge any of the final outer contents with this inner union content?
if unmerged:
if len(contents) >= MAX_UNION_CONTENTS:
raise ValueError(
"UnionArray does not support more than 128 content types"
)
backend.maybe_kernel_error(
backend[
"awkward_UnionArray_simplify",
Expand Down Expand Up @@ -373,6 +382,10 @@ def simplified(
break

if unmerged:
if len(contents) >= MAX_UNION_CONTENTS:
raise ValueError(
"UnionArray does not support more than 128 content types"
)
backend.maybe_kernel_error(
backend[
"awkward_UnionArray_simplify_one",
Expand All @@ -393,11 +406,6 @@ def simplified(
)
contents.append(self_cont)

if len(contents) > 2**7:
raise NotImplementedError(
"FIXME: handle UnionArray with more than 127 contents"
)

# If any contents are non-categorical index types, we can merge them into the union
# This is safe, because any remaining index types at this point in the routine are not considered
# mergeable with the other contents. This means none of the other contents are option or index types,
Expand Down Expand Up @@ -1107,8 +1115,8 @@ def _reverse_merge(self, other):
)
)

if len(contents) > 2**7:
raise AssertionError("FIXME: handle UnionArray with more than 127 contents")
if len(contents) > MAX_UNION_CONTENTS:
raise ValueError("UnionArray cannot have more than 128 content types")

return ak.contents.UnionArray.simplified(
tags, index, contents, parameters=self._parameters
Expand Down Expand Up @@ -1236,8 +1244,8 @@ def _mergemany(self, others: Sequence[Content]) -> Content:

nextcontents.append(array)

if len(nextcontents) > 127:
raise ValueError("FIXME: handle UnionArray with more than 127 contents")
if len(nextcontents) > MAX_UNION_CONTENTS:
raise ValueError("UnionArray cannot have more than 128 content types")

next = ak.contents.UnionArray.simplified(
nexttags,
Expand Down
1 change: 1 addition & 0 deletions src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from awkward.operations.ak_argmax import *
from awkward.operations.ak_argmin import *
from awkward.operations.ak_argsort import *
from awkward.operations.ak_array_equal import *
from awkward.operations.ak_backend import *
from awkward.operations.ak_broadcast_arrays import *
from awkward.operations.ak_broadcast_fields import *
Expand Down
46 changes: 45 additions & 1 deletion src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,32 @@ def almost_equal(
# Dispatch
yield left, right

return _impl(
left,
right,
rtol=rtol,
atol=atol,
dtype_exact=dtype_exact,
check_parameters=check_parameters,
check_regular=check_regular,
exact_eq=False,
same_content_types=False,
equal_nan=False,
)


def _impl(
left,
right,
rtol: float,
atol: float,
dtype_exact: bool,
check_parameters: bool,
check_regular: bool,
exact_eq: bool,
same_content_types: bool,
equal_nan: bool,
):
# Implementation
left_behavior = behavior_of(left)
right_behavior = behavior_of(right)
Expand Down Expand Up @@ -82,6 +108,10 @@ def packed_list_content(layout):
return layout.content[layout.offsets[0] : layout.offsets[-1]]

def visitor(left, right) -> bool:
# Before anything else, check same_content_types (i.e. before any transformations)
if same_content_types and left.__class__ is not right.__class__:
return False

# First, erase indexed types!
if left.is_indexed and not left.is_option:
left = left.project()
Expand Down Expand Up @@ -152,12 +182,26 @@ def visitor(left, right) -> bool:
and backend.nplike.all(left.data == right.data)
and left.shape == right.shape
)
elif exact_eq:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.array_equal(
left.data,
right.data,
equal_nan=equal_nan,
)
and left.shape == right.shape
)
else:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.all(
backend.nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
left.data,
right.data,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
)
and left.shape == right.shape
Expand Down
54 changes: 54 additions & 0 deletions src/awkward/operations/ak_array_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function

__all__ = ("array_equal",)


@high_level_function()
def array_equal(
    a1,
    a2,
    equal_nan: bool = False,
    dtype_exact: bool = True,
    same_content_types: bool = True,
    check_parameters: bool = True,
    check_regular: bool = True,
):
    """
    True if two arrays have the same shape and elements, False otherwise.

    Args:
        a1: Array-like data (anything #ak.to_layout recognizes).
        a2: Array-like data (anything #ak.to_layout recognizes).
        equal_nan (bool): If True, NaN values count as equal to each other.
        dtype_exact (bool): If True, the dtypes must be exactly the same;
            otherwise they only need to belong to the same family.
        same_content_types (bool): If True, require all content classes to
            match. This is only enforced when `check_regular` is also True.
        check_parameters (bool): If True, compare parameters.
        check_regular (bool): If True, consider ragged and regular dimensions
            as unequal.

    The comparison is delegated to the implementation shared with
    #ak.almost_equal, called with zero relative and absolute tolerances and
    exact equality requested.

    TypeTracer arrays are not supported, as there is very little information to
    be compared.
    """
    # Dispatch
    yield a1, a2

    # Implementation
    # Content classes are only required to match when regular-vs-ragged
    # checking is requested as well.
    require_same_classes = same_content_types and check_regular
    compare = ak.operations.ak_almost_equal._impl

    return compare(
        a1,
        a2,
        rtol=0.0,
        atol=0.0,
        dtype_exact=dtype_exact,
        check_parameters=check_parameters,
        check_regular=check_regular,
        exact_eq=True,
        same_content_types=require_same_classes,
        equal_nan=equal_nan,
    )
Loading

0 comments on commit 4f282ca

Please sign in to comment.