
[TKW] Kernel Cacher #329

Merged
merged 8 commits into iree-org:main
Dec 12, 2024
Conversation

raikonenfnu (Contributor)

For eager mode to be viable, we need to implement a kernel cacher so that we do not have to recompile kernels on every invocation. Here are the main changes:

  1. Refactor wave.py's compile_and_invoke into two separate functions, compile_to_vmfb and invoke_vmfb, so that we can intercept the compiled vmfb cleanly and store it in the caches.
  2. Implement a kernel cache dataclass, a struct holding everything needed to reconstruct a kernel so that it is invokable in its original state.
  3. Implement a function to invoke a cached kernel.
  4. Implement a kernel cache manager that can hash kernels and load/store them to RAM and to files.
  5. Add tests and helper functions for the cache manager.
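A minimal sketch of the split described in point 1. The signatures, bodies, and helper names here are illustrative stand-ins, not the PR's exact API; the point is that the vmfb becomes an interceptable artifact between the two calls.

```python
# Sketch of splitting compile_and_invoke; all names/signatures are
# illustrative, not the PR's exact API.

def compile_to_vmfb(mlir_module: str, options: dict) -> bytes:
    """Compile the kernel to a vmfb blob (the artifact we can cache)."""
    # ... invoke the compiler here; return the serialized vmfb bytes ...
    return b"vmfb-bytes"

def invoke_vmfb(vmfb: bytes, args: list):
    """Load an already-compiled vmfb into the runtime and launch it."""
    # ... runtime load + dispatch would go here ...

def compile_and_invoke(mlir_module: str, options: dict, args: list):
    # The original entry point becomes a thin wrapper; a cache manager
    # can now intercept (store or substitute) the vmfb between the calls.
    vmfb = compile_to_vmfb(mlir_module, options)
    invoke_vmfb(vmfb, args)
```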

A little more detail on the newly developed cache manager. The Wave cache manager has two main components/caches:

  1. Session/online cache - the main cache that our compiler and runtime load from and store to. It is essentially a dict keyed by the kernel hash, with WaveCache objects as values. We added LRU eviction with a limit on the number of kernels cached here, because this cache lives in RAM and we do not want to run out of memory.

  2. File/offline cache - this cache persists saved/compiled kernels between sessions/runs. We store the vital kernel information (vmfb, kernel_sig, and mlir) in a CACHE_BASE_DIR/kernel_hash directory. If a kernel is queried during a new run and does not yet exist in the session/online cache, we load the files from its kernel_hash directory and reconstruct the WaveCache from them.
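The session/online cache described in point 1 can be sketched with an `OrderedDict`. This is a simplified illustration, not the PR's implementation: the capacity constant, class name, and entry type are all assumptions.

```python
# Sketch of the in-RAM session cache with LRU eviction described above.
# Names and the capacity limit are illustrative assumptions.
from collections import OrderedDict

MAX_SESSION_KERNELS = 64  # assumed limit on kernels kept in RAM

class SessionCache:
    """kernel_hash -> cache-entry map with least-recently-used eviction."""

    def __init__(self, capacity=MAX_SESSION_KERNELS):
        self.capacity = capacity
        self.entries = OrderedDict()

    def get(self, kernel_hash):
        if kernel_hash not in self.entries:
            return None
        self.entries.move_to_end(kernel_hash)  # mark as recently used
        return self.entries[kernel_hash]

    def put(self, kernel_hash, entry):
        self.entries[kernel_hash] = entry
        self.entries.move_to_end(kernel_hash)
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)  # evict least recently used
```

On a miss here, the manager would fall back to the file/offline cache (point 2) and repopulate this dict.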

```python
    Get a unique identifier for a given kernel.
    """
    key = [
        inspect.getsource(kernel_body),
```
Contributor:

Function source code is not always possible to retrieve, and inspect.getsource will raise an error in that case. We can live with this limitation for now, but we need a way to bypass caching when it happens.

Contributor Author:

Do you know in which cases inspect.getsource will fail? I can add a check there.

Contributor Author:

For now I have wrapped it in try/except. Not the nicest aesthetically, but it should do the job!
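The try/except bypass mentioned above might look something like this sketch. The function name and hashing choice are illustrative, not the PR's actual code; the key idea is returning a sentinel so the caller can skip caching when no source is available.

```python
# Sketch of the try/except caching bypass; names are illustrative.
import hashlib
import inspect

def get_kernel_hash_or_none(kernel_body):
    """Return a cache key for the kernel, or None when caching must be bypassed."""
    try:
        source = inspect.getsource(kernel_body)
    except OSError:
        # No retrievable source (e.g. eval/exec-created functions, or
        # bytecode-only distributions): tell the caller to skip caching.
        return None
    return hashlib.sha256(source.encode()).hexdigest()
```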

Contributor:

The main case is when a function is constructed dynamically via eval/exec. It is also possible to distribute Python programs as bytecode (.pyc) without source code.
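A minimal demonstration of the eval/exec case described above: `inspect.getsource` raises `OSError` for a function created via `exec`, because there is no source file to read from.

```python
# Demonstration: inspect.getsource fails for exec-defined functions.
import inspect

namespace = {}
exec("def dynamic_kernel(): return 42", namespace)

try:
    inspect.getsource(namespace["dynamic_kernel"])
    source_available = True
except OSError:
    source_available = False

print(source_available)  # False: no retrievable source for exec-defined code
```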


Signed-off-by: Stanley Winata <[email protected]>
@harsh-nod (Contributor) left a comment:

LGTM! Thanks for putting this together; it will be a big boost to users. Just had some stylistic comments, but it looks good overall!

iree/turbine/kernel/wave/cache.py (resolved)
```python
cur_cache_dir = f"{CACHE_BASE_DIR}/{kernel_hash}"
if not os.path.exists(cur_cache_dir):
    os.makedirs(cur_cache_dir)
cur_vmfb_path = f"{cur_cache_dir}/{kernel_hash}.vmfb"
```
Contributor:

Minor: can you switch to using pathlib instead of manually concatenating the strings? It gives you lots of additional benefits.

Contributor Author:

done!
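The pathlib version of the quoted snippet might look like this sketch. `CACHE_BASE_DIR` and `kernel_hash` stand in for the PR's real values (a temp directory and a dummy hash are used here so the snippet is self-contained).

```python
# pathlib rewrite of the string-concatenation snippet above.
# CACHE_BASE_DIR and kernel_hash are stand-in values for illustration.
import tempfile
from pathlib import Path

CACHE_BASE_DIR = Path(tempfile.gettempdir()) / "wave_cache_demo"
kernel_hash = "deadbeef"

cur_cache_dir = CACHE_BASE_DIR / kernel_hash
cur_cache_dir.mkdir(parents=True, exist_ok=True)  # replaces the exists() + makedirs() dance
cur_vmfb_path = cur_cache_dir / f"{kernel_hash}.vmfb"
```

Besides cleaner joins, `Path` objects give you `.exists()`, `.read_bytes()`/`.write_bytes()`, and portable separators for free.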

"""
Save given kernel(vmfb, kernel_sig, and MLIR) into session_cache and file/offline cache.
"""
if not WAVE_CACHE_ON or not kernel_hash:
Contributor:

If WAVE_ALWAYS_COMPILE=1, should we store the kernel to the cache?

Contributor Author:

I think we can still store it to the file cache, so that we can use it to debug (Triton has the exact same flag, which I have found useful). But I can make sure it doesn't store to the online/session cache.
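The behavior described above can be sketched as follows. All names here are illustrative (the real file cache writes to `CACHE_BASE_DIR`; a dict stands in for it): with `WAVE_ALWAYS_COMPILE=1`, kernels still go to the file cache for debugging but stay out of the in-RAM session cache.

```python
# Sketch: WAVE_ALWAYS_COMPILE keeps kernels out of the RAM cache only.
# Names are illustrative; a dict stands in for the on-disk store.
import os

session_cache = {}  # in-RAM cache (kernel_hash -> entry)
file_cache = {}     # stand-in for the on-disk CACHE_BASE_DIR store

def store_kernel(kernel_hash, entry):
    always_compile = os.environ.get("WAVE_ALWAYS_COMPILE") == "1"
    file_cache[kernel_hash] = entry           # always persist, handy for debugging
    if not always_compile:
        session_cache[kernel_hash] = entry    # skip RAM cache when recompiling is forced
```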

iree/turbine/kernel/wave/wave.py (resolved)
@raikonenfnu merged commit d9be797 into iree-org:main on Dec 12, 2024
10 checks passed