diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 65f7140ce..793a066cb 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -85,7 +85,7 @@ Steps without :bash:`sudo` access (e.g. on a cluster): - /path/to/directory * [Optional] both Mamba and Miniconda can be automatically activated via :bash:`~/.bashrc`. Do not forget to add these (usually provided at the end of the installation). * Exit the shell and re-enter to make sure Conda is available. :bash:`cd` to the kernel tuner directory. - * [Optional] if you have limited user folder space, the Pip cache can be pointed elsewhere with the environment variable :bash:`PIP_CACHE_DIR`. The cache location can be checked with :bash:`pip cache dir`. + * [Optional] if you have limited user folder space, the Pip cache can be pointed elsewhere with the environment variable :bash:`PIP_CACHE_DIR`. The cache location can be checked with :bash:`pip cache dir`. On Linux, to point the entire :bash:`~/.cache` default elsewhere, use the :bash:`XDG_CACHE_HOME` environment variable. * [Optional] update Conda if available before continuing: :bash:`conda update -n base -c conda-forge conda`. #. Setup a virtual environment: :bash:`conda create --name kerneltuner python=3.11` (or whatever Python version and environment name you prefer). #. Activate the virtual environment: :bash:`conda activate kerneltuner`. diff --git a/kernel_tuner/backends/backend.py b/kernel_tuner/backends/backend.py index a37c9d6e7..b8a90bbc0 100644 --- a/kernel_tuner/backends/backend.py +++ b/kernel_tuner/backends/backend.py @@ -2,6 +2,8 @@ from __future__ import print_function from abc import ABC, abstractmethod +from typing import Any +from numpy import ndarray class Backend(ABC): @@ -65,6 +66,16 @@ class GPUBackend(Backend): def __init__(self, device, iterations, compiler_options, observers): pass + @abstractmethod + def allocate_ndarray(self, array: ndarray) -> Any: + """This method must allocate a buffer on the GPU for the given np.ndarray and return a pointer to it.""" + pass + + @abstractmethod + def free_mem(self, pointer): + """This method must free the GPU buffer behind the given pointer.""" + pass + @abstractmethod def copy_constant_memory_args(self, cmem_args): """This method must implement the allocation and copy of constant memory to the GPU.""" diff --git a/kernel_tuner/backends/cupy.py b/kernel_tuner/backends/cupy.py index a1e13ff03..ca514f279 100644 --- a/kernel_tuner/backends/cupy.py +++ b/kernel_tuner/backends/cupy.py @@ -1,5 +1,6 @@ """This module contains all Cupy specific kernel_tuner functions.""" from __future__ import print_function +from warnings import warn import numpy as np @@ -46,6 +47,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None self.devprops = dev.attributes self.cc = dev.compute_capability self.max_threads = self.devprops["MaxThreadsPerBlock"] + self.cache_size_L2 = int(self.devprops["L2CacheSize"]) self.iterations = iterations self.current_module = None @@ -82,6 +84,18 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None self.env = env self.name = env["device_name"] + def allocate_ndarray(self, array): + alloc = cp.array(array) + self.allocations.append(alloc) + return alloc + + def free_mem(self, pointer): + # list.remove compares with ==, which is element-wise for arrays, so find the matching allocation by index instead + to_remove = [i for i, alloc in enumerate(self.allocations) if cp.array_equal(alloc, pointer)] + assert len(to_remove) == 1 + self.allocations.pop(to_remove[0]) + del pointer # CuPy relies on Python's reference counting to
free the memory once it is no longer referenced def ready_argument_list(self, arguments): """Ready argument list to be passed to the kernel, allocates gpu mem. @@ -97,8 +111,7 @@ def ready_argument_list(self, arguments): for arg in arguments: # if arg i is a numpy array copy to device if isinstance(arg, np.ndarray): - alloc = cp.array(arg) - self.allocations.append(alloc) + alloc = self.allocate_ndarray(arg) gpu_args.append(alloc) # if not a numpy array, just pass argument along else: @@ -124,6 +137,7 @@ def compile(self, kernel_instance): compiler_options = self.compiler_options if not any(["-std=" in opt for opt in self.compiler_options]): compiler_options = ["--std=c++11"] + self.compiler_options + # CuPy already sets the --gpu-architecture by itself, as per https://github.com/cupy/cupy/blob/main/cupy/cuda/compiler.py#L145 options = tuple(compiler_options) @@ -132,6 +146,7 @@ ) self.func = self.current_module.get_function(kernel_name) + self.num_regs = self.func.num_regs return self.func def start_event(self): @@ -197,6 +212,8 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None): of the grid :type grid: tuple(int, int) """ + if stream is None: + stream = self.stream func(grid, threads, gpu_args, stream=stream, shared_mem=self.smem_size) def memset(self, allocation, value, size): diff --git a/kernel_tuner/backends/hip.py b/kernel_tuner/backends/hip.py index 1db4cb302..1973cfc91 100644 --- a/kernel_tuner/backends/hip.py +++ b/kernel_tuner/backends/hip.py @@ -59,6 +59,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None self.name = self.hipProps._name.decode('utf-8') self.max_threads = self.hipProps.maxThreadsPerBlock + self.cache_size_L2 = int(self.hipProps.l2CacheSize) self.device = device self.compiler_options = compiler_options or [] self.iterations = iterations @@ -85,6 +86,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None for obs in self.observers: obs.register_device(self) + def allocate_ndarray(self, array): + return hip.hipMalloc(array.nbytes) + + def free_mem(self, pointer): + raise NotImplementedError("PyHIP currently does not have a free function") def ready_argument_list(self, arguments): """Ready argument list to be passed to the HIP function.
@@ -106,7 +112,7 @@ def ready_argument_list(self, arguments): # Allocate space on device for array and convert to ctypes if isinstance(arg, np.ndarray): if dtype_str in dtype_map.keys(): - device_ptr = hip.hipMalloc(arg.nbytes) + device_ptr = self.allocate_ndarray(arg) data_ctypes = arg.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str])) hip.hipMemcpy_htod(device_ptr, data_ctypes, arg.nbytes) # may be part of run_kernel, return allocations here instead diff --git a/kernel_tuner/backends/nvcuda.py b/kernel_tuner/backends/nvcuda.py index c6fb73d5e..dd2653f04 100644 --- a/kernel_tuner/backends/nvcuda.py +++ b/kernel_tuner/backends/nvcuda.py @@ -1,9 +1,11 @@ """This module contains all NVIDIA cuda-python specific kernel_tuner functions.""" +from warnings import warn + import numpy as np from kernel_tuner.backends.backend import GPUBackend from kernel_tuner.observers.nvcuda import CudaRuntimeObserver -from kernel_tuner.util import SkippableFailure, cuda_error_check +from kernel_tuner.util import SkippableFailure, cuda_error_check, to_valid_nvrtc_gpu_arch_cc # embedded in try block to be able to generate documentation # and run tests without cuda-python installed @@ -66,6 +68,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None cudart.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device ) cuda_error_check(err) + err, self.cache_size_L2 = cudart.cudaDeviceGetAttribute( + cudart.cudaDeviceAttr.cudaDevAttrL2CacheSize, device + ) + cuda_error_check(err) + self.cache_size_L2 = int(self.cache_size_L2) self.cc = f"{major}{minor}" self.iterations = iterations self.current_module = None @@ -107,9 +114,20 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None def __del__(self): - for device_memory in self.allocations: - if isinstance(device_memory, cuda.CUdeviceptr): - err = cuda.cuMemFree(device_memory) - cuda_error_check(err) + # iterate over a copy, since free_mem removes entries from self.allocations + for device_memory in list(self.allocations): + self.free_mem(device_memory) + + def allocate_ndarray(self, array): + err, device_memory = cuda.cuMemAlloc(array.nbytes) + cuda_error_check(err) + self.allocations.append(device_memory) + return device_memory + + def free_mem(self, pointer): + assert isinstance(pointer, cuda.CUdeviceptr) + self.allocations.remove(pointer) + err = cuda.cuMemFree(pointer) + cuda_error_check(err) def ready_argument_list(self, arguments): """Ready argument list to be passed to the kernel, allocates gpu mem.
@@ -126,9 +143,7 @@ def ready_argument_list(self, arguments): for arg in arguments: # if arg is a numpy array copy it to device if isinstance(arg, np.ndarray): - err, device_memory = cuda.cuMemAlloc(arg.nbytes) - cuda_error_check(err) - self.allocations.append(device_memory) + device_memory = self.allocate_ndarray(arg) gpu_args.append(device_memory) self.memcpy_htod(device_memory, arg) # if not array, just pass along @@ -161,12 +176,12 @@ def compile(self, kernel_instance): compiler_options.append(b"--std=c++11") if not any(["--std=" in opt for opt in self.compiler_options]): self.compiler_options.append("--std=c++11") - if not any([b"--gpu-architecture=" in opt for opt in compiler_options]): + if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]): compiler_options.append( - f"--gpu-architecture=compute_{self.cc}".encode("UTF-8") + f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8") ) - if not any(["--gpu-architecture=" in opt for opt in self.compiler_options]): - self.compiler_options.append(f"--gpu-architecture=compute_{self.cc}") + if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]): + self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}") err, program = nvrtc.nvrtcCreateProgram( str.encode(kernel_string), b"CUDAProgram", 0, [], [] ) @@ -192,6 +207,11 @@ ) cuda_error_check(err) + # get the number of registers per thread used in this kernel + num_regs = cuda.cuFuncGetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func) + assert num_regs[0] == 0, f"Retrieving number of registers per thread unsuccessful: code {num_regs[0]}" + self.num_regs = num_regs[1] + except RuntimeError as re: _, n = nvrtc.nvrtcGetProgramLogSize(program) log = b" " * n @@ -273,6 +293,8 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None): of the grid :type grid: tuple(int, int) """ + if stream is None: + stream = self.stream arg_types = list() for arg in gpu_args: if isinstance(arg, cuda.CUdeviceptr): @@ -309,7 +331,7 @@ def memset(allocation, value, size): :type size: int """ - err = cudart.cudaMemset(allocation, value, size) + err = cudart.cudaMemset(int(allocation), value, size) cuda_error_check(err) @staticmethod diff --git a/kernel_tuner/backends/opencl.py b/kernel_tuner/backends/opencl.py index af3be1c00..0377cbb93 100644 --- a/kernel_tuner/backends/opencl.py +++ b/kernel_tuner/backends/opencl.py @@ -45,6 +45,10 @@ def __init__( self.max_threads = self.ctx.devices[0].get_info( cl.device_info.MAX_WORK_GROUP_SIZE ) + # TODO the L2 cache size request fails + # self.cache_size_L2 = self.ctx.devices[0].get_info( + # cl.device_affinity_domain.L2_CACHE + # ) self.compiler_options = compiler_options or [] # observer stuff @@ -68,6 +72,13 @@ def __init__( self.env = env self.name = dev.name + def allocate_ndarray(self, array): + return cl.Buffer(self.ctx, self.mf.READ_WRITE | self.mf.COPY_HOST_PTR, hostbuf=array) + + def free_mem(self, pointer): + assert isinstance(pointer, cl.Buffer) + pointer.release() + def ready_argument_list(self, arguments): """Ready argument list to be passed to the kernel, allocates gpu mem.
@@ -83,13 +94,7 @@ def ready_argument_list(self, arguments): for arg in arguments: # if arg i is a numpy array copy to device if isinstance(arg, np.ndarray): - gpu_args.append( - cl.Buffer( - self.ctx, - self.mf.READ_WRITE | self.mf.COPY_HOST_PTR, - hostbuf=arg, - ) - ) + gpu_args.append(self.allocate_ndarray(arg)) # if not an array, just pass argument along else: gpu_args.append(arg) diff --git a/kernel_tuner/backends/pycuda.py b/kernel_tuner/backends/pycuda.py index 3c168f824..a42ca8a70 100644 --- a/kernel_tuner/backends/pycuda.py +++ b/kernel_tuner/backends/pycuda.py @@ -101,6 +101,7 @@ def _finish_up(): str(k): v for (k, v) in self.context.get_device().get_attributes().items() } self.max_threads = devprops["MAX_THREADS_PER_BLOCK"] + self.cache_size_L2 = int(devprops["L2_CACHE_SIZE"]) cc = str(devprops.get("COMPUTE_CAPABILITY_MAJOR", "0")) + str( devprops.get("COMPUTE_CAPABILITY_MINOR", "0") ) @@ -151,7 +152,18 @@ def __del__(self): - for gpu_mem in self.allocations: - # if needed for when using mocks during testing - if hasattr(gpu_mem, "free"): - gpu_mem.free() + # iterate over a copy, since free_mem removes entries from self.allocations + for gpu_mem in list(self.allocations): + # hasattr check needed for when using mocks during testing + if hasattr(gpu_mem, "free"): + self.free_mem(gpu_mem) + + def allocate_ndarray(self, array): + alloc = drv.mem_alloc(array.nbytes) + self.allocations.append(alloc) + return alloc + + def free_mem(self, pointer): + assert hasattr(pointer, "free") + self.allocations.remove(pointer) + pointer.free() def ready_argument_list(self, arguments): """Ready argument list to be passed to the kernel, allocates gpu mem. @@ -168,8 +179,7 @@ def ready_argument_list(self, arguments): for arg in arguments: # if arg i is a numpy array copy to device if isinstance(arg, np.ndarray): - alloc = drv.mem_alloc(arg.nbytes) - self.allocations.append(alloc) + alloc = self.allocate_ndarray(arg) gpu_args.append(alloc) drv.memcpy_htod(gpu_args[-1], arg) elif isinstance(arg, torch.Tensor): @@ -218,6 +228,9 @@ def compile(self, kernel_instance): ) self.func = self.current_module.get_function(kernel_name) + # num_regs is unavailable when func is mocked as a str during testing + if not isinstance(self.func, str): + self.num_regs = self.func.num_regs return self.func except drv.CompileError as e: if "uses too much shared data" in e.stderr: diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 174cd3af5..44b6e82e9 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -337,17 +337,45 @@ def __init__( self.units = dev.units self.name = dev.name self.max_threads = dev.max_threads + self.flush_possible = lang.upper() not in ['OPENCL', 'HIP', 'C', 'FORTRAN'] and isinstance(self.dev.cache_size_L2, int) and self.dev.cache_size_L2 > 0 + if self.flush_possible: + self.flush_type = np.uint8 + size = (self.dev.cache_size_L2 // self.flush_type(0).itemsize) + self.flush_array = np.empty((size), order='F', dtype=self.flush_type) + self.flush_alloc = None if not quiet: print("Using: " + self.dev.name) - def benchmark_default(self, func, gpu_args, threads, grid, result): - """Benchmark one kernel execution at a time""" + def flush_cache(self): + """Flush the GPU L2 cache, if possible, by memsetting a buffer the size of the cache.""" + if self.flush_possible: + # explicitly free the previous flush buffer + if self.flush_alloc is not None: + self.dev.free_mem(self.flush_alloc) + # inspired by https://github.com/NVIDIA/nvbench/blob/main/nvbench/detail/l2flush.cuh#L51 + self.flush_alloc = self.dev.allocate_ndarray(self.flush_array) + self.dev.memset(self.flush_alloc, value=1, size=self.flush_array.nbytes) + + def benchmark_default(self, func, gpu_args, threads, grid, result, flush_cache=True, recopy_arrays=None): + """ + 
Benchmark one kernel execution at a time. + + Run with `flush_cache=True` to avoid caching effects between iterations. + Pass `recopy_arrays` to rewrite the input arrays to the GPU before each kernel launch; the arrays must be in the same order as `gpu_args`. + """ observers = [ obs for obs in self.dev.observers if not isinstance(obs, ContinuousObserver) ] self.dev.synchronize() for _ in range(self.iterations): + if flush_cache: + self.flush_cache() + if recopy_arrays is not None: + assert len(recopy_arrays) == len(gpu_args), "`recopy_arrays` must have the same length and order as `gpu_args`." + for i, arg in enumerate(recopy_arrays): + if isinstance(arg, (np.ndarray, cp.ndarray, torch.Tensor)): + self.dev.memcpy_htod(gpu_args[i], arg) for obs in observers: obs.before_start() self.dev.synchronize() @@ -391,12 +420,8 @@ def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration): for obs in self.continuous_observers: result.update(obs.get_results()) - def benchmark(self, func, gpu_args, instance, verbose, objective): - """benchmark the kernel instance""" - logging.debug("benchmark " + instance.name) - logging.debug("thread block dimensions x,y,z=%d,%d,%d", *instance.threads) - logging.debug("grid dimensions x,y,z=%d,%d,%d", *instance.grid) - + def set_nvml_parameters(self, instance): + """Set the NVML parameters; doing this separately keeps the setting time out of the benchmark time.""" if self.use_nvml: if "nvml_pwr_limit" in instance.params: new_limit = int( @@ -409,6 +434,15 @@ if "nvml_mem_clock" in instance.params: self.nvml.mem_clock = instance.params["nvml_mem_clock"] + def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_setting=False, flush_L2=True, recopy_arrays=False): + """Benchmark the kernel instance.""" + logging.debug("benchmark " + instance.name) + logging.debug("thread block dimensions x,y,z=%d,%d,%d", *instance.threads) + logging.debug("grid dimensions x,y,z=%d,%d,%d", *instance.grid) + + if self.use_nvml and not skip_nvml_setting: + self.set_nvml_parameters(instance) + # Call the observers to register the configuration to be benchmarked for obs in self.dev.observers: obs.register_configuration(instance.params) @@ -416,7 +450,7 @@ result = {} try: self.benchmark_default( - func, gpu_args, instance.threads, instance.grid, result + func, gpu_args, instance.threads, instance.grid, result, flush_cache=flush_L2, recopy_arrays=instance.arguments if recopy_arrays else None ) if self.continuous_observers: @@ -577,9 +611,12 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, # benchmark if func: + # setting the NVML parameters here avoids this time leaking into the benchmark time; it ends up in the framework time instead + if self.use_nvml: + self.set_nvml_parameters(instance) start_benchmark = time.perf_counter() result.update( - self.benchmark(func, gpu_args, instance, verbose, to.objective) + self.benchmark(func, gpu_args, instance, verbose, to.objective, skip_nvml_setting=True, flush_L2=to.flush_L2_cache, recopy_arrays=to.recopy_arrays) ) last_benchmark_time = 1000 * (time.perf_counter() - start_benchmark) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 96efe3ce8..bdfd26f43 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -464,6 +464,8 @@ def __deepcopy__(self, _): ("metrics", ("specifies user-defined metrics,
please see :ref:`metrics`.", "dict")), ("simulation_mode", ("Simulate an auto-tuning search from an existing cachefile", "bool")), ("observers", ("""A list of Observers to use during tuning, please see :ref:`observers`.""", "list")), + ("flush_L2_cache", ("""Whether to flush the GPU L2 cache between kernel launches. Defaults to True.""", "bool")), + ("recopy_arrays", ("""Whether to rewrite the input arrays to the GPU between kernel launches. Defaults to False.""", "bool")), ] ) @@ -577,6 +579,8 @@ def tune_kernel( observers=None, objective=None, objective_higher_is_better=None, + flush_L2_cache=True, + recopy_arrays=False, ): start_overhead_time = perf_counter() if log: diff --git a/kernel_tuner/observers/nvml.py b/kernel_tuner/observers/nvml.py index d33327a3c..0bd9adc84 100644 --- a/kernel_tuner/observers/nvml.py +++ b/kernel_tuner/observers/nvml.py @@ -135,6 +135,8 @@ def persistence_mode(self, new_mode): raise ValueError( "Illegal value for persistence mode, should be either 0 or 1" ) + if self.persistence_mode == new_mode: + return try: pynvml.nvmlDeviceSetPersistenceMode(self.dev, new_mode) self._persistence_mode = pynvml.nvmlDeviceGetPersistenceMode(self.dev) @@ -168,21 +170,15 @@ def set_clocks(self, mem_clock, gr_clock): self.nvidia_smi, "-i", str(self.id), - "--lock-gpu-clocks=" + str(gr_clock) + "," + str(gr_clock), ] - subprocess.run(args, check=True) - args = [ - "sudo", - self.nvidia_smi, - "-i", - str(self.id), - "--lock-memory-clocks=" + str(mem_clock) + "," + str(mem_clock), - ] - subprocess.run(args, check=True) + command_set_mem_clocks = f"--lock-memory-clocks={mem_clock},{mem_clock}" + command_set_gpu_clocks = f"--lock-gpu-clocks={gr_clock},{gr_clock}" + subprocess.run(args + [command_set_gpu_clocks], check=True) + subprocess.run(args + [command_set_mem_clocks], check=True) else: try: - if self.persistence_mode != 0: - self.persistence_mode = 0 + if self.persistence_mode != 1: + self.persistence_mode = 1 except Exception: pass try: @@ -233,24 +229,20 @@ def reset_clocks(self): if ( gr_app_clock != self.gr_clock_default or mem_app_clock != self.mem_clock_default - ): + ): self.set_clocks(self.mem_clock_default, self.gr_clock_default) @property def gr_clock(self): """Control the graphics clock (may require permission), only values compatible with the memory clock can be set directly.""" - return pynvml.nvmlDeviceGetClockInfo(self.dev, pynvml.NVML_CLOCK_GRAPHICS) + if self.use_locked_clocks: + return pynvml.nvmlDeviceGetClockInfo(self.dev, pynvml.NVML_CLOCK_GRAPHICS) + else: + return pynvml.nvmlDeviceGetApplicationsClock(self.dev, pynvml.NVML_CLOCK_GRAPHICS) @gr_clock.setter def gr_clock(self, new_clock): - cur_clock = ( - pynvml.nvmlDeviceGetClockInfo(self.dev, pynvml.NVML_CLOCK_GRAPHICS) - if self.use_locked_clocks - else pynvml.nvmlDeviceGetApplicationsClock( - self.dev, pynvml.NVML_CLOCK_GRAPHICS - ) - ) - if new_clock != cur_clock: + if new_clock != self.gr_clock: self.set_clocks(self.mem_clock, new_clock) @property @@ -268,12 +260,7 @@ def mem_clock(self): @mem_clock.setter def mem_clock(self, new_clock): - cur_clock = ( - pynvml.nvmlDeviceGetClockInfo(self.dev, pynvml.NVML_CLOCK_MEM) - if self.use_locked_clocks - else pynvml.nvmlDeviceGetApplicationsClock(self.dev, pynvml.NVML_CLOCK_MEM) - ) - if new_clock != cur_clock: + if new_clock != self.mem_clock: self.set_clocks(new_clock, self.gr_clock) @property diff --git a/kernel_tuner/observers/register.py b/kernel_tuner/observers/register.py new file mode 100644 index 000000000..92f22ffd8 ---
/dev/null +++ b/kernel_tuner/observers/register.py @@ -0,0 +1,16 @@ +from kernel_tuner.observers.observer import BenchmarkObserver + +class RegisterObserver(BenchmarkObserver): + """Observer that reports the number of registers per thread used by the compiled kernel.""" + + def __init__(self) -> None: + super().__init__() + + def get_results(self): + try: + registers_per_thread = self.dev.num_regs + except AttributeError as e: + raise NotImplementedError(f"Backend '{type(self.dev).__name__}' does not support counting the number of registers per thread") from e + return { + "num_regs": registers_per_thread + } \ No newline at end of file diff --git a/kernel_tuner/runners/sequential.py b/kernel_tuner/runners/sequential.py index c493a0089..aeebd5116 100644 --- a/kernel_tuner/runners/sequential.py +++ b/kernel_tuner/runners/sequential.py @@ -100,7 +100,7 @@ def run(self, parameter_space, tuning_options): params = process_metrics(params, tuning_options.metrics) # get the framework time by estimating based on other times - total_time = 1000 * (perf_counter() - self.start_time) - warmup_time + total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) params['strategy_time'] = self.last_strategy_time params['framework_time'] = max(total_time - (params['compile_time'] + params['verification_time'] + params['benchmark_time'] + params['strategy_time']), 0) params['timestamp'] = str(datetime.now(timezone.utc)) diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 6e9cdf5b0..df73f2127 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -570,6 +570,24 @@ def get_total_timings(results, env, overhead_time): return env +def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str: + """Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options.""" + valid_cc = ['50', '52', '53', '60', '61', '62', '70', '72', '75', '80', '87', '89', '90', '90a'] # must be in ascending order, when updating also update test_to_valid_nvrtc_gpu_arch_cc + compute_capability = str(compute_capability) + if len(compute_capability) < 2: + raise ValueError(f"Compute capability '{compute_capability}' must have at least length 2, got {len(compute_capability)}") + if compute_capability in valid_cc: + return compute_capability + # if there is no exact match, scale down to the nearest valid CC + subset_cc = [cc for cc in valid_cc if compute_capability[0] == cc[0]] + if len(subset_cc) > 0: + # take the highest valid CC that does not exceed the given one + highest_cc_index = max([i for i, cc in enumerate(subset_cc) if int(cc[1]) <= int(compute_capability[1])]) + return subset_cc[highest_cc_index] + # if all else fails, return the default 52 + return '52' + + def print_config(config, tuning_options, runner): """Print the configuration string with tunable parameters and benchmark results.""" print_config_output(tuning_options.tune_params, config, runner.quiet, tuning_options.metrics, runner.units) diff --git a/test/test_observers.py b/test/test_observers.py index d881fed74..c1cc460a9 100644 --- a/test/test_observers.py +++ b/test/test_observers.py @@ -1,11 +1,14 @@ - - import kernel_tuner from kernel_tuner.observers.nvml import NVMLObserver +from kernel_tuner.observers.register import RegisterObserver from kernel_tuner.observers.observer import BenchmarkObserver -from .context import skip_if_no_pycuda, skip_if_no_pynvml +from .context import skip_if_no_pycuda, skip_if_no_pynvml, skip_if_no_cupy, skip_if_no_cuda, skip_if_no_opencl, skip_if_no_pyhip from .test_runners import env # noqa: F401 +from
.test_opencl_functions import env as env_opencl # noqa: F401 +from .test_hip_functions import env as env_hip # noqa: F401 + +from pytest import raises @skip_if_no_pycuda @@ -20,7 +23,6 @@ def test_nvml_observer(env): assert "temperature" in result[0] assert result[0]["temperature"] > 0 - @skip_if_no_pycuda def test_custom_observer(env): env[-1]["block_size_x"] = [128] @@ -34,3 +36,34 @@ def get_results(self): assert "name" in result[0] assert len(result[0]["name"]) > 0 +@skip_if_no_pycuda +def test_register_observer_pycuda(env): + result, _ = kernel_tuner.tune_kernel(*env, observers=[RegisterObserver()], lang='CUDA') + assert "num_regs" in result[0] + assert result[0]["num_regs"] > 0 + +@skip_if_no_cupy +def test_register_observer_cupy(env): + result, _ = kernel_tuner.tune_kernel(*env, observers=[RegisterObserver()], lang='CuPy') + assert "num_regs" in result[0] + assert result[0]["num_regs"] > 0 + +@skip_if_no_cuda +def test_register_observer_nvcuda(env): + result, _ = kernel_tuner.tune_kernel(*env, observers=[RegisterObserver()], lang='NVCUDA') + assert "num_regs" in result[0] + assert result[0]["num_regs"] > 0 + +@skip_if_no_opencl +def test_register_observer_opencl(env_opencl): + with raises(NotImplementedError) as err: + kernel_tuner.tune_kernel(*env_opencl, observers=[RegisterObserver()], lang='OpenCL') + assert err.errisinstance(NotImplementedError) + assert "OpenCL" in str(err.value) + +@skip_if_no_pyhip +def test_register_observer_hip(env_hip): + with raises(NotImplementedError) as err: + kernel_tuner.tune_kernel(*env_hip, observers=[RegisterObserver()], lang='HIP') + assert err.errisinstance(NotImplementedError) + assert "Hip" in str(err.value) diff --git a/test/test_pycuda_mocked.py b/test/test_pycuda_mocked.py index 21f352a3f..e47fc8e8e 100644 --- a/test/test_pycuda_mocked.py +++ b/test/test_pycuda_mocked.py @@ -13,7 +13,8 @@ def setup_mock(drv): context = Mock() devprops = {'MAX_THREADS_PER_BLOCK': 1024, 'COMPUTE_CAPABILITY_MAJOR': 5, - 'COMPUTE_CAPABILITY_MINOR': 5} + 'COMPUTE_CAPABILITY_MINOR': 5, + 'L2_CACHE_SIZE': 4096} context.return_value.get_device.return_value.get_attributes.return_value = devprops context.return_value.get_device.return_value.compute_capability.return_value = "55" drv.Device.return_value.retain_primary_context.return_value = context() diff --git a/test/test_util_functions.py b/test/test_util_functions.py index c66964b0e..970603d83 100644 --- a/test/test_util_functions.py +++ b/test/test_util_functions.py @@ -146,6 +146,20 @@ def test_get_thread_block_dimensions(): assert threads[2] == 1 +def test_to_valid_nvrtc_gpu_arch_cc(): + assert to_valid_nvrtc_gpu_arch_cc("89") == "89" + assert to_valid_nvrtc_gpu_arch_cc("88") == "87" + assert to_valid_nvrtc_gpu_arch_cc("86") == "80" + assert to_valid_nvrtc_gpu_arch_cc("40") == "52" + assert to_valid_nvrtc_gpu_arch_cc("90b") == "90a" + assert to_valid_nvrtc_gpu_arch_cc("91c") == "90a" + assert to_valid_nvrtc_gpu_arch_cc("1234") == "52" + with pytest.raises(ValueError): + to_valid_nvrtc_gpu_arch_cc("") + with pytest.raises(ValueError): + to_valid_nvrtc_gpu_arch_cc("1") + + def test_prepare_kernel_string(): kernel = "this is a weird kernel" grid = (3, 7)
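Usage sketch (not part of the diff): the example below shows how the features introduced here fit together — the new `flush_L2_cache` and `recopy_arrays` options of `tune_kernel`, and the `RegisterObserver`, which reports the backend's new `num_regs` attribute in every result. The vector-add kernel, problem size, and tunable parameters are illustrative assumptions, not taken from this changeset.

```python
import numpy as np
import kernel_tuner
from kernel_tuner.observers.register import RegisterObserver

# Hypothetical kernel, used only to demonstrate the new options.
kernel_string = """
__global__ void vector_add(float *c, const float *a, const float *b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 1_000_000
a = np.random.randn(size).astype(np.float32)
b = np.random.randn(size).astype(np.float32)
c = np.zeros_like(a)
n = np.int32(size)

tune_params = {"block_size_x": [64, 128, 256, 512]}

results, env = kernel_tuner.tune_kernel(
    "vector_add",
    kernel_string,
    size,
    [c, a, b, n],
    tune_params,
    lang="CUDA",
    observers=[RegisterObserver()],  # adds a 'num_regs' entry to every result
    flush_L2_cache=True,   # new: flush the GPU L2 cache between kernel launches (default True)
    recopy_arrays=False,   # new: rewrite the input arrays to the GPU before each launch (default False)
)

print(f"first config uses {results[0]['num_regs']} registers per thread")
```

With `flush_L2_cache=True`, `benchmark_default` allocates and memsets an L2-sized buffer before each of the `iterations` launches (mirroring nvbench's `l2flush`), so a configuration's timing is not flattered by data the previous launch left in the cache; `recopy_arrays` additionally guards against kernels that mutate their inputs. On the OpenCL, HIP, C, and Fortran backends flushing is skipped, and `RegisterObserver` raises `NotImplementedError` there, as exercised in the tests above.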