[operator] fix libentry to support triton 2.3
StrongSpoon committed Jun 28, 2024
1 parent 9168f2d commit 736822e
Showing 13 changed files with 104 additions and 46 deletions.
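
For context, every kernel touched below follows the same pattern: runtime scalars such as eps, alpha/beta, p, the Philox seed/offset, diagonal, correction, and ord are marked do_not_specialize so Triton 2.3 does not bake their values into the compiled binary, while the kernel stays wrapped in flag_gems' libentry cache. A minimal sketch of that pattern is shown below; the kernel body, the names, and the import path are illustrative assumptions, not code from this commit.

```python
# Minimal sketch of the pattern applied in this commit (illustrative only).
# `eps` is a per-call scalar; do_not_specialize keeps Triton from specializing
# on its value, so libentry's cache can reuse one compiled kernel across calls.
import torch
import triton
import triton.language as tl

from flag_gems.utils import libentry  # import path assumed from the repo layout


@libentry()
@triton.jit(do_not_specialize=["eps"])
def add_eps_kernel(X, Y, eps, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask)
    tl.store(Y + offsets, x + eps, mask=mask)


def add_eps(x, eps=1e-5):
    y = torch.empty_like(x)
    n = x.numel()
    block = 1024
    grid = (triton.cdiv(n, block),)
    add_eps_kernel[grid](x, y, eps, n, BLOCK=block)
    return y
```

Calling add_eps twice with different eps values should then hit the same libentry cache entry instead of triggering a recompile.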
2 changes: 1 addition & 1 deletion benchmark/performance_utils.py
@@ -7,7 +7,7 @@
 
 from .conftest import CPU_MODE
 
-WARMUP = 10
+WARMUP = 100
 REPETITION = 1000
 
 
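
The bump from 10 to 100 warm-up iterations only changes how many untimed launches precede the measured ones. The diff does not show how the constants are consumed; the loop below is purely an assumption about their typical role, not the actual body of performance_utils.py.

```python
# Hypothetical use of WARMUP / REPETITION (illustrative sketch only).
import time

WARMUP = 100
REPETITION = 1000


def bench(fn, *args):
    for _ in range(WARMUP):  # untimed launches: trigger autotuning / JIT compilation
        fn(*args)
    start = time.perf_counter()
    for _ in range(REPETITION):  # timed launches
        fn(*args)
    # a real GPU benchmark would also synchronize the device around the timed region
    return (time.perf_counter() - start) / REPETITION * 1000  # average ms per call
```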
2 changes: 1 addition & 1 deletion src/flag_gems/fused/skip_layernorm.py
@@ -9,7 +9,7 @@
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["eps"])
 def skip_layer_norm_kernel(
     Y,  # pointer to the output
     X,  # pointer to the input
2 changes: 1 addition & 1 deletion src/flag_gems/fused/skip_rms_norm.py
@@ -9,7 +9,7 @@
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["eps"])
 def skip_rms_norm_kernel(
     Y,  # pointer to the output
     X,  # pointer to the input
2 changes: 1 addition & 1 deletion src/flag_gems/ops/addmm.py
@@ -53,7 +53,7 @@
     ],
     key=["M", "N", "K"],
 )
-@triton.jit
+@triton.jit(do_not_specialize=["alpha", "beta"])
 def addmm_kernel(
     a_ptr,
     b_ptr,
6 changes: 3 additions & 3 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -40,7 +40,7 @@ class Reduction(IntEnum):
         ),
     },
 )
-@triton.jit
+@triton.jit(do_not_specialize=["mean_num"])
 def log_softmax_and_mul_kernel(
     output_ptr,
     input_ptr,
@@ -95,7 +95,7 @@ def log_softmax_and_mul_kernel(
         ),
     },
 )
-@triton.jit
+@triton.jit(do_not_specialize=["mean_num"])
 def softmax_and_sub_kernel(
     output_ptr,
     input_ptr,
@@ -158,7 +158,7 @@ def softmax_and_sub_kernel(
         ),
     },
 )
-@triton.jit
+@triton.jit(do_not_specialize=["mean_num"])
 def softmax_and_sub_reduce_kernel(
     output_ptr,
     input_ptr,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/dropout.py
@@ -41,7 +41,7 @@ def _rand(seed, offset):
         "N",
     ],
 )
-@triton.jit
+@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_forward_kernel(
     X,
     Y,
@@ -82,7 +82,7 @@ def dropout_forward_kernel(
         "N",
     ],
 )
-@triton.jit
+@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_backward_kernel(
     DY,
     DX,
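
Dropout is the case where specialization hurts most: p, philox_seed and philox_offset are typically fresh Python scalars on every launch, so any compile-time specialization on them would fragment the kernel cache. The host-side sketch below is a hypothetical illustration of that calling pattern, not the repository's dropout wrapper.

```python
# Hypothetical host-side state (illustrative): a new seed/offset pair per launch.
import torch

_philox_offset = 0


def next_philox_state(numel: int):
    """Return (seed, offset) for one dropout launch and advance the offset."""
    global _philox_offset
    seed = torch.randint(0, 2**31 - 1, (1,)).item()
    offset = _philox_offset
    _philox_offset += numel
    return seed, offset
```

Because these integers differ on every call, marking them do_not_specialize is what keeps a single compiled dropout kernel serving all launches.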
2 changes: 1 addition & 1 deletion src/flag_gems/ops/groupnorm.py
@@ -8,7 +8,7 @@
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["eps"])
 def group_norm_kernel(
     X,
     Y,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/layernorm.py
@@ -23,7 +23,7 @@ def cfggen():
 
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
-@triton.jit
+@triton.jit(do_not_specialize=["eps"])
 def layer_norm_kernel(
     X,
     Y,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/rms_norm.py
@@ -9,7 +9,7 @@
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["eps"])
 def rms_norm_kernel(
     Y,  # pointer to the output
     X,  # pointer to the input
4 changes: 2 additions & 2 deletions src/flag_gems/ops/triu.py
@@ -27,7 +27,7 @@ def cfggen_batch():
 
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
-@triton.jit
+@triton.jit(do_not_specialize=["diagonal"])
 def triu_kernel(
     X,
     Y,
@@ -55,7 +55,7 @@ def triu_kernel(
 
 @libentry()
 @triton.autotune(configs=cfggen_batch(), key=["batch", "MN", "N", "diagonal"])
-@triton.jit
+@triton.jit(do_not_specialize=["diagonal"])
 def triu_batch_kernel(
     X,
     Y,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/var_mean.py
@@ -33,7 +33,7 @@ def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):
 
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
-@triton.jit
+@triton.jit(do_not_specialize=["correction"])
 def var_mean_welford_kernel(
     X,
     Var,
@@ -112,7 +112,7 @@ def var_mean_kernel_1(
 @triton.heuristics(
     values={"BLOCK_N": lambda args: triton.next_power_of_2(args["BLOCK_NUM"])},
 )
-@triton.jit
+@triton.jit(do_not_specialize=["correction"])
 def var_mean_kernel_2(
     Acc,
     Average,
6 changes: 3 additions & 3 deletions src/flag_gems/ops/vector_norm.py
@@ -210,7 +210,7 @@ def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):
 
 @libentry()
 @triton.autotune(configs=cfggen(), key=["M", "N"])
-@triton.jit
+@triton.jit(do_not_specialize=["ord"])
 def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
     X = X + pid * N
@@ -231,7 +231,7 @@ def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["ord"])
 def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(0)
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -245,7 +245,7 @@ def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
 
 
 @libentry()
-@triton.jit
+@triton.jit(do_not_specialize=["ord"])
 def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
     offset = tl.arange(0, BLOCK_MID)
     Mid = Mid + offset
112 changes: 85 additions & 27 deletions src/flag_gems/utils/libentry.py
@@ -1,3 +1,5 @@
+import inspect
+
 import triton
 
 
@@ -14,47 +16,103 @@ def __init__(
         while not isinstance(fn, triton.runtime.JITFunction):
             fn = fn.fn
         self.jit_function: triton.runtime.JITFunction = fn
-        self.kernel_arg_indices = []
+        self.spec_indices = []
+        self.dns_indices = []
         for p in self.jit_function.params:
             if not p.is_constexpr:
-                self.kernel_arg_indices.append(p.num)
+                if p.do_not_specialize:
+                    self.dns_indices.append(p.num)
+                else:
+                    self.spec_indices.append(p.num)
 
-    def run(self, *args, **kwargs):
-        key = []
-        for arg in args:
+    def key(self, spec_args, dns_args, const_args):
+        entry_key = []
+        for arg in spec_args:
+            if hasattr(arg, "data_ptr"):
+                entry_key.append(str(arg.dtype))
+                entry_key.append(arg.data_ptr() % self.divisibility == 0)
+            else:
+                entry_key.append(type(arg))
+                entry_key.append(arg)
+        # args do not specialize
+        for arg in dns_args:
             if hasattr(arg, "data_ptr"):
-                key.append(arg.dtype)
-                key.append(arg.data_ptr() % self.divisibility == 0)
-            elif isinstance(arg, int):
-                key.append(arg)
-        entry_key = tuple(key)
+                entry_key.append(str(arg.dtype))
+            else:
+                entry_key.append(type(arg))
+        # const args passed by position
+        return tuple(entry_key + const_args)
+
+    def run(self, *args, **kwargs):
+        grid = kwargs["grid"]
+
+        # collect all the arguments
+        spec_args = []  # specialize arguments
+        dns_args = []  # do not specialize arguments
+        const_args = []  # constexpr arguments
+        k_args = []  # kernel arguments
+        for i, arg in enumerate(args):
+            if i in self.spec_indices:
+                k_args.append(arg)
+                spec_args.append(arg)
+            elif i in self.dns_indices:
+                k_args.append(arg)
+                dns_args.append(arg)
+            else:
+                const_args.append(arg)
+        for p in self.jit_function.params[len(args) :]:
+            if p.name in kwargs:
+                val = kwargs[p.name]
+            elif p.default is inspect._empty:
+                continue
+            else:
+                val = p.default
+
+            if p.is_constexpr:
+                const_args.append(val)
+            elif p.do_not_specialize:
+                dns_args.append(val)
+                k_args.append(val)
+            else:
+                spec_args.append(val)
+                k_args.append(val)
+
+        entry_key = self.key(spec_args, dns_args, const_args)
+
         if entry_key not in self.kernel_cache:
             kernel = self.fn.run(*args, **kwargs)
-            self.kernel_cache[entry_key] = kernel
+            fn = self.fn
+            # collect constexpr arguments for grid computation
+            constexprs = {}
+            while not isinstance(fn, triton.runtime.JITFunction):
+                if isinstance(fn, triton.runtime.Autotuner):
+                    config = fn.best_config
+                    constexprs["num_warps"] = config.num_warps
+                    constexprs["num_stages"] = config.num_stages
+                    constexprs["num_ctas"] = config.num_ctas
+                    constexprs = {**constexprs, **config.kwargs}
+                elif isinstance(fn, triton.runtime.Heuristics):
+                    for v, heur in fn.values.items():
+                        constexprs[v] = heur(
+                            {**dict(zip(fn.arg_names, args)), **kwargs, **constexprs}
+                        )
+                else:
+                    raise RuntimeError("Invalid Runtime Function")
+                fn = fn.fn
+            for p in self.jit_function.params:
+                if p.is_constexpr and p.name not in constexprs:
+                    constexprs[p.name] = p.default
+            self.kernel_cache[entry_key] = (kernel, constexprs)
         else:
-            kernel = self.kernel_cache[entry_key]
-
-        # collect all the arguments to the kernel, all non-constexpr arguments
-        k_args = [arg for i, arg in enumerate(args) if i in self.kernel_arg_indices]
-        if len(k_args) < len(self.kernel_arg_indices):
-            for p in self.jit_function.params[len(args) :]:
-                if not p.is_constexpr:
-                    if p.name in kwargs:
-                        k_args.append(kwargs[p.name])
-                    else:
-                        k_args.append(p.default)
+            kernel, constexprs = self.kernel_cache[entry_key]
 
-        grid = kwargs["grid"]
         if callable(grid):
            # collect all arguments to the grid fn,ie:
            # 1. args,
            # 2. kwargs,
            # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
-            for k, v in kernel.constants.items():
-                arg_name = self.arg_names[int(k)]
-                kwargs[arg_name] = v
-            meta = {**dict(zip(self.arg_names, args)), **kwargs}
+            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
             grid = grid(meta)
         grid = grid + (1, 1)

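
To make the new caching rule concrete: specialized arguments contribute their dtype and pointer alignment (for tensors) or their type and value (for scalars), do-not-specialize arguments contribute only a dtype or a type, and constexpr values are appended verbatim. The standalone sketch below restates that rule outside the class; it is a simplified paraphrase of the updated LibEntry.key, not the class itself.

```python
# Simplified restatement of the cache-key rule introduced above.
from typing import Any, List, Tuple

DIVISIBILITY = 16  # same alignment boundary LibEntry checks


def make_key(spec_args: List[Any], dns_args: List[Any], const_args: List[Any]) -> Tuple:
    key: List[Any] = []
    for arg in spec_args:  # specialized: dtype + 16-byte alignment, or (type, value)
        if hasattr(arg, "data_ptr"):
            key += [str(arg.dtype), arg.data_ptr() % DIVISIBILITY == 0]
        else:
            key += [type(arg), arg]
    for arg in dns_args:  # do-not-specialize: the value is deliberately ignored
        key.append(str(arg.dtype) if hasattr(arg, "data_ptr") else type(arg))
    return tuple(key + const_args)  # constexpr args always participate by value
```

Under this rule, launching rms_norm_kernel with eps=1e-5 and then eps=1e-6 yields the same key (only type(eps) is recorded), so the second call reuses the cached (kernel, constexprs) pair instead of recompiling, which is the behavior the do_not_specialize annotations in the other twelve files rely on.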
