[TKW] Kernel Cacher
For eager mode to be viable, we need a kernel cacher so that we do not have to
re-compile kernels every time. Here are the main changes:

1. Refactor wave.py's `compile_and_invoke` into two separate functions,
   `compile_to_vmfb` and `invoke_vmfb`, so that we can intercept the compiled
   vmfb cleanly and store it in the caches (see the sketch after this list).
2. Implement a kernel cache dataclass that holds the information needed to
   reconstruct a kernel so it can be invoked in its original state.
3. Implement a function to invoke a cached kernel.
4. Implement a kernel cache manager that can hash kernels, load/store kernels
   in RAM, and load/store kernels to files.
5. Add tests and helper functions for the cache manager.
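
As a rough sketch of how these pieces fit together (the cache-manager calls
below match the new cache.py in this diff; `compile_to_vmfb` and `invoke_vmfb`
are the refactored wave.py functions whose exact signatures are not shown here,
so the miss path is left as comments):

```python
# Illustrative wiring only, not the actual wave.py code.
from iree.turbine.kernel.wave.cache import get_cache_manager, invoke_cached_kernel

def compile_or_reuse(kernel_body, constraints, hyperparams, args, config):
    cache_manager = get_cache_manager()
    kernel_hash = cache_manager.get_hash(
        constraints,
        kernel_body,
        hyperparams,
        dynamic_symbols=[],
        config=config,
        use_scheduling=False,
        use_scheduling_barriers=False,
        run_bench=False,
    )
    cached = cache_manager.load_kernel(kernel_hash)
    if cached is not None:
        # Cache hit: reuse the stored vmfb, no recompilation.
        invoke_cached_kernel(
            cached,
            args,
            config,
            dynamic_symbols=[],
            dynamic_symbols_map={},
            run=True,
            run_bench=False,
        )
        return
    # Cache miss: compile once with compile_to_vmfb(...), then
    #   cache_manager.store_kernel(vmfb, kernel_sig, module_str, kernel_hash)
    # and run it with invoke_vmfb(...).
```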

Let's discuss the newly developed cache manager in a little more detail. The
Wave cache manager has two main components/caches:

1. Session/online cache - This is the main cache that our compiler and runtime
load from and store to. It is essentially a dict that uses kernel hashes as keys
and WaveCache objects as values. We added LRU functionality with a limit on the
number of kernels cached here, because this cache lives in RAM and we don't want
to run out of memory (a minimal sketch of the LRU policy follows this list).

2. File/offline cache - This cache is essential for loading saved/compiled
kernels between sessions/runs. It works by storing vital kernel information
(vmfb, kernel_sig, and mlir) in the CACHE_BASE_DIR/kernel_hash directory. If a
kernel is queried during a new run and does not yet exist in the session/online
cache, we load the files from its kernel_hash directory and reconstruct the
WaveCache from them (the on-disk layout is also sketched below).
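
The LRU policy for the session cache is the standard OrderedDict pattern; a
minimal standalone sketch (toy limit and value type, not the manager's actual
code, which lives in store_kernel_to_session/load_kernel below):

```python
# Toy LRU sketch mirroring the session-cache behavior described above.
from collections import OrderedDict
from typing import Optional

CACHE_LIMIT = 3
session_cache: OrderedDict[str, bytes] = OrderedDict()

def put(kernel_hash: str, kernel: bytes) -> None:
    session_cache[kernel_hash] = kernel
    session_cache.move_to_end(kernel_hash)      # most recently used goes to the end
    if len(session_cache) > CACHE_LIMIT:
        session_cache.popitem(last=False)       # evict the least recently used entry

def get(kernel_hash: str) -> Optional[bytes]:
    if kernel_hash in session_cache:
        session_cache.move_to_end(kernel_hash)  # refresh recency on a hit
    return session_cache.get(kernel_hash, None)
```

The on-disk layout the offline cache reads and writes looks roughly like this
(a sketch assuming the default cache directory; the filenames match
store_kernel_to_file in the diff below, and the hash value is a placeholder):

```python
# Paths and environment knobs used by the file/offline cache (illustrative).
from pathlib import Path

cache_base_dir = Path.home() / ".wave"          # override with WAVE_CACHE_DIR
kernel_hash = "<sha256-of-kernel-key>"          # placeholder

kernel_dir = cache_base_dir / kernel_hash
vmfb_path = kernel_dir / f"{kernel_hash}.vmfb"  # compiled artifact
mlir_path = kernel_dir / f"{kernel_hash}.mlir"  # module op, kept for inspection
sig_path = kernel_dir / f"{kernel_hash}.json"   # kernel_sig (buffer usages)

# Other knobs cache.py reads from the environment:
#   WAVE_CACHE_ON=0        disables caching entirely
#   WAVE_ALWAYS_COMPILE=1  skips cache lookups and forces recompilation
#   WAVE_CACHE_LIMIT=N     caps the number of kernels held in the session cache
```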

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu committed Dec 12, 2024
1 parent bc35630 commit d513890
Showing 7 changed files with 904 additions and 29 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/ci-tk.yaml
@@ -62,18 +62,28 @@ jobs:
run: |
pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
# - name: Test Wave Kernel Cacher
# if: "contains(matrix.os, 'mi300') && !cancelled()"
# run: |
# pip install --no-compile -r pytorch-rocm-requirements.txt
# export WAVE_RUN_E2E_TESTS=1
# export WAVE_CACHE_ON=0
# pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/

- name: Run e2e tests on AMD GPU MI300
if: "contains(matrix.os, 'mi300') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1
export WAVE_CACHE_ON=0
pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
- name: Run e2e tests on AMD GPU MI250
if: "contains(matrix.os, 'mi250') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1
export WAVE_CACHE_ON=0
pytest -n 2 --capture=tee-sys -vv ./tests/kernel/wave/
- name: Run LIT tests
6 changes: 6 additions & 0 deletions iree/turbine/kernel/_support/indexing.py
@@ -104,6 +104,7 @@ class IndexingContext:
"dyn_dims",
"frozen_subs",
"unbacked_symbols",
"finalized",
]

__tk_context_idname__ = "IndexingContext"
@@ -116,6 +117,7 @@ def __init__(self):
self.dyn_dims: list[IndexSymbol] = []
self.frozen_subs: list[tuple[IndexSymbol, int]] = []
self.unbacked_symbols: list[IndexSymbol] = []
self.finalized = False

def next_dyn_dim(self) -> IndexSymbol:
s = index_symbol(f"D{len(self.dyn_dims)}")
@@ -157,6 +159,9 @@ def _bind_symbol(self, symbol: IndexSymbol, value: int):
self.subs[symbol] = value

def finalize(self):
        # Early exit if the indexing context has already been finalized.
if self.finalized:
return
assert len(self.frozen_subs) == 0
# Go over everything we know and bind all free symbols.
for _sb in self.shaped_bindings.values():
@@ -217,6 +222,7 @@ def finalize(self):
if errors:
joined = "\n".join(errors)
raise ValueError(f"Indexing mismatches were encountered:\n{joined}")
self.finalized = True

def eval_dim(self, instance: Any, shaped_type: ShapedType, pos: int) -> IndexExpr:
# TODO: Could see if shaped_type is in self.shaped_bindings: it has some
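
For context: the `finalized` flag above makes `IndexingContext.finalize()` a
no-op on repeated calls, presumably so that re-invoking a cached kernel does not
trip the one-time binding logic. A minimal sketch of the pattern (illustrative
toy class, not the actual code):

```python
# Idempotent-finalize pattern (toy class, not the actual IndexingContext).
class OneTimeFinalize:
    def __init__(self) -> None:
        self.finalized = False

    def finalize(self) -> None:
        if self.finalized:
            # Repeated calls are harmless no-ops.
            return
        # ... expensive one-time symbol binding and validation ...
        self.finalized = True
```
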
277 changes: 277 additions & 0 deletions iree/turbine/kernel/wave/cache.py
@@ -0,0 +1,277 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import copy
import hashlib
import inspect
import json
import os
import shutil
import torch

from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

from .constraints import Constraint, TilingConstraint, WaveConstraint
from ..compiler.kernel_codegen import KernelBufferUsage
from .._support.indexing import IndexExpr
from .utils import invoke_vmfb, _read_file, _write_file

default_cache_base_dir = f"{str(Path.home())}/.wave"
CACHE_BASE_DIR = str(os.environ.get("WAVE_CACHE_DIR", default_cache_base_dir))
WAVE_ALWAYS_COMPILE = int(os.environ.get("WAVE_ALWAYS_COMPILE", 0))
WAVE_CACHE_ON = int(os.environ.get("WAVE_CACHE_ON", 1))
WAVE_CACHE_LIMIT = int(os.environ.get("WAVE_CACHE_LIMIT", 16))


@dataclass
class WaveCache:
"""
    Dataclass/struct that stores the information necessary to
    reconstruct and call a cached kernel.
"""

cache_id: str
kernel_sig: tuple[KernelBufferUsage]
vmfb: bytes

@property
def module_op(self):
filepath = f"{CACHE_BASE_DIR}/{self.cache_id}/{self.cache_id}.mlir"
if not os.path.exists(filepath):
raise ValueError("Failed to find module op MLIR for cached kernel.")
with open(filepath, "r") as f:
module_str = f.read()
return module_str


def anonymize_constraints(input_constraints: list[Constraint]) -> list[Constraint]:
    """
    Helper function to anonymize constraints so that we generate the same
    hash before and after initializing constraints and induction variables.
    This is crucial for kernels launched under the same LaunchableWave to
    share the same kernel cache entry despite having their constraints and
    induction variables initialized.
    Note that this anonymization does not affect the correctness of the hash,
    because the factors that can impact the initialization of these constraints
    are captured in other parts of the hash.
    """
    processed_constraints = copy.deepcopy(input_constraints)
    for constraint in processed_constraints:
        if isinstance(constraint, TilingConstraint):
            constraint.induction_var = None
        elif isinstance(constraint, WaveConstraint):
            constraint.wave_id = None
    return processed_constraints


class WaveCacheManager(object):
"""
    Wave cache manager has two main components/caches:
    1. Session/online cache - This is the main cache that our compiler and runtime load from and store to. It is
    essentially a dict that uses kernel hashes as keys and WaveCache objects as values. We added LRU functionality with a
    limit on the number of kernels cached here, because this cache lives in RAM and we don't want to run out of memory.
    2. File/offline cache - This cache is essential for loading saved/compiled kernels between sessions/runs. It works
    by storing vital kernel information (vmfb, kernel_sig, and mlir) in the CACHE_BASE_DIR/kernel_hash directory. If a
    kernel is queried during a new run and does not yet exist in the session/online cache, we load the files from its
    kernel_hash directory and reconstruct the WaveCache from them.
"""

def __init__(self):
self.file_cache: set[str] = set()
self.session_cache: OrderedDict[str, WaveCache] = OrderedDict()
self.update_file_cache()

def get_hash(
self,
constraints: list[Constraint],
kernel_body: Callable,
hyperparams: dict[IndexExpr, Any],
        dynamic_symbols: list[IndexExpr],
config: dict[str, str],
use_scheduling: bool,
use_scheduling_barriers: bool,
run_bench: bool,
):
"""
Get a unique identifier for a given kernel.
"""
        processed_constraints = anonymize_constraints(constraints)
key = [
inspect.getsource(kernel_body),
processed_constraints,
hyperparams,
dynamic_symbols,
use_scheduling,
use_scheduling_barriers,
]

        # Benchmark-related part of the hash.
        if run_bench and config is not None:
            key.append(config.get("benchmark_batch_size", ""))
return hashlib.sha256(str(key).encode("utf-8")).hexdigest()

###############################################################################
# File Cache related helpers
###############################################################################

def update_file_cache(self):
"""
        Search for saved/cached kernels in the cache base directory and inform
        the cache manager of what is available.
"""
# Early exit if no cache directory found.
if not os.path.exists(CACHE_BASE_DIR):
return
for entry in os.scandir(CACHE_BASE_DIR):
if entry.name not in self.file_cache:
self.file_cache.add(entry.name)

def store_kernel_to_file(
self,
kernel_hash,
vmfb: bytes,
kernel_sig: tuple[KernelBufferUsage],
module_str: str,
):
"""
        Stores/saves a compiled kernel into CACHE_BASE_DIR/kernel_hash,
        including its MLIR, VMFB, and kernel signature.
"""
cur_cache_dir = f"{CACHE_BASE_DIR}/{kernel_hash}"
if not os.path.exists(cur_cache_dir):
os.makedirs(cur_cache_dir)
cur_vmfb_path = f"{cur_cache_dir}/{kernel_hash}.vmfb"
cur_module_path = f"{cur_cache_dir}/{kernel_hash}.mlir"
cur_kernelsig_path = f"{cur_cache_dir}/{kernel_hash}.json"
_write_file(cur_vmfb_path, "wb", vmfb)
_write_file(cur_module_path, "w", module_str)
kernel_sig_str = json.dumps([usage.name for usage in kernel_sig])
_write_file(cur_kernelsig_path, "w", kernel_sig_str)

def load_kernel_from_file(self, kernel_hash):
"""
        Loads the queried kernel (including VMFB and kernel signature)
        from the local cache directory.
"""
cur_cache_dir = f"{CACHE_BASE_DIR}/{kernel_hash}"
vmfb = None
kernel_sig_str = None
if not os.path.exists(cur_cache_dir):
raise ValueError("Failed to find queried cached kernel.")
cur_vmfb_path = f"{cur_cache_dir}/{kernel_hash}.vmfb"
cur_kernelsig_path = f"{cur_cache_dir}/{kernel_hash}.json"
vmfb = _read_file(cur_vmfb_path, "rb")
kernel_sig_str = json.loads(_read_file(cur_kernelsig_path, "r"))
kernel_sig = [KernelBufferUsage[usage] for usage in kernel_sig_str]
return WaveCache(kernel_hash, kernel_sig, vmfb)

###############################################################################
# Session cache related helpers
###############################################################################
def store_kernel_to_session(self, kernel_hash: str, cached_kernel: WaveCache):
"""
        LRU-style storing of a kernel into the session cache. Moves the most recently generated kernel to the top of
        the session cache, and if the cache length exceeds the limit, pops the least recently used entry.
"""
self.session_cache[kernel_hash] = cached_kernel
self.session_cache.move_to_end(kernel_hash)
if len(self.session_cache) > WAVE_CACHE_LIMIT:
self.session_cache.popitem(last=False)

def store_kernel(
self,
vmfb: bytes,
kernel_sig: tuple[KernelBufferUsage],
module_str: str,
kernel_hash: str,
):
"""
        Saves the given kernel (vmfb, kernel_sig, and MLIR) into the session cache and the file/offline cache.
"""
if not WAVE_CACHE_ON:
return
self.store_kernel_to_file(kernel_hash, vmfb, kernel_sig, module_str)
self.store_kernel_to_session(
kernel_hash, WaveCache(kernel_hash, kernel_sig, vmfb)
)

def load_kernel(self, kernel_hash: str):
"""
        LRU-style loading of a kernel from the session cache, moving the queried kernel to the top of the LRU if it exists.
        If it exists only in the file/offline cache, we load it from local files, reconstruct the WaveCache, and store it
        in the session cache. If it exists in neither the session cache nor the offline/file cache, we return None
        and ask the compiler to compile from scratch.
"""
if WAVE_ALWAYS_COMPILE or not WAVE_CACHE_ON:
return None
if kernel_hash in self.session_cache:
self.session_cache.move_to_end(kernel_hash)
elif kernel_hash in self.file_cache:
cached_kernel = self.load_kernel_from_file(kernel_hash)
self.store_kernel_to_session(kernel_hash, cached_kernel)
return self.session_cache.get(kernel_hash, None)


def get_cache_manager() -> WaveCacheManager:
global _global_cache_manager
if not "_global_cache_manager" in globals():
_global_cache_manager = WaveCacheManager()
return _global_cache_manager


def reset_cache_manager() -> None:
    if "_global_cache_manager" not in globals():
return
if os.path.exists(CACHE_BASE_DIR):
shutil.rmtree(CACHE_BASE_DIR)
global _global_cache_manager
del _global_cache_manager


def invoke_cached_kernel(
cached_kernel: WaveCache,
args: list[torch.Tensor],
config: dict[str, str],
dynamic_symbols: list[IndexExpr],
dynamic_symbols_map: dict[IndexExpr, int],
run: bool,
run_bench: bool,
):
kernel_inputs = []
kernel_outputs = []
for arg, usage in zip(args, cached_kernel.kernel_sig):
if usage == KernelBufferUsage.INPUT:
kernel_inputs.append(arg)

if usage == KernelBufferUsage.OUTPUT:
kernel_outputs.append(arg)

kernel_dynamic_dims = []
if dynamic_symbols:
kernel_dynamic_dims = dynamic_symbols_map.values()

if not config:
raise ValueError("no config provided")

invoke_vmfb(
cached_kernel.vmfb,
"isolated_benchmark",
config,
kernel_inputs,
kernel_outputs,
kernel_dynamic_dims,
run,
run_bench,
inplace=True,
)
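
Taken together, a caller that only wants to reuse an already cached kernel could
look like the sketch below (tensor shapes, dtypes, and the `config` contents are
assumptions for illustration; the argument order must match the cached
kernel_sig):

```python
# Sketch: reusing a previously cached kernel by hand (assumed tensors/config).
import torch
from iree.turbine.kernel.wave.cache import get_cache_manager, invoke_cached_kernel

cache_manager = get_cache_manager()
kernel_hash = "<hash previously returned by cache_manager.get_hash(...)>"  # placeholder

cached = cache_manager.load_kernel(kernel_hash)  # checks session cache, then files
if cached is not None:
    a = torch.randn(128, 128, dtype=torch.float16, device="cuda")
    b = torch.randn(128, 128, dtype=torch.float16, device="cuda")
    c = torch.zeros(128, 128, dtype=torch.float32, device="cuda")
    # args are split into inputs/outputs according to cached.kernel_sig usages.
    invoke_cached_kernel(
        cached,
        [a, b, c],
        config={"device": "hip"},  # placeholder; real keys come from the runtime config
        dynamic_symbols=[],
        dynamic_symbols_map={},
        run=True,
        run_bench=False,
    )
```
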
5 changes: 4 additions & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -455,7 +455,10 @@ def _expand_mma_reduction(
for dim in mma.indexing_dims:
if dim not in dim_scaling and mma.vector_shapes[dim] > 0:
tile_size = idxc.get_static_value(dim)
dim_scaling[dim] = max(tile_size // mma.vector_shapes[dim], 1)
            try:
                dim_scaling[dim] = max(tile_size // mma.vector_shapes[dim], 1)
            except TypeError:
                raise ValueError(f"Failed to resolve a static tile size for {dim}.")

# Store the original mma node and accumulator value for expansion.
# When we begin expansion, we have a single mma node with the correct accumulator.
(Diffs for the remaining changed files are not shown.)
