diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 614fb31b..e4f42a1a 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -26,6 +26,7 @@ jobs: - name: unit_test-flag-gems run: | CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py & + CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py & CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py & CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py & CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py & diff --git a/README.md b/README.md index 0aee0981..948363c0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ In FlagGems, we provide automatic code generation that developers can use to con Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code. ```python -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x): return tl.abs(x) @@ -29,7 +29,11 @@ def abs_func(x): By default, `pointwise_dynamic` treats all parameters as tensors, and by passing a list of boolean values to the parameter `is_tensor`, developers can specify which parameters are tensors and which are not. Additionally, developers can pass in `dtypes` to indicate the data types of non-tensor parameters, but this is not required. For example, in the following code, the `alpha` parameter is defined as a non-tensor floating point number, while the `x` and `y` parameters are defined as tensors. ```python -@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float]) +@pointwise_dynamic( + is_tensor=[True, True, False], + dtypes=[None, None, float], + promotion_methods=[(0, "DEFAULT")] +) @triton.jit def add_func(x, y, alpha): return x + y * alpha @@ -37,7 +41,7 @@ def add_func(x, y, alpha): #### Output Data Type -By default, all output tensors have the same data type as the first input tensor, but it can also be customized by providing a list of data types to the parameter `output_dtypes`. For example, in the following code, the output tensor type is specified as `torch.bool`. +Furthermore, developers must provide `promotion_methods` to specify how type promotion should be handled for the operation so that it produces outputs of the correct data type. For example, in the following code, the comparison result is always promoted to `torch.bool`. ```python -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @@ -46,6 +50,27 @@ def ge(x, y): return x > y ``` +In `promotion_methods`, an `int` indicates the position of a parameter that participates in type promotion, while a `str` denotes the promotion method.
The `str` corresponds to the following enumerated types: + +```python +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) +``` + +Examples: + +- `DEFAULT`: add +- `NO_OPMATH`: where, nextafter, cat +- `INT_TO_FLOAT`: sin +- `ALWAYS_BOOL`: eq +- `COMPLEX_TO_FLOAT`: abs +- `BOOL_TO_LONG`: pow + ## Changelog ### v1.0 diff --git a/README_cn.md b/README_cn.md index 74ddbaff..289b9e67 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,7 +18,7 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库 在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。 ```python -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x): return tl.abs(x) @@ -29,7 +29,11 @@ def abs_func(x): 在默认情况下,`pointwise_dynamic`将所有参数均处理为张量,而通过向参数`is_tensor`传递布尔值列表,开发者可以指定哪些参数是张量,哪些参数非张量。此外,开发者还可以传入`dtypes`说明非张量参数的数据类型,但这不是必要的。例如以下代码,将`alpha`参数定义为非张量的浮点数,而`x`和`y`参数定义为张量。 ```python -@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float]) +@pointwise_dynamic( + is_tensor=[True, True, False], + dtypes=[None, None, float], + promotion_methods=[(0, "DEFAULT")] +) @triton.jit def add_func(x, y, alpha): return x + y * alpha @@ -37,15 +41,36 @@ def add_func(x, y, alpha): #### 输出数据类型 -在默认情况下,输出张量使用与首个输入张量相同的数据类型,但也可向参数`output_dtypes`传入数据类型组成的列表来指定。例如以下代码,指定输出张量类型为`torch.bool`。 +此外，开发者必须传入 `promotion_methods` 来说明该 Op 在进行计算时应该如何进行`类型提升`，以获得正确的输出类型。 ```python -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def ge(x, y): return x > y ``` +`promotion_methods` 通过传入 `int` 来表示需要进行类型提升的参数位置，通过传入 `str` 来表示类型提升的方式，`str` 对应以下枚举类型： + +```python +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) +``` + +举例: + +- `DEFAULT`: add +- `NO_OPMATH`: where, nextafter, cat +- `INT_TO_FLOAT`: sin +- `ALWAYS_BOOL`: eq +- `COMPLEX_TO_FLOAT`: abs +- `BOOL_TO_LONG`: pow + ## 更新日志 ### v1.0 diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py index c2cb52f6..52fdc7a7 100644 --- a/src/flag_gems/fused/gelu_and_mul.py +++ b/src/flag_gems/fused/gelu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def gelu_none_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) @@ -15,7 +15,7 @@ def gelu_none_and_mul_kernel(x, y): return x_gelu * y -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def gelu_tanh_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/fused/silu_and_mul.py b/src/flag_gems/fused/silu_and_mul.py index 0d1271ec..70b1b326 100644 --- a/src/flag_gems/fused/silu_and_mul.py +++ b/src/flag_gems/fused/silu_and_mul.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def silu_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) diff --git a/src/flag_gems/ops/abs.py b/src/flag_gems/ops/abs.py index 1f8eabf9..54838034 100644 --- a/src/flag_gems/ops/abs.py +++ b/src/flag_gems/ops/abs.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")]) @triton.jit def abs_func(x):
return tl.abs(x) diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index 5c6dfee3..d7e52385 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -6,19 +6,23 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def add_func(x, y, alpha): return x + y * alpha -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")] +) @triton.jit def add_func_tensor_scalar(x, y, alpha): return x + y * alpha -@pointwise_dynamic(is_tensor=[False, True, False]) +@pointwise_dynamic( + is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")] +) @triton.jit def add_func_scalar_tensor(x, y, alpha): return x + y * alpha diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index 72a994e0..f05655e7 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_and_func(x, y): return x & y @@ -16,7 +16,7 @@ def bitwise_and_tensor(A, B): return bitwise_and_func(A, B) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_and_func_scalar(x, y): return x & y diff --git a/src/flag_gems/ops/bitwise_not.py b/src/flag_gems/ops/bitwise_not.py index b5977514..49402046 100644 --- a/src/flag_gems/ops/bitwise_not.py +++ b/src/flag_gems/ops/bitwise_not.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def bitwise_not_func(x): return ~x diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index 99f654e4..20f0b895 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_or_func(x, y): return x | y @@ -16,7 +16,7 @@ def bitwise_or_tensor(A, B): return bitwise_or_func(A, B) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def bitwise_or_func_scalar(x, y): return x | y diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index c603fe52..74cb293a 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -6,19 +6,19 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")]) @triton.jit def clamp_func_tensor(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_min_tensor(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_max_tensor(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) @@ -36,19 +36,21 @@ def clamp_tensor(A, mini=None, maxi=None): return clamp_func_tensor(A, mini, maxi) -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
+) @triton.jit def clamp_func(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_min(x, mini): return tl.maximum(mini, x.to(tl.float32)) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def clamp_func_max(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) diff --git a/src/flag_gems/ops/cos.py b/src/flag_gems/ops/cos.py index fee00a4b..5762763b 100644 --- a/src/flag_gems/ops/cos.py +++ b/src/flag_gems/ops/cos.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def cos_func(x): return tl.cos(x.to(tl.float32)) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 20e4b60c..46ba3565 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -6,32 +6,106 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit -def div_func(x, y): +def true_div_func(x, y): return x / y -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit -def div_func_tensor_scalar(x, y): +def true_div_func_tensor_scalar(x, y): return x / y -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit -def div_func_scalar_tensor(x, y): +def true_div_func_scalar_tensor(x, y): return x / y -def div(A, B): - logging.debug("GEMS DIV") +def true_divide(A, B): + logging.debug("GEMS TRUE_DIVIDE") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - return div_func(A, B) + return true_div_func(A, B) elif isinstance(A, torch.Tensor): - return div_func_tensor_scalar(A, B) + return true_div_func_tensor_scalar(A, B) elif isinstance(B, torch.Tensor): - return div_func_scalar_tensor(A, B) + return true_div_func_scalar_tensor(A, B) else: # Both scalar return A / B + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def trunc_div_func(x, y): + return triton.div_rz(x, y) + + +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def trunc_div_func_tensor_scalar(x, y): + return triton.div_rz(x, y) + + +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def trunc_div_func_scalar_tensor(x, y): + return triton.div_rz(x, y) + + +def trunc_divide(A, B): + logging.debug("GEMS TRUNC_DIVIDE") + if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): + return trunc_div_func(A, B) + elif isinstance(A, torch.Tensor): + return trunc_div_func_tensor_scalar(A, B) + elif isinstance(B, torch.Tensor): + return trunc_div_func_scalar_tensor(A, B) + else: + # Both scalar + return A / B + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def floor_div_func(x, y): + return x // y + + +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def floor_div_func_tensor_scalar(x, y): + return x // y + + +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def floor_div_func_scalar_tensor(x, y): + return x // y + + +def floor_divide(A, B): + logging.debug("GEMS FLOOR_DIVIDE") + if 
isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): + return floor_div_func(A, B) + elif isinstance(A, torch.Tensor): + return floor_div_func_tensor_scalar(A, B) + elif isinstance(B, torch.Tensor): + return floor_div_func_scalar_tensor(A, B) + else: + # Both scalar + return A // B + + +def div(A, B, rounding_mode=None): + if rounding_mode is None: + return true_divide(A, B) + elif rounding_mode == "trunc": + return trunc_divide(A, B) + elif rounding_mode == "floor": + return floor_divide(A, B) + else: + msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}." + raise ValueError(msg) diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index 38ee0447..21be2422 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def eq_func(x, y): return x.to(tl.float32) == y.to(tl.float32) @@ -18,7 +17,7 @@ def eq(A, B): return eq_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def eq_func_scalar(x, y): return x.to(tl.float32) == y.to(tl.float32) diff --git a/src/flag_gems/ops/exp.py b/src/flag_gems/ops/exp.py index 5d51c5b4..bb774457 100644 --- a/src/flag_gems/ops/exp.py +++ b/src/flag_gems/ops/exp.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def exp_func(x): return tl.exp(x.to(tl.float32)) diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py index 2edd5dfc..064466e1 100644 --- a/src/flag_gems/ops/ge.py +++ b/src/flag_gems/ops/ge.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge_func(x, y): return x.to(tl.float32) >= y @@ -18,7 +17,7 @@ def ge(A, B): return ge_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge_func_scalar(x, y): return x.to(tl.float32) >= y diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index 2df36194..0b612793 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def gelu_none(x): scale = 0.7071067811 @@ -14,7 +14,7 @@ def gelu_none(x): return output -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def gelu_tanh(x): output = ( diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py index 4fe53628..27625421 100644 --- a/src/flag_gems/ops/gt.py +++ b/src/flag_gems/ops/gt.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def gt_func(x, y): return x.to(tl.float32) > y @@ -18,7 +17,7 @@ def gt(A, B): return gt_func(A, B) -@pointwise_dynamic(is_tensor=[True, 
False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def gt_func_scalar(x, y): return x.to(tl.float32) > y diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index 17c8488c..15bb5fe7 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def isinf_func(x): return tl.math.isinf(x.to(tl.float32)) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index f58f2ce1..163a7ddf 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) @triton.jit def isnan_func(x): return tl.math.isnan(x.to(tl.float32)) diff --git a/src/flag_gems/ops/le.py b/src/flag_gems/ops/le.py index 70d84be5..d0dea92c 100644 --- a/src/flag_gems/ops/le.py +++ b/src/flag_gems/ops/le.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def le_func(x, y): return x.to(tl.float32) <= y @@ -18,7 +17,7 @@ def le(A, B): return le_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def le_func_scalar(x, y): return x.to(tl.float32) <= y diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py index 26e828d1..c7bdf71a 100644 --- a/src/flag_gems/ops/lt.py +++ b/src/flag_gems/ops/lt.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def lt_func(x, y): return x.to(tl.float32) < y @@ -18,7 +17,7 @@ def lt(A, B): return lt_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def lt_func_scalar(x, y): return x.to(tl.float32) < y diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index e5745f98..d33d1f83 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -6,13 +6,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def mul_func(x, y): return x * y -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def mul_func_scalar(x, y): return x * y diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py index f322ffff..4cf58980 100644 --- a/src/flag_gems/ops/ne.py +++ b/src/flag_gems/ops/ne.py @@ -1,13 +1,12 @@ import logging -import torch import triton import triton.language as tl from ..utils import pointwise_dynamic -@pointwise_dynamic(output_dtypes=[torch.bool]) +@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ne_func(x, y): 
return x.to(tl.float32) != y.to(tl.float32) @@ -18,7 +17,7 @@ def ne(A, B): return ne_func(A, B) -@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ne_func_scalar(x, y): return x.to(tl.float32) != y.to(tl.float32) diff --git a/src/flag_gems/ops/neg.py b/src/flag_gems/ops/neg.py index 88ea6a26..363eb9d2 100644 --- a/src/flag_gems/ops/neg.py +++ b/src/flag_gems/ops/neg.py @@ -5,7 +5,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def neg_func(x): return -x diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index 60abf1fa..5dc224d8 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -17,7 +17,7 @@ def pow_tensor_tensor(A, exponent): return pow_func(A, exponent) -@pointwise_dynamic(is_tensor=[True, False]) +@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func_tensor_scalar(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) @@ -28,7 +28,7 @@ def pow_tensor_scalar(A, exponent): return pow_func_tensor_scalar(A, exponent) -@pointwise_dynamic(is_tensor=[False, True]) +@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) @triton.jit def pow_func_scalar_tensor(x, exponent): return tl.math.pow(x.to(tl.float32), exponent) diff --git a/src/flag_gems/ops/reciprocal.py b/src/flag_gems/ops/reciprocal.py index 432d45ea..948c2727 100644 --- a/src/flag_gems/ops/reciprocal.py +++ b/src/flag_gems/ops/reciprocal.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def reciprocal_func(x): return 1.0 / x.to(tl.float32) diff --git a/src/flag_gems/ops/relu.py b/src/flag_gems/ops/relu.py index 32d59af1..2eac7b3c 100644 --- a/src/flag_gems/ops/relu.py +++ b/src/flag_gems/ops/relu.py @@ -7,13 +7,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def relu_forward(x): return tl.where(x > 0, x, 0) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def relu_backward(x, dy): return tl.where(x > 0, dy, 0) diff --git a/src/flag_gems/ops/rsqrt.py b/src/flag_gems/ops/rsqrt.py index 24090a0e..636dbb5f 100644 --- a/src/flag_gems/ops/rsqrt.py +++ b/src/flag_gems/ops/rsqrt.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def rsqrt_func(x): return 1.0 / tl.sqrt(x.to(tl.float32)) diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index 833e9559..e5f736d7 100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -8,14 +8,14 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sigmoid_forward(x): log2e: tl.constexpr = math.log2(math.e) return 1 / (1 + tl.math.exp2(-x.to(tl.float32) * log2e)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sigmoid_backward(y, dy): y_f32 = y.to(tl.float32) 
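For reference, the `promotion_methods` tuples applied throughout these per-op hunks map directly onto PyTorch's elementwise promotion helper, which the new `src/flag_gems/utils/type_utils.py` wraps. The snippet below is an illustrative sketch only (not part of this patch): it shows how an entry such as `(0, 1, "ALWAYS_BOOL")` or `(0, "INT_TO_FLOAT")` resolves to a computation dtype and a result dtype; the tensors used here are hypothetical.

```python
# Illustrative sketch, not part of the patch: resolving promotion_methods
# entries with torch._prims_common, as the generated wrappers do.
import torch
import torch._prims_common as utils

x = torch.empty(4, dtype=torch.int64)
y = torch.empty(4, dtype=torch.float16)

# (0, 1, "ALWAYS_BOOL"): promote args 0 and 1, result dtype is always bool (eq, ge, lt, ...)
comp, result = utils.elementwise_dtypes(
    x, y, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
)
assert comp == torch.float32 and result == torch.bool  # float16 is upcast for computation

# (0, "INT_TO_FLOAT"): integer or bool inputs yield a floating-point result (sin, cos, ...)
comp, result = utils.elementwise_dtypes(
    x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
assert comp == result == torch.float32  # the default float dtype
```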
diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index 358ab3aa..a79248c5 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -7,7 +7,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def silu_forward(x): x_fp32 = x.to(tl.float32) @@ -15,7 +15,7 @@ def silu_forward(x): return y -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) @triton.jit def silu_backward(x, dy): dy_fp32 = dy.to(tl.float32) diff --git a/src/flag_gems/ops/sin.py b/src/flag_gems/ops/sin.py index 4431f75e..3cb7ab0b 100644 --- a/src/flag_gems/ops/sin.py +++ b/src/flag_gems/ops/sin.py @@ -6,7 +6,7 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def sin_func(x): return tl.sin(x.to(tl.float32)) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index 72804744..c62faf05 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -6,19 +6,23 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit def sub_func(x, y, alpha): return x - y * alpha -@pointwise_dynamic(is_tensor=[True, False, False]) +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")] +) @triton.jit def sub_func_tensor_scalar(x, y, alpha): return x - y * alpha -@pointwise_dynamic(is_tensor=[False, True, False]) +@pointwise_dynamic( + is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")] +) @triton.jit def sub_func_scalar_tensor(x, y, alpha): return x - y * alpha diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index b2cff858..e45c4ae1 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -7,13 +7,13 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def tanh_forward(x): return tl.math.tanh(x.to(tl.float32)) -@pointwise_dynamic +@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) @triton.jit def tanh_backward(y, dy): return dy * (1.0 - tl.math.pow(y.to(tl.float32), 2)) diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index 68b09868..64d40172 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -6,34 +6,40 @@ from ..utils import pointwise_dynamic -@pointwise_dynamic(is_tensor=[True, True, True]) +@pointwise_dynamic( + is_tensor=[True, True, True], promotion_methods=[(1, 2, "NO_OPMATH")] +) @triton.jit -def where_self_func(self, condition, other): +def where_self_func(condition, self, other): return tl.where(condition, self, other) def where_self(condition, self, other): logging.debug("GEMS WHERE_SELF") - return where_self_func(self, condition, other) + return where_self_func(condition, self, other) -@pointwise_dynamic(is_tensor=[True, True, False]) +@pointwise_dynamic( + is_tensor=[True, True, False], promotion_methods=[(1, 2, "NO_OPMATH")] +) @triton.jit -def where_scalar_self_func(other, condition, self): +def where_scalar_self_func(condition, self, other): return tl.where(condition, self, other) def where_scalar_self(condition, self, other): logging.debug("GEMS WHERE_SCALAR_SELF") - return where_scalar_self_func(other, condition, self) + return where_scalar_self_func(condition, self, other) -@pointwise_dynamic(is_tensor=[True, True, False]) 
+@pointwise_dynamic( + is_tensor=[True, True, False], promotion_methods=[(1, 2, "NO_OPMATH")] +) @triton.jit -def where_scalar_other_func(self, condition, other): +def where_scalar_other_func(condition, self, other): return tl.where(condition, self, other) def where_scalar_other(condition, self, other): logging.debug("GEMS WHERE_SCALAR_OTHER") - return where_scalar_other_func(self, condition, other) + return where_scalar_other_func(condition, self, other) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index c6f47c32..a5b2887f 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -3,6 +3,7 @@ from typing import Any, Callable, List, Mapping, Optional, Tuple import torch +import torch._prims_common as utils import triton from triton import language as tl from triton.runtime.jit import JITFunction @@ -41,7 +42,7 @@ class OPDesc: _num_non_tensor_inputs: int _num_outputs: int - _output_dtypes: List[torch.dtype] + _promotion_methods: List[Tuple[int, ...]] def __init__( self, @@ -50,12 +51,18 @@ def __init__( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - output_dtypes: Optional[List[torch.dtype]] = None, + promotion_methods: Optional[List[Tuple[int, ...]]] = None, ): if is_tensor is not None: _check_typed_list(is_tensor, bool) if dtypes is not None: _check_typed_list(dtypes, (type, type(None))) + if promotion_methods is None: + raise ValueError( + "No type promotion method provided! You must provide type promotion method for each output!" + ) + else: + self._promotion_methods = promotion_methods if num_inputs is not None: self._num_inputs = num_inputs @@ -91,22 +98,11 @@ def __init__( "Cannot make OPDesc when none of (num_inputs, is_tensor, dtypes) is specified." 
) - if output_dtypes is not None: - _check_typed_list(output_dtypes, torch.dtype) - if num_outputs is not None: self._num_outputs = num_outputs - if output_dtypes is not None: - _check_sized_list(output_dtypes, num_outputs) - self._output_dtypes = output_dtypes - else: - self._output_dtypes = [None] * num_outputs # infer from the 1st input - elif output_dtypes is not None: - self._num_outputs = len(output_dtypes) - self._output_dtypes = output_dtypes + _check_sized_list(promotion_methods, num_outputs) else: - self._num_outputs = 1 - self._output_dtypes = [None] + self._num_outputs = len(promotion_methods) assert self._num_inputs >= 1 assert self._num_outputs >= 1 @@ -127,9 +123,6 @@ def is_tensor(self, arg_id: int) -> bool: def input_type(self, arg_id) -> Optional[type]: return self._dtypes[arg_id] - def output_dtype(self, output_id) -> torch.dtype: - return self._output_dtypes[output_id] - def num_input_tensors(self) -> int: return self._num_input_tensors @@ -139,6 +132,23 @@ def num_output_tensors(self) -> int: def num_non_tensor_args(self) -> int: return self._num_non_tensor_inputs + def type_promotion_methods(self) -> List[Tuple[int, ...]]: + return self._promotion_methods + + def _match_enum_by_string( + self, input_str: str + ) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + for kind in utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + if input_str.lower() == kind.name.lower(): + return kind + raise ValueError(f"No matching enum member found for input: {input_str}") + + def ith_type_promotion_args(self, i) -> List[int]: + return self._promotion_methods[i][:-1] + + def ith_type_promotion_kind(self, i) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: + return self._match_enum_by_string(self._promotion_methods[i][-1]) + def signature(self, outputs_in_arg: bool = False): input_types = [] for is_tensor, dtype in zip(self._is_tensor, self._dtypes): @@ -151,11 +161,8 @@ def signature(self, outputs_in_arg: bool = False): input_types.append(_type_name(dtype)) output_types = [] - for dtype in self._output_dtypes: - if dtype is None: - output_types.append("Tensor") - else: - output_types.append(f"Tensor[{_type_name(dtype)}]") + for _ in range(self.num_outputs()): + output_types.append("Tensor") if outputs_in_arg: input_types.extend(output_types) sig = f'Pointwise: ({", ".join(input_types)}) -> ({", ".join(output_types)})' @@ -196,6 +203,31 @@ def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str return ", ".join(parameters) +def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str: + """Generate parameter reference for i-th type promotion rule + Example: in0, val0, out0 + """ + parameters: List[str] = [] + + input_tensor_index = 0 + non_tensor_index = 0 + for i in range(op_desc.num_inputs()): + if i not in op_desc.ith_type_promotion_args(ith): + if op_desc._is_tensor[i]: + input_tensor_index += 1 + else: + non_tensor_index += 1 + continue + if op_desc._is_tensor[i]: + parameters.append(f"in{input_tensor_index}") + input_tensor_index += 1 + else: + parameters.append(f"val{non_tensor_index}") + non_tensor_index += 1 + + return ", ".join(parameters) + + def parameter_ref_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: """Generate parameter reference for wrapper function. 
Example: in0, val0, out0 @@ -245,11 +277,17 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("import triton") code.writeline("from triton import language as tl") code.newline() - code.writeline( - "from flag_gems.utils.shape_utils import broadcast_shapes, \ - broadcasted_stride, c_contiguous_stride, volume, Stride" - ) + code.writeline("from flag_gems.utils.shape_utils import (") + code.writeline(" broadcast_shapes,") + code.writeline(" broadcasted_stride,") + code.writeline(" c_contiguous_stride,") + code.writeline(" volume,") + code.writeline(" Stride,") + code.writeline(")") code.writeline("from flag_gems.utils.libentry import libentry") + code.writeline("from flag_gems.utils.type_utils import type_promotion") + code.writeline("import torch._prims_common as utils") + code.newline() code.newline() return code @@ -278,27 +316,30 @@ def generate_functional_pointwise_wrapper( # output allocation num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): - if op_desc.output_dtype(i) is None: - code.writeline( - f"out{num_output_tensor_index} = \ - torch.empty(shape, dtype=in0.dtype, device=in0.device)" - ) - else: - code.writeline( - f"out{num_output_tensor_index} = \ - torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)" + type_promotion_args = ith_parameter_for_type_promotion(op_desc, i) + k_type_promotion = op_desc.ith_type_promotion_kind(i) + code.writeline( + ( + f"out{num_output_tensor_index} = " + f"torch.empty(shape, dtype=type_promotion" + f"({type_promotion_args}, type_promotion=utils.{k_type_promotion})[1], " + f"device=in0.device)" ) + ) num_output_tensor_index += 1 # call destination_passing_func output_names: str = output_ref_for_wrapper(op_desc) - call_str = f"{output_names} = {destination_passing_func_name} \ - ({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + call_str = ( + f"{output_names} = {destination_passing_func_name}" + f"({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + ) code.writeline(call_str) return_str = f"return {output_names}" code.writeline(return_str) code.newline() + code.newline() return code @@ -381,6 +422,7 @@ def generate_destination_passing_pointwise_wrapper( # return code.writeline(f"return {output_ref_for_wrapper(op_desc)}") code.newline() + code.newline() return code @@ -646,7 +688,7 @@ def pointwise_dynamic( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - output_dtypes: Optional[List[type]] = None, + promotion_methods: Optional[Tuple[int, ...]] = None, ): def decorator(fn): nonlocal num_inputs @@ -657,7 +699,7 @@ def decorator(fn): is_tensor=is_tensor, dtypes=dtypes, num_outputs=num_outputs, - output_dtypes=output_dtypes, + promotion_methods=promotion_methods, ) return PointwiseDynamicFunction(op_desc, fn) @@ -668,7 +710,11 @@ def decorator(fn): if __name__ == "__main__": - @pointwise_dynamic(is_tensor=[True, False, True], dtypes=[None, float, None]) + @pointwise_dynamic( + is_tensor=[True, False, True], + dtypes=[None, float, None], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -682,7 +728,9 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False, True]) + @pointwise_dynamic( + is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")] + ) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -694,7 +742,7 @@ def saxpy(x, alpha, y): 
torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(output_dtypes=[torch.bool]) + @pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit def ge(x, y): return x > y @@ -706,7 +754,7 @@ def ge(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic() + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary(x, y): return tl.sin(x) + tl.cos(y) @@ -718,7 +766,7 @@ def ordinary(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -730,7 +778,7 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic + @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) @triton.jit def ordinary2(x, y): return tl.sin(x) + tl.cos(y) @@ -744,7 +792,9 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) + @pointwise_dynamic( + is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")] + ) @triton.jit def eq(x, y): return x.to(tl.float32) == y.to( diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py new file mode 100644 index 00000000..996cfb21 --- /dev/null +++ b/src/flag_gems/utils/type_utils.py @@ -0,0 +1,9 @@ +import torch._prims_common as utils + + +def type_promotion(*args, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND): + computation_dtype, result_dtype = utils.elementwise_dtypes( + *args, + type_promotion_kind=type_promotion, + ) + return computation_dtype, result_dtype diff --git a/tests/test_pointwise_type_promotion.py b/tests/test_pointwise_type_promotion.py new file mode 100644 index 00000000..4444a8eb --- /dev/null +++ b/tests/test_pointwise_type_promotion.py @@ -0,0 +1,126 @@ +import logging + +import pytest +import torch + +import flag_gems + +from .accuracy_utils import ( + FLOAT_DTYPES, + POINTWISE_SHAPES, + SCALARS, + gems_assert_close, + gems_assert_equal, + to_reference, +) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_default(shape, alpha, float_type): + inp1 = torch.randint(10, shape, device="cuda") + inp2 = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp2 = to_reference(inp2, True) + # arg0:int arg1:float + ref_out = torch.add(inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + gems_assert_close(res_out, ref_out, float_type) + # arg0:float arg1:int + ref_out = torch.add(ref_inp2, inp1, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp2, inp1, alpha=alpha) + gems_assert_close(res_out, ref_out, float_type) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_no_opmath(shape, float_type): + inp1 = torch.randint(10, shape, device="cuda") + inp2 = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp2 = to_reference(inp2) + # arg0:bool arg1:int arg2:float + ref_out = torch.where(inp1 > 0, inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.where(inp1 > 0, inp1, inp2) + gems_assert_equal(res_out, ref_out) + + # arg0:bool arg1:float arg2:int + ref_out = torch.where(inp1 > 0, ref_inp2, inp1) + with flag_gems.use_gems(): + res_out = torch.where(inp1 > 0, inp2, inp1) + 
gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_int_to_float(shape, float_type): + # arg0:float + inp_float = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp = to_reference(inp_float, True) + ref_out = torch.sin(ref_inp) + with flag_gems.use_gems(): + res_out = torch.sin(inp_float) + gems_assert_close(res_out, ref_out, float_type) + + # arg0:int + inp_int = torch.randint(10, shape, device="cuda") + ref_out = torch.sin(inp_int) + with flag_gems.use_gems(): + res_out = torch.sin(inp_int) + gems_assert_close(res_out, ref_out, torch.float32) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +def test_type_promotion_always_bool(shape): + # arg0:int arg1:int + inp1 = torch.randint(0, 10, shape, device="cuda") + inp2 = torch.randint(0, 10, shape, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + ref_out = torch.eq(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_type", FLOAT_DTYPES) +def test_type_promotion_complex_to_float(shape, float_type): + # arg0:float + inp = torch.randn(shape, dtype=float_type, device="cuda") + ref_inp = to_reference(inp) + ref_out = torch.abs(ref_inp) + with flag_gems.use_gems(): + res_out = torch.abs(inp) + gems_assert_equal(res_out, ref_out) + + # arg0:int + inp1 = torch.randint(0, 10, shape, device="cuda") + ref_out1 = torch.abs(inp1) + with flag_gems.use_gems(): + res_out1 = torch.abs(inp1) + gems_assert_equal(res_out1, ref_out1) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) +def test_type_promotion_bool_to_long(shape, float_dtype): + inp1 = torch.randn(shape, dtype=float_dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, device="cuda") + # arg0: float arg1: int + ref_out = torch.pow(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.pow(inp1, inp2) + logging.debug(ref_out.dtype) + logging.debug(res_out.dtype) + gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True) + + # arg0: int arg1: float + ref_out = torch.pow(inp2, inp1) + with flag_gems.use_gems(): + res_out = torch.pow(inp2, inp1) + logging.debug(ref_out.dtype) + logging.debug(res_out.dtype) + gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True)
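As a usage note for the `rounding_mode` support added to `div` in `src/flag_gems/ops/div.py`, the sketch below is hedged and not part of the patch; it assumes the `rounding_mode` overload of `aten::div` is routed to `flag_gems` by the operator registration table, which this diff does not show.

```python
# Usage sketch (assumes the rounding_mode overload of div is registered to
# flag_gems; that registration is not part of this diff).
import torch
import flag_gems

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")

with flag_gems.use_gems():
    q_true = torch.div(a, b)                          # true_divide path
    q_trunc = torch.div(a, b, rounding_mode="trunc")  # trunc_divide path
    q_floor = torch.div(a, b, rounding_mode="floor")  # floor_divide path
```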