Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: fix axis=0 concatenation #457

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion src/dask_awkward/lib/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,119 @@ def _enforce_concatenated_form(array: AwkwardArray, form: Form) -> AwkwardArray:
return ak.Array(result, behavior=array._behavior, attrs=array._attrs)


from awkward.typetracer import TypeTracerReport


class ParentReport(TypeTracerReport):
    """Typetracer report that forwards touches to other reports.

    Every label registered through :meth:`add_child_key` maps to a list of
    ``(report, label)`` pairs. Touching the shape or data of a registered
    label re-issues the same touch on each paired report under its paired
    label. Touching a label that was never registered does nothing.
    """

    def __init__(self):
        # NOTE(review): the base-class initializer is not invoked here (same
        # as the original change); confirm nothing relies on state that
        # TypeTracerReport.__init__ would have set up.
        #
        # Each value is a *list* of (report, label) pairs, because one parent
        # label may fan out to several child reports.
        self._parent_to_child: dict[str, list[tuple[TypeTracerReport, str]]] = {}

    def add_child_key(
        self, parent_key: str, child_key: str, child_report: TypeTracerReport
    ) -> None:
        """Forward future touches of ``parent_key`` to ``child_report``.

        Args:
            parent_key: Label that will be touched on this report.
            child_key: Label to touch on ``child_report`` in response.
            child_report: Report that receives the forwarded touch.
        """
        self._parent_to_child.setdefault(parent_key, []).append(
            (child_report, child_key)
        )

    @property
    def shape_touched(self):
        """Not supported: this report does not record touches itself."""
        raise NotImplementedError

    @property
    def data_touched(self):
        """Not supported: this report does not record touches itself."""
        raise NotImplementedError

    def touch_shape(self, label: str):
        """Forward a shape touch of ``label`` to every registered child."""
        for child_report, child_label in self._parent_to_child.get(label, ()):
            child_report.touch_shape(child_label)

    def touch_data(self, label: str):
        """Forward a data touch of ``label`` to every registered child."""
        for child_report, child_label in self._parent_to_child.get(label, ()):
            child_report.touch_data(child_label)


def maybe_parent_report(parent, children, parent_report):
    """Attach a shared forwarding report to ``parent``.

    The existing reports of ``parent`` and of every item in ``children`` are
    registered on ``parent_report`` under ``parent.form_key``, after which
    ``parent.report`` is replaced by ``parent_report``.

    Args:
        parent: Object exposing ``form_key`` and a writable ``report``.
        children: Iterable of objects exposing ``form_key`` and ``report``.
        parent_report: Report to register on, or ``None`` to create a new
            ``ParentReport``.

    Returns:
        The report now installed on ``parent`` (the one passed in, or the
        newly created one).
    """
    if parent_report is None:
        parent_report = ParentReport()

    parent_key = parent.form_key
    # The parent is registered first, then the children in order.
    for node in (parent, *children):
        node_report = node.report
        if node_report is None:
            continue
        # If this object was already re-pointed at ``parent_report`` (the same
        # buffer reached twice in one traversal), registering it again under
        # its own key would make the report forward that key to itself, and
        # any later touch would recurse without bound.
        if node_report is parent_report and node.form_key == parent_key:
            continue
        parent_report.add_child_key(parent_key, node.form_key, node_report)

    parent.report = parent_report
    return parent_report


def merge_reports(first, *remainder):
    """Link the buffer reports of ``first`` to those of ``remainder``.

    The layouts are walked in lockstep. For every buffer met on ``first``
    (``data``, ``index``, ``mask``, ``offsets``, ``starts``, ``stops``),
    ``maybe_parent_report`` is called with that buffer as the parent and the
    matching buffers of ``remainder`` as the children. A single report object
    is threaded through all of those calls.

    Args:
        first: Layout whose buffers act as the parents.
        *remainder: Layouts of the same type as ``first`` at every level.

    Raises:
        NotImplementedError: A union layout is encountered.
        AssertionError: Node types differ between ``first`` and
            ``remainder``, or a node kind is not recognised.
    """
    parent_report = None

    def link(parent_buffer, child_buffers):
        # Register one buffer of ``first`` against its counterparts, reusing
        # the same report for the whole traversal.
        nonlocal parent_report
        parent_report = maybe_parent_report(
            parent_buffer, child_buffers, parent_report
        )

    def descend(node, others):
        # Recurse into the single ``content`` of a wrapper node.
        impl(node.content, *[other.content for other in others])

    def impl(node, *others):
        # Internal invariant: the lockstep walk only makes sense when every
        # layout has the same node type at this position.
        assert all(type(other) is type(node) for other in others)

        if node.is_numpy:
            link(node.data, [other.data for other in others])

        elif node.is_option and node.is_indexed:
            link(node.index.data, [other.index.data for other in others])
            descend(node, others)

        elif node.is_option and isinstance(node, ak.contents.UnmaskedArray):
            # UnmaskedArray is an option type without a mask buffer, so there
            # is nothing to link at this level; only the content is walked.
            descend(node, others)

        elif node.is_option:
            link(node.mask.data, [other.mask.data for other in others])
            descend(node, others)

        elif node.is_list and isinstance(node, ak.contents.ListOffsetArray):
            link(node.offsets.data, [other.offsets.data for other in others])
            descend(node, others)

        elif node.is_list and isinstance(node, ak.contents.ListArray):
            link(node.starts.data, [other.starts.data for other in others])
            link(node.stops.data, [other.stops.data for other in others])
            descend(node, others)

        elif node.is_list and isinstance(node, ak.contents.RegularArray):
            # No buffer is linked for RegularArray; only its content is walked.
            descend(node, others)

        elif node.is_indexed:
            link(node.index.data, [other.index.data for other in others])
            descend(node, others)

        elif node.is_record:
            # Walk the fields position by position across all layouts.
            for this, *that in zip(
                node.contents, *[other.contents for other in others]
            ):
                impl(this, *that)

        elif node.is_empty:
            return

        elif node.is_union:
            raise NotImplementedError

        else:
            raise AssertionError

    impl(first, *remainder)


def _concatenate_axis_0_meta(*arrays: AwkwardArray) -> AwkwardArray:
    """Build the meta for an axis-0 concatenation of ``arrays``.

    The reports of the first array's layout are linked to those of the other
    layouts via ``merge_reports``, then the first array is returned as the
    meta.

    Args:
        *arrays: Meta arrays to concatenate; at least one is required.

    Returns:
        ``arrays[0]``, after its layout's reports have been linked.

    Raises:
        IndexError: No arrays were given.
    """
    # At this stage, the metas have all been enforced to the same type
    layouts = [arr.layout for arr in arrays]
    # Pass the first layout only as the parent. Including it again among the
    # remainder would register its report twice under the same key.
    merge_reports(layouts[0], *layouts[1:])
    return arrays[0]


Expand Down Expand Up @@ -119,7 +230,11 @@ def concatenate(
)
}

aml = AwkwardMaterializedLayer(g, previous_layer_names=[arrays[0].name])
aml = AwkwardMaterializedLayer(
g,
previous_layer_names=[a.name for a in arrays],
fn=_concatenate_axis_0_meta,
)

hlg = HighLevelGraph.from_collections(name, aml, dependencies=arrays)
return new_array_object(
Expand Down
Loading