Skip to content

Commit

Permalink
add thread locks
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu committed Dec 12, 2024
1 parent 458cfa3 commit 09cdfae
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions iree/turbine/kernel/wave/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import shutil
import torch
import threading

from collections import OrderedDict
from dataclasses import dataclass
Expand Down Expand Up @@ -88,6 +89,7 @@ class WaveCacheManager(object):
def __init__(self):
self.file_cache: set[str] = set()
self.session_cache: OrderedDict[str, WaveCache] = OrderedDict()
self.lock = threading.Lock()
self.update_file_cache()

def get_hash(
Expand Down Expand Up @@ -206,10 +208,11 @@ def store_kernel(
"""
if not WAVE_CACHE_ON or not kernel_hash:
return
self.store_kernel_to_file(kernel_hash, vmfb, kernel_sig, module_str)
self.store_kernel_to_session(
kernel_hash, WaveCache(kernel_hash, kernel_sig, vmfb)
)
with self.lock:
self.store_kernel_to_file(kernel_hash, vmfb, kernel_sig, module_str)
self.store_kernel_to_session(
kernel_hash, WaveCache(kernel_hash, kernel_sig, vmfb)
)

def load_kernel(self, kernel_hash: str):
"""
Expand All @@ -220,12 +223,13 @@ def load_kernel(self, kernel_hash: str):
"""
if WAVE_ALWAYS_COMPILE or not kernel_hash or not WAVE_CACHE_ON:
return None
if kernel_hash in self.session_cache:
self.session_cache.move_to_end(kernel_hash)
elif kernel_hash in self.file_cache:
cached_kernel = self.load_kernel_from_file(kernel_hash)
self.store_kernel_to_session(kernel_hash, cached_kernel)
return self.session_cache.get(kernel_hash, None)
with self.lock:
if kernel_hash in self.session_cache:
self.session_cache.move_to_end(kernel_hash)
elif kernel_hash in self.file_cache:
cached_kernel = self.load_kernel_from_file(kernel_hash)
self.store_kernel_to_session(kernel_hash, cached_kernel)
return self.session_cache.get(kernel_hash, None)


def get_cache_manager() -> WaveCacheManager:
Expand Down

0 comments on commit 09cdfae

Please sign in to comment.