
[operator] fix libentry to support triton 2.3 #89

Merged · 3 commits · Jul 3, 2024
2 changes: 1 addition & 1 deletion benchmark/performance_utils.py
@@ -7,7 +7,7 @@

from .conftest import CPU_MODE

WARMUP = 10
WARMUP = 100
REPETITION = 1000


2 changes: 1 addition & 1 deletion src/flag_gems/fused/skip_layernorm.py
@@ -9,7 +9,7 @@


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
Y, # pointer to the output
X, # pointer to the input
2 changes: 1 addition & 1 deletion src/flag_gems/fused/skip_rms_norm.py
@@ -9,7 +9,7 @@


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["eps"])
def skip_rms_norm_kernel(
Y, # pointer to the output
X, # pointer to the input
2 changes: 1 addition & 1 deletion src/flag_gems/ops/addmm.py
@@ -53,7 +53,7 @@
],
key=["M", "N", "K"],
)
@triton.jit
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
a_ptr,
b_ptr,
6 changes: 3 additions & 3 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -40,7 +40,7 @@ class Reduction(IntEnum):
),
},
)
@triton.jit
@triton.jit(do_not_specialize=["mean_num"])
def log_softmax_and_mul_kernel(
output_ptr,
input_ptr,
@@ -95,7 +95,7 @@ def log_softmax_and_mul_kernel(
),
},
)
@triton.jit
@triton.jit(do_not_specialize=["mean_num"])
def softmax_and_sub_kernel(
output_ptr,
input_ptr,
@@ -158,7 +158,7 @@ def softmax_and_sub_kernel(
),
},
)
@triton.jit
@triton.jit(do_not_specialize=["mean_num"])
def softmax_and_sub_reduce_kernel(
output_ptr,
input_ptr,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/dropout.py
@@ -41,7 +41,7 @@ def _rand(seed, offset):
"N",
],
)
@triton.jit
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
X,
Y,
@@ -82,7 +82,7 @@ def dropout_forward_kernel(
"N",
],
)
@triton.jit
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
Review comment (Contributor): Does Triton specialize on floats?

@StrongSpoon (Collaborator, Author) · Jul 2, 2024: floats not marked as dns (do_not_specialize) are specialized as False. (See the usage sketch after this file's diff.)

def dropout_backward_kernel(
DY,
DX,
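As a follow-up to the specialization question above, here is a minimal sketch of the decorator form this PR adopts: `@triton.jit(do_not_specialize=[...])` takes parameter names and excludes those scalar arguments from Triton's specialization key. The kernel below is illustrative only (not from this repository); it assumes a Triton 2.x install and a CUDA device.

```python
import torch
import triton
import triton.language as tl


# eps is listed in do_not_specialize, so Triton treats it as a plain runtime
# scalar rather than folding properties of its value into the compiled kernel.
@triton.jit(do_not_specialize=["eps"])
def add_eps_kernel(X, Y, N, eps, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask)
    tl.store(Y + offsets, x + eps, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
# both launches reuse the same compiled kernel even though eps differs
add_eps_kernel[grid](x, y, x.numel(), 1e-5, BLOCK=128)
add_eps_kernel[grid](x, y, x.numel(), 1e-6, BLOCK=128)
```
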
2 changes: 1 addition & 1 deletion src/flag_gems/ops/groupnorm.py
@@ -8,7 +8,7 @@


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["eps"])
def group_norm_kernel(
X,
Y,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/layernorm.py
@@ -23,7 +23,7 @@ def cfggen():

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@triton.jit(do_not_specialize=["eps"])
def layer_norm_kernel(
X,
Y,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/rms_norm.py
@@ -9,7 +9,7 @@


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["eps"])
def rms_norm_kernel(
Y, # pointer to the output
X, # pointer to the input
4 changes: 2 additions & 2 deletions src/flag_gems/ops/triu.py
@@ -27,7 +27,7 @@ def cfggen_batch():

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
X,
Y,
@@ -55,7 +55,7 @@ def triu_kernel(

@libentry()
@triton.autotune(configs=cfggen_batch(), key=["batch", "MN", "N", "diagonal"])
@triton.jit
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
X,
Y,
4 changes: 2 additions & 2 deletions src/flag_gems/ops/var_mean.py
@@ -33,7 +33,7 @@ def welford_func(mean_x, count_x, M_x, mean_y, count_y, M_y):

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@triton.jit(do_not_specialize=["correction"])
def var_mean_welford_kernel(
X,
Var,
@@ -112,7 +112,7 @@ def var_mean_kernel_1(
@triton.heuristics(
values={"BLOCK_N": lambda args: triton.next_power_of_2(args["BLOCK_NUM"])},
)
@triton.jit
@triton.jit(do_not_specialize=["correction"])
def var_mean_kernel_2(
Acc,
Average,
6 changes: 3 additions & 3 deletions src/flag_gems/ops/vector_norm.py
@@ -210,7 +210,7 @@ def l0_norm_kernel_2(Mid, Out, MID_SIZE, BLOCK_MID: tl.constexpr):

@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
@triton.jit(do_not_specialize=["ord"])
def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
X = X + pid * N
@@ -231,7 +231,7 @@ def v_norm_kernel(X, Out, M, N, ord, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -245,7 +245,7 @@ def l1_norm_kernel_1(X, Mid, ord, M, BLOCK_SIZE: tl.constexpr):


@libentry()
@triton.jit
@triton.jit(do_not_specialize=["ord"])
def l1_norm_kernel_2(Mid, Out, ord, MID_SIZE, BLOCK_MID: tl.constexpr):
offset = tl.arange(0, BLOCK_MID)
Mid = Mid + offset
122 changes: 93 additions & 29 deletions src/flag_gems/utils/libentry.py
@@ -1,3 +1,5 @@
import inspect

import triton


@@ -14,47 +16,109 @@ def __init__(
while not isinstance(fn, triton.runtime.JITFunction):
fn = fn.fn
self.jit_function: triton.runtime.JITFunction = fn
self.kernel_arg_indices = []
for p in self.jit_function.params:
if not p.is_constexpr:
self.kernel_arg_indices.append(p.num)
self.specialize_indices = [
p.num
for p in self.jit_function.params
if not p.is_constexpr and not p.do_not_specialize
]
self.do_not_specialize_indices = [
p.num
for p in self.jit_function.params
if not p.is_constexpr and p.do_not_specialize
]

def key(self, spec_args, dns_args, const_args):
spec_key = [
(arg.dtype, arg.data_ptr() % self.divisibility == 0)
if hasattr(arg, "data_ptr")
else (type(arg), arg)
for arg in spec_args
]
dns_key = [
arg.dtype
if hasattr(arg, "data_ptr")
else type(arg)
if not isinstance(arg, int)
else "i32"
if -(2**31) <= arg and arg <= 2**31 - 1
else "u64"
if 2**63 <= arg and arg <= 2**64 - 1
else "i64"
for arg in dns_args
]
# const args passed by position
return tuple(spec_key + dns_key + const_args)

def run(self, *args, **kwargs):
key = []
for arg in args:
if hasattr(arg, "data_ptr"):
key.append(arg.dtype)
key.append(arg.data_ptr() % self.divisibility == 0)
elif isinstance(arg, int):
key.append(arg)
entry_key = tuple(key)
grid = kwargs["grid"]

# collect all the arguments
spec_args = [] # specialize arguments
dns_args = [] # do not specialize arguments
const_args = [] # constexpr arguments
k_args = [] # kernel arguments
for i, arg in enumerate(args):
if i in self.specialize_indices:
k_args.append(arg)
spec_args.append(arg)
elif i in self.do_not_specialize_indices:
k_args.append(arg)
dns_args.append(arg)
else:
const_args.append(arg)
for p in self.jit_function.params[len(args) :]:
if p.name in kwargs:
val = kwargs[p.name]
elif p.default is inspect._empty:
continue
else:
val = p.default

if p.is_constexpr:
const_args.append(val)
elif p.do_not_specialize:
dns_args.append(val)
k_args.append(val)
else:
spec_args.append(val)
k_args.append(val)

entry_key = self.key(spec_args, dns_args, const_args)

if entry_key not in self.kernel_cache:
kernel = self.fn.run(*args, **kwargs)
self.kernel_cache[entry_key] = kernel
fn = self.fn
# collect constexpr arguments for grid computation
constexprs = {}
while not isinstance(fn, triton.runtime.JITFunction):
if isinstance(fn, triton.runtime.Autotuner):
config = fn.best_config
constexprs["num_warps"] = config.num_warps
constexprs["num_stages"] = config.num_stages
constexprs["num_ctas"] = config.num_ctas
constexprs = {**constexprs, **config.kwargs}
elif isinstance(fn, triton.runtime.Heuristics):
for v, heur in fn.values.items():
constexprs[v] = heur(
{**dict(zip(fn.arg_names, args)), **kwargs, **constexprs}
)
else:
raise RuntimeError("Invalid Runtime Function")
fn = fn.fn
for p in self.jit_function.params:
if p.is_constexpr and p.name not in constexprs:
constexprs[p.name] = p.default
self.kernel_cache[entry_key] = (kernel, constexprs)
else:
kernel = self.kernel_cache[entry_key]
kernel, constexprs = self.kernel_cache[entry_key]

# collect all the arguments to the kernel, all non-constexpr arguments
k_args = [arg for i, arg in enumerate(args) if i in self.kernel_arg_indices]
if len(k_args) < len(self.kernel_arg_indices):
for p in self.jit_function.params[len(args) :]:
if not p.is_constexpr:
if p.name in kwargs:
k_args.append(kwargs[p.name])
else:
k_args.append(p.default)

grid = kwargs["grid"]
if callable(grid):
# collect all arguments to the grid fn, i.e.:
# 1. args,
# 2. kwargs,
# 3. all other captured arguments in CompiledKernel from Autotuner & Heuristics
# when kwargs & captured args conflict, captured args have higher priority
for k, v in kernel.constants.items():
arg_name = self.arg_names[int(k)]
kwargs[arg_name] = v
meta = {**dict(zip(self.arg_names, args)), **kwargs}
meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
grid = grid(meta)
grid = grid + (1, 1)

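For reference, the following standalone sketch (hypothetical helper names, not part of this PR or of flag_gems) mirrors the cache-key logic that `LibEntry.key` introduces above: specialized arguments contribute their dtype and pointer alignment (tensors) or their (type, value) pair (scalars), while do-not-specialize arguments contribute only a coarse type tag, never their value. The i64/u64 branches are omitted for brevity. The effect is that launches differing only in a dns scalar such as `eps` land on the same cache entry, while a dtype change does not.

```python
# Hypothetical standalone mirror of LibEntry.key(); runs on CPU, no Triton needed.
import torch

DIVISIBILITY = 16  # same alignment granularity LibEntry checks via data_ptr()


def spec_entry(arg):
    # specialized args: tensors keyed by (dtype, 16-byte alignment), scalars by (type, value)
    if hasattr(arg, "data_ptr"):
        return (arg.dtype, arg.data_ptr() % DIVISIBILITY == 0)
    return (type(arg), arg)


def dns_entry(arg):
    # do-not-specialize args: only a coarse type tag, never the value
    if hasattr(arg, "data_ptr"):
        return arg.dtype
    if isinstance(arg, int):
        return "i32" if -(2**31) <= arg <= 2**31 - 1 else "i64"
    return type(arg)


x = torch.randn(4, 4)
key_a = (spec_entry(x), dns_entry(1e-5))
key_b = (spec_entry(x), dns_entry(1e-6))
print(key_a == key_b)  # True: a different eps does not create a new cache entry

x_half = x.half()
print((spec_entry(x_half), dns_entry(1e-5)) == key_a)  # False: dtype is part of the key
```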