diff --git a/iree/turbine/kernel/wave/cache.py b/iree/turbine/kernel/wave/cache.py index aaa50dde..4dcfe8e6 100644 --- a/iree/turbine/kernel/wave/cache.py +++ b/iree/turbine/kernel/wave/cache.py @@ -12,6 +12,7 @@ import os import shutil import torch +import threading from collections import OrderedDict from dataclasses import dataclass @@ -88,6 +89,7 @@ class WaveCacheManager(object): def __init__(self): self.file_cache: set[str] = set() self.session_cache: OrderedDict[str, WaveCache] = OrderedDict() + self.lock = threading.Lock() self.update_file_cache() def get_hash( @@ -206,10 +208,11 @@ def store_kernel( """ if not WAVE_CACHE_ON or not kernel_hash: return - self.store_kernel_to_file(kernel_hash, vmfb, kernel_sig, module_str) - self.store_kernel_to_session( - kernel_hash, WaveCache(kernel_hash, kernel_sig, vmfb) - ) + with self.lock: + self.store_kernel_to_file(kernel_hash, vmfb, kernel_sig, module_str) + self.store_kernel_to_session( + kernel_hash, WaveCache(kernel_hash, kernel_sig, vmfb) + ) def load_kernel(self, kernel_hash: str): """ @@ -220,12 +223,13 @@ def load_kernel(self, kernel_hash: str): """ if WAVE_ALWAYS_COMPILE or not kernel_hash or not WAVE_CACHE_ON: return None - if kernel_hash in self.session_cache: - self.session_cache.move_to_end(kernel_hash) - elif kernel_hash in self.file_cache: - cached_kernel = self.load_kernel_from_file(kernel_hash) - self.store_kernel_to_session(kernel_hash, cached_kernel) - return self.session_cache.get(kernel_hash, None) + with self.lock: + if kernel_hash in self.session_cache: + self.session_cache.move_to_end(kernel_hash) + elif kernel_hash in self.file_cache: + cached_kernel = self.load_kernel_from_file(kernel_hash) + self.store_kernel_to_session(kernel_hash, cached_kernel) + return self.session_cache.get(kernel_hash, None) def get_cache_manager() -> WaveCacheManager: