[Operator] Type promotion for pointwise Ops #79

Merged: 7 commits, Jul 1, 2024. Showing changes from 3 commits.
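Throughout this diff, each pointwise kernel's `pointwise_dynamic` decorator gains a `promotion_methods` argument: every rule lists the argument positions that take part in type promotion, followed by a rule name such as "DEFAULT", "INT_TO_FLOAT", "COMPLEX_TO_FLOAT", or "ALWAYS_BOOL". The sketch below is only an illustration of what such a rule is expected to mean, written against PyTorch's public dtype helpers; `resolve_promotion` is a made-up name for this example, not the flag_gems implementation (the real logic lives behind `pointwise_dynamic` in `src/flag_gems/utils`).

# Illustrative sketch only: how a rule like [0, 1, "DEFAULT"] could be resolved
# with PyTorch's public promotion helpers. Not the flag_gems implementation.
import torch


def resolve_promotion(args, rule):
    """Derive an output dtype from the arguments named in `rule`."""
    *positions, method = rule                      # e.g. [0, 1, "DEFAULT"]
    if method == "ALWAYS_BOOL":                    # comparisons, isinf, ...
        return torch.bool
    dtypes = [
        a.dtype if isinstance(a, torch.Tensor) else torch.tensor(a).dtype
        for a in (args[i] for i in positions)      # simplification for scalars
    ]
    out = dtypes[0]
    for dt in dtypes[1:]:
        out = torch.promote_types(out, dt)         # standard PyTorch promotion
    if method == "INT_TO_FLOAT" and not (out.is_floating_point or out.is_complex):
        out = torch.get_default_dtype()            # e.g. cos(int64) -> float32
    if method == "COMPLEX_TO_FLOAT" and out.is_complex:
        out = torch.float32 if out == torch.complex64 else torch.float64
    return out                                     # "DEFAULT": plain promotion


x = torch.ones(4, dtype=torch.int32)
y = torch.ones(4, dtype=torch.float16)
print(resolve_promotion((x, y), [0, 1, "DEFAULT"]))       # torch.float16
print(resolve_promotion((x, x), [0, 1, "INT_TO_FLOAT"]))  # torch.float32
print(resolve_promotion((x, y), [0, 1, "ALWAYS_BOOL"]))   # torch.bool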
src/flag_gems/fused/gelu_and_mul.py (2 additions & 2 deletions)
@@ -7,15 +7,15 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def gelu_none_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
     x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811))
     return x_gelu * y


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def gelu_tanh_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
src/flag_gems/fused/silu_and_mul.py (1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def silu_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
src/flag_gems/ops/abs.py (1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]])
 @triton.jit
 def abs_func(x):
     return tl.abs(x)
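"COMPLEX_TO_FLOAT" mirrors the eager PyTorch behaviour of `abs`: a complex input produces a real output of matching precision, while other dtypes pass through unchanged. A quick eager-mode reference check (independent of the Triton kernel above):

import torch

# Eager-mode reference behaviour that "COMPLEX_TO_FLOAT" is meant to match.
z = torch.tensor([3 + 4j], dtype=torch.complex64)
print(torch.abs(z).dtype)   # torch.float32 (magnitude of a complex number is real)

i = torch.tensor([-2, 5], dtype=torch.int64)
print(torch.abs(i).dtype)   # torch.int64 (non-complex inputs keep their dtype)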
src/flag_gems/ops/add.py (7 additions & 3 deletions)
@@ -6,19 +6,23 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic(is_tensor=[True, True, False])
+@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def add_func(x, y, alpha):
     return x + y * alpha


-@pointwise_dynamic(is_tensor=[True, False, False])
+@pointwise_dynamic(
+    is_tensor=[True, False, False], promotion_methods=[[0, 1, "DEFAULT"]]
+)
 @triton.jit
 def add_func_tensor_scalar(x, y, alpha):
     return x + y * alpha


-@pointwise_dynamic(is_tensor=[False, True, False])
+@pointwise_dynamic(
+    is_tensor=[False, True, False], promotion_methods=[[0, 1, "DEFAULT"]]
+)
 @triton.jit
 def add_func_scalar_tensor(x, y, alpha):
     return x + y * alpha
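The separate tensor-scalar variants matter for promotion because PyTorch treats Python scalars as weakly typed: they take part in promotion but do not normally widen a tensor operand. An eager-mode reference for what the "DEFAULT" rule should reproduce here:

import torch

# Eager-mode reference for the "DEFAULT" rule with mixed operands.
t16 = torch.ones(3, dtype=torch.float16)
t32 = torch.ones(3, dtype=torch.float32)

print(torch.add(t16, t32).dtype)            # torch.float32 (tensor-tensor promotion)
print(torch.add(t16, 2.0).dtype)            # torch.float16 (Python scalar does not widen)
print(torch.add(t16, t32, alpha=2).dtype)   # torch.float32 (alpha is a scalar multiplier)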
src/flag_gems/ops/bitwise_and.py (2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def bitwise_and_func(x, y):
     return x & y
@@ -16,7 +16,7 @@ def bitwise_and_tensor(A, B):
     return bitwise_and_func(A, B)


-@pointwise_dynamic(is_tensor=[True, False])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def bitwise_and_func_scalar(x, y):
     return x & y
src/flag_gems/ops/bitwise_not.py (1 addition & 1 deletion)
@@ -5,7 +5,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "DEFAULT"]])
 @triton.jit
 def bitwise_not_func(x):
     return ~x
src/flag_gems/ops/bitwise_or.py (2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def bitwise_or_func(x, y):
     return x | y
@@ -16,7 +16,7 @@ def bitwise_or_tensor(A, B):
     return bitwise_or_func(A, B)


-@pointwise_dynamic(is_tensor=[True, False])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def bitwise_or_func_scalar(x, y):
     return x | y
src/flag_gems/ops/clamp.py (8 additions & 6 deletions)
@@ -6,19 +6,19 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, 2, "DEFAULT"]])
 @triton.jit
 def clamp_func_tensor(x, mini, maxi):
     return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def clamp_func_min_tensor(x, mini):
     return tl.maximum(mini, x.to(tl.float32))


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def clamp_func_max_tensor(x, maxi):
     return tl.minimum(maxi, x.to(tl.float32))
@@ -36,19 +36,21 @@ def clamp_tensor(A, mini=None, maxi=None):
     return clamp_func_tensor(A, mini, maxi)


-@pointwise_dynamic(is_tensor=[True, False, False])
+@pointwise_dynamic(
+    is_tensor=[True, False, False], promotion_methods=[[0, 1, 2, "DEFAULT"]]
+)
 @triton.jit
 def clamp_func(x, mini, maxi):
     return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))


-@pointwise_dynamic(is_tensor=[True, False])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def clamp_func_min(x, mini):
     return tl.maximum(mini, x.to(tl.float32))


-@pointwise_dynamic(is_tensor=[True, False])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
 @triton.jit
 def clamp_func_max(x, maxi):
     return tl.minimum(maxi, x.to(tl.float32))
src/flag_gems/ops/cos.py (1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]])
 @triton.jit
 def cos_func(x):
     return tl.cos(x.to(tl.float32))
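"INT_TO_FLOAT" covers ops like `cos` whose integer inputs must yield floating-point outputs; eager PyTorch promotes them to the default float dtype, while floating inputs keep their dtype:

import torch

# Eager-mode reference for "INT_TO_FLOAT".
i = torch.arange(4, dtype=torch.int64)
print(torch.cos(i).dtype)   # torch.float32 (integer input promoted to default float dtype)

d = torch.arange(4, dtype=torch.float64)
print(torch.cos(d).dtype)   # torch.float64 (already floating, unchanged)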
src/flag_gems/ops/div.py (85 additions & 11 deletions)
@@ -6,32 +6,106 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
 @triton.jit
-def div_func(x, y):
+def true_div_func(x, y):
     return x / y


-@pointwise_dynamic(is_tensor=[True, False])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
 @triton.jit
-def div_func_tensor_scalar(x, y):
+def true_div_func_tensor_scalar(x, y):
     return x / y


-@pointwise_dynamic(is_tensor=[False, True])
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "INT_TO_FLOAT"]])
 @triton.jit
-def div_func_scalar_tensor(x, y):
+def true_div_func_scalar_tensor(x, y):
     return x / y


-def div(A, B):
-    logging.debug("GEMS DIV")
+def true_divide(A, B):
+    logging.debug("GEMS TRUE_DIVIDE")
     if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
-        return div_func(A, B)
+        return true_div_func(A, B)
     elif isinstance(A, torch.Tensor):
-        return div_func_tensor_scalar(A, B)
+        return true_div_func_tensor_scalar(A, B)
     elif isinstance(B, torch.Tensor):
-        return div_func_scalar_tensor(A, B)
+        return true_div_func_scalar_tensor(A, B)
     else:
         # Both scalar
         return A / B


+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def trunc_div_func(x, y):
+    return triton.div_rz(x, y)


+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def trunc_div_func_tensor_scalar(x, y):
+    return triton.div_rz(x, y)


+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def trunc_div_func_scalar_tensor(x, y):
+    return triton.div_rz(x, y)


+def trunc_divide(A, B):
+    logging.debug("GEMS TRUNC_DIVIDE")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return trunc_div_func(A, B)
+    elif isinstance(A, torch.Tensor):
+        return trunc_div_func_tensor_scalar(A, B)
+    elif isinstance(B, torch.Tensor):
+        return trunc_div_func_scalar_tensor(A, B)
+    else:
+        # Both scalar
+        return A / B


+@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def floor_div_func(x, y):
+    return x // y


+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def floor_div_func_tensor_scalar(x, y):
+    return x // y


+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]])
+@triton.jit
+def floor_div_func_scalar_tensor(x, y):
+    return x // y


+def floor_divide(A, B):
+    logging.debug("GEMS FLOOR_DIVIDE")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return floor_div_func(A, B)
+    elif isinstance(A, torch.Tensor):
+        return floor_div_func_tensor_scalar(A, B)
+    elif isinstance(B, torch.Tensor):
+        return floor_div_func_scalar_tensor(A, B)
+    else:
+        # Both scalar
+        return A // B


+def div(A, B, rounding_mode=None):
+    if rounding_mode is None:
+        return true_divide(A, B)
+    elif rounding_mode == "trunc":
+        return trunc_divide(A, B)
+    elif rounding_mode == "floor":
+        return floor_divide(A, B)
+    else:
+        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
+        raise ValueError(msg)
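The new `div` dispatcher routes `rounding_mode` to the three kernel families; only true division uses "INT_TO_FLOAT", while "trunc" and "floor" keep the promoted input dtype. Eager-mode reference for the expected values and dtypes:

import torch

# Eager-mode reference for torch.div's three rounding modes.
a = torch.tensor([7, -7])
b = torch.tensor([2, 2])

print(torch.div(a, b))                          # tensor([ 3.5000, -3.5000]), dtype float32
print(torch.div(a, b, rounding_mode="trunc"))   # tensor([ 3, -3]), dtype int64 (rounds toward zero)
print(torch.div(a, b, rounding_mode="floor"))   # tensor([ 3, -4]), dtype int64 (rounds toward -inf)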
src/flag_gems/ops/eq.py (2 additions & 3 deletions)
@@ -1,13 +1,12 @@
 import logging

-import torch
 import triton
 import triton.language as tl

 from ..utils import pointwise_dynamic


-@pointwise_dynamic(output_dtypes=[torch.bool])
+@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def eq_func(x, y):
     return x.to(tl.float32) == y.to(tl.float32)
@@ -18,7 +17,7 @@ def eq(A, B):
     return eq_func(A, B)


-@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def eq_func_scalar(x, y):
     return x.to(tl.float32) == y.to(tl.float32)
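"ALWAYS_BOOL" replaces the earlier `output_dtypes=[torch.bool]` annotation (which is also why the now-unused `torch` import is dropped): the operands are promoted for the comparison itself, but the result is always a bool tensor, as in eager PyTorch:

import torch

# Eager-mode reference for "ALWAYS_BOOL": operands are promoted internally,
# but the result dtype is always bool.
x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.tensor([1.0, 2.5, 3.0], dtype=torch.float64)

print(torch.eq(x, y))        # tensor([ True, False,  True])
print(torch.eq(x, y).dtype)  # torch.bool
print(torch.eq(x, 2).dtype)  # torch.bool (tensor-scalar variant)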
src/flag_gems/ops/exp.py (1 addition & 1 deletion)
@@ -6,7 +6,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]])
 @triton.jit
 def exp_func(x):
     return tl.exp(x.to(tl.float32))
src/flag_gems/ops/ge.py (2 additions & 3 deletions)
@@ -1,13 +1,12 @@
 import logging

-import torch
 import triton
 import triton.language as tl

 from ..utils import pointwise_dynamic


-@pointwise_dynamic(output_dtypes=[torch.bool])
+@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def ge_func(x, y):
     return x.to(tl.float32) >= y
@@ -18,7 +17,7 @@ def ge(A, B):
     return ge_func(A, B)


-@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def ge_func_scalar(x, y):
     return x.to(tl.float32) >= y
src/flag_gems/ops/gelu.py (2 additions & 2 deletions)
@@ -6,15 +6,15 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]])
 @triton.jit
 def gelu_none(x):
     scale = 0.7071067811
     output = 0.5 * x * (1 + tl.math.erf(x * scale))
     return output


-@pointwise_dynamic
+@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]])
 @triton.jit
 def gelu_tanh(x):
     output = (
src/flag_gems/ops/gt.py (2 additions & 3 deletions)
@@ -1,13 +1,12 @@
 import logging

-import torch
 import triton
 import triton.language as tl

 from ..utils import pointwise_dynamic


-@pointwise_dynamic(output_dtypes=[torch.bool])
+@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def gt_func(x, y):
     return x.to(tl.float32) > y
@@ -18,7 +17,7 @@ def gt(A, B):
     return gt_func(A, B)


-@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]])
 @triton.jit
 def gt_func_scalar(x, y):
     return x.to(tl.float32) > y
src/flag_gems/ops/isinf.py (1 addition & 2 deletions)
@@ -1,13 +1,12 @@
 import logging

-import torch
 import triton
 import triton.language as tl

 from ..utils import pointwise_dynamic


-@pointwise_dynamic(output_dtypes=[torch.bool])
+@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]])
 @triton.jit
 def isinf_func(x):
     return tl.math.isinf(x.to(tl.float32))