Feature request: Avoid repeating identical calculations (speedup) #306

Open
@rebeccamccabe

Description

Feature description.
A way to cache functions so that repeated evaluations of identical inputs are not recalculated during optimization. Or, a way to remove the time-consuming error-checking from functions during optimization and instead validate the inputs beforehand.

Issue addressed.
Inspired by #305, I tried profiling WecOptTool to see what is taking up the most time. Below are the results of a sweep of optimizations over 27 different impedances, with a regular wave, nsubsteps=16, nfreq=5, and use_grad=False.

[profiler screenshot: time breakdown of the optimization sweep]

Surprisingly, a substantial portion of the time (2.33 s / 6.12 s = 38%) is spent repeatedly (6244 times) evaluating the wave_excitation function. Much of the time in this function comes from the error checking (subset_close, allclose). This seems strange because this function is constant in the optimization, so it could be computed a single time and reused for the whole optimization.
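
For reference, a profile like the one above can be collected with cProfile; run_impedance_sweep below is a hypothetical stand-in for the actual sweep script, not a WecOptTool API:

import cProfile
import pstats

def run_impedance_sweep():
    # Hypothetical driver for the 27-impedance sweep described above.
    ...

with cProfile.Profile() as profiler:
    run_impedance_sweep()
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)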

Describe alternatives you've considered
In another project, I've used functools.cache, which caches function evaluations when the inputs are identical, with just a decorator. I managed to get it working in WecOptTool, but it was a bit more complicated than I expected because xarrays aren't hashable, so I had to use dask (which, conveniently, is already a WecOptTool dependency through wavespectra). You can see from the profiling results below that wave_excitation now only gets called once; the other 6243 times, the value is served from the cache.

[profiler screenshot: wave_excitation called once; the remaining 6243 calls hit the cache]
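
For context on why plain functools.cache can't be applied directly: xarray objects define an elementwise __eq__ and therefore no __hash__, so they can't be used as cache keys. A minimal demonstration:

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros(3))
try:
    hash(da)  # DataArray defines __eq__, which sets __hash__ to None
except TypeError as err:
    print(err)  # unhashable type: 'DataArray'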

Unfortunately, this doesn't work in practice because the dask hashing (around 1 ms per argument, so 2 ms per function call) is slower than the <1 ms it takes to actually execute the wave_excitation function, so the overall time goes up substantially. Most of the time is now in HashWrapper.__init__, which calls dask.base.tokenize to generate the hash.
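
The overhead is easy to confirm in isolation by timing tokenize on a representative array (a rough sketch; the ~1 ms figure above came from the profiler, and the exact number will vary by machine and array size):

import timeit
import numpy as np
import xarray as xr
import dask.base

da = xr.DataArray(np.random.default_rng(0).standard_normal((5, 16)))

# Average cost of hashing one argument with dask's tokenizer.
per_call = timeit.timeit(lambda: dask.base.tokenize(da), number=100) / 100
print(f"tokenize: {per_call * 1e3:.2f} ms per argument")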

Describe the solution you'd like
This would be beneficial to users who plan to perform many optimizations, where speed is important. Perhaps there is a better way to do the hashing that would make my solution viable, e.g.:

  • both arguments could be hashed together so hashing has to happen half as often
  • use a faster hashing package (cityhash, xxhash, or murmurhash) instead of the default (recommended by dask here); see the sketch after this list
  • hash some other way without dask (I'm not very knowledgeable about hashing, so I don't know how this would work)
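
Here is a rough sketch of the no-dask direction, assuming the xxhash package. It hashes only the raw data buffer plus shape/dtype and ignores coordinates, so it is not equivalent to dask.base.tokenize, just an illustration of what a cheaper hash could look like:

import numpy as np
import xarray as xr
import xxhash  # one of the fast non-cryptographic hashes mentioned above

def fast_array_hash(da: xr.DataArray) -> int:
    # Hash the contiguous data buffer plus enough metadata to
    # distinguish arrays that share bytes but differ in shape/dtype.
    arr = np.ascontiguousarray(da.values)
    h = xxhash.xxh3_64()
    h.update(str((arr.shape, str(arr.dtype))).encode())
    h.update(arr.tobytes())
    return h.intdigest()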

Or, ditch hashing altogether and do something more manual:

  • Evaluate the force once in the wec.solve setup, before running the optimization (this seemingly requires changing how wec.forces works, so I didn't want to mess with it)
  • Do the error checking (the most costly part) only in the setup of wec.solve, or add a conditional error-checking flag to the wave_excitation function (this is easy to implement but could result in uncaught silent problems when wave_excitation is called outside of an optimization); a sketch follows this list
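
A minimal sketch of the flag variant; the function body, argument names, and _validate helper here are hypothetical stand-ins, not the actual core.py code:

import numpy as np
import xarray as xr

def _validate(exc_coeff, waves):
    # Stand-in for the costly subset_close/allclose checks.
    if not np.allclose(exc_coeff["omega"].values, waves["omega"].values):
        raise ValueError("wave and excitation frequencies do not match")

def wave_excitation(exc_coeff, waves, check_inputs=True):
    if check_inputs:
        _validate(exc_coeff, waves)
    return exc_coeff * waves  # placeholder for the real excitation math

omega = np.linspace(0.1, 1.0, 5)
exc_coeff = xr.DataArray(np.ones(5), coords={"omega": omega}, dims="omega")
waves = xr.DataArray(np.ones(5), coords={"omega": omega}, dims="omega")

_validate(exc_coeff, waves)  # once, during wec.solve setup
force = wave_excitation(exc_coeff, waves, check_inputs=False)  # inside the loop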

Interest in leading this feature development?
Sure, but I have reached the limit of my knowledge of hashing, so someone with more CS experience might be more useful than me if that approach is preferred.

Additional information
The timings were done using main (598e875) rather than 2.6.0, because 2.6.0 has a transpose call in the wave_excitation function which seems to slow it down considerably.
I used this xarray issue and this Stack Overflow page to figure out the dask hashing.

My code for dask hashing, inserted into core.py:

import functools

import dask
import dask.base

class HashWrapper:
    """Wraps an unhashable object (e.g. an xarray.DataArray) so it can
    serve as a functools.cache key, using dask's tokenize as the hash."""
    def __init__(self, x) -> None:
        self.value = x
        with dask.config.set({"tokenize.ensure-deterministic": True}):
            self.h = dask.base.tokenize(x)

    def __hash__(self) -> int:
        return hash(self.h)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashWrapper) and other.h == self.h

def hashable_cache(function):
    """Like functools.cache, but accepts unhashable arguments by
    wrapping each one in a HashWrapper before the cache lookup."""
    @functools.cache
    def cached_wrapper(*args, **kwargs):
        # Unwrap the arguments before calling the real function.
        arg_values = [a.value for a in args]
        kwargs_values = {k: v.value for k, v in kwargs.items()}
        return function(*arg_values, **kwargs_values)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Wrap every argument so functools.cache can hash it.
        shell_args = [HashWrapper(a) for a in args]
        shell_kwargs = {k: HashWrapper(v) for k, v in kwargs.items()}
        return cached_wrapper(*shell_args, **shell_kwargs)

    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear

    return wrapper

plus the @hashable_cache decorator applied above the wave_excitation function definition. Hypothetically, if the caching were actually faster, the decorator could be added to any function with dask-hashable inputs (i.e. numpy arrays and xarrays, but not custom classes unless a custom tokenize method were written).
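
A quick sanity check of the wrapper (slow_identity is a hypothetical stand-in for any pure function of its inputs):

import numpy as np
import xarray as xr

@hashable_cache
def slow_identity(da):
    return da * 1.0

da = xr.DataArray(np.zeros(8))
slow_identity(da)
slow_identity(da)  # identical input: served from the cache
print(slow_identity.cache_info())  # hits=1, misses=1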
