Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New NanoEvents form mapping implementation plan #1163

Open
nsmith- opened this issue Aug 14, 2024 · 1 comment
Open

New NanoEvents form mapping implementation plan #1163

nsmith- opened this issue Aug 14, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@nsmith-
Copy link
Member

nsmith- commented Aug 14, 2024

This is a sketch of how to add new columns to an `uproot.dask` call lazily. The example here resolves cross-references, which we would rather do at typetracer time using the mixin class, since otherwise we have to resolve all of them ahead of time; it is only meant to demonstrate that this can be done.

from dataclasses import dataclass
from typing import Any, Callable, Mapping

import dask_awkward
import uproot
from uproot._dask import TrivialFormMapping, TrivialFormMappingInfo
from uproot.behaviors.TBranch import HasBranches

import awkward as ak
from awkward.forms import Form

# Signature of a derived-column constructor: takes the events array and
# returns the new column. NOTE(review): although annotated with
# dask_awkward.Array, AddColumnsMappingInfo.load_buffers also calls the
# constructor on an eager awkward Array built from the parent branches.
ArrayTransform = Callable[[dask_awkward.Array], dask_awkward.Array]


@dataclass
class ColumnData:
    """Description of one derived column to be added to the events form.

    Instances are created by ``build_new_column`` and consumed by
    ``AddColumnsMapping`` (which splices ``form`` into the mapped form) and
    ``AddColumnsMappingInfo`` (which reads ``parent_columns`` and calls
    ``constructor`` when loading buffers).
    """

    # Form of the derived column, as traced by build_new_column; the array
    # actually produced at load time is checked against it.
    form: Form
    # Names of the input branches the constructor reads (from
    # dask_awkward.necessary_columns); these are read whenever the derived
    # column is requested.
    parent_columns: frozenset[str]
    # Callable that builds the column from an array containing (at least)
    # the parent branches.
    constructor: ArrayTransform


class AddColumnsMappingInfo(TrivialFormMappingInfo):
    """Form-mapping info that serves derived columns alongside TTree branches.

    Buffers whose key starts with ``<root>`` belong to TTree branches and are
    handled by ``TrivialFormMappingInfo``. Buffers whose key starts with
    ``<derived>.`` belong to a derived column; they are computed in
    :meth:`load_buffers` by running the column's constructor on its parent
    branches.

    Args:
        form: the mapped record form. It is expected to be the same object to
            which ``AddColumnsMapping`` appends the derived columns, so that
            ``self._form.content(<derived name>)`` resolves.
        new_columns: derived column name -> ``ColumnData``.
    """

    def __init__(
        self,
        form: Form,
        new_columns: dict[str, ColumnData],
    ):
        super().__init__(form)
        # Derived column name -> (form, parent branches, constructor).
        self._new_columns = new_columns

    def keys_for_buffer_keys(self, buffer_keys: frozenset[str]) -> frozenset[str]:
        """Return the TTree branch names needed to produce ``buffer_keys``.

        Root buffers are resolved by the base class; a derived buffer requires
        every parent branch of its column.

        Raises:
            ValueError: if a derived buffer key names an unknown column.
        """
        root_keys = super().keys_for_buffer_keys(
            frozenset(k for k in buffer_keys if k.startswith("<root>"))
        )
        derived_keys: set[str] = set()
        for key in buffer_keys:
            if not key.startswith("<derived>."):
                continue
            form_key, _ = self.parse_buffer_key(key)
            # Derived form keys are prefixed "<derived>.{column}" by
            # AddColumnsMapping. NOTE(review): this assumes column names
            # contain no ".".
            column_key = form_key.split(".")[1]
            if column_key not in self._new_columns:
                raise ValueError(f"Unknown derived column {column_key}")
            derived_keys.update(self._new_columns[column_key].parent_columns)
        return frozenset(root_keys | derived_keys)

    def load_buffers(
        self,
        tree: HasBranches,
        keys: frozenset[str],
        start: int,
        stop: int,
        decompression_executor,
        interpretation_executor,
        options: Any,
    ) -> Mapping[str, Any]:
        """Read the requested branches and compute the derivable new columns.

        A derived column is only built when all of its parent branches are in
        ``keys``. Otherwise its parent buffers were not read, so it cannot be
        built. It also cannot have been requested, because
        ``keys_for_buffer_keys`` adds the parents of every requested derived
        column.

        Raises:
            ValueError: if a constructor returns an array whose form differs
                from the form recorded in its ``ColumnData``.
        """
        buffers = super().load_buffers(
            tree,
            keys,
            start,
            stop,
            decompression_executor,
            interpretation_executor,
            options,
        )
        loaded = frozenset(keys)
        for key, column in self._new_columns.items():
            if not column.parent_columns <= loaded:
                # Parent branches were not read for this chunk; skip instead of
                # failing in ak.from_buffers on missing buffers.
                continue
            # Fix the field order once so fields and contents line up.
            parent_names = sorted(column.parent_columns)
            reduced_form = ak.forms.RecordForm(
                contents=[self._form.content(k) for k in parent_names],
                fields=parent_names,
            )
            parents = ak.from_buffers(
                reduced_form,
                stop - start,
                buffers,
                behavior=self.behavior,
                buffer_key=self.buffer_key,
            )
            derived_form, _, derived_container = ak.to_buffers(
                column.constructor(parents)
            )
            if derived_form != column.form:
                raise ValueError(
                    f"Derived column {key!r} produced form {derived_form}, "
                    f"expected {column.form}"
                )
            # ak.to_buffers used its own default form keys; re-key each buffer
            # onto this column's "<derived>.{key}" subtree of the mapped form.
            # The raw default-keyed buffers are not copied into the result:
            # they are outside the mapped form's key space.
            for (src, src_dtype), (dst, dst_dtype) in zip(
                derived_form.expected_from_buffers().items(),
                self._form.content(key).expected_from_buffers(self.buffer_key).items(),
            ):
                assert src_dtype == dst_dtype  # Sanity check!
                buffers[dst] = derived_container[src]
        return buffers


