From 1e8f1aa6072e94dab1ef1c50b557a3a1a9af1fdc Mon Sep 17 00:00:00 2001
From: zfu82
Date: Sat, 14 Sep 2024 07:37:07 +0000
Subject: [PATCH 1/3] [Operator] Add repeat_interleave_self_int op

---
 benchmark/test_pointwise_perf.py       | 17 +++++++
 src/flag_gems/__init__.py              |  1 +
 src/flag_gems/ops/__init__.py          |  2 +
 src/flag_gems/ops/repeat_interleave.py | 64 ++++++++++++++++++++++++++
 tests/test_binary_pointwise_ops.py     | 18 ++++++++
 5 files changed, 102 insertions(+)
 create mode 100644 src/flag_gems/ops/repeat_interleave.py

diff --git a/benchmark/test_pointwise_perf.py b/benchmark/test_pointwise_perf.py
index cd4b4a6f..66f67584 100644
--- a/benchmark/test_pointwise_perf.py
+++ b/benchmark/test_pointwise_perf.py
@@ -630,3 +630,20 @@ def repeat_arg(dtype, batch, size):
         sizes=SIZES,
     )
     bench.run()
+
+
+def test_perf_repeat_interleave_self_int():
+    def repeat_interleave_self_int_arg(dtype, batch, size):
+        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
+        repeats = 2
+        return inp, repeats
+
+    bench = Benchmark(
+        op_name="repeat_interleave_self_int",
+        torch_op=torch.repeat_interleave,
+        arg_func=repeat_interleave_self_int_arg,
+        dtypes=FLOAT_DTYPES,
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+    )
+    bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index df0fc095..ef07cf79 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -150,6 +150,7 @@ def enable(lib=aten_lib):
     lib.impl("masked_select", masked_select, "CUDA")
     lib.impl("stack", stack, "CUDA")
     lib.impl("hstack", hstack, "CUDA")
+    lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
 
 
 class use_gems:
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index fd556f2e..59fb5353 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -78,6 +78,7 @@
 from .reciprocal import reciprocal
 from .relu import relu
 from .repeat import repeat
+from .repeat_interleave import repeat_interleave_self_int
 from .resolve_conj import resolve_conj
 from .resolve_neg import resolve_neg
 from .rms_norm import rms_norm
@@ -233,4 +234,5 @@
     "masked_select",
     "stack",
     "hstack",
+    "repeat_interleave_self_int",
 ]
diff --git a/src/flag_gems/ops/repeat_interleave.py b/src/flag_gems/ops/repeat_interleave.py
new file mode 100644
index 00000000..6b49cc6e
--- /dev/null
+++ b/src/flag_gems/ops/repeat_interleave.py
@@ -0,0 +1,64 @@
+import torch
+import triton
+
+from ..utils.pointwise_dynamic import pointwise_dynamic
+from ..utils.shape_utils import c_contiguous_stride, volume
+from ..utils.tensor_wrapper import StridedBuffer
+
+
+@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
+@triton.jit
+def copy_func(x):
+    return x
+
+
+def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
+    if dim is None:
+        nelems = volume(inp.shape)
+        inp_shape = [
+            nelems,
+        ]
+        inp_stride = [
+            1,
+        ]
+        output_shape = [
+            nelems,
+        ]
+        dim = 0
+    else:
+        if (dim < -inp.ndim) or (dim >= inp.ndim):
+            raise IndexError(
+                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
+                    -inp.ndim, inp.ndim - 1, dim
+                )
+            )
+        inp_shape = list(inp.shape)
+        inp_stride = list(inp.stride())
+        output_shape = list(inp.shape)
+
+    if dim < 0:
+        dim = dim + len(inp_shape)
+
+    output_shape[dim] *= repeats
+
+    if output_size is not None and output_size != output_shape[dim]:
+        raise RuntimeError(
+            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
+                output_shape[dim], output_size
+            )
+        )
+
+    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)
+
+    if repeats == 0:
+        return output
+
+    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
+    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
+    out_view_stride = c_contiguous_stride(out_view_shape)
+
+    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
+    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
+    ndim = len(out_view_shape)
+    copy_func.instantiate(ndim)(in_view, out0=out_view)
+    return output
diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index fa3d53bd..73d7f7fb 100644
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -976,3 +976,21 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
     ref_out = torch.allclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)
 
     assert res_out == ref_out
+
+
+REPEAT_INTERLEAVE_REPEATS = [2]
+REPEAT_INTERLEAVE_DIM = [-1, 0]
+
+
+@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
+    inp = torch.randn(shape, dtype=dtype, device="cuda")
+    repeats = 2
+    ref_inp = to_reference(inp)
+
+    ref_out = torch.repeat_interleave(ref_inp, repeats, dim)
+    with flag_gems.use_gems():
+        res_out = torch.repeat_interleave(inp, repeats, dim)
+    gems_assert_equal(res_out, ref_out)

From e1aee9c85a1ab3dd88448b3fa33a5071227a6862 Mon Sep 17 00:00:00 2001
From: zfu82
Date: Wed, 18 Sep 2024 03:07:20 +0000
Subject: [PATCH 2/3] Fix case when dim=None

---
 src/flag_gems/ops/repeat_interleave.py | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/src/flag_gems/ops/repeat_interleave.py b/src/flag_gems/ops/repeat_interleave.py
index 6b49cc6e..84c3bea6 100644
--- a/src/flag_gems/ops/repeat_interleave.py
+++ b/src/flag_gems/ops/repeat_interleave.py
@@ -2,7 +2,7 @@
 import triton
 
 from ..utils.pointwise_dynamic import pointwise_dynamic
-from ..utils.shape_utils import c_contiguous_stride, volume
+from ..utils.shape_utils import c_contiguous_stride
 from ..utils.tensor_wrapper import StridedBuffer
 
 
@@ -14,16 +14,7 @@ def copy_func(x):
 
 def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
     if dim is None:
-        nelems = volume(inp.shape)
-        inp_shape = [
-            nelems,
-        ]
-        inp_stride = [
-            1,
-        ]
-        output_shape = [
-            nelems,
-        ]
+        inp = inp.flatten()
         dim = 0
     else:
         if (dim < -inp.ndim) or (dim >= inp.ndim):
@@ -32,9 +23,9 @@ def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
                 -inp.ndim, inp.ndim - 1, dim
             )
         )
-        inp_shape = list(inp.shape)
-        inp_stride = list(inp.stride())
-        output_shape = list(inp.shape)
+    inp_shape = list(inp.shape)
+    inp_stride = list(inp.stride())
+    output_shape = list(inp.shape)
 
     if dim < 0:
         dim = dim + len(inp_shape)

From af93afb2fc5473d9b68bf4b475b77ca15c7e8a87 Mon Sep 17 00:00:00 2001
From: zfu82
Date: Wed, 18 Sep 2024 03:27:47 +0000
Subject: [PATCH 3/3] Add more test cases

---
 tests/test_binary_pointwise_ops.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index 38fc87f2..d1fc4554 100644
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -1034,11 +1034,18 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
     assert res_out == ref_out
 
 
+REPEAT_INTERLEAVE_SHAPES = [
+    (1,),
+    (1024, 1024),
+    (20, 320, 15),
+    (16, 128, 64, 60),
+    (16, 7, 57, 32, 29),
+]
 REPEAT_INTERLEAVE_REPEATS = [2]
-REPEAT_INTERLEAVE_DIM = [-1, 0]
+REPEAT_INTERLEAVE_DIM = [-1, 0, None]
 
 
-@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
+@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES)
 @pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
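
Implementation note: repeat_interleave_self_int does no index arithmetic at all. It reads the input through a view whose inserted repeat axis has stride 0 (so each element is fetched `repeats` times) and pairs it with a C-contiguous output view of the same shape, letting the generated Triton copy kernel materialize the result in a single pass. A minimal sketch of the same idea in plain PyTorch, for reference only (the helper name repeat_interleave_ref is made up here; `expand` is what supplies the zero-stride view, and `reshape` plays the role of the kernel's contiguous copy):

    import torch

    def repeat_interleave_ref(x: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
        dim = dim % x.ndim  # normalize negative dims, as the patch does
        # Insert a length-1 axis after `dim` and broadcast it to `repeats`;
        # expand returns a stride-0 view, so no data is copied yet.
        expanded = x.unsqueeze(dim + 1).expand(
            *x.shape[: dim + 1], repeats, *x.shape[dim + 1 :]
        )
        # Folding the (dim, repeats) axis pair back into `dim` forces the
        # one contiguous copy, mirroring out_view + copy_func above.
        out_shape = list(x.shape)
        out_shape[dim] *= repeats
        return expanded.reshape(out_shape)

    x = torch.arange(6).reshape(2, 3)
    assert torch.equal(
        repeat_interleave_ref(x, 2, dim=1), torch.repeat_interleave(x, 2, dim=1)
    )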