From f4b2495ca3b83d88744b17c7b66dd299c21dbafb Mon Sep 17 00:00:00 2001 From: yjl0101 Date: Fri, 20 Sep 2024 10:52:11 +0800 Subject: [PATCH] [Operator] Add vstack op (#175) --- benchmark/test_special_perf.py | 18 +++++ src/flag_gems/__init__.py | 1 + src/flag_gems/ops/__init__.py | 2 + src/flag_gems/ops/vstack.py | 141 +++++++++++++++++++++++++++++++++ tests/test_special_ops.py | 34 ++++++++ 5 files changed, 196 insertions(+) create mode 100644 src/flag_gems/ops/vstack.py diff --git a/benchmark/test_special_perf.py b/benchmark/test_special_perf.py index 6d26ff59..256595c1 100644 --- a/benchmark/test_special_perf.py +++ b/benchmark/test_special_perf.py @@ -270,3 +270,21 @@ def cat_kwargs(dtype, batch, size): kwargs_func=cat_kwargs, ) bench.run() + + +def test_perf_vstack(): + def vstack_args(dtype, batch, size): + inp1 = torch.randn(size=(batch, size), dtype=dtype, device="cuda") + inp2 = torch.randn(size=(batch + 1, size), dtype=dtype, device="cuda") + inp3 = torch.randn(size=(batch + 2, size), dtype=dtype, device="cuda") + return [[inp1, inp2, inp3]] + + bench = Benchmark( + op_name="vstack", + torch_op=torch.vstack, + arg_func=vstack_args, + dtypes=FLOAT_DTYPES, + batch=(512), + sizes=SIZES, + ) + bench.run() diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 933bf51a..114cc1ee 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -153,6 +153,7 @@ def enable(lib=aten_lib): lib.impl("hstack", hstack, "CUDA") lib.impl("cat", cat, "CUDA") lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA") + lib.impl("vstack", vstack, "CUDA") class use_gems: diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index cc32e3ff..c7e153f5 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -100,6 +100,7 @@ from .unique import _unique2 from .var_mean import var_mean from .vector_norm import vector_norm +from .vstack import vstack from .where import where_scalar_other, where_scalar_self, where_self from .zeros import zeros from .zeros_like import zeros_like @@ -237,4 +238,5 @@ "hstack", "cat", "repeat_interleave_self_int", + "vstack", ] diff --git a/src/flag_gems/ops/vstack.py b/src/flag_gems/ops/vstack.py new file mode 100644 index 00000000..39f9dc9a --- /dev/null +++ b/src/flag_gems/ops/vstack.py @@ -0,0 +1,141 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import libentry + + +@libentry() +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": k}, num_warps=w) + for w in [4, 8, 16, 32] + for k in [512, 1024, 2048, 4096] + ], + key=[ + "max_tile_elems", + ], +) +@triton.jit +def vstack_kernel( + itensor_ptr0, + itensor_ptr1, + itensor_ptr2, + itensor_ptr3, + output_ptr, + local_row0, + local_row1, + local_row2, + local_row3, + exc_row_offset0, + exc_row_offset1, + exc_row_offset2, + exc_row_offset3, + total_row_offset, + row_stride, + max_tile_elems, + BLOCK_SIZE: tl.constexpr, +): + pid_x = tl.program_id(axis=0) + tensor_idx = tl.program_id(axis=1) + col_idx = tl.arange(0, BLOCK_SIZE) + + intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1) + intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr) + intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr) + base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1) + base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx) + base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx) + local_row = tl.where(tensor_idx == 0, local_row0, local_row1) + local_row = tl.where(tensor_idx == 2, local_row2, local_row) + local_row = tl.where(tensor_idx == 3, local_row3, local_row) + + end_idx = local_row * row_stride.to(tl.int64) + idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64) + offset_mask = idx < end_idx + in_offset = intensor_ptr + idx + row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64) + out_offset = output_ptr + row_stride_offset + idx + out = tl.load(in_offset, mask=offset_mask) + tl.store(out_offset, out, mask=offset_mask) + + +def vstack(tensors: list[torch.Tensor]): + logging.debug("GEMS VSTACK") + + tensors = torch.atleast_2d(tensors) + num_tensors = len(tensors) + assert num_tensors > 0 + + # Ensure all tensors are on the same device and have the same dtype + device = tensors[0].device + dtype = tensors[0].dtype + for tensor in tensors: + assert ( + tensor.device == device + and tensor.dtype == dtype + and tensors[0].shape[1:] == tensor.shape[1:] + ) + + c_tensors = [t.contiguous() for t in tensors] + # Calculate the output shape + total_rows = sum(tensor.shape[0] for tensor in c_tensors) + output_shape = list(c_tensors[0].shape) + output_shape[0] = total_rows + output = torch.empty(output_shape, device=device, dtype=dtype) + row_stride = c_tensors[0].stride(0) + + outer_iters = triton.cdiv(num_tensors, 4) + total_row_offset = 0 + for i in range(outer_iters): + max_rows = 1 + itensors = [] + exclusive_row = [] + local_row = [] + array_row_offset = 0 + scheduled_num_tensors = 0 + for j in range(4): + tensor_idx = i * 4 + j + if tensor_idx < num_tensors: + scheduled_num_tensors += 1 + itensors.append(c_tensors[tensor_idx]) + local_row.append(c_tensors[tensor_idx].shape[0]) + exclusive_row.append(array_row_offset) + array_row_offset += c_tensors[tensor_idx].shape[0] + max_rows = max(max_rows, c_tensors[tensor_idx].shape[0]) + else: + empty_tensor = torch.empty( + 0, dtype=c_tensors[0].dtype, device=c_tensors[0].device + ) + itensors.append(empty_tensor) + local_row.append(local_row[-1]) + exclusive_row.append(exclusive_row[-1]) + max_tile_elems = max_rows * row_stride + grid = lambda META: ( + triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]), + scheduled_num_tensors, + ) + # Launch the kernel + with torch.cuda.device(c_tensors[0].device): + vstack_kernel[grid]( + itensors[0], + itensors[1], + itensors[2], + itensors[3], + output, + local_row[0], + local_row[1], + local_row[2], + local_row[3], + exclusive_row[0], + exclusive_row[1], + exclusive_row[2], + exclusive_row[3], + total_row_offset, + row_stride, + max_tile_elems, + ) + total_row_offset += array_row_offset + return output diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index ae5b07e3..4ecc467c 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -652,3 +652,37 @@ def test_accuracy_cat(shape, dim, dtype): with flag_gems.use_gems(): res_out = torch.cat(inp, dim) gems_assert_equal(res_out, ref_out) + + +VSTACK_SHAPES = [ + [(3,), (3,)], + [(3, 33), (7, 33)], + [(13, 3, 333), (17, 3, 333), (7, 3, 333)], + [ + (13, 3, 64, 5, 2), + (16, 3, 64, 5, 2), + (7, 3, 64, 5, 2), + (4, 3, 64, 5, 2), + (1, 3, 64, 5, 2), + ], +] + + +@pytest.mark.parametrize("shape", VSTACK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) +def test_accuracy_vstack(shape, dtype): + if dtype in FLOAT_DTYPES: + inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + else: + inp = [ + torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( + dtype + ) + for s in shape + ] + ref_inp = [to_reference(_) for _ in inp] + ref_out = torch.vstack(ref_inp) + + with flag_gems.use_gems(): + res_out = torch.vstack(inp) + gems_assert_equal(res_out, ref_out)