[AUTOTUNER] Make autotuner take do_bench as a parameter #4496

Merged: 6 commits, Oct 7, 2024
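To illustrate the interface this PR adds, here is a minimal usage sketch modeled on the updated tests below; the kernel and tensor names are hypothetical and a CUDA device is assumed:

import torch
import triton
import triton.language as tl
import triton.testing

# Any callable accepting (kernel_call, quantiles) can be supplied; this one just
# forwards to triton.testing.do_bench with a short warmup/rep budget.
do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
    kernel_call, quantiles=quantiles, warmup=1, rep=1)

@triton.autotune(
    configs=[triton.Config(kwargs={'BLOCK_SIZE': 32}),
             triton.Config(kwargs={'BLOCK_SIZE': 128})],
    key=['N'],
    do_bench=do_bench,  # replaces the deprecated warmup=/rep= arguments
)
@triton.jit
def _copy_kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + offsets, mask=offsets < N)
    tl.store(dst + offsets, x, mask=offsets < N)

N = 4096
src = torch.randn(N, device='cuda')
dst = torch.empty_like(src)
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),)
_copy_kernel[grid](dst, src, N=N)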
6 changes: 2 additions & 4 deletions python/test/unit/hopper/test_flashattention.py
@@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ['fwd', 'bwd']
- warmup = 25
- rep = 100
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn)
return ms
if provider == "flash":
lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device)
@@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn)
return ms


4 changes: 3 additions & 1 deletion python/test/unit/language/test_decorator.py
@@ -33,7 +33,9 @@ def test_triton_heuristic(device):
src = torch.empty(N, device=device)
dst = torch.zeros(N, device=device)

- @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1)
+ do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1)
+
+ @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench)
@triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs
@triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args
@triton.jit
12 changes: 8 additions & 4 deletions python/test/unit/runtime/test_autotuner.py
@@ -5,6 +5,10 @@
import pytest


+ def do_bench(kernel_call, quantiles):
+     return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)


@pytest.mark.parametrize('use_cuda_graph', [False, True])
def test_kwargs(use_cuda_graph: bool, device: str):
M, N = 1024, 16
@@ -13,7 +17,7 @@ def test_kwargs(use_cuda_graph: bool, device: str):

configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

- @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+ @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph, do_bench=do_bench)
@triton.jit
def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
@@ -34,7 +38,7 @@ def test_restore(device):

configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]

- @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1)
+ @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench)
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -64,7 +68,7 @@ def _post_hook(*args, exception):
values["has_exception"] = True
assert values["counter"] == 0

- @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook)
+ @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook)
@triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4})
@triton.jit
def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr):
@@ -115,7 +119,7 @@ def perf_model(*args, **kwargs):
else:
prune_configs_by = {'early_config_prune': early_config_prune}

- @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1)
+ @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench)
@triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
14 changes: 14 additions & 0 deletions python/triton/backends/driver.py
@@ -1,4 +1,11 @@
from abc import ABCMeta, abstractmethod, abstractclassmethod
+ from typing import Callable, List, Protocol, Sequence


+ class Benchmarker(Protocol):
+
+     def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
+         pass


class DriverBase(metaclass=ABCMeta):
@@ -11,6 +18,13 @@ def is_active(self):
def get_current_target(self):
pass

+ @abstractmethod
+ def get_benchmarker(self) -> Benchmarker:
+     """
+     Return the benchmarking function that this backend should use by default.
+     """
+     raise NotImplementedError

def __init__(self) -> None:
pass

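As a sketch of what satisfies the Benchmarker protocol introduced above (not part of this diff), any callable with a keyword-only quantiles argument that returns a sequence of timings will do; the fixed warmup/rep budget below is an arbitrary choice for illustration:

from typing import Callable, List, Sequence

import triton.testing

def short_budget_bench(kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:
    # Delegate to triton.testing.do_bench with a small, fixed measurement budget
    # and let it compute the requested quantiles.
    return triton.testing.do_bench(kernel_call, warmup=5, rep=20, quantiles=quantiles)

An Autotuner constructed with do_bench=short_budget_bench will invoke it as short_budget_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)).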
53 changes: 40 additions & 13 deletions python/triton/runtime/autotuner.py
@@ -6,9 +6,9 @@
import inspect
from typing import Dict

- from ..testing import do_bench, do_bench_cudagraph
from .jit import KernelInterface
from .errors import OutOfResources
+ from .driver import driver


class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
pre_hook=None,
post_hook=None,
prune_configs_by: Dict = None,
- warmup=25,
- rep=100,
+ warmup=None,
+ rep=None,
use_cuda_graph=False,
+ do_bench=None,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,36 @@ def _post_hook(args, exception):
self.base_fn = fn
while not inspect.isfunction(self.base_fn):
self.base_fn = self.base_fn.fn
- self.num_warmups = warmup
- self.num_reps = rep
- import torch
- self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
Comment on lines -91 to -94

Contributor:

Hi @Jokeren, @int3,

The fields self.num_warmups, self.num_reps, and self.use_cuda_graph are used by PyTorch to find out which parameters the autotuner was called with:

https://github.com/pytorch/pytorch/blame/5141ade8e30c64e873e14dcc8de233da45d15025/torch/_higher_order_ops/triton_kernel_wrap.py#L829

Can they be left in place until the corresponding parameters are removed from the __init__ signature?

Contributor:

@int3 is driving the effort. It's up to him. I'm OK either way.
+ # If we got explicitly called via the old interface, raise a warning
+ # and proceed with the old behavior.
+ if warmup is not None or rep is not None or use_cuda_graph:
+     import warnings
+     warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
+                    "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning,
+                   stacklevel=1)
+     if use_cuda_graph:
+         from ..testing import do_bench_cudagraph
+         self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+             kernel_call,
+             rep=rep if rep is not None else 100,
+             quantiles=quantiles,
+         )
+         return
+
+     import triton.testing
+     self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+         kernel_call,
+         warmup=warmup if warmup is not None else 25,
+         rep=rep if rep is not None else 100,
+         quantiles=quantiles,
+     )
+     return
+
+ if do_bench is None:
+     self.do_bench = driver.active.get_benchmarker()
+ else:
+     self.do_bench = do_bench

def _bench(self, *args, config, **meta):
from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,9 +152,7 @@ def kernel_call():
self.post_hook(args, exception=None)

try:
- if self.use_cuda_graph:
-     return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
- return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+ return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
except (OutOfResources, CompileTimeAssertionFailure):
return [float("inf"), float("inf"), float("inf")]

@@ -257,7 +282,7 @@ def __str__(self):


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
- warmup=25, rep=100, use_cuda_graph=False):
+ warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.

@@ -305,10 +330,12 @@ def kernel(x_ptr, x_size, **META):
'args': a list of arguments passed to the kernel.
'exception': the exception raised by the kernel in case of a compilation or runtime error.
:type post_hook: lambda args, exception
- :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
+ :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
:type warmup: int
- :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
+ :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
:type rep: int
+ :param do_bench: a benchmark function to measure the time of each run.
+ :type do_bench: lambda fn, quantiles
"""

def decorator(fn):
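For downstream users, a sketch of migrating off the deprecated keywords, mirroring the compatibility shim in __init__ above (configs and the decorated kernel are assumed to exist elsewhere):

import triton
import triton.testing

# Old style, which now emits a DeprecationWarning:
#   @triton.autotune(configs=configs, key=['N'], warmup=25, rep=100)
#   @triton.autotune(configs=configs, key=['N'], use_cuda_graph=True)

# New style: pass the benchmarking function explicitly.
plain_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
    kernel_call, warmup=25, rep=100, quantiles=quantiles)
cudagraph_bench = lambda kernel_call, quantiles: triton.testing.do_bench_cudagraph(
    kernel_call, rep=100, quantiles=quantiles)

# @triton.autotune(configs=configs, key=['N'], do_bench=plain_bench)
# @triton.autotune(configs=configs, key=['N'], do_bench=cudagraph_bench)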
6 changes: 2 additions & 4 deletions python/tutorials/06-fused-attention.py
@@ -601,8 +601,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
assert mode in ["fwd", "bwd"]
- warmup = 25
- rep = 100
dtype = torch.float16
if "triton" in provider:
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
@@ -620,15 +618,15 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn)
if provider == "flash":
qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, causal=causal)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
- ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ ms = triton.testing.do_bench(fn)
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
total_flops = 2 * flops_per_matmul
if causal:
4 changes: 4 additions & 0 deletions third_party/amd/backend/driver.py
@@ -499,3 +499,7 @@ def get_current_target(self):
arch = device_properties['arch']
warp_size = device_properties['warpSize']
return GPUTarget("hip", arch.split(':')[0], warp_size)

+ def get_benchmarker(self):
+     from triton.testing import do_bench
+     return do_bench
4 changes: 4 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -448,3 +448,7 @@ def get_device_interface(self):
def is_active():
import torch
return torch.cuda.is_available() and (torch.version.hip is None)

+ def get_benchmarker(self):
+     from triton.testing import do_bench
+     return do_bench
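These backend hooks are what the autotuner consumes when no do_bench argument is supplied; a usage sketch, assuming one of the in-tree GPU backends is active:

from triton.runtime.driver import driver

bench = driver.active.get_benchmarker()  # triton.testing.do_bench for the in-tree backends

# `bench` can then be called exactly the way the autotuner calls it, e.g.:
#   timings = bench(kernel_call, quantiles=(0.5, 0.2, 0.8))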