[bugfix] fix libentry on argument processing (#68)
* fix libentry:
1. ensure that decorator cascading works as expected, i.e. an inner decorator can use arguments provided by an outer decorator (see the usage sketch below)
2. ensure that the grid function can use all the arguments provided by the decorators (Autotuner & Heuristics)
3. simplify LibEntry: extract captured constant arguments from the CompiledKernel instead of traversing layers of decorators

* add test_libentry to CI

* add test_libentry

* assert that a certain kind of exception is not raised

* clean code
iclementine authored Jun 18, 2024
1 parent 1630cde commit fcc56c5
Showing 3 changed files with 217 additions and 53 deletions.
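
A minimal usage sketch of the decorator cascade this fix targets, condensed from tests/test_libentry.py added below. The kernel name and body here are placeholders and the call is illustrative only; it is not part of the commit.

import torch
import triton
from triton import language as tl

from flag_gems.utils import libentry


@libentry()
@triton.autotune(configs=[triton.Config({"TILE_N": 256})], key=["N"])
@triton.heuristics(values={"TILE_M": lambda args: 1024 // args["TILE_N"]})
@triton.jit
def cascade_kernel(out_ptr, in_ptr, M, N, TILE_M: tl.constexpr, TILE_N: tl.constexpr):
    # body elided; see softmax_kernel_inner in tests/test_libentry.py below
    _ = tl.program_id(0)


x = torch.randn((128, 128), device="cuda")
out = torch.empty_like(x)
M, N = x.shape
# The grid fn reads TILE_M, which is supplied by @triton.heuristics, not by the caller.
grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
cascade_kernel[grid](out, x, M, N)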
3 changes: 2 additions & 1 deletion .github/workflows/python-test.yaml
@@ -29,7 +29,8 @@ jobs:
CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &
CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &
CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &
CUDA_VISIBLE_DEVICES=4 pytest -s tests/test_special_ops.py && wait
CUDA_VISIBLE_DEVICES=4 pytest -s tests/test_special_ops.py &
CUDA_VISIBLE_DEVICES=5 pytest -s tests/test_libentry.py && wait
container-model-test:
runs-on: [self-hosted, docker]
85 changes: 33 additions & 52 deletions src/flag_gems/utils/libentry.py
@@ -9,14 +9,15 @@ def __init__(
self.fn = fn
self.arg_names = fn.arg_names
self.divisibility = 16
self.config_cache = dict()
self.kernel_cache = dict()
if isinstance(fn, triton.runtime.Autotuner):
self.rt = "Autotuner"
elif isinstance(fn, triton.runtime.Heuristics):
self.rt = "Heuristics"
else:
self.rt = "JitFunction"
fn = self.fn
while not isinstance(fn, triton.runtime.JITFunction):
fn = fn.fn
self.jit_function: triton.runtime.JITFunction = fn
self.kernel_arg_indices = []
for p in self.jit_function.params:
if not p.is_constexpr:
self.kernel_arg_indices.append(p.num)

def run(self, *args, **kwargs):
key = []
@@ -27,57 +28,37 @@ def run(self, *args, **kwargs):
elif isinstance(arg, int):
key.append(arg)
entry_key = tuple(key)

config = {}
# Autotuner
if self.rt == "Autotuner":
if entry_key not in self.config_cache:
# tune
kernel = self.fn.run(*args, **kwargs)
config = self.fn.best_config.kwargs
self.config_cache[entry_key] = config
self.kernel_cache[entry_key] = kernel
return
else:
# tuned
config = self.config_cache[entry_key]
kernel = self.kernel_cache[entry_key]
# Heuristics
elif self.rt == "Heuristics":
if entry_key not in self.kernel_cache:
# compile
kernel = self.fn.run(*args, **kwargs)
self.kernel_cache[entry_key] = kernel
return
else:
# compiled
for v, heur in self.fn.values.items():
config[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
kernel = self.kernel_cache[entry_key]
# JitFunction
if entry_key not in self.kernel_cache:
kernel = self.fn.run(*args, **kwargs)
self.kernel_cache[entry_key] = kernel
else:
if entry_key not in self.kernel_cache:
# compile
kernel = self.fn.run(*args, **kwargs)
self.kernel_cache[entry_key] = kernel
return
else:
# compiled
args = [
arg
for i, arg in enumerate(args)
if not self.fn.params[i].is_constexpr
]
kernel = self.kernel_cache[entry_key]
kernel = self.kernel_cache[entry_key]

# collect all the arguments to the kernel, i.e. all non-constexpr arguments
k_args = [arg for i, arg in enumerate(args) if i in self.kernel_arg_indices]
if len(k_args) < len(self.kernel_arg_indices):
for p in self.jit_function.params[len(args) :]:
if not p.is_constexpr:
if p.name in kwargs:
k_args.append(kwargs[p.name])
else:
k_args.append(p.default)

grid = kwargs["grid"]
if callable(grid):
# grid_fn
current = dict(**kwargs, **config)
meta = {**dict(zip(self.arg_names, args)), **current}
# collect all arguments to the grid fn, i.e.:
# 1. args,
# 2. kwargs,
# 3. all other captured arguments in the CompiledKernel from Autotuner & Heuristics
# when kwargs & captured args conflict, captured args have higher priority
for k, v in kernel.constants.items():
arg_name = self.arg_names[int(k)]
kwargs[arg_name] = v
meta = {**dict(zip(self.arg_names, args)), **kwargs}
grid = grid(meta)
grid = grid + (1, 1)

kernel[grid[0:3]](*args)
kernel[grid[0:3]](*k_args)
return


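For readers of the diff, here is a small, self-contained Python sketch of how run() now assembles the meta dict for a callable grid: the captured constexpr values from CompiledKernel.constants (assumed here to map parameter indices to values) are folded back into kwargs so the grid fn can see them, with captured values winning on a name clash. All concrete names and values below are illustrative stand-ins, not part of the commit.

# Illustrative stand-ins for self.arg_names, the call's args/kwargs,
# and kernel.constants as captured by the autotune/heuristics decorators.
arg_names = ["output_ptr", "input_ptr", "M", "N", "TILE_M", "TILE_N"]
args = ("out_ptr", "in_ptr", 1024, 128)
kwargs = {"grid": lambda meta: (1024 // meta["TILE_M"], 1, 1)}
constants = {"4": 8, "5": 128}  # e.g. TILE_M=8, TILE_N=128

for k, v in constants.items():
    kwargs[arg_names[int(k)]] = v  # captured constants override kwargs on conflict
meta = {**dict(zip(arg_names, args)), **kwargs}
grid = kwargs["grid"](meta)  # -> (128, 1, 1); TILE_M came from the decorators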
182 changes: 182 additions & 0 deletions tests/test_libentry.py
@@ -0,0 +1,182 @@
from contextlib import contextmanager

import torch
import triton
from triton import language as tl

from flag_gems.utils import libentry


# not_raises is copied from https://gist.github.com/oisinmulvihill/45c14271fad7794a4a52516ecb784e69
@contextmanager
def not_raises(ExpectedException):
try:
yield

except ExpectedException as error:
raise AssertionError(f"Raised exception {error} when it should not!")

except Exception as error:
raise AssertionError(f"An unexpected exception {error} raised.")


def softmax_inner_decorator_cascade(x, dim, dtype=None):
assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
dim = dim % x.ndim
M = 1
N = x.shape[dim]
for i in range(dim):
M *= x.shape[i] # pre_dim
inp = x.contiguous()
if dtype is None:
dtype = x.dtype
out = torch.empty_like(inp, dtype=dtype)

grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
softmax_kernel_inner[grid](
out,
inp,
M,
N,
DUMMY=60,
)
return out


def softmax_inner_pass_kernel_arg_via_kw(x, dim, dtype=None):
assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
dim = dim % x.ndim
M = 1
N = x.shape[dim]
for i in range(dim):
M *= x.shape[i] # pre_dim
inp = x.contiguous()
if dtype is None:
dtype = x.dtype
out = torch.empty_like(inp, dtype=dtype)

grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
softmax_kernel_inner[grid](
out,
inp,
M,
N=N,
DUMMY=60,
)
return out


def softmax_inner_kernel_arg_apply_default(x, dim, dtype=None):
assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
dim = dim % x.ndim
M = 1
N = x.shape[dim]
for i in range(dim):
M *= x.shape[i] # pre_dim
inp = x.contiguous()
if dtype is None:
dtype = x.dtype
out = torch.empty_like(inp, dtype=dtype)

grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
softmax_kernel_inner[grid](
out,
inp,
M,
N,
)
return out


@libentry()
@triton.autotune(
configs=[
triton.Config({"TILE_N": 32}),
triton.Config({"TILE_N": 64}),
triton.Config({"TILE_N": 128}),
triton.Config({"TILE_N": 256}),
triton.Config({"TILE_N": 512}),
triton.Config({"TILE_N": 1024}),
],
key=["N"],
)
@triton.heuristics(
values={
"TILE_M": lambda args: 1024 // args["TILE_N"],
"ONE_TILE_PER_CTA": lambda args: args["TILE_N"] >= args["N"],
},
)
@triton.jit
def softmax_kernel_inner(
output_ptr,
input_ptr,
M,
N,
TILE_M: tl.constexpr,
TILE_N: tl.constexpr,
ONE_TILE_PER_CTA: tl.constexpr,
DUMMY=42,
):
_ = DUMMY
pid_m = tl.program_id(0)
m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
if ONE_TILE_PER_CTA:
n_offsets = tl.arange(0, TILE_N)
offset = m_offsets[:, None] * N + n_offsets
input_ptrs = input_ptr + offset
mask = (m_offsets[:, None] < M) & (n_offsets < N)
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
m = tl.max(inp, 1)
e = tl.exp(inp - m[:, None])
z = tl.sum(e, 1)
out = e / z[:, None]
output_ptrs = output_ptr + offset
tl.store(output_ptrs, out, mask=mask)
else:
m = tl.full([TILE_M], value=float("-inf"), dtype=tl.float32)
z = tl.full([TILE_M], value=0.0, dtype=tl.float32)

n_offsets = tl.arange(0, TILE_N)
offset = m_offsets[:, None] * N + n_offsets
for _ in range(0, N, TILE_N):
mask = (m_offsets[:, None] < M) & (n_offsets < N)
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
m_new = tl.maximum(m, tl.max(inp, 1))
alpha = m - m_new
z = z * tl.exp(alpha) + tl.sum(tl.exp(inp - m_new[:, None]), axis=1)
m = m_new
n_offsets += TILE_N
offset += TILE_N

n_offsets = tl.arange(0, TILE_N)
offset = m_offsets[:, None] * N + n_offsets
for _ in range(0, N, TILE_N):
mask = (m_offsets[:, None] < M) & (n_offsets < N)
input_ptrs = input_ptr + offset
inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
o = tl.exp(inp - m[:, None]) / z[:, None]
output_ptrs = output_ptr + offset
tl.store(output_ptrs, o, mask=mask)
n_offsets += TILE_N
offset += TILE_N


def test_decorator_cascade():
# test that an inner decorator can use arguments supplied by an outer decorator
# and that the grid function can use arguments supplied by all the decorators
x = torch.randn((128, 128, 128), device="cuda")
with not_raises(KeyError):
_ = softmax_inner_decorator_cascade(x, dim=2)


def test_pass_kernel_arg_via_kw():
x = torch.randn((128, 128, 128), device="cuda")
with not_raises(KeyError):
_ = softmax_inner_pass_kernel_arg_via_kw(x, dim=2)


def test_kernel_arg_apply_default():
x = torch.randn((128, 128, 128), device="cuda")
with not_raises(KeyError):
_ = softmax_inner_kernel_arg_apply_default(x, dim=2)

