
Applied suggestions from comments by @csbnw
fjwillemsen committed Mar 1, 2024
1 parent 0cb5e3a commit b682506
Showing 8 changed files with 5 additions and 65 deletions.
3 changes: 1 addition & 2 deletions kernel_tuner/backends/cupy.py
@@ -47,7 +47,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
         self.devprops = dev.attributes
         self.cc = dev.compute_capability
         self.max_threads = self.devprops["MaxThreadsPerBlock"]
-        self.cache_size_L2 = self.devprops["L2CacheSize"]

         self.iterations = iterations
         self.current_module = None

@@ -126,7 +125,7 @@ def compile(self, kernel_instance):
         compiler_options = self.compiler_options
         if not any(["-std=" in opt for opt in self.compiler_options]):
             compiler_options = ["--std=c++11"] + self.compiler_options
-        # CuPy already sets the --gpu-architecture by itself, as per https://github.com/cupy/cupy/blob/20ccd63c0acc40969c851b1917dedeb032209e8b/cupy/cuda/compiler.py#L145
+        # CuPy already sets the --gpu-architecture by itself, as per https://github.com/cupy/cupy/blob/main/cupy/cuda/compiler.py#L145

         options = tuple(compiler_options)
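As a side note on the surviving logic above: compile() prepends a C++ standard flag only when the caller has not supplied one. A small standalone illustration of that normalization, using a hypothetical user-supplied options list:

    compiler_options = ["-DBLOCK_SIZE=128"]  # hypothetical user options
    # Prepend a default C++ standard unless the user already chose one.
    if not any("-std=" in opt for opt in compiler_options):
        compiler_options = ["--std=c++11"] + compiler_options
    options = tuple(compiler_options)
    print(options)  # ('--std=c++11', '-DBLOCK_SIZE=128')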
1 change: 0 additions & 1 deletion kernel_tuner/backends/hip.py
@@ -59,7 +59,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None

         self.name = self.hipProps._name.decode('utf-8')
         self.max_threads = self.hipProps.maxThreadsPerBlock
-        self.cache_size_L2 = self.hipProps.l2CacheSize
         self.device = device
         self.compiler_options = compiler_options or []
         self.iterations = iterations
4 changes: 0 additions & 4 deletions kernel_tuner/backends/nvcuda.py
@@ -68,10 +68,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
             cudart.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device
         )
         cuda_error_check(err)
-        err, self.cache_size_L2 = cudart.cudaDeviceGetAttribute(
-            cudart.cudaDeviceAttr.cudaDevAttrL2CacheSize, device
-        )
-        cuda_error_check(err)
         self.cc = f"{major}{minor}"
         self.iterations = iterations
         self.current_module = None
4 changes: 0 additions & 4 deletions kernel_tuner/backends/opencl.py
@@ -45,10 +45,6 @@ def __init__(
         self.max_threads = self.ctx.devices[0].get_info(
             cl.device_info.MAX_WORK_GROUP_SIZE
         )
-        # TODO the L2 cache size request fails
-        # self.cache_size_L2 = self.ctx.devices[0].get_info(
-        #     cl.device_affinity_domain.L2_CACHE
-        # )
         self.compiler_options = compiler_options or []

         # observer stuff
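An aside on the removed TODO: cl.device_affinity_domain.L2_CACHE is a sub-device partitioning flag, not a cl.device_info key, which is presumably why the get_info call failed. Standard OpenCL does not report an L2 cache size at all; the closest query is the global memory cache size. A minimal sketch with PyOpenCL, assuming a context with at least one device:

    import pyopencl as cl

    # OpenCL has no CL_DEVICE_L2_CACHE_SIZE; CL_DEVICE_GLOBAL_MEM_CACHE_SIZE
    # (the global memory cache, typically the L2 on GPUs) is the nearest
    # standard query.
    ctx = cl.create_some_context()
    device = ctx.devices[0]
    cache_size = device.get_info(cl.device_info.GLOBAL_MEM_CACHE_SIZE)
    print(f"global memory cache size: {cache_size} bytes")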
1 change: 0 additions & 1 deletion kernel_tuner/backends/pycuda.py
@@ -101,7 +101,6 @@ def _finish_up():
             str(k): v for (k, v) in self.context.get_device().get_attributes().items()
         }
         self.max_threads = devprops["MAX_THREADS_PER_BLOCK"]
-        self.cache_size_L2 = devprops["L2_CACHE_SIZE"]
         cc = str(devprops.get("COMPUTE_CAPABILITY_MAJOR", "0")) + str(
             devprops.get("COMPUTE_CAPABILITY_MINOR", "0")
         )
52 changes: 2 additions & 50 deletions kernel_tuner/core.py
@@ -340,62 +340,14 @@ def __init__(
         if not quiet:
             print("Using: " + self.dev.name)

-        if lang.upper() not in ['OPENCL', 'C', 'FORTRAN']:
-            # flush the L2 cache, inspired by https://github.com/pytorch/FBGEMM/blob/eb3c304e6c213b81f2b2077813d3c6d16597aa97/fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu#L130
-            flush_gpu_string = """
-            __global__ void flush_gpu(char* d_flush, char* d_flush2, bool do_write) {
-                const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-                const char val = d_flush[idx];
-                if (do_write * val) {
-                    d_flush2[idx] = val;
-                }
-            }
-            """
-            cache_size = self.dev.cache_size_L2
-            d_flush = np.ones((cache_size), order='F').astype(np.float32)
-            d_flush2 = np.ones((cache_size), order='F').astype(np.float32)
-            self.flush_kernel_gpu_args = [d_flush, d_flush2, np.int32(True)]
-
-            from kernel_tuner.interface import Options
-            options = {
-                'kernel_name': 'flush_gpu',
-                'lang': 'CUDA',
-                'arguments': self.flush_kernel_gpu_args,
-                'problem_size': cache_size,
-                'grid_div_x': None,
-                'grid_div_y': None,
-                'grid_div_z': None,
-                'block_size_names': None,
-            }
-            options = Options(options)
-            flush_kernel_lang = lang.upper() if lang.upper() in ['CUDA', 'CUPY', 'NVCUDA'] else 'CUPY'
-            flush_kernel_source = KernelSource('flush_gpu', flush_gpu_string, flush_kernel_lang)
-            self.flush_kernel_instance = self.create_kernel_instance(flush_kernel_source, kernel_options=options, params=dict(), verbose=not quiet)
-            self.flush_kernel = self.compile_kernel(self.flush_kernel_instance, verbose=not quiet)
-            self.flush_kernel_gpu_args = self.ready_argument_list(self.flush_kernel_gpu_args)
-
-            # from kernel_tuner.kernelbuilder import PythonKernel
-            # self.flush_kernel = PythonKernel('flush_gpu', flush_gpu_string, cache_size, self.flush_kernel_gpu_args)
-
-    def flush_cache(self):
-        """This special function can be called to flush the L2 cache."""
-        if hasattr(self, 'flush_kernel'):
-            return
-        self.dev.synchronize()
-        assert self.run_kernel(self.flush_kernel, self.flush_kernel_gpu_args, self.flush_kernel_instance)
-        # self.flush_kernel.run_kernel(self.flush_kernel.gpu_args)
-        self.dev.synchronize()
-
-    def benchmark_default(self, func, gpu_args, threads, grid, result, flush_cache=True):
-        """Benchmark one kernel execution at a time. Run with `flush_cache=True` to avoid caching effects between iterations."""
+    def benchmark_default(self, func, gpu_args, threads, grid, result):
+        """Benchmark one kernel execution at a time."""
         observers = [
             obs for obs in self.dev.observers if not isinstance(obs, ContinuousObserver)
         ]

         self.dev.synchronize()
         for _ in range(self.iterations):
-            if flush_cache:
-                self.flush_cache()
             for obs in observers:
                 obs.before_start()
             self.dev.synchronize()
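For context on what was deleted: the flush trick writes through a device buffer at least as large as the L2 cache, so that data cached by the previous benchmark iteration is evicted before the next timing run. A minimal standalone sketch of the same idea in CuPy; flush_l2_cache is an illustrative name, not part of the Kernel Tuner API:

    import cupy as cp

    def flush_l2_cache(device_id=0):
        """Evict the L2 cache by writing through a buffer at least its size."""
        with cp.cuda.Device(device_id) as dev:
            cache_size = dev.attributes["L2CacheSize"]  # in bytes
            # Filling a buffer of the full L2 size pushes older cache lines out.
            flush_buffer = cp.empty(cache_size, dtype=cp.uint8)
            flush_buffer.fill(1)
        cp.cuda.runtime.deviceSynchronize()

With this commit, benchmark_default no longer flushes between iterations, so measurements may include warm-cache effects unless the caller flushes manually.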
3 changes: 1 addition & 2 deletions test/test_pycuda_mocked.py
@@ -13,8 +13,7 @@ def setup_mock(drv):
     context = Mock()
     devprops = {'MAX_THREADS_PER_BLOCK': 1024,
                 'COMPUTE_CAPABILITY_MAJOR': 5,
-                'COMPUTE_CAPABILITY_MINOR': 5,
-                'L2_CACHE_SIZE': 4096}
+                'COMPUTE_CAPABILITY_MINOR': 5,}
     context.return_value.get_device.return_value.get_attributes.return_value = devprops
     context.return_value.get_device.return_value.compute_capability.return_value = "55"
     drv.Device.return_value.retain_primary_context.return_value = context()
2 changes: 1 addition & 1 deletion test/test_util_functions.py
@@ -153,7 +153,7 @@ def test_to_valid_nvrtc_gpu_arch_cc():
     assert to_valid_nvrtc_gpu_arch_cc("40") == "52"
     assert to_valid_nvrtc_gpu_arch_cc("90b") == "90a"
     assert to_valid_nvrtc_gpu_arch_cc("91c") == "90a"
-    assert to_valid_nvrtc_gpu_arch_cc("10123001") == "52"
+    assert to_valid_nvrtc_gpu_arch_cc("1234") == "52"
     with pytest.raises(ValueError):
         assert to_valid_nvrtc_gpu_arch_cc("")
         assert to_valid_nvrtc_gpu_arch_cc("1")

csbnw (Collaborator) commented on the changed line, Mar 4, 2024:

    Both 10123001 and 1234 seem like odd numbers to test. Maybe add a comment that an invalid number should return the lowest supported architecture? You could also make "52" a constant.

csbnw (Collaborator) commented Mar 4, 2024:

    I just noticed your comment in the previous thread. It makes sense to test edge cases here, but I still think a comment would help.
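One possible way to act on the review suggestion, sketched below. The constant name, the test comment, and the import path are illustrative assumptions, not repository code:

    from kernel_tuner.util import to_valid_nvrtc_gpu_arch_cc

    # Hypothetical constant per the review suggestion, replacing the bare "52".
    LOWEST_SUPPORTED_NVRTC_ARCH = "52"

    def test_to_valid_nvrtc_gpu_arch_cc_invalid():
        # Invalid or unknown compute capabilities should fall back to the
        # lowest architecture supported by NVRTC.
        assert to_valid_nvrtc_gpu_arch_cc("1234") == LOWEST_SUPPORTED_NVRTC_ARCH
        assert to_valid_nvrtc_gpu_arch_cc("40") == LOWEST_SUPPORTED_NVRTC_ARCH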
