diff --git a/benchmark/test_pointwise_perf.py b/benchmark/test_pointwise_perf.py
index cd4b4a6f..66f67584 100644
--- a/benchmark/test_pointwise_perf.py
+++ b/benchmark/test_pointwise_perf.py
@@ -630,3 +630,20 @@ def repeat_arg(dtype, batch, size):
         sizes=SIZES,
     )
     bench.run()
+
+
+def test_perf_repeat_interleave_self_int():
+    def repeat_interleave_self_int_arg(dtype, batch, size):
+        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
+        repeats = 2
+        return inp, repeats
+
+    bench = Benchmark(
+        op_name="repeat_interleave_self_int",
+        torch_op=torch.repeat_interleave,
+        arg_func=repeat_interleave_self_int_arg,
+        dtypes=FLOAT_DTYPES,
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+    )
+    bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 2b59ed98..933bf51a 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -152,6 +152,7 @@ def enable(lib=aten_lib):
     lib.impl("stack", stack, "CUDA")
     lib.impl("hstack", hstack, "CUDA")
     lib.impl("cat", cat, "CUDA")
+    lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
 
 
 class use_gems:
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index da4b396f..cc32e3ff 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -79,6 +79,7 @@
 from .reciprocal import reciprocal
 from .relu import relu
 from .repeat import repeat
+from .repeat_interleave import repeat_interleave_self_int
 from .resolve_conj import resolve_conj
 from .resolve_neg import resolve_neg
 from .rms_norm import rms_norm
@@ -235,4 +236,5 @@
     "stack",
     "hstack",
     "cat",
+    "repeat_interleave_self_int",
 ]
diff --git a/src/flag_gems/ops/repeat_interleave.py b/src/flag_gems/ops/repeat_interleave.py
new file mode 100644
index 00000000..84c3bea6
--- /dev/null
+++ b/src/flag_gems/ops/repeat_interleave.py
@@ -0,0 +1,55 @@
+import torch
+import triton
+
+from ..utils.pointwise_dynamic import pointwise_dynamic
+from ..utils.shape_utils import c_contiguous_stride
+from ..utils.tensor_wrapper import StridedBuffer
+
+
+@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
+@triton.jit
+def copy_func(x):
+    return x
+
+
+def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
+    if dim is None:
+        inp = inp.flatten()
+        dim = 0
+    else:
+        if (dim < -inp.ndim) or (dim >= inp.ndim):
+            raise IndexError(
+                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
+                    -inp.ndim, inp.ndim - 1, dim
+                )
+            )
+    inp_shape = list(inp.shape)
+    inp_stride = list(inp.stride())
+    output_shape = list(inp.shape)
+
+    if dim < 0:
+        dim = dim + len(inp_shape)
+
+    output_shape[dim] *= repeats
+
+    if output_size is not None and output_size != output_shape[dim]:
+        raise RuntimeError(
+            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
+                output_shape[dim], output_size
+            )
+        )
+
+    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)
+
+    if repeats == 0:
+        return output
+
+    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
+    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
+    out_view_stride = c_contiguous_stride(out_view_shape)
+
+    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
+    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
+    ndim = len(out_view_shape)
+    copy_func.instantiate(ndim)(in_view, out0=out_view)
+    return output
diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index 52a110e2..d1fc4554 100644
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -1032,3 +1032,28 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
     ref_out = torch.allclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)
 
     assert res_out == ref_out
+
+
+REPEAT_INTERLEAVE_SHAPES = [
+    (1,),
+    (1024, 1024),
+    (20, 320, 15),
+    (16, 128, 64, 60),
+    (16, 7, 57, 32, 29),
+]
+REPEAT_INTERLEAVE_REPEATS = [2]
+REPEAT_INTERLEAVE_DIM = [-1, 0, None]
+
+
+@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES)
+@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
+    inp = torch.randn(shape, dtype=dtype, device="cuda")
+    repeats = 2
+    ref_inp = to_reference(inp)
+
+    ref_out = torch.repeat_interleave(ref_inp, repeats, dim)
+    with flag_gems.use_gems():
+        res_out = torch.repeat_interleave(inp, repeats, dim)
+    gems_assert_equal(res_out, ref_out)
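
Note (not part of the patch): the new kernel implements integer-`repeats` `repeat_interleave` as a copy from a zero-stride input view into a contiguous output view of shape `inp_shape[:dim+1] + [repeats] + inp_shape[dim+1:]`. A minimal sketch of the same equivalence using only public PyTorch ops is below; `repeat_interleave_ref` is an illustrative helper, not part of flag_gems.

```python
import torch


def repeat_interleave_ref(inp, repeats, dim=None):
    # Illustrative reference only: insert a size-`repeats` broadcast axis
    # right after `dim` (stride 0 in memory), then collapse it into `dim`
    # by copying to a contiguous layout -- the view trick the kernel uses.
    if dim is None:
        inp = inp.flatten()
        dim = 0
    dim = dim % inp.ndim
    expanded = inp.unsqueeze(dim + 1).expand(
        *inp.shape[: dim + 1], repeats, *inp.shape[dim + 1 :]
    )
    out_shape = list(inp.shape)
    out_shape[dim] *= repeats
    return expanded.reshape(out_shape)


x = torch.arange(6).reshape(2, 3)
assert torch.equal(
    repeat_interleave_ref(x, 2, dim=1), torch.repeat_interleave(x, 2, dim=1)
)
```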