Description
Feature description.
A way to cache functions so that repeated evaluations with identical inputs are not recalculated during optimization. Alternatively, a way to remove the time-consuming error checking from functions during optimization and validate the inputs beforehand instead.
Issue addressed.
Inspired by #305, I tried profiling wecopttool to see what is taking up the most time. Below are the results of a sweep of optimizations with 27 different impedances, with a regular wave, `nsubsteps = 16`, `nfreq = 5`, `use_grad=False`.
Surprisingly, a substantial portion of the time (2.33 / 6.12 = 38%) is spent repeatedly (6244 times) evaluating the `wave_excitation` function. A lot of the time in this function comes from the error checking (`subset_close`, `allclose`). This seems strange because the function's result is constant during the optimization, so it could be computed once and reused for the whole optimization.
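For anyone wanting to reproduce this kind of measurement, here is a minimal sketch using the standard-library `cProfile` and `pstats` modules. The functions `checked_force`, `run_sweep` and `profile_report` are hypothetical stand-ins, not WecOptTool code; the numbers above came from profiling the real optimization sweep.

```python
import cProfile
import io
import pstats


def checked_force(values):
    # Hypothetical stand-in for a function that re-validates its inputs on every call.
    if not all(isinstance(v, float) for v in values):
        raise TypeError("expected floats")
    return sum(values)


def run_sweep(n_calls=1000):
    # Hypothetical stand-in for an optimization loop that calls the same function repeatedly.
    data = [0.1 * i for i in range(50)]
    total = 0.0
    for _ in range(n_calls):
        total += checked_force(data)
    return total


def profile_report(func, *args, top=5):
    """Run func under cProfile; return (result, text report sorted by cumulative time)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


result, report = profile_report(run_sweep, 1000)
print(report)
```

The call count and cumulative time columns of the report are the ones that expose a cheap function being evaluated thousands of times.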
Describe alternatives you've considered
In another project, I've used `functools.cache`, which can cache function evaluations when inputs are identical, with just a decorator. I managed to get it working on WecOptTool, but it was a bit more complicated than I expected because xarrays aren't hashable, so I had to use `dask` (which conveniently is already a WecOptTool dependency through `wavespectra`). You can see from the profiling results that now `wave_excitation` only gets called once, and the other 6243 times, the value is cached.
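The hashability requirement can be reproduced without xarray. The sketch below uses a plain list as the unhashable argument and a hypothetical `scaled_sum` function; it only illustrates why a wrapper is needed, not how WecOptTool behaves.

```python
import functools


@functools.cache
def scaled_sum(values, factor):
    # Stand-in for an expensive function; not WecOptTool code.
    return factor * sum(values)


# Hashable arguments (a tuple) are cached: the second call is a cache hit.
scaled_sum((1.0, 2.0, 3.0), 2.0)
scaled_sum((1.0, 2.0, 3.0), 2.0)
hits = scaled_sum.cache_info().hits

# Unhashable arguments (a list here) are rejected when the cache key is built,
# before the function body ever runs.
try:
    scaled_sum([1.0, 2.0, 3.0], 2.0)
    unhashable_error = None
except TypeError as err:
    unhashable_error = err
```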
Unfortunately, this caching approach doesn't pay off: the dask hashing (which takes around 1 ms per argument, so 2 ms per function call) is slower than the <1 ms it takes to actually execute the `wave_excitation` function, so the overall time goes up substantially. Now most of the time is spent in the `__init__` function of `HashWrapper`, which calls `dask.base.tokenize` to generate the hash.
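The break-even condition is that a cache hit still pays for building the key, so caching only saves time when the key costs less than the function itself. A rough way to check this with `timeit` is sketched below; `cheap_function` and `make_key` are hypothetical stand-ins (SHA-256 plays the role of the slow tokenizing step), and the absolute timings will differ from the dask numbers above.

```python
import hashlib
import timeit


def cheap_function(data: bytes) -> int:
    # Hypothetical stand-in for a function that is already fast.
    return len(data)


def make_key(data: bytes) -> str:
    # Hypothetical stand-in for a comparatively slow hashing step.
    return hashlib.sha256(data).hexdigest()


def per_call_seconds(func, arg, number=200):
    """Average wall-clock time of one func(arg) call."""
    return timeit.timeit(lambda: func(arg), number=number) / number


payload = bytes(100_000)
t_func = per_call_seconds(cheap_function, payload)
t_key = per_call_seconds(make_key, payload)

# A cache hit still pays t_key, so caching only saves time when t_key < t_func.
caching_pays_off = t_key < t_func
print(f"function: {t_func:.2e} s, key: {t_key:.2e} s, pays off: {caching_pays_off}")
```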
Describe the solution you'd like
This would be beneficial to users who plan to perform a large number of optimizations, for whom speed is important. Perhaps there is a better way to do the hashing that would make my solution viable, e.g.:
- hash both arguments together, so hashing has to happen half as often
- use a faster hashing package (cityhash, xxhash or murmurhash) instead of the default (recommended by dask here)
- hash some other way without dask (I'm not very knowledgeable about hashing, so I don't know how this would work)
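As a rough illustration of the first and third ideas together, the sketch below hashes all arguments in a single pass over their raw bytes with the standard-library `hashlib.blake2b`, with no dask involved. Everything here is an assumption for illustration: `combined_key`, `bytes_cache` and `excitation` are hypothetical names, and the demo uses stdlib `array.array` objects because they expose `tobytes()` just like NumPy arrays do. For NumPy or xarray inputs the key would also need to include shape, dtype and coordinates, and I have not measured whether this is actually faster than `dask.base.tokenize`.

```python
import hashlib
from array import array


def combined_key(*buffers) -> bytes:
    """Hash several array-like arguments into one digest (one hashing pass per call)."""
    h = hashlib.blake2b(digest_size=16)
    for buf in buffers:
        raw = buf.tobytes()
        # Prefix each argument with its byte length so that
        # ([1, 2], [3]) and ([1], [2, 3]) produce different keys.
        h.update(len(raw).to_bytes(8, "little"))
        h.update(raw)
    return h.digest()


def bytes_cache(function):
    """Hypothetical decorator: memoize on the combined digest of all positional arguments."""
    store = {}

    def wrapper(*args):
        key = combined_key(*args)
        if key not in store:
            store[key] = function(*args)
        return store[key]

    wrapper.store = store
    return wrapper


calls = []


@bytes_cache
def excitation(coefficients, amplitudes):
    # Stand-in for the real computation; records each real evaluation.
    calls.append(1)
    return sum(c * a for c, a in zip(coefficients, amplitudes))


coeffs = array("d", [1.0, 2.0, 3.0])
amps = array("d", [0.5, 0.5, 0.5])
first = excitation(coeffs, amps)
second = excitation(array("d", [1.0, 2.0, 3.0]), amps)  # equal contents, different object
```

Note that this treats equal digests as equal inputs, which is the same trade-off the dask-token approach makes.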
Or, ditch hashing altogether and do something more manual:
- Evaluate the force once in the `wec.solve` setup before running the optimization (seemingly requires changing how the `wec.forces` work, so I didn't want to mess with it)
- Do the error checking (the most costly part) only in the setup of `wec.solve`, or add a conditional error-checking flag to the `wave_excitation` function (this is easy to implement but could result in uncaught silent problems when `wave_excitation` is called outside of an optimization).
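The two manual options could look roughly like the sketch below. All names (`validate_inputs`, `excitation`, `solve`, the `check` flag) are hypothetical and the loop is a toy; this is not the WecOptTool API, just the shape of the idea: validate once in setup, evaluate the constant force once, and keep checking on by default for calls made outside an optimization.

```python
import math


def validate_inputs(coefficients, amplitudes):
    # Stand-in for the costly consistency checks; raises on bad input.
    if len(coefficients) != len(amplitudes):
        raise ValueError("coefficients and amplitudes must have the same length")
    if not all(math.isfinite(x) for x in (*coefficients, *amplitudes)):
        raise ValueError("inputs must be finite")


def excitation(coefficients, amplitudes, check=True):
    # `check` defaults to True so calls outside an optimization remain validated.
    if check:
        validate_inputs(coefficients, amplitudes)
    return sum(c * a for c, a in zip(coefficients, amplitudes))


def solve(coefficients, amplitudes, n_iterations=100):
    """Hypothetical optimization loop: validate once in setup, then skip the checks."""
    validate_inputs(coefficients, amplitudes)  # one-time check in setup
    # The force does not depend on the optimization state, so evaluate it once.
    force = excitation(coefficients, amplitudes, check=False)
    total = 0.0
    for _ in range(n_iterations):
        total += force  # reuse the precomputed value in every iteration
    return total
```

Defaulting the flag to the safe behaviour limits the risk of silent problems to code that explicitly opts out.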
Interest in leading this feature development?
Sure, but I have reached the limit of my knowledge of hashing, so someone with more CS experience might be more useful than me if that approach is preferred.
Additional information
The timings were done using `main` 598e875 rather than `2.6.0` because `2.6.0` has a `transpose` function in the `wave_excitation` function which actually seems to slow it down a lot.
I used this xarray issue and this stackoverflow page to figure out the dask hashing.
My code for dask hashing, inserted into `core.py`:
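```python
import functools

import dask
import dask.base


class HashWrapper:
    """Wrap an unhashable object (e.g. an xarray) so it can be used as a cache key."""

    def __init__(self, x) -> None:
        self.value = x
        # Require a deterministic token; dask raises if it cannot produce one.
        with dask.config.set({"tokenize.ensure-deterministic": True}):
            self.h = dask.base.tokenize(x)

    def __hash__(self) -> int:
        return hash(self.h)

    def __eq__(self, other: object) -> bool:
        # Guard against comparison with objects that are not HashWrappers.
        return isinstance(other, HashWrapper) and other.h == self.h


def hashable_cache(function):
    """Like functools.cache, but for arguments that dask can tokenize."""

    @functools.cache
    def cached_wrapper(*args, **kwargs):
        # Unwrap the arguments before calling the real function.
        arg_values = [a.value for a in args]
        kwargs_values = {k: v.value for k, v in kwargs.items()}
        return function(*arg_values, **kwargs_values)

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Wrap every argument so that functools.cache can hash it.
        shell_args = [HashWrapper(a) for a in args]
        shell_kwargs = {k: HashWrapper(v) for k, v in kwargs.items()}
        return cached_wrapper(*shell_args, **shell_kwargs)

    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear
    return wrapper
```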
plus the `@hashable_cache` decorator applied above the `wave_excitation` function definition. Hypothetically, if the caching were actually faster, the decorator could be added to any function with dask-hashable inputs (i.e. NumPy arrays and xarrays, but not custom classes unless a custom tokenize method were written).