From 8cee10f9dc18848c977dfdde065d933f2f2b5847 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Fri, 21 Jun 2024 12:54:05 +0800 Subject: [PATCH 1/7] add type promotion for pointwise op --- src/flag_gems/ops/abs.py | 6 +++++ src/flag_gems/ops/add.py | 6 +++++ src/flag_gems/ops/bitwise_and.py | 14 ++++++++++ src/flag_gems/ops/bitwise_not.py | 6 +++++ src/flag_gems/ops/bitwise_or.py | 14 ++++++++++ src/flag_gems/ops/clamp.py | 10 +++++++ src/flag_gems/ops/cos.py | 6 +++++ src/flag_gems/ops/eq.py | 10 +++++++ src/flag_gems/ops/exp.py | 6 +++++ src/flag_gems/ops/ge.py | 10 +++++++ src/flag_gems/ops/gelu.py | 6 +++++ src/flag_gems/ops/gt.py | 10 +++++++ src/flag_gems/ops/isinf.py | 6 +++++ src/flag_gems/ops/isnan.py | 6 +++++ src/flag_gems/ops/le.py | 10 +++++++ src/flag_gems/ops/lt.py | 10 +++++++ src/flag_gems/ops/mul.py | 6 +++++ src/flag_gems/ops/ne.py | 10 +++++++ src/flag_gems/ops/neg.py | 6 +++++ src/flag_gems/ops/pow.py | 14 ++++++++++ src/flag_gems/ops/reciprocal.py | 6 +++++ src/flag_gems/ops/relu.py | 6 +++++ src/flag_gems/ops/rsqrt.py | 6 +++++ src/flag_gems/ops/sigmoid.py | 6 +++++ src/flag_gems/ops/silu.py | 6 +++++ src/flag_gems/ops/sin.py | 6 +++++ src/flag_gems/ops/sub.py | 6 +++++ src/flag_gems/ops/tanh.py | 6 +++++ src/flag_gems/ops/where.py | 14 ++++++++++ src/flag_gems/utils/pointwise_dynamic.py | 33 +++++++++++++++++------- 30 files changed, 257 insertions(+), 10 deletions(-) diff --git a/src/flag_gems/ops/abs.py b/src/flag_gems/ops/abs.py index 1f8eabf9..5f2e1dd3 100644 --- a/src/flag_gems/ops/abs.py +++ b/src/flag_gems/ops/abs.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def abs_func(x): return tl.abs(x) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) def abs(A): logging.debug("GEMS ABS") return abs_func(A) diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index 5c6dfee3..efe29d73 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -2,6 +2,8 @@ import torch import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -24,6 +26,10 @@ def add_func_scalar_tensor(x, y, alpha): return x + y * alpha +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def add(A, B, *, alpha=1): logging.debug("GEMS ADD") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index 72a994e0..a8c9874b 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -1,6 +1,8 @@ import logging import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -11,6 +13,10 @@ def bitwise_and_func(x, y): return x & y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_and_tensor(A, B): logging.debug("GEMS BITWISE AND") return bitwise_and_func(A, B) @@ -22,11 +28,19 @@ def bitwise_and_func_scalar(x, y): return x & 
y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_and_scalar(A, B): logging.debug("GEMS BITWISE AND SCALAR") return bitwise_and_func_scalar(A, B) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_and_scalar_tensor(A, B): logging.debug("GEMS BITWISE AND SCALAR TENSOR") return bitwise_and_func_scalar(B, A) diff --git a/src/flag_gems/ops/bitwise_not.py b/src/flag_gems/ops/bitwise_not.py index b5977514..11126b98 100644 --- a/src/flag_gems/ops/bitwise_not.py +++ b/src/flag_gems/ops/bitwise_not.py @@ -1,6 +1,8 @@ import logging import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -11,6 +13,10 @@ def bitwise_not_func(x): return ~x +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_not(A): logging.debug("GEMS BITWISE NOT") return bitwise_not_func(A) diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index 99f654e4..f2490e13 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -1,6 +1,8 @@ import logging import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -11,6 +13,10 @@ def bitwise_or_func(x, y): return x | y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_or_tensor(A, B): logging.debug("GEMS BITWISE OR") return bitwise_or_func(A, B) @@ -22,11 +28,19 @@ def bitwise_or_func_scalar(x, y): return x | y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_or_scalar(A, B): logging.debug("GEMS BITWISE OR SCALAR") return bitwise_or_func_scalar(A, B) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def bitwise_or_scalar_tensor(A, B): logging.debug("GEMS BITWISE OR SCALAR TENSOR") return bitwise_or_func_scalar(B, A) diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index c603fe52..de2019b2 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -24,6 +26,10 @@ def clamp_func_max_tensor(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "mini", "maxi"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def clamp_tensor(A, mini=None, maxi=None): logging.debug("GEMS CLAMP TENSOR") if mini is None and maxi is None: @@ -54,6 +60,10 @@ def clamp_func_max(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "mini", "maxi"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def clamp(A, mini=None, maxi=None): logging.debug("GEMS CLAMP") if mini is None and maxi is 
None: diff --git a/src/flag_gems/ops/cos.py b/src/flag_gems/ops/cos.py index fee00a4b..3ecad3ab 100644 --- a/src/flag_gems/ops/cos.py +++ b/src/flag_gems/ops/cos.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def cos_func(x): return tl.cos(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def cos(A): logging.debug("GEMS COS") return cos_func(A) diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index 38ee0447..b5707c46 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def eq_func(x, y): return x.to(tl.float32) == y.to(tl.float32) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def eq(A, B): logging.debug("GEMS EQ") return eq_func(A, B) @@ -24,6 +30,10 @@ def eq_func_scalar(x, y): return x.to(tl.float32) == y.to(tl.float32) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def eq_scalar(A, B): logging.debug("GEMS EQ SCALAR") return eq_func_scalar(A, B) diff --git a/src/flag_gems/ops/exp.py b/src/flag_gems/ops/exp.py index 5d51c5b4..b350e65b 100644 --- a/src/flag_gems/ops/exp.py +++ b/src/flag_gems/ops/exp.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def exp_func(x): return tl.exp(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def exp(A): logging.debug("GEMS EXP") return exp_func(A) diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py index 2edd5dfc..8b13b8e9 100644 --- a/src/flag_gems/ops/ge.py +++ b/src/flag_gems/ops/ge.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def ge_func(x, y): return x.to(tl.float32) >= y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def ge(A, B): logging.debug("GEMS GE") return ge_func(A, B) @@ -24,6 +30,10 @@ def ge_func_scalar(x, y): return x.to(tl.float32) >= y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def ge_scalar(A, B): logging.debug("GEMS GE SCALAR") return ge_func_scalar(A, B) diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index 2df36194..028f0018 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common 
import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -30,6 +32,10 @@ def gelu_tanh(x): return output +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def gelu(A, *, approximate="none"): logging.debug("GEMS GELU") if approximate == "tanh": diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py index 4fe53628..599a5c5c 100644 --- a/src/flag_gems/ops/gt.py +++ b/src/flag_gems/ops/gt.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def gt_func(x, y): return x.to(tl.float32) > y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def gt(A, B): logging.debug("GEMS GT") return gt_func(A, B) @@ -24,6 +30,10 @@ def gt_func_scalar(x, y): return x.to(tl.float32) > y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def gt_scalar(A, B): logging.debug("GEMS GT SCALAR") return gt_func_scalar(A, B) diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index 17c8488c..5505cffb 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def isinf_func(x): return tl.math.isinf(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def isinf(A): logging.debug("GEMS ISINF") return isinf_func(A) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index f58f2ce1..efeedea2 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def isnan_func(x): return tl.math.isnan(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def isnan(A): logging.debug("GEMS ISNAN") return isnan_func(A) diff --git a/src/flag_gems/ops/le.py b/src/flag_gems/ops/le.py index 70d84be5..c837c48e 100644 --- a/src/flag_gems/ops/le.py +++ b/src/flag_gems/ops/le.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def le_func(x, y): return x.to(tl.float32) <= y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def le(A, B): logging.debug("GEMS LE") return le_func(A, B) @@ -24,6 +30,10 @@ def le_func_scalar(x, y): return x.to(tl.float32) <= y 
+@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def le_scalar(A, B): logging.debug("GEMS LE SCALAR") return le_func_scalar(A, B) diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py index 26e828d1..be811c07 100644 --- a/src/flag_gems/ops/lt.py +++ b/src/flag_gems/ops/lt.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def lt_func(x, y): return x.to(tl.float32) < y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def lt(A, B): logging.debug("GEMS LT") return lt_func(A, B) @@ -24,6 +30,10 @@ def lt_func_scalar(x, y): return x.to(tl.float32) < y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def lt_scalar(A, B): logging.debug("GEMS LT SCALAR") return lt_func_scalar(A, B) diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index e5745f98..6a6f6ccd 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -2,6 +2,8 @@ import torch import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -18,6 +20,10 @@ def mul_func_scalar(x, y): return x * y +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def mul(A, B): logging.debug("GEMS MUL") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py index f322ffff..e410d1b3 100644 --- a/src/flag_gems/ops/ne.py +++ b/src/flag_gems/ops/ne.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -13,6 +15,10 @@ def ne_func(x, y): return x.to(tl.float32) != y.to(tl.float32) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def ne(A, B): logging.debug("GEMS NE") return ne_func(A, B) @@ -24,6 +30,10 @@ def ne_func_scalar(x, y): return x.to(tl.float32) != y.to(tl.float32) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) def ne_scalar(A, B): logging.debug("GEMS NE SCALAR") return ne_func_scalar(A, B) diff --git a/src/flag_gems/ops/neg.py b/src/flag_gems/ops/neg.py index 88ea6a26..6e3db4c8 100644 --- a/src/flag_gems/ops/neg.py +++ b/src/flag_gems/ops/neg.py @@ -1,6 +1,8 @@ import logging import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -11,6 +13,10 @@ def neg_func(x): return -x +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def neg(A): logging.debug("GEMS NEG") return neg_func(A) diff --git a/src/flag_gems/ops/pow.py 
b/src/flag_gems/ops/pow.py index 60abf1fa..98e20a65 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def pow_func(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) def pow_tensor_tensor(A, exponent): logging.debug("GEMS POW_TENSOR_TENSOR") return pow_func(A, exponent) @@ -23,6 +29,10 @@ def pow_func_tensor_scalar(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) def pow_tensor_scalar(A, exponent): logging.debug("GEMS POW_TENSOR_SCALAR") return pow_func_tensor_scalar(A, exponent) @@ -34,6 +44,10 @@ def pow_func_scalar_tensor(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) def pow_scalar(A, exponent): logging.debug("GEMS POW_SCALAR") return pow_func_scalar_tensor(A, exponent) diff --git a/src/flag_gems/ops/reciprocal.py b/src/flag_gems/ops/reciprocal.py index 432d45ea..33f7ebf3 100644 --- a/src/flag_gems/ops/reciprocal.py +++ b/src/flag_gems/ops/reciprocal.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def reciprocal_func(x): return 1.0 / x.to(tl.float32) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def reciprocal(A): logging.debug("GEMS RECIPROCAL") return reciprocal_func(A) diff --git a/src/flag_gems/ops/relu.py b/src/flag_gems/ops/relu.py index 32d59af1..aec698c4 100644 --- a/src/flag_gems/ops/relu.py +++ b/src/flag_gems/ops/relu.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -35,5 +37,9 @@ def backward(ctx, out_grad): return in_grad +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def relu(A): return Relu.apply(A) diff --git a/src/flag_gems/ops/rsqrt.py b/src/flag_gems/ops/rsqrt.py index 24090a0e..e15010eb 100644 --- a/src/flag_gems/ops/rsqrt.py +++ b/src/flag_gems/ops/rsqrt.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def rsqrt_func(x): return 1.0 / tl.sqrt(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def rsqrt(A): logging.debug("GEMS RSQRT") return rsqrt_func(A) diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index 833e9559..e989da69 
100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -4,6 +4,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -39,5 +41,9 @@ def backward(ctx, out_grad): return in_grad +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def sigmoid(A): return Sigmoid.apply(A) diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index 358ab3aa..f0f4a7fc 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -41,5 +43,9 @@ def backward(ctx, out_grad): return in_grad +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def silu(A): return Silu.apply(A) diff --git a/src/flag_gems/ops/sin.py b/src/flag_gems/ops/sin.py index 4431f75e..7394b93a 100644 --- a/src/flag_gems/ops/sin.py +++ b/src/flag_gems/ops/sin.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def sin_func(x): return tl.sin(x.to(tl.float32)) +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def sin(A): logging.debug("GEMS SIN") return sin_func(A) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index 72804744..00dcc4fd 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -2,6 +2,8 @@ import torch import triton +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -24,6 +26,10 @@ def sub_func_scalar_tensor(x, y, alpha): return x - y * alpha +@elementwise_type_promotion_wrapper( + type_promoting_args=("A", "B"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) def sub(A, B, *, alpha=1): logging.debug("GEMS SUB") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index b2cff858..b6353fc8 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -3,6 +3,8 @@ import torch import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -35,5 +37,9 @@ def backward(ctx, out_grad): return in_grad +@elementwise_type_promotion_wrapper( + type_promoting_args=("A"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) def tanh(A): return Tanh.apply(A) diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index 68b09868..21f9fc06 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from 
torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic @@ -12,6 +14,10 @@ def where_self_func(self, condition, other): return tl.where(condition, self, other) +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) def where_self(condition, self, other): logging.debug("GEMS WHERE_SELF") return where_self_func(self, condition, other) @@ -23,6 +29,10 @@ def where_scalar_self_func(other, condition, self): return tl.where(condition, self, other) +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) def where_scalar_self(condition, self, other): logging.debug("GEMS WHERE_SCALAR_SELF") return where_scalar_self_func(other, condition, self) @@ -34,6 +44,10 @@ def where_scalar_other_func(self, condition, other): return tl.where(condition, self, other) +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) def where_scalar_other(condition, self, other): logging.debug("GEMS WHERE_SCALAR_OTHER") return where_scalar_other_func(self, condition, other) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index c6f47c32..7f658423 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -245,12 +245,16 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("import triton") code.writeline("from triton import language as tl") code.newline() - code.writeline( - "from flag_gems.utils.shape_utils import broadcast_shapes, \ - broadcasted_stride, c_contiguous_stride, volume, Stride" - ) + code.writeline("from flag_gems.utils.shape_utils import (") + code.writeline(" broadcast_shapes,") + code.writeline(" broadcasted_stride,") + code.writeline(" c_contiguous_stride,") + code.writeline(" volume,") + code.writeline(" Stride,") + code.writeline(")") code.writeline("from flag_gems.utils.libentry import libentry") code.newline() + code.newline() return code @@ -280,25 +284,33 @@ def generate_functional_pointwise_wrapper( for i in range(op_desc.num_outputs()): if op_desc.output_dtype(i) is None: code.writeline( - f"out{num_output_tensor_index} = \ - torch.empty(shape, dtype=in0.dtype, device=in0.device)" + ( + f"out{num_output_tensor_index} = " + f"torch.empty(shape, dtype=in0.dtype, device=in0.device)" + ) ) else: code.writeline( - f"out{num_output_tensor_index} = \ - torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)" + ( + f"out{num_output_tensor_index} = " + f"torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, " + f"device=in0.device)" + ) ) num_output_tensor_index += 1 # call destination_passing_func output_names: str = output_ref_for_wrapper(op_desc) - call_str = f"{output_names} = {destination_passing_func_name} \ - ({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + call_str = ( + f"{output_names} = {destination_passing_func_name}" + f"({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + ) code.writeline(call_str) return_str = f"return {output_names}" code.writeline(return_str) code.newline() + code.newline() return code @@ -381,6 +393,7 @@ def generate_destination_passing_pointwise_wrapper( # return code.writeline(f"return {output_ref_for_wrapper(op_desc)}") code.newline() + code.newline() return code From 
41991e3933f5b4f22730262ced8aa17d6f4ca62c Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Fri, 21 Jun 2024 15:17:09 +0800 Subject: [PATCH 2/7] add trunc & floor mode for divOp --- src/flag_gems/ops/div.py | 91 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 83 insertions(+), 8 deletions(-) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 20e4b60c..7fb92be3 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -8,30 +8,105 @@ @pointwise_dynamic @triton.jit -def div_func(x, y): +def true_div_func(x, y): return x / y @pointwise_dynamic(is_tensor=[True, False]) @triton.jit -def div_func_tensor_scalar(x, y): +def true_div_func_tensor_scalar(x, y): return x / y @pointwise_dynamic(is_tensor=[False, True]) @triton.jit -def div_func_scalar_tensor(x, y): +def true_div_func_scalar_tensor(x, y): return x / y -def div(A, B): - logging.debug("GEMS DIV") +def true_divide(A, B): + logging.debug("GEMS TRUE_DIVIDE") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - return div_func(A, B) + return true_div_func(A, B) elif isinstance(A, torch.Tensor): - return div_func_tensor_scalar(A, B) + return true_div_func_tensor_scalar(A, B) elif isinstance(B, torch.Tensor): - return div_func_scalar_tensor(A, B) + return true_div_func_scalar_tensor(A, B) else: # Both scalar return A / B + + +@pointwise_dynamic +@triton.jit +def trunc_div_func(x, y): + return triton.div_rz(x, y) + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def trunc_div_func_tensor_scalar(x, y): + return triton.div_rz(x, y) + + +@pointwise_dynamic(is_tensor=[False, True]) +@triton.jit +def trunc_div_func_scalar_tensor(x, y): + return triton.div_rz(x, y) + + +def trunc_divide(A, B): + logging.debug("GEMS TRUNC_DIVIDE") + if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): + return trunc_div_func(A, B) + elif isinstance(A, torch.Tensor): + return trunc_div_func_tensor_scalar(A, B) + elif isinstance(B, torch.Tensor): + return trunc_div_func_scalar_tensor(A, B) + else: + # Both scalar + return A / B + + +@pointwise_dynamic +@triton.jit +def floor_div_func(x, y): + return x // y + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def floor_div_func_tensor_scalar(x, y): + return x // y + + +@pointwise_dynamic(is_tensor=[False, True]) +@triton.jit +def floor_div_func_scalar_tensor(x, y): + return x // y + + +def floor_divide(A, B): + logging.debug("GEMS FLOOR_DIVIDE") + if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): + return floor_div_func(A, B) + elif isinstance(A, torch.Tensor): + return floor_div_func_tensor_scalar(A, B) + elif isinstance(B, torch.Tensor): + return floor_div_func_scalar_tensor(A, B) + else: + # Both scalar + return A // B + + +def div(A, B, rounding_mode=None): + print("xxx") + if rounding_mode is None: + return true_divide(A, B) + elif rounding_mode == "trunc": + return trunc_divide(A, B) + elif rounding_mode == "floor": + return floor_divide(A, B) + else: + msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}." 
+ raise ValueError(msg) From d4c36d688734b70370be16a0c3f5ff3b949b5755 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Tue, 25 Jun 2024 18:22:10 +0800 Subject: [PATCH 3/7] add type promotion while code gen --- src/flag_gems/fused/gelu_and_mul.py | 4 +- src/flag_gems/fused/silu_and_mul.py | 2 +- src/flag_gems/ops/abs.py | 8 +- src/flag_gems/ops/add.py | 16 ++-- src/flag_gems/ops/bitwise_and.py | 18 +--- src/flag_gems/ops/bitwise_not.py | 8 +- src/flag_gems/ops/bitwise_or.py | 18 +--- src/flag_gems/ops/clamp.py | 24 ++---- src/flag_gems/ops/cos.py | 8 +- src/flag_gems/ops/div.py | 19 ++--- src/flag_gems/ops/eq.py | 15 +--- src/flag_gems/ops/exp.py | 8 +- src/flag_gems/ops/ge.py | 15 +--- src/flag_gems/ops/gelu.py | 10 +-- src/flag_gems/ops/gt.py | 15 +--- src/flag_gems/ops/isinf.py | 9 +- src/flag_gems/ops/isnan.py | 9 +- src/flag_gems/ops/le.py | 15 +--- src/flag_gems/ops/lt.py | 15 +--- src/flag_gems/ops/mul.py | 10 +-- src/flag_gems/ops/ne.py | 15 +--- src/flag_gems/ops/neg.py | 8 +- src/flag_gems/ops/pow.py | 20 +---- src/flag_gems/ops/reciprocal.py | 8 +- src/flag_gems/ops/relu.py | 10 +-- src/flag_gems/ops/rsqrt.py | 8 +- src/flag_gems/ops/sigmoid.py | 10 +-- src/flag_gems/ops/silu.py | 10 +-- src/flag_gems/ops/sin.py | 8 +- src/flag_gems/ops/sub.py | 16 ++-- src/flag_gems/ops/tanh.py | 10 +-- src/flag_gems/ops/where.py | 38 ++++----- src/flag_gems/utils/pointwise_dynamic.py | 104 ++++++++++++++--------- src/flag_gems/utils/type_utils.py | 9 ++ 34 files changed, 164 insertions(+), 356 deletions(-) create mode 100644 src/flag_gems/utils/type_utils.py diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py index c2cb52f6..d9dd984b 100644 --- a/src/flag_gems/fused/gelu_and_mul.py +++ b/src/flag_gems/fused/gelu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def gelu_none_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) @@ -15,7 +15,7 @@ def gelu_none_and_mul_kernel(x, y): return x_gelu * y -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def gelu_tanh_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/fused/silu_and_mul.py b/src/flag_gems/fused/silu_and_mul.py index 0d1271ec..e5cbbfd1 100644 --- a/src/flag_gems/fused/silu_and_mul.py +++ b/src/flag_gems/fused/silu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def silu_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/ops/abs.py b/src/flag_gems/ops/abs.py index 5f2e1dd3..d2267278 100644 --- a/src/flag_gems/ops/abs.py +++ b/src/flag_gems/ops/abs.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) @triton.jit def abs_func(x): return tl.abs(x) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, -) def abs(A): logging.debug("GEMS ABS") return abs_func(A) diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index efe29d73..a0181087 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -2,34 +2,32 @@ import torch import triton 
-from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def add_func(x, y, alpha): return x + y * alpha -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[[0, 1, "DEFAULT"]] +) @triton.jit def add_func_tensor_scalar(x, y, alpha): return x + y * alpha -@pointwise_dynamic(is_tensor=[False, True, False]) +@pointwise_dynamic( + is_tensor=[False, True, False], promotion_methods=[[0, 1, "DEFAULT"]] +) @triton.jit def add_func_scalar_tensor(x, y, alpha): return x + y * alpha -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def add(A, B, *, alpha=1): logging.debug("GEMS ADD") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index a8c9874b..7820ccac 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -1,46 +1,32 @@ import logging import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def bitwise_and_func(x, y): return x & y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_and_tensor(A, B): logging.debug("GEMS BITWISE AND") return bitwise_and_func(A, B) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def bitwise_and_func_scalar(x, y): return x & y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_and_scalar(A, B): logging.debug("GEMS BITWISE AND SCALAR") return bitwise_and_func_scalar(A, B) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_and_scalar_tensor(A, B): logging.debug("GEMS BITWISE AND SCALAR TENSOR") return bitwise_and_func_scalar(B, A) diff --git a/src/flag_gems/ops/bitwise_not.py b/src/flag_gems/ops/bitwise_not.py index 11126b98..e3d720c7 100644 --- a/src/flag_gems/ops/bitwise_not.py +++ b/src/flag_gems/ops/bitwise_not.py @@ -1,22 +1,16 @@ import logging import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "DEFAULT"]]) @triton.jit def bitwise_not_func(x): return ~x -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_not(A): logging.debug("GEMS BITWISE NOT") return bitwise_not_func(A) diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index f2490e13..bd0d97ec 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -1,46 +1,32 @@ import 
logging import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def bitwise_or_func(x, y): return x | y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_or_tensor(A, B): logging.debug("GEMS BITWISE OR") return bitwise_or_func(A, B) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def bitwise_or_func_scalar(x, y): return x | y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_or_scalar(A, B): logging.debug("GEMS BITWISE OR SCALAR") return bitwise_or_func_scalar(A, B) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def bitwise_or_scalar_tensor(A, B): logging.debug("GEMS BITWISE OR SCALAR TENSOR") return bitwise_or_func_scalar(B, A) diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index de2019b2..c0ff1c5f 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -2,34 +2,28 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, 2, "DEFAULT"]]) @triton.jit def clamp_func_tensor(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def clamp_func_min_tensor(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def clamp_func_max_tensor(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "mini", "maxi"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def clamp_tensor(A, mini=None, maxi=None): logging.debug("GEMS CLAMP TENSOR") if mini is None and maxi is None: @@ -42,28 +36,26 @@ def clamp_tensor(A, mini=None, maxi=None): return clamp_func_tensor(A, mini, maxi) -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[[0, 1, 2, "DEFAULT"]] +) @triton.jit def clamp_func(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def clamp_func_min(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def clamp_func_max(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "mini", "maxi"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def clamp(A, mini=None, maxi=None): logging.debug("GEMS CLAMP") if mini is None and maxi is None: diff --git a/src/flag_gems/ops/cos.py b/src/flag_gems/ops/cos.py index 
3ecad3ab..8bb1bff0 100644 --- a/src/flag_gems/ops/cos.py +++ b/src/flag_gems/ops/cos.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def cos_func(x): return tl.cos(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def cos(A): logging.debug("GEMS COS") return cos_func(A) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 7fb92be3..8acbc848 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -6,19 +6,19 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def true_div_func(x, y): return x / y -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def true_div_func_tensor_scalar(x, y): return x / y -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def true_div_func_scalar_tensor(x, y): return x / y @@ -37,19 +37,19 @@ def true_divide(A, B): return A / B -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def trunc_div_func(x, y): return triton.div_rz(x, y) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def trunc_div_func_tensor_scalar(x, y): return triton.div_rz(x, y) -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def trunc_div_func_scalar_tensor(x, y): return triton.div_rz(x, y) @@ -68,19 +68,19 @@ def trunc_divide(A, B): return A / B -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def floor_div_func(x, y): return x // y -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def floor_div_func_tensor_scalar(x, y): return x // y -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def floor_div_func_scalar_tensor(x, y): return x // y @@ -100,7 +100,6 @@ def floor_divide(A, B): def div(A, B, rounding_mode=None): - print("xxx") if rounding_mode is None: return true_divide(A, B) elif rounding_mode == "trunc": diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index b5707c46..ec4a1f40 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def eq_func(x, y): return x.to(tl.float32) == y.to(tl.float32) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def eq(A, B): 
logging.debug("GEMS EQ") return eq_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def eq_func_scalar(x, y): return x.to(tl.float32) == y.to(tl.float32) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def eq_scalar(A, B): logging.debug("GEMS EQ SCALAR") return eq_func_scalar(A, B) diff --git a/src/flag_gems/ops/exp.py b/src/flag_gems/ops/exp.py index b350e65b..3b37048e 100644 --- a/src/flag_gems/ops/exp.py +++ b/src/flag_gems/ops/exp.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def exp_func(x): return tl.exp(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def exp(A): logging.debug("GEMS EXP") return exp_func(A) diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py index 8b13b8e9..5d58410a 100644 --- a/src/flag_gems/ops/ge.py +++ b/src/flag_gems/ops/ge.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ge_func(x, y): return x.to(tl.float32) >= y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def ge(A, B): logging.debug("GEMS GE") return ge_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ge_func_scalar(x, y): return x.to(tl.float32) >= y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def ge_scalar(A, B): logging.debug("GEMS GE SCALAR") return ge_func_scalar(A, B) diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index 028f0018..6f62689e 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -2,13 +2,11 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def gelu_none(x): scale = 0.7071067811 @@ -16,7 +14,7 @@ def gelu_none(x): return output -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def gelu_tanh(x): output = ( @@ -32,10 +30,6 @@ def gelu_tanh(x): return output -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def gelu(A, *, approximate="none"): logging.debug("GEMS GELU") if approximate == "tanh": diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py index 
599a5c5c..ce23dc32 100644 --- a/src/flag_gems/ops/gt.py +++ b/src/flag_gems/ops/gt.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def gt_func(x, y): return x.to(tl.float32) > y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def gt(A, B): logging.debug("GEMS GT") return gt_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def gt_func_scalar(x, y): return x.to(tl.float32) > y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def gt_scalar(A, B): logging.debug("GEMS GT SCALAR") return gt_func_scalar(A, B) diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index 5505cffb..63189975 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -1,24 +1,17 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) @triton.jit def isinf_func(x): return tl.math.isinf(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def isinf(A): logging.debug("GEMS ISINF") return isinf_func(A) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index efeedea2..46ff0f6d 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -1,24 +1,17 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) @triton.jit def isnan_func(x): return tl.math.isnan(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def isnan(A): logging.debug("GEMS ISNAN") return isnan_func(A) diff --git a/src/flag_gems/ops/le.py b/src/flag_gems/ops/le.py index c837c48e..514ccd49 100644 --- a/src/flag_gems/ops/le.py +++ b/src/flag_gems/ops/le.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def le_func(x, y): return x.to(tl.float32) <= y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def 
le(A, B): logging.debug("GEMS LE") return le_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def le_func_scalar(x, y): return x.to(tl.float32) <= y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def le_scalar(A, B): logging.debug("GEMS LE SCALAR") return le_func_scalar(A, B) diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py index be811c07..8d81bf11 100644 --- a/src/flag_gems/ops/lt.py +++ b/src/flag_gems/ops/lt.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def lt_func(x, y): return x.to(tl.float32) < y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def lt(A, B): logging.debug("GEMS LT") return lt_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def lt_func_scalar(x, y): return x.to(tl.float32) < y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def lt_scalar(A, B): logging.debug("GEMS LT SCALAR") return lt_func_scalar(A, B) diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index 6a6f6ccd..ac2ff9d4 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -2,28 +2,22 @@ import torch import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def mul_func(x, y): return x * y -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def mul_func_scalar(x, y): return x * y -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def mul(A, B): logging.debug("GEMS MUL") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py index e410d1b3..35370400 100644 --- a/src/flag_gems/ops/ne.py +++ b/src/flag_gems/ops/ne.py @@ -1,39 +1,28 @@ import logging -import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ne_func(x, y): return x.to(tl.float32) != y.to(tl.float32) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def ne(A, B): logging.debug("GEMS NE") return ne_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], 
output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ne_func_scalar(x, y): return x.to(tl.float32) != y.to(tl.float32) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, -) def ne_scalar(A, B): logging.debug("GEMS NE SCALAR") return ne_func_scalar(A, B) diff --git a/src/flag_gems/ops/neg.py b/src/flag_gems/ops/neg.py index 6e3db4c8..fb462ac5 100644 --- a/src/flag_gems/ops/neg.py +++ b/src/flag_gems/ops/neg.py @@ -1,22 +1,16 @@ import logging import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def neg_func(x): return -x -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def neg(A): logging.debug("GEMS NEG") return neg_func(A) diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index 98e20a65..664531ba 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -2,52 +2,38 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "BOOL_TO_LONG"]]) @triton.jit def pow_func(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, -) def pow_tensor_tensor(A, exponent): logging.debug("GEMS POW_TENSOR_TENSOR") return pow_func(A, exponent) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "BOOL_TO_LONG"]]) @triton.jit def pow_func_tensor_scalar(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, -) def pow_tensor_scalar(A, exponent): logging.debug("GEMS POW_TENSOR_SCALAR") return pow_func_tensor_scalar(A, exponent) -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, "BOOL_TO_LONG"]]) @triton.jit def pow_func_scalar_tensor(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, -) def pow_scalar(A, exponent): logging.debug("GEMS POW_SCALAR") return pow_func_scalar_tensor(A, exponent) diff --git a/src/flag_gems/ops/reciprocal.py b/src/flag_gems/ops/reciprocal.py index 33f7ebf3..d1bdaa45 100644 --- a/src/flag_gems/ops/reciprocal.py +++ b/src/flag_gems/ops/reciprocal.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def reciprocal_func(x): return 1.0 / x.to(tl.float32) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def reciprocal(A): logging.debug("GEMS RECIPROCAL") return reciprocal_func(A) diff --git a/src/flag_gems/ops/relu.py b/src/flag_gems/ops/relu.py index aec698c4..a56f2ca3 100644 --- a/src/flag_gems/ops/relu.py +++ b/src/flag_gems/ops/relu.py @@ -3,19 +3,17 @@ import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def relu_forward(x): return tl.where(x > 0, x, 0) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def relu_backward(x, dy): return tl.where(x > 0, dy, 0) @@ -37,9 +35,5 @@ def backward(ctx, out_grad): return in_grad -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def relu(A): return Relu.apply(A) diff --git a/src/flag_gems/ops/rsqrt.py b/src/flag_gems/ops/rsqrt.py index e15010eb..da235975 100644 --- a/src/flag_gems/ops/rsqrt.py +++ b/src/flag_gems/ops/rsqrt.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def rsqrt_func(x): return 1.0 / tl.sqrt(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def rsqrt(A): logging.debug("GEMS RSQRT") return rsqrt_func(A) diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index e989da69..ef382f19 100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -4,20 +4,18 @@ import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def sigmoid_forward(x): log2e: tl.constexpr = math.log2(math.e) return 1 / (1 + tl.math.exp2(-x.to(tl.float32) * log2e)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def sigmoid_backward(y, dy): y_f32 = y.to(tl.float32) @@ -41,9 +39,5 @@ def backward(ctx, out_grad): return in_grad -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def sigmoid(A): return Sigmoid.apply(A) diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index f0f4a7fc..84ee682b 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -3,13 +3,11 @@ import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit def silu_forward(x): x_fp32 = x.to(tl.float32) @@ -17,7 +15,7 @@ def silu_forward(x): return y -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) @triton.jit 
def silu_backward(x, dy): dy_fp32 = dy.to(tl.float32) @@ -43,9 +41,5 @@ def backward(ctx, out_grad): return in_grad -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def silu(A): return Silu.apply(A) diff --git a/src/flag_gems/ops/sin.py b/src/flag_gems/ops/sin.py index 7394b93a..d0046e21 100644 --- a/src/flag_gems/ops/sin.py +++ b/src/flag_gems/ops/sin.py @@ -2,22 +2,16 @@ import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def sin_func(x): return tl.sin(x.to(tl.float32)) -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def sin(A): logging.debug("GEMS SIN") return sin_func(A) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index 00dcc4fd..0bc09fb0 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -2,34 +2,32 @@ import torch import triton -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[[0, 1, "DEFAULT"]]) @triton.jit def sub_func(x, y, alpha): return x - y * alpha -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[[0, 1, "DEFAULT"]] +) @triton.jit def sub_func_tensor_scalar(x, y, alpha): return x - y * alpha -@pointwise_dynamic(is_tensor=[False, True, False]) +@pointwise_dynamic( + is_tensor=[False, True, False], promotion_methods=[[0, 1, "DEFAULT"]] +) @triton.jit def sub_func_scalar_tensor(x, y, alpha): return x - y * alpha -@elementwise_type_promotion_wrapper( - type_promoting_args=("A", "B"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, -) def sub(A, B, *, alpha=1): logging.debug("GEMS SUB") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index b6353fc8..f87dec29 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -3,19 +3,17 @@ import torch import triton import triton.language as tl -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def tanh_forward(x): return tl.math.tanh(x.to(tl.float32)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) @triton.jit def tanh_backward(y, dy): return dy * (1.0 - tl.math.pow(y.to(tl.float32), 2)) @@ -37,9 +35,5 @@ def backward(ctx, out_grad): return in_grad -@elementwise_type_promotion_wrapper( - type_promoting_args=("A"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, -) def tanh(A): return Tanh.apply(A) diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index 21f9fc06..859868ec 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -2,52 +2,44 @@ import triton import triton.language as tl -from torch._prims_common import 
ELEMENTWISE_TYPE_PROMOTION_KIND -from torch._prims_common.wrappers import elementwise_type_promotion_wrapper from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, True]) +@pointwise_dynamic( + is_tensor=[True, True, True], promotion_methods=[[0, 1, "NO_OPMATH"]] +) @triton.jit -def where_self_func(self, condition, other): +def where_self_func(condition, self, other): return tl.where(condition, self, other) -@elementwise_type_promotion_wrapper( - type_promoting_args=("self", "other"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, -) def where_self(condition, self, other): logging.debug("GEMS WHERE_SELF") - return where_self_func(self, condition, other) + return where_self_func(condition, self, other) -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic( + is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]] +) @triton.jit -def where_scalar_self_func(other, condition, self): +def where_scalar_self_func(condition, self, other): return tl.where(condition, self, other) -@elementwise_type_promotion_wrapper( - type_promoting_args=("self", "other"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, -) def where_scalar_self(condition, self, other): logging.debug("GEMS WHERE_SCALAR_SELF") - return where_scalar_self_func(other, condition, self) + return where_scalar_self_func(condition, self, other) -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic( + is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]] +) @triton.jit -def where_scalar_other_func(self, condition, other): +def where_scalar_other_func(condition, self, other): return tl.where(condition, self, other) -@elementwise_type_promotion_wrapper( - type_promoting_args=("self", "other"), - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, -) def where_scalar_other(condition, self, other): logging.debug("GEMS WHERE_SCALAR_OTHER") - return where_scalar_other_func(self, condition, other) + return where_scalar_other_func(condition, self, other) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 7f658423..177061d8 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -3,6 +3,7 @@ from typing import Any, Callable, List, Mapping, Optional, Tuple import torch +import torch._prims_common as utils import triton from triton import language as tl from triton.runtime.jit import JITFunction @@ -41,7 +42,7 @@ class OPDesc: _num_non_tensor_inputs: int _num_outputs: int - _output_dtypes: List[torch.dtype] + _promotion_methods: List[Tuple[int, ...]] def __init__( self, @@ -50,12 +51,18 @@ def __init__( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - output_dtypes: Optional[List[torch.dtype]] = None, + promotion_methods: Optional[List[Tuple[int, ...]]] = None, ): if is_tensor is not None: _check_typed_list(is_tensor, bool) if dtypes is not None: _check_typed_list(dtypes, (type, type(None))) + if promotion_methods is None: + raise ValueError( + "No type promotion method provided! You must provide type promotion method for each output!" + ) + else: + self._promotion_methods = promotion_methods if num_inputs is not None: self._num_inputs = num_inputs @@ -91,22 +98,11 @@ def __init__( "Cannot make OPDesc when none of (num_inputs, is_tensor, dtypes) is specified." 
) - if output_dtypes is not None: - _check_typed_list(output_dtypes, torch.dtype) - if num_outputs is not None: self._num_outputs = num_outputs - if output_dtypes is not None: - _check_sized_list(output_dtypes, num_outputs) - self._output_dtypes = output_dtypes - else: - self._output_dtypes = [None] * num_outputs # infer from the 1st input - elif output_dtypes is not None: - self._num_outputs = len(output_dtypes) - self._output_dtypes = output_dtypes + _check_sized_list(promotion_methods, num_outputs) else: - self._num_outputs = 1 - self._output_dtypes = [None] + self._num_outputs = len(promotion_methods) assert self._num_inputs >= 1 assert self._num_outputs >= 1 @@ -127,9 +123,6 @@ def is_tensor(self, arg_id: int) -> bool: def input_type(self, arg_id) -> Optional[type]: return self._dtypes[arg_id] - def output_dtype(self, output_id) -> torch.dtype: - return self._output_dtypes[output_id] - def num_input_tensors(self) -> int: return self._num_input_tensors @@ -139,6 +132,23 @@ def num_output_tensors(self) -> int: def num_non_tensor_args(self) -> int: return self._num_non_tensor_inputs + def type_promotion_methods(self) -> List[Tuple[int, ...]]: + return self._promotion_methods + + def _match_enum_by_string( + self, input_str: str + ) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + for kind in utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + if input_str.lower() == kind.name.lower(): + return kind + raise ValueError(f"No matching enum member found for input: {input_str}") + + def ith_type_promotion_args(self, i) -> List[int]: + return self._promotion_methods[i][:-1] + + def ith_type_promotion_kind(self, i) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + return self._match_enum_by_string(self._promotion_methods[i][-1]) + def signature(self, outputs_in_arg: bool = False): input_types = [] for is_tensor, dtype in zip(self._is_tensor, self._dtypes): @@ -151,11 +161,8 @@ def signature(self, outputs_in_arg: bool = False): input_types.append(_type_name(dtype)) output_types = [] - for dtype in self._output_dtypes: - if dtype is None: - output_types.append("Tensor") - else: - output_types.append(f"Tensor[{_type_name(dtype)}]") + for _ in range(self.num_outputs()): + output_types.append("Tensor") if outputs_in_arg: input_types.extend(output_types) sig = f'Pointwise: ({", ".join(input_types)}) -> ({", ".join(output_types)})' @@ -196,6 +203,27 @@ def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str return ", ".join(parameters) +def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str: + """Generate parameter reference for i-th type promotion rule + Example: in0, val0, out0 + """ + parameters: List[str] = [] + + input_tensor_index = 0 + non_tensor_index = 0 + for i in range(op_desc.num_inputs()): + if i not in op_desc.ith_type_promotion_args(ith): + break + if op_desc._is_tensor[i]: + parameters.append(f"in{input_tensor_index}") + input_tensor_index += 1 + else: + parameters.append(f"val{non_tensor_index}") + non_tensor_index += 1 + + return ", ".join(parameters) + + def parameter_ref_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: """Generate parameter reference for wrapper function. 
Example: in0, val0, out0 @@ -253,6 +281,8 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline(" Stride,") code.writeline(")") code.writeline("from flag_gems.utils.libentry import libentry") + code.writeline("from flag_gems.utils.type_utils import type_promotion") + code.writeline("import torch._prims_common as utils") code.newline() code.newline() return code @@ -282,21 +312,17 @@ def generate_functional_pointwise_wrapper( # output allocation num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): - if op_desc.output_dtype(i) is None: - code.writeline( - ( - f"out{num_output_tensor_index} = " - f"torch.empty(shape, dtype=in0.dtype, device=in0.device)" - ) - ) - else: - code.writeline( - ( - f"out{num_output_tensor_index} = " - f"torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, " - f"device=in0.device)" - ) + type_promotion_args = ith_parameter_for_type_promotion(op_desc, i) + print(type_promotion_args) + k_type_promotion = op_desc.ith_type_promotion_kind(i) + code.writeline( + ( + f"out{num_output_tensor_index} = " + f"torch.empty(shape, dtype=type_promotion" + f"({type_promotion_args}, type_promotion=utils.{k_type_promotion})[1], " + f"device=in0.device)" ) + ) num_output_tensor_index += 1 # call destination_passing_func @@ -659,7 +685,7 @@ def pointwise_dynamic( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - output_dtypes: Optional[List[type]] = None, + promotion_methods: Optional[Tuple[int, ...]] = None, ): def decorator(fn): nonlocal num_inputs @@ -670,7 +696,7 @@ def decorator(fn): is_tensor=is_tensor, dtypes=dtypes, num_outputs=num_outputs, - output_dtypes=output_dtypes, + promotion_methods=promotion_methods, ) return PointwiseDynamicFunction(op_desc, fn) diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py new file mode 100644 index 00000000..996cfb21 --- /dev/null +++ b/src/flag_gems/utils/type_utils.py @@ -0,0 +1,9 @@ +import torch._prims_common as utils + + +def type_promotion(*args, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND): + computation_dtype, result_dtype = utils.elementwise_dtypes( + *args, + type_promotion_kind=type_promotion, + ) + return computation_dtype, result_dtype From 26147cb6605c8c9d0dd3a83e534f90d11a508c0a Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Wed, 26 Jun 2024 10:09:13 +0800 Subject: [PATCH 4/7] change docs for pointwise_dy --- README.md | 31 ++++++++++++++++++++++++++++--- README_cn.md | 33 +++++++++++++++++++++++++++++---- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 0aee0981..011eb546 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ In FlagGems, we provide automatic code generation that developers can use to con Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code. 
```python -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) @triton.jit def abs_func(x): return tl.abs(x) @@ -29,7 +29,11 @@ def abs_func(x): By default, `pointwise_dynamic` treats all parameters as tensors, and by passing a list of boolean values to the parameter `is_tensor`, developers can specify which parameters are tensors and which are not. Additionally, developers can pass in `dtypes` to indicate the data types of non-tensor parameters, but this is not required. For example, in the following code, the `alpha` parameter is defined as a non-tensor floating point number, while the `x` and `y` parameters are defined as tensors. ```python -@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float]) +@pointwise_dynamic( + is_tensor=[True, True, False], + dtypes=[None, None, float], + promotion_methods=[[0,"DEFAULT"]] +) @triton.jit def add_func(x, y, alpha): return x + y * alpha ``` #### Output Data Type -By default, all output tensors have the same data type as the first input tensor, but it can also be customized by providing a list of data types to the parameter `output_dtypes`. For example, in the following code, the output tensor type is specified as `torch.bool`. +Furthermore, developers must provide `promotion_methods` to specify how type promotion should be handled for the operation so that it produces the correct output data type. ```python -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) @triton.jit def ge(x, y): return x > y ``` +In `promotion_methods`, each `int` indicates the position of a parameter that participates in type promotion, while the trailing `str` names the promotion method. The `str` corresponds to one of the following enumeration members: + +```python +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) +``` + +Examples: + +- `DEFAULT`: add +- `NO_OPMATH`: where, nextafter, cat +- `INT_TO_FLOAT`: sin +- `ALWAYS_BOOL`: eq +- `COMPLEX_TO_FLOAT`: abs +- `BOOL_TO_LONG`: pow + ## Changelog ### v1.0 diff --git a/README_cn.md b/README_cn.md index 74ddbaff..9ae3c2a9 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,7 +18,7 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库 在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。 ```python -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) @triton.jit def abs_func(x): return tl.abs(x) @@ -29,7 +29,11 @@ def abs_func(x): 在默认情况下,`pointwise_dynamic`将所有参数均处理为张量,而通过向参数`is_tensor`传递布尔值列表,开发者可以指定哪些参数是张量,哪些参数非张量。此外,开发者还可以传入`dtypes`说明非张量参数的数据类型,但这不是必要的。例如以下代码,将`alpha`参数定义为非张量的浮点数,而`x`和`y`参数定义为张量。 ```python -@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float]) +@pointwise_dynamic( + is_tensor=[True, True, False], + dtypes=[None, None, float], + promotion_methods=[[0,"DEFAULT"]] +) @triton.jit def add_func(x, y, alpha): return x + y * alpha ``` #### 输出数据类型 -在默认情况下,输出张量使用与首个输入张量相同的数据类型,但也可向参数`output_dtypes`传入数据类型组成的列表来指定。例如以下代码,指定输出张量类型为`torch.bool`。 +此外,开发者必须传入 `promotion_methods` 来说明该 Op 在进行计算时应该如何进行`类型提升`以获得正确的输出类型 ```python -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) @triton.jit def ge(x, y): return x > y ``` `promotion_methods` 通过传入 `int` 来表示需要进行类型提升的参数位置, 通过传入 `str` 来表示类型提升的方式, `str` 对于以下枚举类型
+```python +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) +``` + +举例: + +- `DEFAULT` :add +- `NO_OPMATH` : where, nextafter, cat +- `INT_TO_FLOAT` :sin +- `ALWAYS_BOOL` :eq +- `COMPLEX_TO_FLOAT` :abs +- `BOOL_TO_LONG` :pow + ## 更新日志 ### v1.0 From 6c80334aa81c0706abfc142f7f1a611cece1e9ad Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Wed, 26 Jun 2024 16:12:00 +0800 Subject: [PATCH 5/7] add unit test --- .github/workflows/python-test.yaml | 1 + src/flag_gems/ops/pow.py | 6 +- src/flag_gems/ops/where.py | 6 +- src/flag_gems/utils/pointwise_dynamic.py | 6 +- tests/test_pointwise_type_promotion.py | 126 +++++++++++++++++++++++ 5 files changed, 138 insertions(+), 7 deletions(-) create mode 100644 tests/test_pointwise_type_promotion.py diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 614fb31b..e4f42a1a 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -26,6 +26,7 @@ jobs: - name: unit_test-flag-gems run: | CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py & + CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py & CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py & CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py & CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py & diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index 664531ba..e3c17f48 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "BOOL_TO_LONG"]]) +@pointwise_dynamic(promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) @triton.jit def pow_func(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -17,7 +17,7 @@ def pow_tensor_tensor(A, exponent): return pow_func(A, exponent) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "BOOL_TO_LONG"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) @triton.jit def pow_func_tensor_scalar(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -28,7 +28,7 @@ def pow_tensor_scalar(A, exponent): return pow_func_tensor_scalar(A, exponent) -@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, "BOOL_TO_LONG"]]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) @triton.jit def pow_func_scalar_tensor(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index 859868ec..ec48a136 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -7,7 +7,7 @@ @pointwise_dynamic( - is_tensor=[True, True, True], promotion_methods=[[0, 1, "NO_OPMATH"]] + is_tensor=[True, True, True], promotion_methods=[[1, 2, "NO_OPMATH"]] ) @triton.jit def where_self_func(condition, self, other): @@ -20,7 +20,7 @@ def where_self(condition, self, other): @pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]] + is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]] ) @triton.jit def where_scalar_self_func(condition, self, other): @@ -33,7 +33,7 @@ def where_scalar_self(condition, self, other): @pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]] + is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]] ) @triton.jit 
def where_scalar_other_func(condition, self, other): diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 177061d8..e6493943 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -213,7 +213,11 @@ def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str: non_tensor_index = 0 for i in range(op_desc.num_inputs()): if i not in op_desc.ith_type_promotion_args(ith): - break + if op_desc._is_tensor[i]: + input_tensor_index += 1 + else: + non_tensor_index += 1 + continue if op_desc._is_tensor[i]: parameters.append(f"in{input_tensor_index}") input_tensor_index += 1 diff --git a/tests/test_pointwise_type_promotion.py b/tests/test_pointwise_type_promotion.py new file mode 100644 index 00000000..4444a8eb --- /dev/null +++ b/tests/test_pointwise_type_promotion.py @@ -0,0 +1,126 @@ +import logging + +import pytest +import torch + +import flag_gems + +from .accuracy_utils import ( + FLOAT_DTYPES, + POINTWISE_SHAPES, + SCALARS, + gems_assert_close, + gems_assert_equal, + to_reference, +) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_default(shape, alpha, float_type): + inp1 = torch.randint(10, shape, device="cuda") + inp2 = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp2 = to_reference(inp2, True) + # arg0:int arg1:float + ref_out = torch.add(inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + gems_assert_close(res_out, ref_out, float_type) + # arg0:float arg1:int + ref_out = torch.add(ref_inp2, inp1, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp2, inp1, alpha=alpha) + gems_assert_close(res_out, ref_out, float_type) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_no_opmath(shape, float_type): + inp1 = torch.randint(10, shape, device="cuda") + inp2 = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp2 = to_reference(inp2) + # arg0:bool arg1:int arg2:float + ref_out = torch.where(inp1 > 0, inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.where(inp1 > 0, inp1, inp2) + gems_assert_equal(res_out, ref_out) + + # arg0:bool arg1:float arg2:int + ref_out = torch.where(inp1 > 0, ref_inp2, inp1) + with flag_gems.use_gems(): + res_out = torch.where(inp1 > 0, inp2, inp1) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_int_to_float(shape, float_type): + # arg0:float + inp_float = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp = to_reference(inp_float, True) + ref_out = torch.sin(ref_inp) + with flag_gems.use_gems(): + res_out = torch.sin(inp_float) + gems_assert_close(res_out, ref_out, float_type) + + # arg0:int + inp_int = torch.randint(10, shape, device="cuda") + ref_out = torch.sin(inp_int) + with flag_gems.use_gems(): + res_out = torch.sin(inp_int) + gems_assert_close(res_out, ref_out, torch.float32) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +def test_type_promotion_always_bool(shape): + # arg0:int arg0:int + inp1 = torch.randint(0, 10, shape, device="cuda") + inp2 = torch.randint(0, 10, shape, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + ref_out = torch.eq(ref_inp1, ref_inp2) + 
with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_complex_to_long(shape, float_type): + # arg0:float + inp = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp = to_reference(inp) + ref_out = torch.abs(ref_inp) + with flag_gems.use_gems(): + res_out = torch.abs(inp) + gems_assert_equal(res_out, ref_out) + + # arg0:int + inp1 = torch.randint(0, 10, shape, device="cuda") + ref_out1 = torch.abs(inp1) + with flag_gems.use_gems(): + res_out1 = torch.abs(inp1) + gems_assert_equal(res_out1, ref_out1) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) +def test_type_promotion_bool_to_long(shape, float_dtype): + inp1 = torch.randn(shape, dtype=float_dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, device="cuda") + # arg0: float arg1: int + ref_out = torch.pow(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.pow(inp1, inp2) + logging.debug(ref_out.dtype) + logging.debug(res_out.dtype) + gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True) + + # arg0: int arg1: float + ref_out = torch.pow(inp2, inp1) + with flag_gems.use_gems(): + res_out = torch.pow(inp2, inp1) + logging.debug(ref_out.dtype) + logging.debug(res_out.dtype) + gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True) From 204e8443d2d7fe24440e9b6bd447ca538e730da3 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Thu, 27 Jun 2024 14:48:33 +0800 Subject: [PATCH 6/7] add promotion_methods --- src/flag_gems/utils/pointwise_dynamic.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index e6493943..2ef9885c 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -317,7 +317,6 @@ def generate_functional_pointwise_wrapper( num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): type_promotion_args = ith_parameter_for_type_promotion(op_desc, i) - print(type_promotion_args) k_type_promotion = op_desc.ith_type_promotion_kind(i) code.writeline( ( @@ -711,7 +710,11 @@ def decorator(fn): if __name__ == "__main__": - @pointwise_dynamic(is_tensor=[True, False, True], dtypes=[None, float, None]) + @pointwise_dynamic( + is_tensor=[True, False, True], + dtypes=[None, float, None], + promotion_methods=[[0, 1, 2, "DEFAULT"]], + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -725,7 +728,9 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False, True]) + @pointwise_dynamic( + is_tensor=[True, False, True], promotion_methods=[[0, 1, 2, "DEFAULT"]] + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -737,7 +742,7 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(output_dtypes=[torch.bool]) + @pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) @triton.jit def ge(x, y): return x > y @@ -749,7 +754,7 @@ def ge(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic() + @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary(x, y): return tl.sin(x) + tl.cos(y) @@ -761,7 +766,7 @@ def ordinary(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + 
@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -773,7 +778,7 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -787,7 +792,9 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) + @pointwise_dynamic( + is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]] + ) @triton.jit def eq(x, y): return x.to(tl.float32) == y.to( From 70099832d927be34b1905dac764640bf048603a3 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Mon, 1 Jul 2024 15:08:49 +0800 Subject: [PATCH 7/7] use tuple rather than list --- README.md | 4 ++-- README_cn.md | 6 +++--- src/flag_gems/fused/gelu_and_mul.py | 4 ++-- src/flag_gems/fused/silu_and_mul.py | 2 +- src/flag_gems/ops/abs.py | 2 +- src/flag_gems/ops/add.py | 6 +++--- src/flag_gems/ops/bitwise_and.py | 4 ++-- src/flag_gems/ops/bitwise_not.py | 2 +- src/flag_gems/ops/bitwise_or.py | 4 ++-- src/flag_gems/ops/clamp.py | 12 ++++++------ src/flag_gems/ops/cos.py | 2 +- src/flag_gems/ops/div.py | 18 +++++++++--------- src/flag_gems/ops/eq.py | 4 ++-- src/flag_gems/ops/exp.py | 2 +- src/flag_gems/ops/ge.py | 4 ++-- src/flag_gems/ops/gelu.py | 4 ++-- src/flag_gems/ops/gt.py | 4 ++-- src/flag_gems/ops/isinf.py | 2 +- src/flag_gems/ops/isnan.py | 2 +- src/flag_gems/ops/le.py | 4 ++-- src/flag_gems/ops/lt.py | 4 ++-- src/flag_gems/ops/mul.py | 4 ++-- src/flag_gems/ops/ne.py | 4 ++-- src/flag_gems/ops/neg.py | 2 +- src/flag_gems/ops/pow.py | 6 +++--- src/flag_gems/ops/reciprocal.py | 2 +- src/flag_gems/ops/relu.py | 4 ++-- src/flag_gems/ops/rsqrt.py | 2 +- src/flag_gems/ops/sigmoid.py | 4 ++-- src/flag_gems/ops/silu.py | 4 ++-- src/flag_gems/ops/sin.py | 2 +- src/flag_gems/ops/sub.py | 6 +++--- src/flag_gems/ops/tanh.py | 4 ++-- src/flag_gems/ops/where.py | 6 +++--- src/flag_gems/utils/pointwise_dynamic.py | 14 +++++++------- 35 files changed, 80 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 011eb546..948363c0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ In FlagGems, we provide automatic code generation that developers can use to con Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code. 
```python -@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x): return tl.abs(x) @@ -32,7 +32,7 @@ By default, `pointwise_dynamic` treats all parameters as tensors, and by passing @pointwise_dynamic( is_tensor=[True, True, False], dtypes=[None, None, float], - promotion_methods=[[0,"DEFAULT"]] + promotion_methods=[(0,"DEFAULT")] ) @triton.jit def add_func(x, y, alpha): diff --git a/README_cn.md b/README_cn.md index 9ae3c2a9..289b9e67 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,7 +18,7 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库 在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。 ```python -@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x): return tl.abs(x) @@ -32,7 +32,7 @@ def abs_func(x): @pointwise_dynamic( is_tensor=[True, True, False], dtypes=[None, None, float], - promotion_methods=[[0,"DEFAULT"]] + promotion_methods=[(0,"DEFAULT")] ) @triton.jit def add_func(x, y, alpha): @@ -44,7 +44,7 @@ def add_func(x, y, alpha): 此外,开发者必须传入 `promotion_methods` 来说明该 Op 在进行计算时应该如何进行`类型提升`以获得正确的输出类型 ```python -@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def ge(x, y): return x > y diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py index d9dd984b..52fdc7a7 100644 --- a/src/flag_gems/fused/gelu_and_mul.py +++ b/src/flag_gems/fused/gelu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def gelu_none_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) @@ -15,7 +15,7 @@ def gelu_none_and_mul_kernel(x, y): return x_gelu * y -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def gelu_tanh_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/fused/silu_and_mul.py b/src/flag_gems/fused/silu_and_mul.py index e5cbbfd1..70b1b326 100644 --- a/src/flag_gems/fused/silu_and_mul.py +++ b/src/flag_gems/fused/silu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def silu_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/ops/abs.py b/src/flag_gems/ops/abs.py index d2267278..54838034 100644 --- a/src/flag_gems/ops/abs.py +++ b/src/flag_gems/ops/abs.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x): return tl.abs(x) diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index a0181087..d7e52385 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -6,14 +6,14 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def add_func(x, y, alpha): return x + y * alpha @pointwise_dynamic( - is_tensor=[True, False, False], promotion_methods=[[0, 1, "DEFAULT"]] + 
is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")] ) @triton.jit def add_func_tensor_scalar(x, y, alpha): @@ -21,7 +21,7 @@ def add_func_tensor_scalar(x, y, alpha): @pointwise_dynamic( - is_tensor=[False, True, False], promotion_methods=[[0, 1, "DEFAULT"]] + is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")] ) @triton.jit def add_func_scalar_tensor(x, y, alpha): diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index 7820ccac..f05655e7 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_and_func(x, y): return x & y @@ -16,7 +16,7 @@ def bitwise_and_tensor(A, B): return bitwise_and_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_and_func_scalar(x, y): return x & y diff --git a/src/flag_gems/ops/bitwise_not.py b/src/flag_gems/ops/bitwise_not.py index e3d720c7..49402046 100644 --- a/src/flag_gems/ops/bitwise_not.py +++ b/src/flag_gems/ops/bitwise_not.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")]) @triton.jit def bitwise_not_func(x): return ~x diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index bd0d97ec..20f0b895 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_or_func(x, y): return x | y @@ -16,7 +16,7 @@ def bitwise_or_tensor(A, B): return bitwise_or_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_or_func_scalar(x, y): return x | y diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index c0ff1c5f..74cb293a 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -6,19 +6,19 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, 2, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")]) @triton.jit def clamp_func_tensor(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_min_tensor(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_max_tensor(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) @@ -37,20 +37,20 @@ def clamp_tensor(A, mini=None, maxi=None): @pointwise_dynamic( - is_tensor=[True, False, False], promotion_methods=[[0, 1, 2, "DEFAULT"]] + is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")] ) @triton.jit def clamp_func(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic(is_tensor=[True, False], 
promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_min(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_max(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) diff --git a/src/flag_gems/ops/cos.py b/src/flag_gems/ops/cos.py index 8bb1bff0..5762763b 100644 --- a/src/flag_gems/ops/cos.py +++ b/src/flag_gems/ops/cos.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def cos_func(x): return tl.cos(x.to(tl.float32)) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 8acbc848..46ba3565 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -6,19 +6,19 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def true_div_func(x, y): return x / y -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def true_div_func_tensor_scalar(x, y): return x / y -@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def true_div_func_scalar_tensor(x, y): return x / y @@ -37,19 +37,19 @@ def true_divide(A, B): return A / B -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func(x, y): return triton.div_rz(x, y) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func_tensor_scalar(x, y): return triton.div_rz(x, y) -@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def trunc_div_func_scalar_tensor(x, y): return triton.div_rz(x, y) @@ -68,19 +68,19 @@ def trunc_divide(A, B): return A / B -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func(x, y): return x // y -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func_tensor_scalar(x, y): return x // y -@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def floor_div_func_scalar_tensor(x, y): return x // y diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index ec4a1f40..21be2422 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def eq_func(x, y): return x.to(tl.float32) == 
y.to(tl.float32) @@ -17,7 +17,7 @@ def eq(A, B): return eq_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def eq_func_scalar(x, y): return x.to(tl.float32) == y.to(tl.float32) diff --git a/src/flag_gems/ops/exp.py b/src/flag_gems/ops/exp.py index 3b37048e..bb774457 100644 --- a/src/flag_gems/ops/exp.py +++ b/src/flag_gems/ops/exp.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def exp_func(x): return tl.exp(x.to(tl.float32)) diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py index 5d58410a..064466e1 100644 --- a/src/flag_gems/ops/ge.py +++ b/src/flag_gems/ops/ge.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge_func(x, y): return x.to(tl.float32) >= y @@ -17,7 +17,7 @@ def ge(A, B): return ge_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge_func_scalar(x, y): return x.to(tl.float32) >= y diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index 6f62689e..0b612793 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def gelu_none(x): scale = 0.7071067811 @@ -14,7 +14,7 @@ def gelu_none(x): return output -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def gelu_tanh(x): output = ( diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py index ce23dc32..27625421 100644 --- a/src/flag_gems/ops/gt.py +++ b/src/flag_gems/ops/gt.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def gt_func(x, y): return x.to(tl.float32) > y @@ -17,7 +17,7 @@ def gt(A, B): return gt_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def gt_func_scalar(x, y): return x.to(tl.float32) > y diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index 63189975..15bb5fe7 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def isinf_func(x): return tl.math.isinf(x.to(tl.float32)) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index 46ff0f6d..163a7ddf 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def isnan_func(x): return tl.math.isnan(x.to(tl.float32)) diff --git a/src/flag_gems/ops/le.py 
b/src/flag_gems/ops/le.py index 514ccd49..d0dea92c 100644 --- a/src/flag_gems/ops/le.py +++ b/src/flag_gems/ops/le.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def le_func(x, y): return x.to(tl.float32) <= y @@ -17,7 +17,7 @@ def le(A, B): return le_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def le_func_scalar(x, y): return x.to(tl.float32) <= y diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py index 8d81bf11..c7bdf71a 100644 --- a/src/flag_gems/ops/lt.py +++ b/src/flag_gems/ops/lt.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def lt_func(x, y): return x.to(tl.float32) < y @@ -17,7 +17,7 @@ def lt(A, B): return lt_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def lt_func_scalar(x, y): return x.to(tl.float32) < y diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index ac2ff9d4..d33d1f83 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -6,13 +6,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def mul_func(x, y): return x * y -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def mul_func_scalar(x, y): return x * y diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py index 35370400..4cf58980 100644 --- a/src/flag_gems/ops/ne.py +++ b/src/flag_gems/ops/ne.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ne_func(x, y): return x.to(tl.float32) != y.to(tl.float32) @@ -17,7 +17,7 @@ def ne(A, B): return ne_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ne_func_scalar(x, y): return x.to(tl.float32) != y.to(tl.float32) diff --git a/src/flag_gems/ops/neg.py b/src/flag_gems/ops/neg.py index fb462ac5..363eb9d2 100644 --- a/src/flag_gems/ops/neg.py +++ b/src/flag_gems/ops/neg.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def neg_func(x): return -x diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index e3c17f48..5dc224d8 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) +@pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -17,7 +17,7 @@ def pow_tensor_tensor(A, exponent): return pow_func(A, exponent) 
-@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func_tensor_scalar(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -28,7 +28,7 @@ def pow_tensor_scalar(A, exponent): return pow_func_tensor_scalar(A, exponent) -@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "BOOL_TO_LONG"]]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func_scalar_tensor(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) diff --git a/src/flag_gems/ops/reciprocal.py b/src/flag_gems/ops/reciprocal.py index d1bdaa45..948c2727 100644 --- a/src/flag_gems/ops/reciprocal.py +++ b/src/flag_gems/ops/reciprocal.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def reciprocal_func(x): return 1.0 / x.to(tl.float32) diff --git a/src/flag_gems/ops/relu.py b/src/flag_gems/ops/relu.py index a56f2ca3..2eac7b3c 100644 --- a/src/flag_gems/ops/relu.py +++ b/src/flag_gems/ops/relu.py @@ -7,13 +7,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def relu_forward(x): return tl.where(x > 0, x, 0) -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def relu_backward(x, dy): return tl.where(x > 0, dy, 0) diff --git a/src/flag_gems/ops/rsqrt.py b/src/flag_gems/ops/rsqrt.py index da235975..636dbb5f 100644 --- a/src/flag_gems/ops/rsqrt.py +++ b/src/flag_gems/ops/rsqrt.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def rsqrt_func(x): return 1.0 / tl.sqrt(x.to(tl.float32)) diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index ef382f19..e5f736d7 100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -8,14 +8,14 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sigmoid_forward(x): log2e: tl.constexpr = math.log2(math.e) return 1 / (1 + tl.math.exp2(-x.to(tl.float32) * log2e)) -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sigmoid_backward(y, dy): y_f32 = y.to(tl.float32) diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index 84ee682b..a79248c5 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def silu_forward(x): x_fp32 = x.to(tl.float32) @@ -15,7 +15,7 @@ def silu_forward(x): return y -@pointwise_dynamic(promotion_methods=[[0, "DEFAULT"]]) +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def silu_backward(x, dy): dy_fp32 = dy.to(tl.float32) diff --git a/src/flag_gems/ops/sin.py b/src/flag_gems/ops/sin.py index d0046e21..3cb7ab0b 100644 --- a/src/flag_gems/ops/sin.py +++ b/src/flag_gems/ops/sin.py @@ -6,7 +6,7 @@ from ..utils import 
pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sin_func(x): return tl.sin(x.to(tl.float32)) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index 0bc09fb0..c62faf05 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -6,14 +6,14 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[[0, 1, "DEFAULT"]]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def sub_func(x, y, alpha): return x - y * alpha @pointwise_dynamic( - is_tensor=[True, False, False], promotion_methods=[[0, 1, "DEFAULT"]] + is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")] ) @triton.jit def sub_func_tensor_scalar(x, y, alpha): @@ -21,7 +21,7 @@ def sub_func_tensor_scalar(x, y, alpha): @pointwise_dynamic( - is_tensor=[False, True, False], promotion_methods=[[0, 1, "DEFAULT"]] + is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")] ) @triton.jit def sub_func_scalar_tensor(x, y, alpha): diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index f87dec29..e45c4ae1 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -7,13 +7,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def tanh_forward(x): return tl.math.tanh(x.to(tl.float32)) -@pointwise_dynamic(promotion_methods=[[0, "INT_TO_FLOAT"]]) +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def tanh_backward(y, dy): return dy * (1.0 - tl.math.pow(y.to(tl.float32), 2)) diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index ec48a136..64d40172 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -7,7 +7,7 @@ @pointwise_dynamic( - is_tensor=[True, True, True], promotion_methods=[[1, 2, "NO_OPMATH"]] + is_tensor=[True, True, True], promotion_methods=[(1, 2, "NO_OPMATH")] ) @triton.jit def where_self_func(condition, self, other): @@ -20,7 +20,7 @@ def where_self(condition, self, other): @pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]] + is_tensor=[True, True, False], promotion_methods=[(1, 2, "NO_OPMATH")] ) @triton.jit def where_scalar_self_func(condition, self, other): @@ -33,7 +33,7 @@ def where_scalar_self(condition, self, other): @pointwise_dynamic( - is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]] + is_tensor=[True, True, False], promotion_methods=[(1, 2, "NO_OPMATH")] ) @triton.jit def where_scalar_other_func(condition, self, other): diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 2ef9885c..a5b2887f 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -713,7 +713,7 @@ def decorator(fn): @pointwise_dynamic( is_tensor=[True, False, True], dtypes=[None, float, None], - promotion_methods=[[0, 1, 2, "DEFAULT"]], + promotion_methods=[(0, 1, 2, "DEFAULT")], ) @triton.jit def saxpy(x, alpha, y): @@ -729,7 +729,7 @@ def saxpy(x, alpha, y): print() @pointwise_dynamic( - is_tensor=[True, False, True], promotion_methods=[[0, 1, 2, "DEFAULT"]] + is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")] ) @triton.jit def saxpy(x, alpha, y): @@ -742,7 +742,7 @@ def saxpy(x, alpha, y): 
torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(promotion_methods=[[0, 1, "ALWAYS_BOOL"]]) + @pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge(x, y): return x > y @@ -754,7 +754,7 @@ def ge(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary(x, y): return tl.sin(x) + tl.cos(y) @@ -766,7 +766,7 @@ def ordinary(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -778,7 +778,7 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(promotion_methods=[[0, 1, "INT_TO_FLOAT"]]) + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -793,7 +793,7 @@ def ordinary2(x, y): print() @pointwise_dynamic( - is_tensor=[True, False], promotion_methods=[[0, 1, "ALWAYS_BOOL"]] + is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")] ) @triton.jit def eq(x, y):