Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add repeat_interleave_self_int op #214

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,20 @@ def repeat_arg(dtype, batch, size):
sizes=SIZES,
)
bench.run()


def test_perf_repeat_interleave_self_int():
def repeat_interleave_self_int_arg(dtype, batch, size):
inp = torch.randn([batch, size], dtype=dtype, device="cuda")
repeats = 2
return inp, repeats

bench = Benchmark(
op_name="repeat_interleave_self_int",
torch_op=torch.repeat_interleave,
arg_func=repeat_interleave_self_int_arg,
dtypes=FLOAT_DTYPES,
batch=POINTWISE_BATCH,
sizes=SIZES,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def enable(lib=aten_lib):
lib.impl("stack", stack, "CUDA")
lib.impl("hstack", hstack, "CUDA")
lib.impl("cat", cat, "CUDA")
lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")


class use_gems:
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .reciprocal import reciprocal
from .relu import relu
from .repeat import repeat
from .repeat_interleave import repeat_interleave_self_int
from .resolve_conj import resolve_conj
from .resolve_neg import resolve_neg
from .rms_norm import rms_norm
Expand Down Expand Up @@ -235,4 +236,5 @@
"stack",
"hstack",
"cat",
"repeat_interleave_self_int",
]
55 changes: 55 additions & 0 deletions src/flag_gems/ops/repeat_interleave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import triton

from ..utils.pointwise_dynamic import pointwise_dynamic
from ..utils.shape_utils import c_contiguous_stride
from ..utils.tensor_wrapper import StridedBuffer


@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
return x


def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
if dim is None:
inp = inp.flatten()
dim = 0
else:
if (dim < -inp.ndim) or (dim >= inp.ndim):
raise IndexError(
"Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
-inp.ndim, inp.ndim - 1, dim
)
)
inp_shape = list(inp.shape)
inp_stride = list(inp.stride())
output_shape = list(inp.shape)

if dim < 0:
dim = dim + len(inp_shape)

output_shape[dim] *= repeats

if output_size is not None and output_size != output_shape[dim]:
raise RuntimeError(
"repeat_interleave: Invalid output_size, expected {} but got {}".format(
output_shape[dim], output_size
)
)

output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

if repeats == 0:
return output

in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
out_view_stride = c_contiguous_stride(out_view_shape)

in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
out_view = StridedBuffer(output, out_view_shape, out_view_stride)
ndim = len(out_view_shape)
copy_func.instantiate(ndim)(in_view, out0=out_view)
return output
25 changes: 25 additions & 0 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,3 +1032,28 @@ def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan):
ref_out = torch.allclose(ref_inp1, ref_inp2, rtol, atol, equal_nan=equal_nan)

assert res_out == ref_out


REPEAT_INTERLEAVE_SHAPES = [
(1,),
(1024, 1024),
(20, 320, 15),
(16, 128, 64, 60),
(16, 7, 57, 32, 29),
]
REPEAT_INTERLEAVE_REPEATS = [2]
REPEAT_INTERLEAVE_DIM = [-1, 0, None]


@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES)
@pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_repeat_interleave_self_int(shape, dim, dtype):
inp = torch.randn(shape, dtype=dtype, device="cuda")
repeats = 2
ref_inp = to_reference(inp)

ref_out = torch.repeat_interleave(ref_inp, repeats, dim)
with flag_gems.use_gems():
res_out = torch.repeat_interleave(ref_inp, repeats, dim)
gems_assert_equal(res_out, ref_out)