Potential speedups in pipeline runtime #35

Closed
eberrigan opened this issue Jun 22, 2023 · 5 comments · Fixed by #51

Comments

@eberrigan
Collaborator

eberrigan commented Jun 22, 2023

Doing some profiling, we see that we may be able to shave off ~20–40% of the pipeline runtime by reducing repeated calls to the convex hull and ellipse fitting routines.

Here's a profiling script:

from sleap_roots.graphpipeline import get_traits_value_plant_summary
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
from rich.progress import track


def main():
    
    data_folders = """
    week1_3do_4-18-22
    week2_3do_4-25-22
    week3_3_do_5-2-22
    week4_3do_5-6-22
    """.strip().split()
    
    monocot = True
    primary_name = "longest_3do_6nodes"
    lateral_name = "main_3do_6nodes"
    overwrite = False
    
    
    warnings.filterwarnings("ignore", message="invalid value encountered in intersection", category=RuntimeWarning, module="shapely")
    warnings.filterwarnings("ignore", message="All-NaN slice encountered", category=RuntimeWarning)
    warnings.filterwarnings("ignore", message="All-NaN axis encountered", category=RuntimeWarning)
    warnings.filterwarnings("ignore", message="Degrees of freedom <= 0 for slice.", category=RuntimeWarning, module="numpy")
    warnings.filterwarnings("ignore", message="Mean of empty slice", category=RuntimeWarning)
    warnings.filterwarnings("ignore", message="invalid value encountered in sqrt", category=RuntimeWarning, module="skimage")
    warnings.filterwarnings("ignore", message="invalid value encountered in double_scalars", category=RuntimeWarning)
    
    all_traits = []
    h5s = []
    for data_folder in data_folders:
        h5s.extend(Path(data_folder).glob("*.h5"))
    
    # for h5 in track(h5s):
    for h5 in h5s[:5]:
        csv_path = h5.with_suffix(".traits.csv")
        plant_traits = get_traits_value_plant_summary(
            h5.as_posix(),
            monocot,
            primary_name=primary_name,
            lateral_name=lateral_name,
            stem_width_tolerance=0.02,
            n_line=50,
            network_fraction=2 / 3,
            # write_csv=(overwrite or not csv_path.exists()),
            write_csv=False,
            csv_name=csv_path.as_posix(),
        )
        plant_traits["path"] = h5.as_posix()
        all_traits.append(plant_traits)
    all_traits = pd.concat(all_traits, ignore_index=True)
    
    # data_plant_frame_summary.to_csv("all_traits.csv", index=False)


if __name__ == "__main__":
    main()

This runs the pipeline without saving the CSVs for 5 plants.

You can profile by saving the above script as get_traits.py and:

pip install pyinstrument

Then:

pyinstrument -r html get_traits.py

Here are the relevant parts:

[pyinstrument profile screenshot]

The solution would be to cache the results from functions like get_convhull and fit_ellipse that are called repeatedly for each derived feature.

I appreciate that it might be annoying to store these in the data dict in the graph pipeline, since they're not serializable to a dataframe later. One solution could be to use functools.lru_cache. This would work like:

from functools import lru_cache

@lru_cache
def get_convhull(...):
    ...

Then, the next time get_convhull is called with the same arguments in the pipeline, it'll just return the cached results without recomputing them. If it works, it's a lot less work than refactoring the graph pipeline to carry the cached results around with the data dict :)

@linwang9926
Collaborator

The issue was solved by adding a dictionary that stores the scanline, ellipse, and convex hull results for each plant (PR #44).
But there is a problem when testing the modified functions: the dictionary was not cleared before each new function test.
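A minimal sketch of the idea (hypothetical names, not the actual PR #44 code): a module-level dictionary keyed by plant and trait, plus a clear helper that tests can call between datasets:

```python
# Hypothetical per-plant cache for expensive intermediates
# (scanline, ellipse, convex hull).
_plant_cache: dict = {}


def get_cached(plant_id: str, trait: str, compute):
    """Return the cached value for (plant_id, trait), computing it only once."""
    key = (plant_id, trait)
    if key not in _plant_cache:
        _plant_cache[key] = compute()
    return _plant_cache[key]


def clear_cache():
    """Reset the cache; call this between test cases so results from the
    previous dataset (e.g. canola) don't leak into the next one (e.g. rice)."""
    _plant_cache.clear()
```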

@linwang9926
Collaborator

The testing issue was fixed by adding a function that appends results to the cache dictionary in the test functions for different datasets.
For example, if the last test case was canola and the current test case is rice, we call get_convhull when testing the convex hull area of rice. This makes sure the previous cache dictionary doesn't carry over into this test.

@linwang9926
Collaborator

This approach saves more than half of the computation time (2.781 vs. 6.589 seconds for 5 rice plants on my desktop)!

@talmo
Contributor

talmo commented Jul 19, 2023

We should refactor the traits pipeline to actually use the trait map as a computation graph. Currently it just serves as metadata, but we don't use the graph structure at all -- we're in fact recomputing intermediate traits everywhere since all of the inputs to all of the traits are just the original primary and/or lateral points.

In get_traits_value_frame() where we actually do the function execution, we should be re-using values from the data variable:

    trait_computation_order = ... # Find breadth-first ordering of trait graph.

    # Initialize traits container with initial points.
    traits = {"pts": pts}

    # Compute each trait.
    for trait_name in trait_computation_order:
        fn, input_traits, kwargs = trait_map[trait_name]
        fn_outputs = fn(*[traits[input_trait] for input_trait in input_traits], **kwargs)
        traits[trait_name] = fn_outputs

Re-define the traits map so it's formatted as:

{
    "trait_name": (function, ["input_trait1", "input_trait2", ...], {"additional_kwarg1": True, "additional_kwarg2": 0.5})
}

For example, for a derived convhull feature, we would specify it as:

{
    "chull_area": (get_chull_area, ["convex_hull"], {}),
}

This indicates that the "convex_hull" trait should be reused as input to this function rather than recomputed from the original points.

The **kwargs allows you to pass in arbitrary, trait-specific arguments that aren't in traits, for example:

{
    "scanline_intersection_counts": (
            count_scanline_intersections,
            ["primary_pts", "lateral_pts"],  # these come from traits dict
            {"height": 1080, "width": 2048, "n_line": 50, "monocots": monocots}  # these are fixed inputs that don't depend on the trait graph
        ),
}

This format would also make traitsgraph unnecessary since it can be inferred from the trait_map directly like:

edges = []
for output_trait, (_, input_traits, _) in trait_map.items():
    for input_trait in input_traits:
        edges.append((input_trait, output_trait))

Putting it all together:

    # Define trait map.
    trait_map = {
        # Ignore these if precomputed already:
        # "primary_pts": (get_primary_pts, ["pts"], {}),
        # "lateral_pts": (get_lateral_pts, ["pts"], {}),
        # ...
        "scanline_intersection_counts": (
            count_scanline_intersections,
            ["primary_pts", "lateral_pts"],  # these come from traits dict
            {"height": 1080, "width": 2048, "n_line": 50, "monocots": monocots},  # fixed inputs that don't depend on the trait graph
        ),
        # ...
        "chull_area": (get_chull_area, ["convex_hull"], {}),
        # ...
    }

    # Initialize edges with precomputed top-level traits.
    edges = [("pts", "primary_pts"), ("pts", "lateral_pts")]

    # Infer edges from trait map.
    for output_trait, (_, input_traits, _) in trait_map.items():
        for input_trait in input_traits:
            edges.append((input_trait, output_trait))

    # Compute breadth-first ordering. The [2:] skips the two precomputed edges
    # from "pts". Note: a topological sort (e.g., nx.topological_sort(G)) may be
    # safer if a trait's inputs can sit at different BFS depths.
    G = nx.DiGraph()
    G.add_edges_from(edges)
    trait_computation_order = [dst for (src, dst) in list(nx.bfs_tree(G, "pts").edges())[2:]]

    # Initialize traits container with initial points.
    traits = {"primary_pts": primary_pts, "lateral_pts": lateral_pts}

    # Compute each trait.
    for trait_name in trait_computation_order:
        if trait_name in traits:
            # Ignore traits that are already computed.
            continue
        fn, input_traits, kwargs = trait_map[trait_name]
        fn_outputs = fn(*[traits[input_trait] for input_trait in input_traits], **kwargs)
        traits[trait_name] = fn_outputs
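To sanity-check the scheme, here's a self-contained toy version with dummy trait names and functions; it uses the stdlib graphlib for the ordering (a topological sort, which generalizes the breadth-first idea) and counts calls to confirm each intermediate runs exactly once, even when two downstream traits consume it:

```python
from graphlib import TopologicalSorter

# Toy trait map in the proposed format:
#   trait_name -> (function, [input trait names], {fixed kwargs})
call_counts = {}


def counted(name, fn):
    """Wrap fn to count how many times it actually runs."""
    def wrapper(*args, **kwargs):
        call_counts[name] = call_counts.get(name, 0) + 1
        return fn(*args, **kwargs)
    return wrapper


trait_map = {
    "convex_hull": (counted("convex_hull", lambda pts: sorted(pts)), ["pts"], {}),
    "chull_area": (counted("chull_area", lambda hull: len(hull)), ["convex_hull"], {}),
    "chull_perimeter": (
        counted("chull_perimeter", lambda hull, scale: len(hull) * scale),
        ["convex_hull"],
        {"scale": 2.0},
    ),
}

# Infer the dependency graph from the trait map (node -> predecessors).
graph = {name: set(inputs) for name, (_, inputs, _) in trait_map.items()}
order = [t for t in TopologicalSorter(graph).static_order() if t in trait_map]

# Execute: reuse values from the traits dict instead of recomputing them.
traits = {"pts": [(0, 0), (1, 0), (0, 1)]}
for name in order:
    fn, inputs, kwargs = trait_map[name]
    traits[name] = fn(*[traits[i] for i in inputs], **kwargs)
```

Both chull_area and chull_perimeter read the "convex_hull" entry from the traits dict, so the hull function is called once rather than once per derived trait.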

@talmo
Contributor

talmo commented Jul 19, 2023

btw, pls rename graphpipeline to just pipeline :)

@talmo talmo linked a pull request Aug 17, 2023 that will close this issue