From 728540db2ea7394a0c3f46828a8ab5e18fff92bc Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Fri, 21 Jun 2024 00:14:25 -0700
Subject: [PATCH] Add variable seqlen and sparsity parameters to jagged_sum benchmark

Summary: Modify the existing `jagged_sum` operator benchmark to optionally accept any of the following parameters: `B` (dimension 0 of the nested tensor), `M` (dimension 2 of the nested tensor), `seqlen` (maximum sequence length on the ragged dimension), or `sparsity` (average sparsity on the ragged dimension). This diff holds any parameters provided on the command line fixed and sweeps all remaining parameters above, enabling benchmarking of every combination of the swept parameters in a single run.

The following errors persist with sufficiently large inputs:
- `RuntimeError: numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64` (when running the command `buck2 run mode/{opt,inplace} //pytorch/benchmark:triton -- --op jagged_sum --B 1024 --M 1024 --sparsity 0.3`)
- `torch.OutOfMemoryError: CUDA out of memory.`

Reviewed By: davidberard98

Differential Revision: D58772201
---
 .../operators/jagged_sum/kernels.py  |   8 +-
 .../operators/jagged_sum/operator.py | 119 +++++++++++-------
 2 files changed, 83 insertions(+), 44 deletions(-)

diff --git a/torchbenchmark/operators/jagged_sum/kernels.py b/torchbenchmark/operators/jagged_sum/kernels.py
index ea71634aa..23f920df2 100644
--- a/torchbenchmark/operators/jagged_sum/kernels.py
+++ b/torchbenchmark/operators/jagged_sum/kernels.py
@@ -59,7 +59,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end
 
@@ -132,7 +134,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end
 
diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py
index 775982698..1b63a8071 100644
--- a/torchbenchmark/operators/jagged_sum/operator.py
+++ b/torchbenchmark/operators/jagged_sum/operator.py
@@ -31,17 +31,25 @@ def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--B",
+        type=int,
+        help="[Optional] Size of dimension 0 in shape (B, *, M) (integer)",
+    )
+    parser.add_argument(
+        "--M",
+        type=int,
+        help="[Optional] Size of dimension 2 in shape (B, *, M) (integer)",
+    )
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=500,
-        help="Maximum sequence length on ragged dimension (integer)",
+        help="[Optional] Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
         "--sparsity",
         type=float,
-        default=0.5,
-        help="Average sparsity for nested tensor (float, (0.0-1.0))",
+        help="[Optional] Average sparsity for nested tensor (float, 0.0-1.0)",
     )
     parser.add_argument(
         "--sum-then-buffer",
         type=int,
@@ -91,12 +99,16 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
         )  # bias towards larger sizes, which are more representative of real-world shapes
         args = parse_op_args(self.extra_args)
-        self.seqlen = args.seqlen
-        self.sparsity = args.sparsity
+        self.B = args.B if args.B is not None else None
+        self.M = args.M if args.M is not None else None
+        self.seqlen = args.seqlen if args.seqlen is not None else None
+        self.sparsity = args.sparsity if args.sparsity is not None else None
         self.sum_then_buffer = args.sum_then_buffer
 
     @register_benchmark(baseline=True)
-    def torch_jagged_sum_no_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.tensor(
             [
                 torch.sum(t, dim=0).tolist() for t in x.unbind()
@@ -106,66 +118,87 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
         )
 
     @register_benchmark()
-    def torch_jagged_sum_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.sum(
             torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
                 [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[self.seqlen],  # max length of ragged dimension
+                max_lengths=[seqlen],  # max length of ragged dimension
             ),
             dim=1,
         )  # sum along ragged dimension (dim == 1)
 
     @register_benchmark()
-    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+    def triton_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         def _inner():
-            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+            return execute_kernel_simple_fused(x, seqlen, self.sum_then_buffer)
 
         return _inner
 
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
-    def get_x_vals(self) -> Tuple[List[int], List[int]]:
-        B_vals, M_vals = [], []
-
-        B_vals.extend([2**n for n in self.sizes])
-        B_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in B_vals
-            ]
-        )
+    def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
+        B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
+
+        def get_dim_vals():
+            vals = []
+            vals.extend([2**n for n in self.sizes])
+            vals.extend(
+                [
+                    (n - 1) * (n + 1)
+                    for n in self.sizes
+                    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
+                ]
+            )
+            return vals
+
+        if self.B is None:
+            B_vals.extend(get_dim_vals())
+        else:
+            B_vals.extend([self.B])
+
+        if self.M is None:
+            M_vals.extend(get_dim_vals())
+        else:
+            M_vals.extend([self.M])
+
+        if self.seqlen is None:
+            seqlen_vals.extend(
+                list(range(100, 1000, 100))
+                + list(range(1000, 10000, 1000))
+            )
+        else:
+            seqlen_vals.extend([self.seqlen])
 
-        M_vals.extend([2**n for n in self.sizes])
-        M_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in M_vals
-            ]
-        )
+        if self.sparsity is None:
+            sparsity_vals.extend([n / 10 for n in range(1, 10)])
+        else:
+            sparsity_vals.extend([self.sparsity])
 
-        return B_vals, M_vals
+        return B_vals, M_vals, seqlen_vals, sparsity_vals
 
     def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
-        B_vals, M_vals = self.get_x_vals()
-        B_M_vals = itertools.product(B_vals, M_vals)
+        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+        vals = itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)
 
-        for B, M in B_M_vals:
+        for B, M, seqlen, sparsity in vals:
             tensors = []
 
             # greater sparsity --> shorter sequence lengths on ragged dimension
             seqlen_avg = math.floor(
-                self.seqlen * (1 - self.sparsity)
+                seqlen * (1 - sparsity)
             )  # average sequence length across all tensors in nested tensor
             seqlen_margin = math.floor(
-                self.seqlen * RANDOM_CHOICE_MARGIN
+                seqlen * RANDOM_CHOICE_MARGIN
             )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
 
             for _ in range(B):
@@ -174,7 +207,7 @@ def get_input_iter(self) -> Generator:
                     seqlen_avg - seqlen_margin, 1
                 ),  # seqlen_randint must be at least 1
                 min(
-                    seqlen_avg + seqlen_margin, self.seqlen
-                ),  # seqlen_randint must not exceed self.seqlen
+                    seqlen_avg + seqlen_margin, seqlen
+                ),  # seqlen_randint must not exceed seqlen
             )
             tensor_2d = torch.randn(
@@ -189,7 +222,7 @@ def get_input_iter(self) -> Generator:
             dtype=self.dtype,
         )
 
-        yield (nt,)
+        yield (nt, B, M, seqlen, sparsity)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
@@ -205,15 +238,17 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * GIGABYTES_PER_BYTE
         )
 
     @register_metric(x_only=True)
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
     ):
         return (
-            example_inputs[0].shape[0],
+            f"B: {example_inputs[1]}",  # B
             "*",
-            example_inputs[0].shape[2],
-        )  # return (B, '*', M) for each example input
+            f"M: {example_inputs[2]}",  # M
+            f"max seqlen: {example_inputs[3]}",  # seqlen
+            f"sparsity: {example_inputs[4]}",  # sparsity
+        )  # return (B, '*', M, max seqlen, sparsity) for each example input
 
     @register_metric(skip_baseline=True)
     def best_config(
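
Usage note (editor's illustration, not part of the patch): the sweep behavior added in `get_x_vals` can be reproduced standalone. The sketch below pins hypothetical values for `B` and `sparsity`, as passing `--B` and `--sparsity` would, and sweeps `M` and `seqlen` over the same ranges the benchmark uses; `fixed_B`, `fixed_sparsity`, and `sizes` are illustrative stand-ins, not names from this patch.

    import itertools

    # Hypothetical pinned parameters, as if passed via --B and --sparsity.
    fixed_B, fixed_sparsity = 32, 0.3

    sizes = range(2, 8)  # stand-in for self.sizes

    def get_dim_vals():
        # Powers of two, plus nearby values (n - 1) * (n + 1) == n**2 - 1, deduplicated.
        vals = [2**n for n in sizes]
        vals.extend(
            [
                (n - 1) * (n + 1)
                for n in sizes
                if n - 1 > 0 and (n - 1) * (n + 1) not in vals
            ]
        )
        return vals

    B_vals = [fixed_B]  # pinned by --B
    M_vals = get_dim_vals()  # swept
    seqlen_vals = list(range(100, 1000, 100)) + list(range(1000, 10000, 1000))  # swept
    sparsity_vals = [fixed_sparsity]  # pinned by --sparsity

    # Each benchmark input is one element of the cross product, mirroring
    # the itertools.product(...) call in get_input_iter.
    for B, M, seqlen, sparsity in itertools.product(
        B_vals, M_vals, seqlen_vals, sparsity_vals
    ):
        print(B, M, seqlen, sparsity)  # the operator would yield a nested tensor of shape (B, *, M) here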