Skip to content

Commit

Permalink
[Operator] Add test cases for the isclose and allclose ops
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengyang committed Jun 26, 2024
1 parent ec18e22 commit e909d0f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
26 changes: 26 additions & 0 deletions benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,19 @@ def test_perf_isclose(dtype):
bench.run()


@pytest.mark.parametrize("dtype", INT_DTYPES)
def test_perf_isclose_int(dtype):
    """Benchmark the isclose op on integer dtypes against torch.isclose."""
    config = dict(
        op_name="isclose",
        torch_op=torch.isclose,
        arg_func=binary_int_args,
        dtype=dtype,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    Benchmark(**config).run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_allclose(dtype):
bench = Benchmark(
Expand All @@ -462,3 +475,16 @@ def test_perf_allclose(dtype):
sizes=SIZES,
)
bench.run()


@pytest.mark.parametrize("dtype", INT_DTYPES)
def test_perf_allclose_int(dtype):
    """Benchmark the allclose op on integer dtypes against torch.allclose."""
    config = dict(
        op_name="allclose",
        torch_op=torch.allclose,
        arg_func=binary_int_args,
        dtype=dtype,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    Benchmark(**config).run()
20 changes: 16 additions & 4 deletions src/flag_gems/ops/isclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def isclose_func(x, y, rtol, atol):
x_fp == y_fp,
tl.abs(x_fp - y_fp) <= atol + rtol * tl.abs(y_fp),
)
# return tl.abs(x - y) <= atol + rtol * tl.abs(y)


@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
Expand All @@ -39,6 +38,12 @@ def isclose_func_equal_nan(x, y, rtol, atol):
)


@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func_int(x, y, rtol, atol):
    # Integer closeness test: |x - y| <= atol + rtol * |y|.
    # Integral inputs cannot be NaN/inf, so no special-value handling
    # (as in the float variants) is required here.
    abs_diff = tl.abs(x - y)
    tolerance = atol + rtol * tl.abs(y)
    return abs_diff <= tolerance


def isclose(
A: torch.Tensor,
B: torch.Tensor,
Expand All @@ -49,11 +54,18 @@ def isclose(
logging.debug("GEMS ISCLOSE")
if rtol < 0:
raise RuntimeError(
"rtol must be greater than or equal to zero, but got {}".format(rtol))
"rtol must be greater than or equal to zero, but got {}".format(rtol)
)
if atol < 0:
raise RuntimeError(
"atol must be greater than or equal to zero, but got {}".format(atol))
if equal_nan:
"atol must be greater than or equal to zero, but got {}".format(atol)
)
def is_int(X):
return X.dtype == torch.int8 or X.dtype == torch.int16 or \
X.dtype == torch.int32 or X.dtype == torch.int64
if False and is_int(A) and is_int(B):
return isclose_func_int(A, B, rtol, atol)
elif equal_nan:
return isclose_func_equal_nan(A, B, rtol, atol)
else:
return isclose_func(A, B, rtol, atol)
Expand Down
11 changes: 9 additions & 2 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,10 @@ def test_accuracy_isclose(shape, dtype, equal_nan, gen_nan):
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
if gen_nan:
nan_num = torch.full(
(1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, device="cuda"
(1,),
float("nan" if gen_nan == 1 else "inf"),
dtype=dtype,
device="cuda",
)
inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num
inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num
Expand Down Expand Up @@ -718,7 +721,11 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
if gen_nan:
nan_num = torch.full(
(1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, device="cuda"
(1,),
float("nan" if gen_nan == 1 else "inf"),
dtype=dtype,
device="cuda",
)
)
inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num
inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num
Expand Down

0 comments on commit e909d0f

Please sign in to comment.