From 6de10e364a2d75a76136ae439dce28ea2b896ce9 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 9 Sep 2024 12:08:51 -0500 Subject: [PATCH] Add kernel execution time measurement using hooks for do_bench (#139) * Add timing measurements using launch hooks for CPU. Signed-off-by: Ilya Enkovich * Avoid OMP for trivial grid in CPU launcher. Signed-off-by: Ilya Enkovich * Add more measurement options for vector-add tutorial. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- python/triton/testing.py | 55 ++++++++++++++++++--- python/tutorials/01-vector-add.py | 80 +++++++++++++++++++++++++++---- third_party/cpu/backend/driver.py | 6 ++- 3 files changed, 123 insertions(+), 18 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index a0813ba1611b..de73c0595185 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -6,27 +6,63 @@ from contextlib import contextmanager from typing import Any, Dict, List from . import language as tl +import triton class CPUDeviceInterface: - class Event: + class HooksTimeAccessor: - def __init__(self, enable_timing=True): - self.time = 0 + def __init__(self, di): + self.di = di + self.record_idx = 0 def elapsed_time(self, end_event) -> float: - return (end_event.time - self.time) * 1000 + total_time = 0 + for i in range(self.record_idx, end_event.record_idx): + total_time += self.di.kernel_times[i] + return total_time * 1000 def record(self): - self.time = time.perf_counter() + self.record_idx = len(self.di.kernel_times) + + class TimerEvent: + + def __init__(self): + self.timer = 0 + + def elapsed_time(self, end_event) -> float: + return (end_event.timer - self.timer) * 1000 + + def record(self): + self.timer = time.perf_counter() def __init__(self): - pass + self.kernel_times = [] + self.last_start = 0 + self.use_hooks = False + triton.compiler.CompiledKernel.launch_enter_hook = None + triton.compiler.CompiledKernel.launch_exit_hook = None + + def enable_hook_timing(self): + self.use_hooks = True + triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook() + triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook() def synchronize(self): pass + def _enter_hook(self): + self.last_start = time.perf_counter() + + def _exit_hook(self): + self.kernel_times.append(time.perf_counter() - self.last_start) + + def Event(self, enable_timing=True): + if self.use_hooks: + return CPUDeviceInterface.HooksTimeAccessor(self) + return CPUDeviceInterface.TimerEvent() + def nvsmi(attrs): attrs = ','.join(attrs) @@ -113,7 +149,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", - device_type="cuda"): + device_type="cuda", measure_time_with_hooks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -168,6 +204,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 + # For CPU we can use entry and exit hooks to measure execution time + # more precisely. 
+ if measure_time_with_hooks: + di.enable_hook_timing() + # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) n_repeat = max(1, int(rep / estimate_ms)) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index ea2c816a463a..0acb69fc3b21 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -25,6 +25,7 @@ GPU_BLOCK_SIZE = 1024 CPU_BLOCK_SIZE = 4096 +CPU_ST_THRESHOLD = 65536 USE_GPU = False @@ -56,6 +57,26 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) +@triton.jit +def add_kernel_tiled(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + TILE_SIZE: tl.constexpr, # Number of elements each iteration should process. + # NOTE `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)): + offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + # %% # Let's also declare a helper function to (1) allocate the `z` tensor # and (2) enqueue the above kernel with appropriate grid/block sizes: @@ -80,6 +101,28 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): return output +def add_tiled(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE, TILE_SIZE=16) + return output + + +def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # TODO: try to choose the best block size using autotuner + BLOCK_SIZE = triton.next_power_of_2(n_elements) + if BLOCK_SIZE > CPU_ST_THRESHOLD: + BLOCK_SIZE = CPU_BLOCK_SIZE + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE, TILE_SIZE=16) + return output + + # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: torch.manual_seed(0) @@ -94,10 +137,19 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): print(output_triton_cpu) print(f'The maximum difference between torch-cpu and triton-cpu is ' f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') +output_triton_cpu = add_tiled(x, y, None) +print(f'The maximum difference between torch-cpu-tiled and triton-cpu is ' + f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') -LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')] +LINE_VALS = [ + 'triton-cpu', 'triton-cpu-hooks', 'triton-cpu-tiled', 'triton-cpu-tiled-hooks', 'triton-cpu-tiled-tuned-hooks', + 'torch-cpu' +] +LINE_NAMES = [ + 'TritonCPU', 'TritonCPU (hooks)', 'TritonCPUTiled', 'TritonCPUTiled (hooks)', 'TritonCPUTiled (tuned, hooks)', + 'TorchCPU' +] 
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('green', '-')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -156,6 +208,7 @@ def benchmark(size, provider): os.unsetenv('TRITON_CPU_SINGLE_CORE') else: triton.runtime.driver.set_active_to_gpu() + output = torch.empty_like(x) quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': @@ -166,17 +219,24 @@ def benchmark(size, provider): elif provider == 'torch-cpu': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. - output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, device_type=device) - elif provider == 'triton-cpu-single': - output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, - device_type=device) elif provider == 'triton-cpu': - output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles, + device_type=device) + elif provider == 'triton-cpu-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles, + device_type=device, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, device_type=device) + elif provider == 'triton-cpu-tiled-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, + device_type=device, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled-tuned-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output), + quantiles=quantiles, device_type=device, + measure_time_with_hooks=True) gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 44d980e01987..9bc9db4379f8 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -229,9 +229,13 @@ def format_of(ty): static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ // TODO: Consider using omp collapse(3) clause for simplicity? - auto all_grids = get_all_grids(gridX, gridY, gridZ); size_t N = gridX * gridY * gridZ; + if (N == 1) {{ + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} 0, 0, 0, 1, 1, 1); + return; + }} + auto all_grids = get_all_grids(gridX, gridY, gridZ); if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) printf("Single core launcher\\n");
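Usage note (illustrative sketch, not part of the patch): the new `measure_time_with_hooks` argument to `triton.testing.do_bench` makes the CPU device interface time each kernel launch via the launch enter/exit hooks and sum those per-launch times, instead of wall-clock timing the whole benchmarked callable. Below is a minimal example of how a caller might use it, assuming the patched `triton.testing.do_bench` and the tutorial's `add` helper defined above; the tensor size and quantiles mirror the tutorial and are otherwise arbitrary.

    import torch
    import triton

    # Arbitrary CPU inputs; any matching-shape float tensors work.
    x = torch.rand(98432, device='cpu')
    y = torch.rand(98432, device='cpu')
    out = torch.empty_like(x)

    # With measure_time_with_hooks=True, do_bench installs the launch
    # enter/exit hooks (see CPUDeviceInterface.enable_hook_timing), so the
    # reported times cover only the kernel launches, not the Python-side
    # overhead of the benchmarked lambda.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: add(x, y, out, 'cpu'),  # `add` is the tutorial helper shown above
        quantiles=[0.5, 0.2, 0.8],
        device_type='cpu',
        measure_time_with_hooks=True,
    )
    print(f'hook-measured median: {ms:.6f} ms')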