Add kernel execution time measurement using hooks for do_bench (#139)
* Add timing measurements using launch hooks for CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Avoid OMP for trivial grid in CPU launcher.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add more measurement options for vector-add tutorial.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Sep 9, 2024
1 parent 0b65730 commit 6de10e3
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 18 deletions.
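
Note: the snippet below is a minimal usage sketch of the option added by this commit; the copy kernel, sizes, and block size are illustrative only and not part of the change.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical toy kernel used only to demonstrate the benchmarking call.
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


x = torch.rand(1 << 20, device='cpu')
y = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )

# Wall-clock timing around each call (includes Python launch overhead).
ms = triton.testing.do_bench(lambda: copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=4096), device_type='cpu')

# Hook-based timing: only the time between the kernel's launch enter/exit hooks is accumulated.
ms_hooks = triton.testing.do_bench(lambda: copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=4096),
                                   device_type='cpu', measure_time_with_hooks=True)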
55 changes: 48 additions & 7 deletions python/triton/testing.py
@@ -6,27 +6,63 @@
from contextlib import contextmanager
from typing import Any, Dict, List
from . import language as tl
+import triton


class CPUDeviceInterface:

-    class Event:
+    class HooksTimeAccessor:

-        def __init__(self, enable_timing=True):
-            self.time = 0
+        def __init__(self, di):
+            self.di = di
+            self.record_idx = 0

        def elapsed_time(self, end_event) -> float:
-            return (end_event.time - self.time) * 1000
+            total_time = 0
+            for i in range(self.record_idx, end_event.record_idx):
+                total_time += self.di.kernel_times[i]
+            return total_time * 1000

        def record(self):
-            self.time = time.perf_counter()
+            self.record_idx = len(self.di.kernel_times)
+
+    class TimerEvent:
+
+        def __init__(self):
+            self.timer = 0
+
+        def elapsed_time(self, end_event) -> float:
+            return (end_event.timer - self.timer) * 1000
+
+        def record(self):
+            self.timer = time.perf_counter()

    def __init__(self):
-        pass
+        self.kernel_times = []
+        self.last_start = 0
+        self.use_hooks = False
+        triton.compiler.CompiledKernel.launch_enter_hook = None
+        triton.compiler.CompiledKernel.launch_exit_hook = None
+
+    def enable_hook_timing(self):
+        self.use_hooks = True
+        triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook()
+        triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook()

    def synchronize(self):
        pass

+    def _enter_hook(self):
+        self.last_start = time.perf_counter()
+
+    def _exit_hook(self):
+        self.kernel_times.append(time.perf_counter() - self.last_start)
+
+    def Event(self, enable_timing=True):
+        if self.use_hooks:
+            return CPUDeviceInterface.HooksTimeAccessor(self)
+        return CPUDeviceInterface.TimerEvent()


def nvsmi(attrs):
    attrs = ','.join(attrs)
@@ -113,7 +149,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
-             device_type="cuda"):
+             device_type="cuda", measure_time_with_hooks=False):
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20-th and 80-th performance percentile.
@@ -168,6 +204,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
    di.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

+    # For CPU we can use entry and exit hooks to measure execution time
+    # more precisely.
+    if measure_time_with_hooks:
+        di.enable_hook_timing()
+
    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
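Note: an illustrative sketch of how the hook-based events above fit together (run_some_triton_kernels is a placeholder for any code that launches Triton kernels):

di = CPUDeviceInterface()
di.enable_hook_timing()
start_event, end_event = di.Event(), di.Event()
start_event.record()               # remembers the current length of di.kernel_times
run_some_triton_kernels()          # placeholder; each launch appends its duration to di.kernel_times
end_event.record()
elapsed_ms = start_event.elapsed_time(end_event)   # sum of per-kernel times recorded in between, in ms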
80 changes: 70 additions & 10 deletions python/tutorials/01-vector-add.py
@@ -25,6 +25,7 @@

GPU_BLOCK_SIZE = 1024
CPU_BLOCK_SIZE = 4096
+CPU_ST_THRESHOLD = 65536
USE_GPU = False


@@ -56,6 +57,26 @@ def add_kernel(x_ptr,  # *Pointer* to first input vector.
    tl.store(output_ptr + offsets, output, mask=mask)


+@triton.jit
+def add_kernel_tiled(x_ptr,  # *Pointer* to first input vector.
+                     y_ptr,  # *Pointer* to second input vector.
+                     output_ptr,  # *Pointer* to output vector.
+                     n_elements,  # Size of the vector.
+                     BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+                     TILE_SIZE: tl.constexpr,  # Number of elements each iteration should process.
+                     # NOTE `constexpr` so it can be used as a shape value.
+                     ):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)):
+        offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.store(output_ptr + offsets, output, mask=mask)
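Note: with the defaults used by the helpers below (BLOCK_SIZE = CPU_BLOCK_SIZE = 4096, TILE_SIZE = 16), each program runs tl.cdiv(4096, 16) = 256 iterations, loading and storing 16 contiguous elements per iteration.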


# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:
@@ -80,6 +101,28 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device):
    return output


+def add_tiled(x: torch.Tensor, y: torch.Tensor, output):
+    if output is None:
+        output = torch.empty_like(x)
+    n_elements = output.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE, TILE_SIZE=16)
+    return output
+
+
+def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output):
+    if output is None:
+        output = torch.empty_like(x)
+    n_elements = output.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+    # TODO: try to choose the best block size using autotuner
+    BLOCK_SIZE = triton.next_power_of_2(n_elements)
+    if BLOCK_SIZE > CPU_ST_THRESHOLD:
+        BLOCK_SIZE = CPU_BLOCK_SIZE
+    add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE, TILE_SIZE=16)
+    return output


# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
torch.manual_seed(0)
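Note: a sketch (with hypothetical sizes) of how the single-thread threshold in add_tiled_with_st_threshold above selects the grid; the single-program case is exactly what the launcher fast path added in driver.py below targets:

n_elements = 32768                                   # <= CPU_ST_THRESHOLD (65536)
BLOCK_SIZE = triton.next_power_of_2(n_elements)      # 32768: one program covers the whole vector
num_programs = triton.cdiv(n_elements, BLOCK_SIZE)   # 1 -> a single kernel invocation, no OpenMP dispatch

n_elements = 1 << 24                                  # > CPU_ST_THRESHOLD
BLOCK_SIZE = CPU_BLOCK_SIZE                           # falls back to 4096
num_programs = triton.cdiv(n_elements, BLOCK_SIZE)    # 4096 programs spread across OpenMP threads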
@@ -94,10 +137,19 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device):
print(output_triton_cpu)
print(f'The maximum difference between torch-cpu and triton-cpu is '
      f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')
+output_triton_cpu = add_tiled(x, y, None)
+print(f'The maximum difference between torch-cpu-tiled and triton-cpu is '
+      f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}')

-LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
-LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
-LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')]
+LINE_VALS = [
+    'triton-cpu', 'triton-cpu-hooks', 'triton-cpu-tiled', 'triton-cpu-tiled-hooks', 'triton-cpu-tiled-tuned-hooks',
+    'torch-cpu'
+]
+LINE_NAMES = [
+    'TritonCPU', 'TritonCPU (hooks)', 'TritonCPUTiled', 'TritonCPUTiled (hooks)', 'TritonCPUTiled (tuned, hooks)',
+    'TorchCPU'
+]
+LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('green', '-')]

if USE_GPU and triton.runtime.driver.get_active_gpus():
triton.runtime.driver.set_active_to_gpu()
@@ -156,6 +208,7 @@ def benchmark(size, provider):
        os.unsetenv('TRITON_CPU_SINGLE_CORE')
    else:
        triton.runtime.driver.set_active_to_gpu()
+    output = torch.empty_like(x)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch-gpu':
@@ -166,17 +219,24 @@ def benchmark(size, provider):
    elif provider == 'torch-cpu':
        # Note that we preallocate the output buffer here to only measure the kernel performance
        # without a large chunk of memory allocation.
-        output = torch.empty_like(x)
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
                                                     device_type=device)
-    elif provider == 'triton-cpu-single':
-        output = torch.empty_like(x)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
-                                                     device_type=device)
    elif provider == 'triton-cpu':
-        output = torch.empty_like(x)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles,
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles,
                                                     device_type=device)
+    elif provider == 'triton-cpu-hooks':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles,
+                                                     device_type=device, measure_time_with_hooks=True)
+    elif provider == 'triton-cpu-tiled':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles,
+                                                     device_type=device)
+    elif provider == 'triton-cpu-tiled-hooks':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles,
+                                                     device_type=device, measure_time_with_hooks=True)
+    elif provider == 'triton-cpu-tiled-tuned-hooks':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output),
+                                                     quantiles=quantiles, device_type=device,
+                                                     measure_time_with_hooks=True)
    gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)
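Note: as a worked example of the reported metric (hypothetical numbers): for size = 2**22 float32 elements the kernel moves 3 * 2**22 * 4 = 50,331,648 bytes per call (two loads plus one store), so a measured ms of 1.0 corresponds to gbps(1.0) = 50331648 / 1.0 * 1e-6 ≈ 50.3 GB/s.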

6 changes: 5 additions & 1 deletion third_party/cpu/backend/driver.py
@@ -229,9 +229,13 @@ def format_of(ty):
static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
  // TODO: Consider using omp collapse(3) clause for simplicity?
-  auto all_grids = get_all_grids(gridX, gridY, gridZ);
+  size_t N = gridX * gridY * gridZ;
+  if (N == 1) {{
+    (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} 0, 0, 0, 1, 1, 1);
+    return;
+  }}
+  auto all_grids = get_all_grids(gridX, gridY, gridZ);
  if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{
    if (getBoolEnv("TRITON_CPU_OMP_DEBUG"))
      printf("Single core launcher\\n");