class AddColumnsMapping(TrivialFormMapping):
    """Form mapping that appends derived columns to the TTree's record form.

    Args:
        new_columns: derived column name -> ``ColumnData`` (for example built
            with ``build_new_column``). Each name becomes a new top-level field.
    """

    def __init__(self, new_columns: dict[str, ColumnData]):
        super().__init__()
        self._new_columns = new_columns

    def __call__(self, form: Form) -> tuple[Form, AddColumnsMappingInfo]:
        """Return the extended form and the mapping info that can load it.

        Raises:
            ValueError: if a derived column name equals an existing field name.
                A duplicate field would make ``content(name)`` resolve to the
                original branch and mis-route the derived buffers.
        """
        new_form = dask_awkward.lib.utils.form_with_unique_keys(form, "<root>")
        # The mapping info is created before the derived fields are appended.
        # NOTE(review): presumably this keeps the base class's key bookkeeping
        # (done in its __init__) limited to TTree branches; confirm against
        # uproot. The info holds this same ``new_form`` object, so it still
        # sees the fields appended below when it calls
        # ``self._form.content(key)``.
        mappinginfo = AddColumnsMappingInfo(new_form, self._new_columns)
        existing_fields = set(new_form.fields)
        for key, column in self._new_columns.items():
            if key in existing_fields:
                raise ValueError(
                    f"Derived column {key!r} clashes with an existing field"
                )
            # Mutates the record form in place (its fields/contents lists).
            new_form.fields.append(key)
            new_form.contents.append(
                dask_awkward.lib.utils.form_with_unique_keys(
                    column.form, f"<derived>.{key}"
                )
            )
        return new_form, mappinginfo


def build_new_column(base_array: dask_awkward.Array, constructor: ArrayTransform):
    """Trace ``constructor`` on ``base_array`` and describe the resulting column.

    The constructor is applied lazily to ``base_array``. The traced result
    supplies the column's form, and ``dask_awkward.necessary_columns`` supplies
    the input branches it reads.

    Raises:
        NotImplementedError: if the traced column depends on anything other
            than exactly one input layer.
    """
    traced = constructor(base_array)
    columns_by_layer = dask_awkward.necessary_columns(traced)
    if len(columns_by_layer) == 1:
        [parent_columns] = columns_by_layer.values()
        return ColumnData(
            form=traced.form,
            parent_columns=parent_columns,
            constructor=constructor,
        )
    raise NotImplementedError(
        "Understand how to map multiple input layers' columns (probably not possible until dak.bundle is implemented)"
    )


