
Task graphs for record arrays with systematic variations or multi-step corrections are very large #138

Open · lgray opened this issue Dec 29, 2022 · 12 comments

@lgray (Collaborator) commented Dec 29, 2022:

Starting a new topic since there is a concrete example:

You'll need to install coffea from this branch: https://github.com/CoffeaTeam/coffea/tree/awkward2_dev (pip install -e '.[dev]')
You'll need to install dask_awkward from this branch: https://github.com/lgray/dask-awkward/tree/feat_dak_with_field

Once you're there, go to the coffea install directory, spin up a JupyterLab Python 3 notebook, and paste the following into a cell:

def jetmet_evaluator():
    from coffea.lookup_tools import extractor

    extract = extractor()

    extract.add_weight_sets(
        [
            "* * tests/samples/Summer16_23Sep2016V3_MC_L1FastJet_AK4PFPuppi.jec.txt.gz",
            "* * tests/samples/Summer16_23Sep2016V3_MC_L2L3Residual_AK4PFPuppi.jec.txt.gz",
            "* * tests/samples/Summer16_23Sep2016V3_MC_L2Relative_AK4PFPuppi.jec.txt.gz",
            "* * tests/samples/Summer16_23Sep2016V3_MC_L3Absolute_AK4PFPuppi.jec.txt.gz",
            "* * tests/samples/Summer16_23Sep2016V3_MC_UncertaintySources_AK4PFPuppi.junc.txt.gz",
            "* * tests/samples/Summer16_23Sep2016V3_MC_Uncertainty_AK4PFPuppi.junc.txt.gz",
            "* * tests/samples/Fall17_17Nov2017_V6_MC_UncertaintySources_AK4PFchs.junc.txt.gz",
            "* * tests/samples/RegroupedV2_Fall17_17Nov2017_V32_MC_UncertaintySources_AK4PFchs.junc.txt.gz",
            "* * tests/samples/Regrouped_Fall17_17Nov2017_V32_MC_UncertaintySources_AK4PFchs.junc.txt",
            "* * tests/samples/Spring16_25nsV10_MC_PtResolution_AK4PFPuppi.jr.txt.gz",
            "* * tests/samples/Spring16_25nsV10_MC_SF_AK4PFPuppi.jersf.txt.gz",
            "* * tests/samples/Autumn18_V7_MC_SF_AK4PFchs.jersf.txt.gz",
        ]
    )

    extract.finalize()

    return extract.make_evaluator()


evaluator = jetmet_evaluator()

import os
import time
import awkward as ak
import numpy as np

from coffea.jetmet_tools import CorrectedJetsFactory, CorrectedMETFactory, JECStack

events = None
from coffea.nanoevents import NanoEventsFactory

factory = NanoEventsFactory.from_root(os.path.abspath("tests/samples/nano_dy.root"))
events = factory.events()

jec_stack_names = [
    "Summer16_23Sep2016V3_MC_L1FastJet_AK4PFPuppi",
    "Summer16_23Sep2016V3_MC_L2Relative_AK4PFPuppi",
    "Summer16_23Sep2016V3_MC_L2L3Residual_AK4PFPuppi",
    "Summer16_23Sep2016V3_MC_L3Absolute_AK4PFPuppi",
    "Spring16_25nsV10_MC_PtResolution_AK4PFPuppi",
    "Spring16_25nsV10_MC_SF_AK4PFPuppi",
]
for key in evaluator.keys():
    if "Summer16_23Sep2016V3_MC_UncertaintySources_AK4PFPuppi" in key:
        jec_stack_names.append(key)

jec_inputs = {name: evaluator[name] for name in jec_stack_names}
jec_stack = JECStack(jec_inputs)

name_map = jec_stack.blank_name_map
name_map["JetPt"] = "pt"
name_map["JetMass"] = "mass"
name_map["JetEta"] = "eta"
name_map["JetA"] = "area"

jets = events.Jet

jets["pt_raw"] = (1 - jets["rawFactor"]) * jets["pt"]
jets["mass_raw"] = (1 - jets["rawFactor"]) * jets["mass"]
jets["pt_gen"] = ak.values_astype(ak.fill_none(jets.matched_gen.pt, 0), np.float32)
jets["rho"] = ak.broadcast_arrays(events.fixedGridRhoFastjetAll, jets.pt)[0]
name_map["ptGenJet"] = "pt_gen"
name_map["ptRaw"] = "pt_raw"
name_map["massRaw"] = "mass_raw"
name_map["Rho"] = "rho"

print(name_map)

tic = time.time()
jet_factory = CorrectedJetsFactory(name_map, jec_stack)
toc = time.time()

print("setup corrected jets time =", toc - tic)

tic = time.time()
corrected_jets = jet_factory.build(jets)
toc = time.time()

print("corrected_jets build time =", toc - tic)
corrected_jets.visualize()

and you'll get this rather amazing task graph with 743 keys.
What's even more interesting is that the graph executes faster with optimization turned off!
optimization off:

%time corrected_jets.JES_AbsoluteStat.up.pt.compute(optimize_graph=False)
CPU times: user 621 ms, sys: 108 ms, total: 728 ms
Wall time: 649 ms

optimization on:

%time corrected_jets.JES_AbsoluteStat.up.pt.compute(optimize_graph=True)
CPU times: user 2.76 s, sys: 12.3 ms, total: 2.77 s
Wall time: 2.77 s

Both results are correct! I'm sure we just need to think about when fusing and inlining operations are allowed in awkward task graphs, especially for highly structured data, since the data's form implies where caching/inlining are helpful. Anyway, this gets us super close to having realistic analysis use cases already.
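For what it's worth, one quick way to eyeball what the optimizer does to the structure (a sketch; visualize(optimize_graph=...) is standard dask, and graphviz needs to be installed for the rendering):

# render the raw and optimized graphs side by side for comparison
corrected_jets.visualize(filename="jets_raw.png", optimize_graph=False)
corrected_jets.visualize(filename="jets_optimized.png", optimize_graph=True)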

The other major bit is a dak.from_buffers-style interface to relabel data and mutate interfaces.

@lgray (Collaborator, Author) commented Dec 29, 2022:

If you then add on met corrections:

name_map["METpt"] = "pt"
name_map["METphi"] = "phi"
name_map["JetPhi"] = "phi"
name_map["UnClusteredEnergyDeltaX"] = "MetUnclustEnUpDeltaX"
name_map["UnClusteredEnergyDeltaY"] = "MetUnclustEnUpDeltaY"

met = events.MET
met_factory = CorrectedMETFactory(name_map)
corrected_met = met_factory.build(met, corrected_jets)

You'll find a 4655-node graph:

%time corrected_met.JES_AbsoluteStat.up.pt.compute(optimize_graph=False)
CPU times: user 2.24 s, sys: 86.7 ms, total: 2.33 s
Wall time: 2.28 s

vs.

%time corrected_met.JES_AbsoluteStat.up.pt.compute(optimize_graph=True)
CPU times: user 13.9 s, sys: 41.4 ms, total: 14 s
Wall time: 14 s

In both cases, when optimization is invoked, the process collapses the graph to a single node.
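One way to quantify that collapse (a sketch using only the generic dask collection protocol, nothing dask-awkward specific) is to count graph keys before and after optimization:

import dask

target = corrected_met.JES_AbsoluteStat.up.pt
# the high-level graph is a mapping, so len() gives the number of keys
print("keys before optimization:", len(target.__dask_graph__()))
(optimized,) = dask.optimize(target)
print("keys after optimization:", len(optimized.__dask_graph__()))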

@agoose77 (Collaborator) commented Dec 29, 2022:

Just chiming in whilst people are on holiday: the optimisation strategy currently includes an approach to reduce disk access by pruning un-needed inputs. The current approach, which I believe involves a brute-force strategy, has limitations that make the optimisation both slow and redundant if you're using most of the inputs. My guess is that this brute-forcing is what you're seeing here. You could test this by manually invoking the optimisation step on the graph and seeing if it hangs.

It might also be that execution is actually slower, i.e. the optimised graph has longer wall-time. That's a separate kettle of fish.
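A concrete version of that test (a sketch via the generic dask API rather than any dask-awkward internals) is to run the optimization pass by itself, without computing anything, and time it:

import time
import dask

tic = time.time()
# optimization only; no tasks are executed here
(optimized,) = dask.optimize(corrected_jets.JES_AbsoluteStat.up.pt)
print("optimization-only time =", time.time() - tic)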

@lgray (Collaborator, Author) commented Dec 29, 2022:

There's no disk access in this case since dak is reading from awkward arrays in memory. But... perhaps in January.
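For reference, a minimal sketch of the in-memory path being described here (dak.from_awkward wraps an existing awkward array that is already in memory, so no file reads are involved):

import awkward as ak
import dask_awkward as dak

# a small in-memory jagged array, partitioned into a lazy dask-awkward collection
arr = ak.Array([[1.0, 2.0], [], [3.0]])
lazy = dak.from_awkward(arr, npartitions=1)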

@martindurant (Collaborator) commented:

I see that for the optimized graph, execution time is 500ms, so it is faster to run.

Profiling the optimization, all the time is taken by optimize_blockwise, a dask builtin, and not by the brute-force method mentioned by @agoose77 (which isn't on by default at the moment). In particular, dask.blockwise.subs is called 7.6M times. I would dare to bet that it has always been "fast enough" so far, but that the rewriting strategy has N**2 (or worse) characteristics, and this is the first time anyone has worked with a high-level graph with 745 layers.

I will see what can be done, short of turning off optimization altogether.

Side note: HLG optimization is particularly necessary when working with a distributed scheduler. With the threaded scheduler, there is no serialization/transmission cost, and the overhead per task is very small. After converting the high-level graph to a low-level graph, you get about the same execution time with or without further optimization. Of course, IO optimization (column pruning) would still make a big difference in many workloads.
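For anyone who wants to reproduce the profile above, a sketch (cProfile around the standalone optimization pass; the cumulative time should be dominated by optimize_blockwise / dask.blockwise.subs):

import cProfile
import dask

target = corrected_jets.JES_AbsoluteStat.up.pt
# profile only the graph optimization, not the execution
cProfile.runctx("dask.optimize(target)", globals(), locals(), sort="cumulative")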

@martindurant (Collaborator) commented:

A further comment: although I have no idea what the example code is doing, I suspect it could be done in a single call to map_partitions, with a function that does all of that work on one piece of the input ak array. I don't immediately see any aggregation steps.
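A rough sketch of that idea (the helper below is a hypothetical placeholder, not a coffea API): push the whole eager correction into a single function that runs on one concrete awkward array per partition.

import dask_awkward as dak

def correct_one_partition(jets_piece):
    # placeholder: run the full eager jet-correction chain on one in-memory
    # awkward array here (e.g. with the non-dask coffea tools)
    return jets_piece

# one blockwise layer instead of hundreds of small ones
corrected_in_one_step = dak.map_partitions(correct_one_partition, jets)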

@lgray (Collaborator, Author) commented Dec 30, 2022:

Ok that's interesting. I'm surprised it's so much slower in my case (I'm running on osx/arm64).

@lgray (Collaborator, Author) commented Dec 30, 2022:

I also agree this is likely very easy to optimize with fusing! The ~700 nodes were the outcome of doing a straight port from the awkward.virtual-based awkward1 implementation.

@martindurant (Collaborator) commented:

> I'm surprised it's so much slower in my case (I'm running on osx/arm64).

You misunderstand me: I had the same wall time as you. I was simply saying that the compute part (after optimizing) was faster as a result of optimizing, as one would hope. If only the cost of the optimization hadn't outweighed the benefit.

@lgray (Collaborator, Author) commented Dec 30, 2022:

Ah, sorry, I understand now. That is good. We'll get it ship shape in no time I am sure. :-)

@lgray (Collaborator, Author) commented Dec 30, 2022:

So, after thinking about this a bit, what I hope to achieve optimization-wise is that, even after zipping the products together, when filling (systematic variations of) histograms we avoid any recalculation of the other corrections and avoid fusing too many histogram fill operations together on a single node.

I should also point out that this is strictly a subgraph of an analysis, not a complete one, and it's fundamentally a one-to-many mapping from uncorrected jets to O(10-50) variations of corrected jets. All the corrected jets come from the same set of initial corrections, which can certainly be fused calculation-wise (hopefully that result can be reused automatically). The variations are trivially parallelizable starting from this base corrected jet. For the user interface we zip them all back together into a central estimate plus a list of cloned and partially substituted (i.e. with varied kinematics) jet variants.

The final zip takes this rather splayed graph and pulls it all to one node (for a decent user interface that we would like to preserve). If these varied 'leaves' each go into a histogram, is dask smart enough to parallelize the calculation of the variants and the filling of the histograms, starting from a common corrected-jet node?

I suppose one possible outcome is that just replicating the whole correction and variation stack is faster at scale, but we know for sure that locally caching the 'central' corrected jet is incredibly efficient.
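On the "common corrected-jet node" question, a small sketch of what I'd hope works: asking dask for several variations in one compute() call lets the scheduler evaluate whatever subgraph they share only once (the .down accessor here simply mirrors the .up one used above and is assumed):

import dask

# one compute call, three outputs sharing the base corrected-jet subgraph
nominal_pt, up_pt, down_pt = dask.compute(
    corrected_jets.pt,
    corrected_jets.JES_AbsoluteStat.up.pt,
    corrected_jets.JES_AbsoluteStat.down.pt,
)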

@lgray (Collaborator, Author) commented Dec 30, 2022:

The good news is that dask_histogram appears to have enough smarts to do this out of the box. A 111-category histogram fill also seems to spend most of its time in optimization and is otherwise mega-speedy. Getting it to yield a task graph visualization seems to take a while, though.
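A minimal sketch of the kind of fill being described (assuming, as above, that dask-histogram accepts dask-awkward arrays as fill inputs; the axis choice is arbitrary):

import boost_histogram as bh
import dask
import dask_awkward as dak
import dask_histogram.boost as dhb

# lazy histogram: the fill is recorded in the task graph, not executed
h = dhb.Histogram(bh.axis.Regular(50, 0.0, 500.0))
h.fill(dak.flatten(corrected_jets.JES_AbsoluteStat.up.pt))
(result,) = dask.compute(h)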

I'll condense the graph these corrections generate so it's less sprawling, which should help many things.

There are also some other things, which take us to the dask-histogram repository, that will need to be worked out (UI quirks, growth axes not working as expected).

@lgray (Collaborator, Author) commented Dec 30, 2022:

For a limited set of systematics we get quite a nice condensed graph (it isn't very indicative of what's being filled, but that's OK; I checked that the resulting histogram is correct). So we are indeed optimizing as I would hope!
[attached: visualization of the condensed task graph]
