Rebase onto upstream triton 4a54311 and fix regressions
minjang committed Oct 24, 2024
1 parent fc3d76b commit 0f6dcd7
Showing 9 changed files with 109 additions and 123 deletions.
3 changes: 1 addition & 2 deletions python/triton/runtime/autotuner.py
@@ -152,8 +152,7 @@ def kernel_call():
             self.post_hook(args, exception=None)
 
         try:
-            device = driver.active.get_current_target().backend
-            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), device_type=device)
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]

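With device_type removed, the autotuner no longer inspects the target backend itself; benchmarking is routed through the active driver, and each backend supplies its own do_bench via get_benchmarker() (added to CPUDriver at the bottom of this commit). A minimal sketch of that dispatch, assuming a hypothetical bench_config helper rather than the real autotuner internals:

    # Sketch only: `bench_config` is a hypothetical helper, not autotuner code.
    from triton.runtime import driver


    def bench_config(kernel_call):
        # Ask the active backend for its benchmarking function instead of
        # threading a device_type string through a generic do_bench.
        do_bench = driver.active.get_benchmarker()
        return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
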
60 changes: 2 additions & 58 deletions python/triton/testing.py
@@ -2,67 +2,10 @@
 import os
 import subprocess
 import sys
-import time
 from contextlib import contextmanager
 from typing import Any, Dict, List
 from . import language as tl
 from . import runtime
-import triton
 
 
-class CPUDeviceInterface:
-
-    class HooksTimeAccessor:
-
-        def __init__(self, di):
-            self.di = di
-            self.record_idx = 0
-
-        def elapsed_time(self, end_event) -> float:
-            total_time = 0
-            for i in range(self.record_idx, end_event.record_idx):
-                total_time += self.di.kernel_times[i]
-            return total_time * 1000
-
-        def record(self):
-            self.record_idx = len(self.di.kernel_times)
-
-    class TimerEvent:
-
-        def __init__(self):
-            self.timer = 0
-
-        def elapsed_time(self, end_event) -> float:
-            return (end_event.timer - self.timer) * 1000
-
-        def record(self):
-            self.timer = time.perf_counter()
-
-    def __init__(self):
-        self.kernel_times = []
-        self.last_start = 0
-        self.use_hooks = False
-        triton.compiler.CompiledKernel.launch_enter_hook = None
-        triton.compiler.CompiledKernel.launch_exit_hook = None
-
-    def enable_hook_timing(self):
-        self.use_hooks = True
-        triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook()
-        triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook()
-
-    def synchronize(self):
-        pass
-
-    def _enter_hook(self):
-        self.last_start = time.perf_counter()
-
-    def _exit_hook(self):
-        self.kernel_times.append(time.perf_counter() - self.last_start)
-
-    def Event(self, enable_timing=True):
-        if self.use_hooks:
-            return CPUDeviceInterface.HooksTimeAccessor(self)
-        return CPUDeviceInterface.TimerEvent()
-
-
 def nvsmi(attrs):
@@ -149,7 +92,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", measure_time_with_hooks=False):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean",
+             measure_time_with_hooks=False):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
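The hook-based timing path survives the move out of testing.py: with measure_time_with_hooks=True, do_bench accumulates only the kernel-body time captured by the launch hooks instead of wall-clocking the whole Python call. A self-contained usage sketch, assuming the triton-cpu backend is active (copy_kernel is illustrative and not part of this commit):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program copies one BLOCK_SIZE chunk of the input.
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)


    src = torch.rand(4096, device='cpu')
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), 1024), )
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=1024),
        quantiles=[0.5, 0.2, 0.8], measure_time_with_hooks=True)
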
21 changes: 8 additions & 13 deletions python/tutorials/01-vector-add.py
@@ -213,31 +213,26 @@ def benchmark(size, provider):
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles)
     elif provider == 'torch-cpu':
         # Note that we preallocate the output buffer here to only measure the kernel performance
         # without a large chunk of memory allocation.
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles)
     elif provider == 'triton-cpu-hooks':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, device), quantiles=quantiles,
-                                                     device_type=device, measure_time_with_hooks=True)
+                                                     measure_time_with_hooks=True)
     elif provider == 'triton-cpu-tiled':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles)
     elif provider == 'triton-cpu-tiled-hooks':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles,
-                                                     device_type=device, measure_time_with_hooks=True)
+                                                     measure_time_with_hooks=True)
     elif provider == 'triton-cpu-tiled-tuned-hooks':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output),
-                                                     quantiles=quantiles, device_type=device,
-                                                     measure_time_with_hooks=True)
+                                                     quantiles=quantiles, measure_time_with_hooks=True)
     gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)

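The remaining tutorial diffs below repeat the same mechanical change, so the post-rebase call shape is worth spelling out once. For the torch-cpu provider, for example (standalone sketch with illustrative sizes):

    import torch
    import triton

    x = torch.rand(1 << 20, device='cpu')
    y = torch.rand(1 << 20, device='cpu')
    output = torch.empty_like(x)

    # do_bench now picks up the device from the active driver; no device_type kwarg.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output),
                                                 quantiles=[0.5, 0.2, 0.8])
    # Two reads and one write per element, reported as GB/s.
    gbps = 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
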
18 changes: 8 additions & 10 deletions python/tutorials/02-fused-softmax-cpu.py
@@ -211,24 +211,22 @@ def benchmark(M, N, provider):
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-cpu-native':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
     if provider == 'torch-cpu-jit':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
     if provider == 'torch-cpu-compile':
         compiled = torch.compile(naive_softmax)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles)
     if provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles)
    if provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles)
     if provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
     if provider == 'torch-gpu-native':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
     if provider == 'torch-gpu-jit':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
     gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)

16 changes: 6 additions & 10 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -378,22 +378,18 @@ def benchmark(M, N, K, provider):
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles)
     elif provider == 'torch-cpu-native':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles)
     elif provider == 'torch-cpu-compile':
         compiled = torch.compile(torch.matmul)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)

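The perf lambda converts milliseconds into GFLOP/s, counting 2 * M * N * K floating-point operations for a matmul (one multiply and one add per inner-product term). A worked instance with illustrative numbers:

    # Illustrative numbers, not measured results.
    M = N = K = 1024
    ms = 2.0  # runtime in milliseconds
    gflops = 2 * M * N * K * 1e-9 / (ms * 1e-3)
    print(f"{gflops:.1f} GFLOP/s")  # ~1073.7
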
4 changes: 2 additions & 2 deletions python/tutorials/05-layer-norm.py
@@ -357,13 +357,13 @@ def y_fwd():
 
     # forward pass
     if mode == 'forward':
         gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
-        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
     # backward pass
     if mode == 'backward':
         y = y_fwd()
         gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)  # noqa: F811, E704
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles,
-                                                     grad_to_none=[x], rep=500, device_type=device)
+                                                     grad_to_none=[x], rep=500)
     return gbps(ms), gbps(max_ms), gbps(min_ms)


14 changes: 5 additions & 9 deletions python/tutorials/matrix-vector-multiplication-bf16.py
@@ -169,20 +169,16 @@ def benchmark(M, N, provider):
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
     elif 'torch-cpu-native' in provider:
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles)
     elif 'torch-cpu-compile' in provider:
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output),
-                                                     quantiles=quantiles, device_type=device)
+                                                     quantiles=quantiles)
     elif 'triton-cpu' in provider:
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
 
     perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)
29 changes: 10 additions & 19 deletions python/tutorials/matrix-vector-multiplication.py
@@ -173,38 +173,29 @@ def benchmark(M, N, provider):
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'torch-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles)
     elif provider == 'triton-gpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
     elif provider == 'torch-cpu-native' or provider == 'torch-cpu-2d-native':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles)
     elif provider == 'torch-cpu-compile' or provider == 'torch-cpu-2d-compile':
         compiled = torch.compile(torch.matmul)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles)
     elif provider == 'torch-cpu-transpose-native':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles)
     elif provider == 'torch-cpu-transpose-compile':
         compiled = torch.compile(torch.matmul)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles)
     elif provider == 'torch-cpu-linear':
         weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype)
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles, device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles)
     elif provider == 'triton-cpu-single':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
     elif provider == 'triton-cpu':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles)
     elif provider == 'triton-cpu-linear':
         # torch.nn.Linear.forward does not take a preallocated output buffer, so we also do not provide one for a fair comparison
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles,
-                                                     device_type=device)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3)
     return perf(ms), perf(max_ms), perf(min_ms)

67 changes: 67 additions & 0 deletions third_party/cpu/backend/driver.py
@@ -3,7 +3,9 @@
 import importlib
 import importlib.resources
 import tempfile
+import time
 
+import triton
 import triton._C
 from triton.runtime.build import _build
 from triton.runtime.cache import get_cache_manager
@@ -353,6 +355,61 @@ def __call__(self, *args, **kwargs):
         self.launch(*args, **kwargs)
 
 
+class CPUDeviceInterface:
+
+    class HooksTimeAccessor:
+
+        def __init__(self, di):
+            self.di = di
+            self.record_idx = 0
+
+        def elapsed_time(self, end_event) -> float:
+            total_time = 0
+            for i in range(self.record_idx, end_event.record_idx):
+                total_time += self.di.kernel_times[i]
+            return total_time * 1000
+
+        def record(self):
+            self.record_idx = len(self.di.kernel_times)
+
+    class TimerEvent:
+
+        def __init__(self):
+            self.timer = 0
+
+        def elapsed_time(self, end_event) -> float:
+            return (end_event.timer - self.timer) * 1000
+
+        def record(self):
+            self.timer = time.perf_counter()
+
+    def __init__(self):
+        self.kernel_times = []
+        self.last_start = 0
+        self.use_hooks = False
+        triton.compiler.CompiledKernel.launch_enter_hook = None
+        triton.compiler.CompiledKernel.launch_exit_hook = None
+
+    def enable_hook_timing(self):
+        self.use_hooks = True
+        triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook()
+        triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook()
+
+    def synchronize(self):
+        pass
+
+    def _enter_hook(self):
+        self.last_start = time.perf_counter()
+
+    def _exit_hook(self):
+        self.kernel_times.append(time.perf_counter() - self.last_start)
+
+    def Event(self, enable_timing=True):
+        if self.use_hooks:
+            return CPUDeviceInterface.HooksTimeAccessor(self)
+        return CPUDeviceInterface.TimerEvent()
+
+
 class CPUDriver(DriverBase):
 
     def __init__(self):
@@ -372,10 +429,20 @@ def get_current_target(self):
         cpu_arch = llvm.get_cpu_tripple().split("-")[0]
         return GPUTarget("cpu", cpu_arch, 0)
 
+    def get_device_interface(self):
+        return CPUDeviceInterface()
 
     @staticmethod
     def is_active():
         return True
 
+    def get_benchmarker(self):
+        from triton.testing import do_bench
+        return do_bench
 
+    def get_empty_cache_for_benchmark(self):
+        import torch
 
+        # A typical LLC size for high-end server CPUs is ~400MB.
+        cache_size = 512 * 1024 * 1024
+        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cpu')
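
Taken together, CPUDeviceInterface mimics the CUDA event API that do_bench expects: record() brackets a region, synchronize() is a no-op on CPU, and elapsed_time() reports milliseconds, either wall-clock (TimerEvent) or the summed kernel-body times collected by the launch hooks (HooksTimeAccessor); get_empty_cache_for_benchmark supplies a 512MB buffer (128M int32 elements) that the benchmarker can touch between reps to flush the cache. A sketch of how a benchmarking loop consumes the interface (run_kernels_under_test is a hypothetical callable, not code from this commit):

    # Illustrative consumer of CPUDeviceInterface; not code from this commit.
    di = CPUDeviceInterface()
    di.enable_hook_timing()  # have Event() return hook-based accessors

    start, end = di.Event(), di.Event()
    start.record()            # snapshot how many kernel times exist so far
    run_kernels_under_test()  # hypothetical: launches the Triton kernels
    end.record()
    di.synchronize()          # no-op on CPU, kept for CUDA API parity
    elapsed_ms = start.elapsed_time(end)  # summed kernel-body time in ms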