if __name__ == "__main__":
    import sys

    # Input file: the first command-line argument if given (for example
    # tests/samples/nano_dy.root), otherwise the sample this sketch used.
    filename = sys.argv[1] if len(sys.argv) > 1 else "form_mapping/nanozstd.root"

    def open_events(form_mapping):
        """Lazily open the ``Events`` tree of ``filename`` with ``form_mapping``."""
        return uproot.dask(
            {filename: "Events"},
            full_paths=True,
            open_files=False,
            ak_add_doc=True,
            form_mapping=form_mapping,
        )

    # First pass with the trivial mapping. It is only used to trace each
    # derived column's form and parent branches.
    events = open_events(TrivialFormMapping())

    def negative_to_none(array):
        """Replace the negative entries of ``array`` with None.

        ``array.mask[cond]`` keeps entries where ``cond`` is True and sets the
        rest to None, so the condition selects the non-negative (usable)
        indices. Negative values are presumably the "no match" sentinel of the
        index branch.
        """
        return array.mask[array >= 0]

    # TODO: allow new columns to reference other new columns
    # would require a recursive approach in AddColumnsMappingInfo.load_buffers()
    new_columns = {
        "Electron_genId": build_new_column(
            events,
            lambda events: events.GenPart_pdgId[
                negative_to_none(events.Electron_genPartIdx)
            ],
        ),
        "Electron_genPt": build_new_column(
            events,
            lambda events: events.GenPart_pt[
                negative_to_none(events.Electron_genPartIdx)
            ],
        ),
    }
    print("New columns:", new_columns)

    # Second pass: the same tree with the derived columns spliced into its form.
    events = open_events(AddColumnsMapping(new_columns))

    # Keep electrons whose matched gen particle has pdgId == -11 * charge.
    true_electrons = events.Electron_genId == -11 * events.Electron_charge
    dptrel = (1 - events.Electron_genPt / events.Electron_pt)[true_electrons]
    flat_dptrel = ak.flatten(dptrel, axis=None)
    print(
        "Necessary columns for flat_dptrel:",
        dask_awkward.necessary_columns(flat_dptrel),
    )
    print(flat_dptrel.compute())

In the long term, this is not the best solution; rather, we would want to introduce a "non-touching zip" to dask-awkward. I had proposed `dak.bundle` as the verb: users could zip dask-awkward arrays together without forcing all fields to be materialized, by specifying exactly which broadcasting assumptions are expected to hold between the inputs.

@nsmith- nsmith- added the enhancement New feature or request label Aug 14, 2024
@nsmith-
Copy link
Member Author

nsmith- commented Aug 14, 2024

The output of the above example (printed with `rich.print`) is:

New columns:
{
    'Electron_genId': ColumnData(
        form=ListOffsetForm('i64', IndexedOptionForm('i64', NumpyForm('int32'), parameters={'__doc__': 'PDG id'}), parameters={'__doc__': 'PDG id'}),
        parent_columns=frozenset({'GenPart_pdgId', 'Electron_genPartIdx'}),
        constructor=<function <lambda> at 0x1044937e0>
    ),
    'Electron_genPt': ColumnData(
        form=ListOffsetForm('i64', IndexedOptionForm('i64', NumpyForm('float32'), parameters={'__doc__': 'pt'}), parameters={'__doc__': 'pt'}),
        parent_columns=frozenset({'GenPart_pt', 'Electron_genPartIdx'}),
        constructor=<function <lambda> at 0x104493f60>
    )
}
Necessary columns for flat_dptrel:
{'from-uproot-a24b37c9c0d42135bcbf2dd760ac48e3': frozenset({'GenPart_pt', 'GenPart_pdgId', 'Electron_pt', 'Electron_charge', 'Electron_genPartIdx'})}
[0.983, 0.998, 0.307, 0.98, 0.999, ..., 0.993, 0.363, 0.972, 0.997, 0.943]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant