Skip to content

Commit

Permalink
[Operator] Fix cal acc and inputs check of isclose and allclose op, a…
Browse files Browse the repository at this point in the history
…dd all types tests
  • Loading branch information
zhengyang committed Jun 28, 2024
1 parent cc99b5f commit 2aabb57
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 58 deletions.
87 changes: 66 additions & 21 deletions src/flag_gems/ops/isclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func(x, y, rtol, atol):
x_fp = x.to(tl.float64)
y_fp = y.to(tl.float64)
x_fp = x.to(tl.float32)
y_fp = y.to(tl.float32)
return tl.where(
tl.math.isinf(x.to(tl.float32)) | tl.math.isinf(y.to(tl.float32)),
tl.math.isinf(x_fp) | tl.math.isinf(y_fp),
x_fp == y_fp,
tl.abs(x_fp - y_fp) <= atol + rtol * tl.abs(y_fp),
)
Expand All @@ -23,35 +23,70 @@ def isclose_func(x, y, rtol, atol):
@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func_equal_nan(x, y, rtol, atol):
x_fp = x.to(tl.float64)
y_fp = y.to(tl.float64)
x_fp = x.to(tl.float32)
y_fp = y.to(tl.float32)
x_nan = x_fp != x_fp
y_nan = y_fp != y_fp
return tl.where(
x_nan | y_nan,
x_nan == y_nan,
tl.where(
tl.math.isinf(x.to(tl.float32)) | tl.math.isinf(y.to(tl.float32)),
tl.math.isinf(x_fp) | tl.math.isinf(y_fp),
x_fp == y_fp,
tl.abs(x_fp - y_fp) <= atol + rtol * tl.abs(y_fp),
),
)


@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func_fp(x, y, rtol, atol):
return tl.where(
tl.math.isinf(x) | tl.math.isinf(y),
x == y,
tl.abs(x - y) <= atol + rtol * tl.abs(y),
)


@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func_equal_nan_fp(x, y, rtol, atol):
x_nan = x != x
y_nan = y != y
return tl.where(
x_nan | y_nan,
x_nan == y_nan,
tl.where(
tl.math.isinf(x) | tl.math.isinf(y),
x == y,
tl.abs(x - y) <= atol + rtol * tl.abs(y),
),
)


@pointwise_dynamic(is_tensor=[True, True, False, False], output_dtypes=[torch.bool])
@triton.jit
def isclose_func_int(x, y, rtol, atol):
return tl.abs(x - y) <= atol + rtol * tl.abs(y)
x_long = x.to(tl.int64)
y_long = y.to(tl.int64)
return tl.abs(x_long - y_long) <= atol + rtol * tl.abs(y_long)


def isclose(
def _isclose(
A: torch.Tensor,
B: torch.Tensor,
rtol=1e-05,
atol=1e-08,
equal_nan: bool = False,
) -> torch.Tensor:
logging.debug("GEMS ISCLOSE")
if A.dtype != B.dtype:
raise RuntimeError(
"{} did not match {}".format(A.dtype, B.dtype)
)
if A.is_quantized or B.is_quantized:
raise RuntimeError(
"isclose is not supported for quantized inputs."
)
if rtol < 0:
raise RuntimeError(
"rtol must be greater than or equal to zero, but got {}".format(rtol)
Expand All @@ -61,20 +96,30 @@ def isclose(
"atol must be greater than or equal to zero, but got {}".format(atol)
)

def is_int(X):
return (
X.dtype == torch.int8
or X.dtype == torch.int16
or X.dtype == torch.int32
or X.dtype == torch.int64
)

if False and is_int(A) and is_int(B):
if (A.dtype == torch.int64 or A.dtype == torch.int32 or A.dtype == torch.int16
or A.dtype == torch.int8 or A.dtype == torch.bool):
return isclose_func_int(A, B, rtol, atol)
elif equal_nan:
return isclose_func_equal_nan(A, B, rtol, atol)
if A.dtype == torch.float32 or A.dtype == torch.float64:
return isclose_func_equal_nan_fp(A, B, rtol, atol)
else:
return isclose_func_equal_nan(A, B, rtol, atol)
else:
return isclose_func(A, B, rtol, atol)
if A.dtype == torch.float32 or A.dtype == torch.float64:
return isclose_func_fp(A, B, rtol, atol)
else:
return isclose_func(A, B, rtol, atol)


def isclose(
A: torch.Tensor,
B: torch.Tensor,
rtol=1e-05,
atol=1e-08,
equal_nan: bool = False,
) -> torch.Tensor:
logging.debug("GEMS ISCLOSE")
return _isclose(A, B, rtol, atol, equal_nan)


def allclose(
Expand All @@ -85,4 +130,4 @@ def allclose(
equal_nan: bool = False,
) -> bool:
logging.debug("GEMS ALLCLOSE")
return all(isclose(A, B, rtol, atol, equal_nan)).item()
return all(_isclose(A, B, rtol, atol, equal_nan)).item()
80 changes: 43 additions & 37 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,97 +659,103 @@ def test_accuracy_where_scalar_other(shape, scalar, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
@pytest.mark.parametrize("shape", POINTWISE_SHAPES + [(128, 1024, 1024)])
@pytest.mark.parametrize(
"dtype", [torch.float64, torch.int64, torch.int8, torch.bool] + FLOAT_DTYPES + INT_DTYPES
)
@pytest.mark.parametrize("equal_nan", [False, True])
@pytest.mark.parametrize(
"gen_nan", [0, 1, 2, 3, 4]
) # 1: nan, 2: inf, 3: -inf, 4: inf vs -inf
def test_accuracy_isclose(shape, dtype, equal_nan, gen_nan):
rtol = torch.rand(1, dtype=torch.float16, device="cuda").item()
atol = torch.rand(1, dtype=torch.bfloat16, device="cuda").item()
rtol = torch.rand(1, dtype=torch.bfloat16, device="cuda").item()
atol = torch.rand(1, dtype=torch.float32, device="cuda").item()
if dtype in FLOAT_DTYPES:
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
if gen_nan:
nan_num = torch.full(
(1,),
float("nan" if gen_nan == 1 else "inf"),
dtype=dtype,
device="cuda",
(1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, device="cuda"
)
inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num
inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num
else:
atol *= 10
inp1 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype)
inp2 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype)
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_inp1 = to_reference(inp1, False)
ref_inp2 = to_reference(inp2, False)
logging.debug(
"shape={}, dtype={}, rtol={}, atol={}".format(shape, dtype, rtol, atol)
)

with flag_gems.use_gems():
res_out = torch.isclose(inp1, inp2, rtol, atol, equal_nan=equal_nan)
ref_out = torch.isclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)

inp1_flat = inp1.view(-1)
inp2_flat = inp2.view(-1)
ref_flat = ref_out.view(-1)
res_flat = res_out.view(-1)
if dtype in FLOAT_DTYPES and gen_nan:
logging.debug(
"equal_nan={}, gen_nan={}: inp1={}, inp2={}, res={}, ref={}".format(
equal_nan,
gen_nan,
inp1.view(-1)[0],
inp2.view(-1)[0],
res_out.view(-1)[0],
ref_out.view(-1)[0],
inp1_flat[0],
inp2_flat[0],
res_flat[0],
ref_flat[0],
)
)
gems_assert_equal(res_out, ref_out)
if dtype in [torch.float32, torch.float16, torch.bfloat16]:
if dtype == torch.float32:
rate = 0.000001
elif dtype == torch.float16:
rate = 0.001
else:
rate = 0.01
err_num = torch.sum(torch.ne(res_out, ref_out)).item()
logging.debug("err_num = {}, err_rate = {}".format(err_num, err_num / ref_out.numel()))
assert err_num <= max(1, ref_out.numel() * rate)
else:
gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
@pytest.mark.parametrize(
"dtype", [torch.float64, torch.int64, torch.int8, torch.bool] + FLOAT_DTYPES + INT_DTYPES
)
@pytest.mark.parametrize("equal_nan", [False, True])
@pytest.mark.parametrize(
"gen_nan", [0, 1, 2, 3, 4]
) # 1: nan, 2: inf, 3: -inf, 4: inf vs -inf
def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
rtol = torch.rand(1, dtype=torch.float16, device="cuda").item()
atol = torch.rand(1, dtype=torch.bfloat16, device="cuda").item()
rtol = torch.rand(1, dtype=torch.bfloat16, device="cuda").item()
atol = torch.rand(1, dtype=torch.float32, device="cuda").item()
if dtype in FLOAT_DTYPES:
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
inp1 = torch.full(shape, 1.234, dtype=dtype, device="cuda")
inp2 = torch.full(shape, 1.234, dtype=dtype, device="cuda")
if gen_nan:
nan_num = torch.full(
(1,),
float("nan" if gen_nan == 1 else "inf"),
dtype=dtype,
device="cuda",
(1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, device="cuda"
)
inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num
inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num
else:
atol *= 10
inp1 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype)
inp2 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype)
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_inp1 = to_reference(inp1, False)
ref_inp2 = to_reference(inp2, False)
logging.debug(
"shape={}, dtype={}, rtol={}, atol={}".format(shape, dtype, rtol, atol)
)

with flag_gems.use_gems():
res_out = torch.allclose(inp1, inp2, rtol, atol, equal_nan=equal_nan)
ref_out = torch.allclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)
if dtype in FLOAT_DTYPES and gen_nan:
logging.debug(
"equal_nan={}, gen_nan={}: inp1={}, inp2={}, res={}, ref={}".format(
equal_nan,
gen_nan,
inp1.view(-1)[0],
inp2.view(-1)[0],
res_out,
ref_out,
)
)

assert res_out == ref_out

0 comments on commit 2aabb57

Please sign in to comment.