Merge pull request #52 from Bowen12992/where_op
Add where op
Bowen12992 authored Jun 13, 2024
2 parents 19fb335 + ff000eb commit 57d4612
Showing 6 changed files with 114 additions and 2 deletions.
20 changes: 20 additions & 0 deletions benchmark/test_pointwise_perf.py
@@ -405,3 +405,23 @@ def test_perf_triu(dtype):
sizes=SIZES,
)
bench.run()


def where_args(dtype, batch, size):
inp1 = torch.randn([size], dtype=dtype, device="cuda")
inp2 = torch.randn([size], dtype=dtype, device="cuda")
condition = inp1 > 0
return condition, inp1, inp2


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_where(dtype):
bench = Benchmark(
op_name="where",
torch_op=torch.where,
arg_func=where_args,
dtype=dtype,
batch=POINTWISE_BATCH,
sizes=SIZES,
)
bench.run()
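A usage note (not part of the diff): like the other pointwise perf cases in this file, the new benchmark can presumably be selected on its own with pytest's keyword filter, for example:

    # hypothetical invocation; exact path and flags depend on the repo's test setup
    pytest benchmark/test_pointwise_perf.py -k test_perf_where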
4 changes: 3 additions & 1 deletion src/flag_gems/__init__.py
@@ -67,7 +67,9 @@ def enable(lib=aten_lib):
lib.impl("triu", triu, "CUDA")
lib.impl("var_mean.correction", var_mean, "CUDA")
lib.impl("linalg_vector_norm", vector_norm, "CUDA")

lib.impl("where.self", where_self, "CUDA")
lib.impl("where.ScalarSelf", where_scalar_self, "CUDA")
lib.impl("where.ScalarOther", where_scalar_other, "CUDA")
lib.impl("max", max, "CUDA")
lib.impl("max.dim", max_dim, "CUDA")
lib.impl("min", min, "CUDA")
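For context (a hedged reading, not stated in the diff): lib is assumed to be a torch.library handle for the aten namespace, so each lib.impl(...) call above overrides one ATen overload of where with a Triton-backed wrapper on CUDA: where.self for two tensor operands, where.ScalarSelf when self is a Python scalar, and where.ScalarOther when other is a Python scalar. A minimal sketch of the effect once the library is enabled:

    import torch
    import flag_gems

    flag_gems.enable()  # registers the overloads shown in this hunk

    x = torch.randn(1024, device="cuda")
    y = torch.randn(1024, device="cuda")
    torch.where(x > 0, x, y)    # where.self        -> where_self
    torch.where(x > 0, 1.0, y)  # where.ScalarSelf  -> where_scalar_self
    torch.where(x > 0, x, 0.0)  # where.ScalarOther -> where_scalar_other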
4 changes: 4 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -58,6 +58,7 @@

from .var_mean import var_mean
from .vector_norm import vector_norm
from .where import where_self, where_scalar_self, where_scalar_other

__all__ = [
"all",
@@ -136,4 +137,7 @@
"log_softmax",
"outer",
"cross_entropy_loss",
"where_self",
"where_scalar_self",
"where_scalar_other",
]
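Side note (a hedged sketch): the new import plus the __all__ entries expose the wrappers at package level, so they can be imported directly, for example:

    from flag_gems.ops import where_self, where_scalar_self, where_scalar_other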
41 changes: 41 additions & 0 deletions src/flag_gems/ops/where.py
@@ -0,0 +1,41 @@
import torch
import triton
import triton.language as tl
import logging
from ..utils import pointwise_dynamic


@pointwise_dynamic(is_tensor=[True, True, True])
@triton.jit
def where_self_func(self, condition, other):
return tl.where(condition, self, other)


def where_self(condition, self, other):
logging.debug("GEMS WHERE_SELF")
O = where_self_func(self, condition, other)
return O


@pointwise_dynamic(is_tensor=[True, True, False])
@triton.jit
def where_scalar_self_func(other, condition, self):
return tl.where(condition, self, other)


def where_scalar_self(condition, self, other):
logging.debug("GEMS WHERE_SCALAR_SELF")
O = where_scalar_self_func(other, condition, self)
return O


@pointwise_dynamic(is_tensor=[True, True, False])
@triton.jit
def where_scalar_other_func(self, condition, other):
return tl.where(condition, self, other)


def where_scalar_other(condition, self, other):
logging.debug("GEMS WHERE_SCALAR_OTHER")
O = where_scalar_other_func(self, condition, other)
return O
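Reading the new module (a hedged interpretation, not stated in the diff): pointwise_dynamic is assumed to turn each decorated scalar function into a broadcasting elementwise Triton kernel, with is_tensor marking which arguments are tensors (True) and which are plain Python scalars (False); that is why the two scalar variants reorder their parameters so the non-tensor value comes last. A minimal sketch of calling the wrappers directly, assuming CUDA tensors:

    import torch
    from flag_gems.ops.where import where_self, where_scalar_self, where_scalar_other

    cond = torch.randn(8, device="cuda") > 0
    a = torch.randn(8, device="cuda")
    b = torch.randn(8, device="cuda")

    out1 = where_self(cond, a, b)            # a where cond is True, else b
    out2 = where_scalar_self(cond, 1.0, b)   # 1.0 where cond is True, else b
    out3 = where_scalar_other(cond, a, 0.0)  # a where cond is True, else 0.0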
2 changes: 1 addition & 1 deletion src/flag_gems/utils/pointwise_dynamic.py
@@ -100,7 +100,7 @@ def __init__(
_check_sized_list(output_dtypes, num_outputs)
self._output_dtypes = output_dtypes
else:
self._output_dtypes = [None] * num_inputs # infer from the 1st input
self._output_dtypes = [None] * num_outputs # infer from the 1st input
elif output_dtypes is not None:
self._num_outputs = len(output_dtypes)
self._output_dtypes = output_dtypes
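Reading the deletion/addition pair above as replacing num_inputs with num_outputs, this is a small bug fix in the fallback branch: when output_dtypes is not supplied but num_outputs is, the placeholder list should hold one None per output (each later inferred from the first input, per the comment), not one per input. A trivial sketch of the difference for a hypothetical op with three inputs and one output, such as where:

    num_inputs, num_outputs = 3, 1
    # before the fix: [None] * num_inputs  -> [None, None, None]
    # after the fix:  [None] * num_outputs -> [None]
    output_dtypes = [None] * num_outputs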
45 changes: 45 additions & 0 deletions tests/test_binary_pointwise_ops.py
@@ -600,3 +600,48 @@ def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype):
res_out = torch.sub(inp1, inp2, alpha=alpha)

gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_where_self(shape, dtype):
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp1 = to_reference(inp1)
ref_inp2 = to_reference(inp2)

ref_out = torch.where(ref_inp1 > 0, ref_inp1, ref_inp2)
with flag_gems.use_gems():
res_out = torch.where(inp1 > 0, inp1, inp2)

gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("scalar", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_where_scalar_self(shape, scalar, dtype):
inp1 = scalar
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp2 = to_reference(inp2)

ref_out = torch.where(inp2 > 0, inp1, ref_inp2)
with flag_gems.use_gems():
res_out = torch.where(inp2 > 0, inp1, inp2)

gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("scalar", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_where_scalar_other(shape, scalar, dtype):
inp1 = scalar
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp2 = to_reference(inp2)

ref_out = torch.where(inp2 > 0, ref_inp2, inp1)
with flag_gems.use_gems():
res_out = torch.where(inp2 > 0, inp2, inp1)

gems_assert_equal(res_out, ref_out)
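As with the perf benchmark, the new accuracy tests can presumably be run on their own with a keyword filter (hypothetical invocation):

    pytest tests/test_binary_pointwise_ops.py -k where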
