diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml
index 614fb31b..e4f42a1a 100644
--- a/.github/workflows/python-test.yaml
+++ b/.github/workflows/python-test.yaml
@@ -26,6 +26,7 @@ jobs:
       - name: unit_test-flag-gems
         run: |
           CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py &
+          CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py &
           CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &
           CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &
           CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &
diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py
index 664531ba..e3c17f48 100644
--- a/src/flag_gems/ops/pow.py
+++ b/src/flag_gems/ops/pow.py
@@ -6,7 +6,7 @@
 from ..utils import pointwise_dynamic
 
 
-@pointwise_dynamic(promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
@@ -17,7 +17,7 @@ def pow_tensor_tensor(A, exponent):
     return pow_func(A, exponent)
 
 
-@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func_tensor_scalar(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
@@ -28,7 +28,7 @@ def pow_tensor_scalar(A, exponent):
     return pow_func_tensor_scalar(A, exponent)
 
 
-@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func_scalar_tensor(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py
index 859868ec..ec48a136 100644
--- a/src/flag_gems/ops/where.py
+++ b/src/flag_gems/ops/where.py
@@ -7,7 +7,7 @@
 
 
 @pointwise_dynamic(
-    is_tensor=[True, True, True], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, True], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_self_func(condition, self, other):
@@ -20,7 +20,7 @@ def where_self(condition, self, other):
 
 
 @pointwise_dynamic(
-    is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_scalar_self_func(condition, self, other):
@@ -33,7 +33,7 @@ def where_scalar_self(condition, self, other):
 
 
 @pointwise_dynamic(
-    is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_scalar_other_func(condition, self, other):
diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py
index 177061d8..e6493943 100644
--- a/src/flag_gems/utils/pointwise_dynamic.py
+++ b/src/flag_gems/utils/pointwise_dynamic.py
@@ -213,7 +213,11 @@ def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str:
     non_tensor_index = 0
     for i in range(op_desc.num_inputs()):
         if i not in op_desc.ith_type_promotion_args(ith):
-            break
+            if op_desc._is_tensor[i]:
+                input_tensor_index += 1
+            else:
+                non_tensor_index += 1
+            continue
         if op_desc._is_tensor[i]:
             parameters.append(f"in{input_tensor_index}")
             input_tensor_index += 1
diff --git a/tests/test_pointwise_type_promotion.py b/tests/test_pointwise_type_promotion.py
new file mode 100644
index 00000000..4444a8eb
--- /dev/null
+++ b/tests/test_pointwise_type_promotion.py
@@ -0,0 +1,126 @@
+import logging
+
+import pytest
+import torch
+
+import flag_gems
+
+from .accuracy_utils import (
+    FLOAT_DTYPES,
+    POINTWISE_SHAPES,
+    SCALARS,
+    gems_assert_close,
+    gems_assert_equal,
+    to_reference,
+)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("alpha", SCALARS)
+@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
+def test_type_promotion_default(shape, alpha, float_type):
+    inp1 = torch.randint(10, shape, device="cuda")
+    inp2 = torch.randn(shape, dtype=float_type, device="cuda")
+    ref_inp2 = to_reference(inp2, True)
+    # arg0:int arg1:float
+    ref_out = torch.add(inp1, ref_inp2, alpha=alpha)
+    with flag_gems.use_gems():
+        res_out = torch.add(inp1, inp2, alpha=alpha)
+    gems_assert_close(res_out, ref_out, float_type)
+    # arg0:float arg1:int
+    ref_out = torch.add(ref_inp2, inp1, alpha=alpha)
+    with flag_gems.use_gems():
+        res_out = torch.add(inp2, inp1, alpha=alpha)
+    gems_assert_close(res_out, ref_out, float_type)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
+def test_type_promotion_no_opmath(shape, float_type):
+    inp1 = torch.randint(10, shape, device="cuda")
+    inp2 = torch.randn(shape, dtype=float_type, device="cuda")
+    ref_inp2 = to_reference(inp2)
+    # arg0:bool arg1:int arg2:float
+    ref_out = torch.where(inp1 > 0, inp1, ref_inp2)
+    with flag_gems.use_gems():
+        res_out = torch.where(inp1 > 0, inp1, inp2)
+    gems_assert_equal(res_out, ref_out)
+
+    # arg0:bool arg1:float arg2:int
+    ref_out = torch.where(inp1 > 0, ref_inp2, inp1)
+    with flag_gems.use_gems():
+        res_out = torch.where(inp1 > 0, inp2, inp1)
+    gems_assert_equal(res_out, ref_out)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
+def test_type_promotion_int_to_float(shape, float_type):
+    # arg0:float
+    inp_float = torch.randn(shape, dtype=float_type, device="cuda")
+    ref_inp = to_reference(inp_float, True)
+    ref_out = torch.sin(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.sin(inp_float)
+    gems_assert_close(res_out, ref_out, float_type)
+
+    # arg0:int
+    inp_int = torch.randint(10, shape, device="cuda")
+    ref_out = torch.sin(inp_int)
+    with flag_gems.use_gems():
+        res_out = torch.sin(inp_int)
+    gems_assert_close(res_out, ref_out, torch.float32)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+def test_type_promotion_always_bool(shape):
+    # arg0:int arg1:int
+    inp1 = torch.randint(0, 10, shape, device="cuda")
+    inp2 = torch.randint(0, 10, shape, device="cuda")
+    ref_inp1 = to_reference(inp1)
+    ref_inp2 = to_reference(inp2)
+    ref_out = torch.eq(ref_inp1, ref_inp2)
+    with flag_gems.use_gems():
+        res_out = torch.eq(inp1, inp2)
+    gems_assert_equal(res_out, ref_out)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
+def test_type_promotion_complex_to_long(shape, float_type):
+    # arg0:float
+    inp = torch.randn(shape, dtype=float_type, device="cuda")
+    ref_inp = to_reference(inp)
+    ref_out = torch.abs(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.abs(inp)
+    gems_assert_equal(res_out, ref_out)
+
+    # arg0:int
+    inp1 = torch.randint(0, 10, shape, device="cuda")
+    ref_out1 = torch.abs(inp1)
+    with flag_gems.use_gems():
+        res_out1 = torch.abs(inp1)
+    gems_assert_equal(res_out1, ref_out1)
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("float_dtype", FLOAT_DTYPES)
+def test_type_promotion_bool_to_long(shape, float_dtype):
+    inp1 = torch.randn(shape, dtype=float_dtype, device="cuda")
+    inp2 = torch.randint(0, 10, shape, device="cuda")
+    # arg0: float arg1: int
+    ref_out = torch.pow(inp1, inp2)
+    with flag_gems.use_gems():
+        res_out = torch.pow(inp1, inp2)
+    logging.debug(ref_out.dtype)
+    logging.debug(res_out.dtype)
+    gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True)
+
+    # arg0: int arg1: float
+    ref_out = torch.pow(inp2, inp1)
+    with flag_gems.use_gems():
+        res_out = torch.pow(inp2, inp1)
+    logging.debug(ref_out.dtype)
+    logging.debug(res_out.dtype)
+    gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True)