
Commit

add unit test
Bowen12992 committed Jun 26, 2024
1 parent 26147cb commit 6c80334
Showing 5 changed files with 138 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-test.yaml
@@ -26,6 +26,7 @@ jobs:
       - name: unit_test-flag-gems
         run: |
           CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py &
+          CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py &
           CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &
           CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &
           CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &
6 changes: 3 additions & 3 deletions src/flag_gems/ops/pow.py
@@ -6,7 +6,7 @@
 from ..utils import pointwise_dynamic


-@pointwise_dynamic(promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
@@ -17,7 +17,7 @@ def pow_tensor_tensor(A, exponent):
     return pow_func(A, exponent)


-@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func_tensor_scalar(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
@@ -28,7 +28,7 @@ def pow_tensor_scalar(A, exponent):
     return pow_func_tensor_scalar(A, exponent)


-@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, "BOOL_TO_LONG"]])
+@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[[0, 1, "BOOL_TO_LONG"]])
 @triton.jit
 def pow_func_scalar_tensor(x, exponent):
     return tl.math.pow(x.to(tl.float32), exponent)
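The change above adds the exponent (argument index 1) to the promotion spec, so pointwise_dynamic derives the output dtype of pow from both operands rather than the base alone. A minimal sketch of the dtype behaviour the new spec is meant to match, using plain PyTorch (not part of the commit):

import torch

x = torch.randn(4, dtype=torch.float16)  # float base
n = torch.randint(0, 5, (4,))            # int64 exponent

print(torch.pow(x, n).dtype)  # torch.float16 - the int exponent follows the float base
print(torch.pow(n, x).dtype)  # torch.float16 - the int base follows the float exponent
print(torch.pow(n, 2).dtype)  # torch.int64   - all-integer inputs stay integral

These mixed int/float cases are exactly what test_type_promotion_bool_to_long in the new test file below asserts.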
6 changes: 3 additions & 3 deletions src/flag_gems/ops/where.py
@@ -7,7 +7,7 @@


 @pointwise_dynamic(
-    is_tensor=[True, True, True], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, True], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_self_func(condition, self, other):
@@ -20,7 +20,7 @@ def where_self(condition, self, other):


 @pointwise_dynamic(
-    is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_scalar_self_func(condition, self, other):
@@ -33,7 +33,7 @@ def where_scalar_self(condition, self, other):


 @pointwise_dynamic(
-    is_tensor=[True, True, False], promotion_methods=[[0, 1, "NO_OPMATH"]]
+    is_tensor=[True, True, False], promotion_methods=[[1, 2, "NO_OPMATH"]]
 )
 @triton.jit
 def where_scalar_other_func(condition, self, other):
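For where, argument 0 is the boolean condition and should not influence the result dtype; torch computes the output type by promoting self and other only. Pointing the promotion spec at indices 1 and 2 mirrors that. A minimal illustration in plain PyTorch (not part of the commit):

import torch

cond = torch.tensor([True, False, True])
a = torch.arange(3)                # int64
b = torch.tensor([0.5, 1.5, 2.5])  # float32

out = torch.where(cond, a, b)
print(out.dtype)  # torch.float32 - promoted from int64 and float32; the bool condition is ignored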
6 changes: 5 additions & 1 deletion src/flag_gems/utils/pointwise_dynamic.py
@@ -213,7 +213,11 @@ def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str:
     non_tensor_index = 0
     for i in range(op_desc.num_inputs()):
         if i not in op_desc.ith_type_promotion_args(ith):
-            break
+            if op_desc._is_tensor[i]:
+                input_tensor_index += 1
+            else:
+                non_tensor_index += 1
+            continue
         if op_desc._is_tensor[i]:
             parameters.append(f"in{input_tensor_index}")
             input_tensor_index += 1
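The old loop stopped with break at the first argument that was not part of the current promotion group, so any later promotion argument was never reached and the tensor/non-tensor counters went stale. The fix keeps iterating: skipped arguments still advance the matching counter, so later parameters get the correct positional names. A standalone sketch of the pattern (hypothetical helper and val{i} naming for non-tensor arguments, not FlagGems' actual code):

def promotion_parameters(is_tensor, promotion_args):
    # Collect parameter names only for arguments in this promotion group,
    # while every argument, included or skipped, advances its own counter.
    parameters = []
    input_tensor_index = 0
    non_tensor_index = 0
    for i, arg_is_tensor in enumerate(is_tensor):
        if i not in promotion_args:
            # Old behaviour was `break` here, dropping all later arguments.
            if arg_is_tensor:
                input_tensor_index += 1
            else:
                non_tensor_index += 1
            continue
        if arg_is_tensor:
            parameters.append(f"in{input_tensor_index}")
            input_tensor_index += 1
        else:
            parameters.append(f"val{non_tensor_index}")
            non_tensor_index += 1
    return parameters

# Example: three arguments (tensor, tensor, scalar), promotion over args 1 and 2.
print(promotion_parameters([True, True, False], (1, 2)))  # ['in1', 'val0']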
126 changes: 126 additions & 0 deletions tests/test_pointwise_type_promotion.py
@@ -0,0 +1,126 @@
import logging

import pytest
import torch

import flag_gems

from .accuracy_utils import (
    FLOAT_DTYPES,
    POINTWISE_SHAPES,
    SCALARS,
    gems_assert_close,
    gems_assert_equal,
    to_reference,
)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("alpha", SCALARS)
@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
def test_type_promotion_default(shape, alpha, float_type):
    inp1 = torch.randint(10, shape, device="cuda")
    inp2 = torch.randn(shape, dtype=float_type, device="cuda")
    ref_inp2 = to_reference(inp2, True)
    # arg0:int arg1:float
    ref_out = torch.add(inp1, ref_inp2, alpha=alpha)
    with flag_gems.use_gems():
        res_out = torch.add(inp1, inp2, alpha=alpha)
    gems_assert_close(res_out, ref_out, float_type)
    # arg0:float arg1:int
    ref_out = torch.add(ref_inp2, inp1, alpha=alpha)
    with flag_gems.use_gems():
        res_out = torch.add(inp2, inp1, alpha=alpha)
    gems_assert_close(res_out, ref_out, float_type)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
def test_type_promotion_no_opmath(shape, float_type):
    inp1 = torch.randint(10, shape, device="cuda")
    inp2 = torch.randn(shape, dtype=float_type, device="cuda")
    ref_inp2 = to_reference(inp2)
    # arg0:bool arg1:int arg2:float
    ref_out = torch.where(inp1 > 0, inp1, ref_inp2)
    with flag_gems.use_gems():
        res_out = torch.where(inp1 > 0, inp1, inp2)
    gems_assert_equal(res_out, ref_out)

    # arg0:bool arg1:float arg2:int
    ref_out = torch.where(inp1 > 0, ref_inp2, inp1)
    with flag_gems.use_gems():
        res_out = torch.where(inp1 > 0, inp2, inp1)
    gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
def test_type_promotion_int_to_float(shape, float_type):
    # arg0:float
    inp_float = torch.randn(shape, dtype=float_type, device="cuda")
    ref_inp = to_reference(inp_float, True)
    ref_out = torch.sin(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.sin(inp_float)
    gems_assert_close(res_out, ref_out, float_type)

    # arg0:int
    inp_int = torch.randint(10, shape, device="cuda")
    ref_out = torch.sin(inp_int)
    with flag_gems.use_gems():
        res_out = torch.sin(inp_int)
    gems_assert_close(res_out, ref_out, torch.float32)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
def test_type_promotion_always_bool(shape):
    # arg0:int arg1:int
    inp1 = torch.randint(0, 10, shape, device="cuda")
    inp2 = torch.randint(0, 10, shape, device="cuda")
    ref_inp1 = to_reference(inp1)
    ref_inp2 = to_reference(inp2)
    ref_out = torch.eq(ref_inp1, ref_inp2)
    with flag_gems.use_gems():
        res_out = torch.eq(inp1, inp2)
    gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("float_type", FLOAT_DTYPES)
def test_type_promotion_complex_to_long(shape, float_type):
    # arg0:float
    inp = torch.randn(shape, dtype=float_type, device="cuda")
    ref_inp = to_reference(inp)
    ref_out = torch.abs(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.abs(inp)
    gems_assert_equal(res_out, ref_out)

    # arg0:int
    inp1 = torch.randint(0, 10, shape, device="cuda")
    ref_out1 = torch.abs(inp1)
    with flag_gems.use_gems():
        res_out1 = torch.abs(inp1)
    gems_assert_equal(res_out1, ref_out1)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("float_dtype", FLOAT_DTYPES)
def test_type_promotion_bool_to_long(shape, float_dtype):
    inp1 = torch.randn(shape, dtype=float_dtype, device="cuda")
    inp2 = torch.randint(0, 10, shape, device="cuda")
    # arg0: float arg1: int
    ref_out = torch.pow(inp1, inp2)
    with flag_gems.use_gems():
        res_out = torch.pow(inp1, inp2)
    logging.debug(ref_out.dtype)
    logging.debug(res_out.dtype)
    gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True)

    # arg0: int arg1: float
    ref_out = torch.pow(inp2, inp1)
    with flag_gems.use_gems():
        res_out = torch.pow(inp2, inp1)
    logging.debug(ref_out.dtype)
    logging.debug(res_out.dtype)
    gems_assert_close(res_out, ref_out, float_dtype, equal_nan=True)
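Locally, the new suite can be run on its own with the same command the workflow uses: CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py.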
