
Feature request: Avoid repeating identical calculations (speedup) #306

Open
rebeccamccabe opened this issue Dec 28, 2023 · 3 comments
Labels: enhancement (New feature or request)

Comments

@rebeccamccabe

Feature description
A way to cache functions so that evaluations with identical inputs are not recalculated during optimization. Alternatively, a way to remove the time-consuming error checking from functions during optimization and instead validate the inputs beforehand.

Issue addressed
Inspired by #305, I tried profiling wecopttool to see what is taking up the most time. Below are the results of a sweep of optimizations with 27 different impedances, with a regular wave, nsubsteps = 16, nfreq = 5, use_grad=False.

[image: profiling results before caching]

Surprisingly, a substantial portion (2.33 s / 6.12 s = 38%) of the time is spent repeatedly (6244 times) evaluating the wave_excitation function. Much of the time in this function comes from the error checking (subset_close, allclose). This seems strange because this function is constant over the optimization, so it could be computed a single time and reused for the whole optimization.

Describe alternatives you've considered
In another project, I've used functools.cache, which can cache function evaluations when inputs are identical, with just a decorator. I managed to get it working on WecOptTool, but it was more complicated than I expected because xarrays aren't hashable, so I had to use dask (which conveniently is already a WecOptTool dependency through wavespectra). You can see from the profiling results that wave_excitation now only gets called once, and the other 6243 times the value is served from the cache.
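For context, the plain functools.cache pattern (which only works when every argument is hashable) looks like this — a minimal stand-alone sketch, not WecOptTool code:

```python
import functools

call_count = 0

@functools.cache
def expensive(x):
    """Stand-in for a costly function with hashable inputs."""
    global call_count
    call_count += 1
    return 2 * x

expensive(21)
expensive(21)  # identical input: served from cache, body not re-run
print(expensive.cache_info())  # hits=1, misses=1 at this point
```

The wrinkle described above is that xarray objects cannot be passed to such a function directly, since functools.cache uses the arguments themselves as dictionary keys.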

[image: profiling results with dask-based caching]

Unfortunately, this doesn't work in practice because the dask hashing (which takes around 1 ms per argument, so 2 ms per function call) is slower than the <1 ms it takes to actually execute the wave_excitation function, so the overall time goes up substantially. Most of the time is now spent in the __init__ function, which calls dask.base.tokenize to generate the hash.

Describe the solution you'd like
This would benefit users who plan to perform many optimizations, for whom speed is important. Perhaps there is a better way to do the hashing that would make my solution viable, e.g.:

  • both arguments could be hashed together, so hashing has to happen half as often
  • use a faster hashing package (cityhash, xxhash, or murmurhash) instead of the default (recommended by dask here)
  • hash some other way without dask (I'm not very knowledgeable about hashing, so I don't know how this would work)
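As a sketch of the last bullet, one dask-free option would be hashing an argument's raw bytes with a fast stdlib digest such as hashlib.blake2b. For a numpy array this would mean something like arr.tobytes() plus shape and dtype; this stdlib-only sketch just hashes bytes directly, and whether it would actually beat dask.base.tokenize is untested:

```python
import hashlib

def digest_bytes(buf: bytes) -> str:
    """One possible dask-free hash key: blake2b over the raw buffer.
    For a numpy array, buf could be arr.tobytes() combined with the
    shape and dtype; that detail is omitted in this sketch."""
    return hashlib.blake2b(buf, digest_size=16).hexdigest()

a = digest_bytes(b"\x00\x01\x02\x03")
b = digest_bytes(b"\x00\x01\x02\x03")
c = digest_bytes(b"\x00\x01\x02\x04")
print(a == b, a == c)  # True False
```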

Or, ditch hashing altogether and do something more manual:

  • Evaluate the force once in the wec.solve setup before running the optimization (seemingly requires changing how the wec.forces work, so I didn't want to mess with it)
  • Do the error checking (the most costly part) only in the setup of wec.solve, or add a conditional error-checking flag to the wave_excitation function (this is easy to implement but could result in uncaught silent problems when wave_excitation is called outside of an optimization).
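The conditional error-checking flag from the last bullet could look something like this (the names and the check are illustrative, not the actual wave_excitation signature):

```python
def wave_excitation(exc_coeff, waves, validate=True):
    """Illustrative sketch: gate the expensive input validation behind
    a flag so the optimization inner loop can skip it, after validating
    once up front in the wec.solve setup."""
    if validate:
        # stand-in for the costly subset_close/allclose coordinate checks
        if len(exc_coeff) != len(waves):
            raise ValueError("coordinate mismatch")
    return [c * w for c, w in zip(exc_coeff, waves)]

# Validate once, outside the optimization loop...
wave_excitation([1.0, 2.0], [0.5, 0.5], validate=True)
# ...then skip validation on the thousands of inner-loop calls.
result = wave_excitation([1.0, 2.0], [0.5, 0.5], validate=False)
print(result)  # [0.5, 1.0]
```

As noted above, the risk is that callers outside an optimization forget to leave validation on, so the flag should probably default to True.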

Interest in leading this feature development?
Sure, but I have reached the limit of my knowledge of hashing, so someone with more CS experience might be more useful than me if that approach is preferred.

Additional information
The timings were done using main at 598e875 rather than 2.6.0, because 2.6.0 has a transpose inside the wave_excitation function which actually seems to slow it down a lot.
I used this xarray issue and this stackoverflow page to figure out the dask hashing.

My code for dask hashing, inserted into core.py:

import functools

import dask
import dask.base


class HashWrapper:
    """Wrap an unhashable argument (e.g. an xarray.DataArray) with a
    deterministic hash computed by dask.base.tokenize."""
    def __init__(self, x) -> None:
        self.value = x
        with dask.config.set({"tokenize.ensure-deterministic": True}):
            self.h = dask.base.tokenize(x)

    def __hash__(self) -> int:
        return hash(self.h)

    def __eq__(self, other: object) -> bool:
        return other.h == self.h


def hashable_cache(function):
    """Like functools.cache, but supports unhashable arguments by
    wrapping each one in a HashWrapper before the cache lookup."""
    @functools.cache
    def cached_wrapper(*args, **kwargs):
        # Unwrap the arguments before calling the real function.
        arg_values = [a.value for a in args]
        kwargs_values = {k: v.value for k, v in kwargs.items()}
        return function(*arg_values, **kwargs_values)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        shell_args = [HashWrapper(a) for a in args]
        shell_kwargs = {k: HashWrapper(v) for k, v in kwargs.items()}
        return cached_wrapper(*shell_args, **shell_kwargs)

    # Expose the inner cache's statistics and management methods.
    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear

    return wrapper

plus the @hashable_cache decorator applied above the wave_excitation function definition. Hypothetically, if the caching were actually faster, the decorator could be added to any function with dask-hashable inputs (i.e., numpy arrays and xarrays, but not custom classes unless a custom tokenize method were written).

@michaelcdevin michaelcdevin self-assigned this Jan 4, 2024
@cmichelenstrofer
Member

This is great! I think we can make the wave excitation function always return some pre-calculated value. This would need to be set up when we call WEC.from_impedance or WEC.from_bem.
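A minimal sketch of this idea, assuming a simplified force-function interface with the usual (wec, x_wec, x_opt, waves, nsubsteps) arguments (a hypothetical stand-in, not the actual WecOptTool API): compute the excitation once at setup time and return a closure that just hands back the stored value.

```python
def make_constant_excitation(precomputed):
    """Compute the excitation force once at setup time and return a
    closure that ignores the solver state and returns the stored value."""
    def excitation_force(wec=None, x_wec=None, x_opt=None, waves=None,
                         nsubsteps=1):
        return precomputed
    return excitation_force

force = make_constant_excitation([3.0, 1.5])  # done once, at setup
print(force(x_wec=[0.0]))  # no recomputation on each solver iteration
```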

@cmichelenstrofer cmichelenstrofer added the enhancement New feature or request label Jan 8, 2024
@ryancoe
Collaborator

ryancoe commented Jan 11, 2024

Note that we may have done it this way because we were trying to allow for nonlinear excitation. However, this is only really possible using the default WEC.__init__ method. Solution: for the other static init methods, make it so excitation is only calculated once.

@rebeccamccabe - Thanks for doing this. Can you share with us what tools you used to do the profiling so we can explore this a bit more?

@rebeccamccabe
Author

I used the default profiler that comes with the Spyder IDE. I had to turn autograd off to get profiling results that were interpretable, because otherwise most computations show up as coming from the autograd box function wrapper.

4 participants