diff --git a/torchbenchmark/operators/jagged_sum/kernels.py b/torchbenchmark/operators/jagged_sum/kernels.py
index ea71634aa..23f920df2 100644
--- a/torchbenchmark/operators/jagged_sum/kernels.py
+++ b/torchbenchmark/operators/jagged_sum/kernels.py
@@ -59,7 +59,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end
 
@@ -132,7 +134,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end
 
diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py
index 775982698..1b63a8071 100644
--- a/torchbenchmark/operators/jagged_sum/operator.py
+++ b/torchbenchmark/operators/jagged_sum/operator.py
@@ -31,17 +31,25 @@ def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--B",
+        type=int,
+        help="[Optional] Size of dimension 0 in shape (B, *, M) (integer)",
+    )
+    parser.add_argument(
+        "--M",
+        type=int,
+        help="[Optional] Size of dimension 2 in shape (B, *, M) (integer)",
+    )
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=500,
-        help="Maximum sequence length on ragged dimension (integer)",
+        help="[Optional] Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
         "--sparsity",
        type=float,
-        default=0.5,
-        help="Average sparsity for nested tensor (float, (0.0-1.0))",
+        help="[Optional] Average sparsity for nested tensor (float, (0.0-1.0))",
     )
     parser.add_argument(
         "--sum-then-buffer",
@@ -91,12 +99,16 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
         )  # bias towards larger sizes, which are more representative of real-world shapes
 
         args = parse_op_args(self.extra_args)
-        self.seqlen = args.seqlen
-        self.sparsity = args.sparsity
+        self.B = args.B if args.B is not None else None
+        self.M = args.M if args.M is not None else None
+        self.seqlen = args.seqlen if args.seqlen is not None else None
+        self.sparsity = args.sparsity if args.sparsity is not None else None
         self.sum_then_buffer = args.sum_then_buffer
 
     @register_benchmark(baseline=True)
-    def torch_jagged_sum_no_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.tensor(
             [
                 torch.sum(t, dim=0).tolist() for t in x.unbind()
@@ -106,66 +118,87 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
             ],
             device=self.device,
             dtype=self.dtype,
         )
 
     @register_benchmark()
-    def torch_jagged_sum_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.sum(
             torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
                 [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[self.seqlen],  # max length of ragged dimension
+                max_lengths=[seqlen],  # max length of ragged dimension
             ),
             dim=1,
         )  # sum along ragged dimension (dim == 1)
 
     @register_benchmark()
-    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+    def triton_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         def _inner():
-            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+            return execute_kernel_simple_fused(x, seqlen, self.sum_then_buffer)
 
         return _inner
 
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
-    def get_x_vals(self) -> Tuple[List[int], List[int]]:
-        B_vals, M_vals = [], []
-
-        B_vals.extend([2**n for n in self.sizes])
-        B_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in B_vals
-            ]
-        )
+    def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
+        B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
+
+        def get_dim_vals():
+            vals = []
+            vals.extend([2**n for n in self.sizes])
+            vals.extend(
+                [
+                    (n - 1) * (n + 1)
+                    for n in self.sizes
+                    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
+                ]
+            )
+            return vals
+
+        if self.B is None:
+            B_vals.extend(get_dim_vals())
+        else:
+            B_vals.extend([self.B])
+
+        if self.M is None:
+            M_vals.extend(get_dim_vals())
+        else:
+            M_vals.extend([self.M])
+
+        if self.seqlen is None:
+            seqlen_vals.extend(
+                list(range(100, 1000, 100)) + list(range(1000, 10000, 1000))
+            )
+        else:
+            seqlen_vals.extend([self.seqlen])
 
-        M_vals.extend([2**n for n in self.sizes])
-        M_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in M_vals
-            ]
-        )
+        if self.sparsity is None:
+            sparsity_vals.extend([n / 10 for n in range(1, 10)])
+        else:
+            sparsity_vals.extend([self.sparsity])
 
-        return B_vals, M_vals
+        return B_vals, M_vals, seqlen_vals, sparsity_vals
 
     def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
 
-        B_vals, M_vals = self.get_x_vals()
-        B_M_vals = itertools.product(B_vals, M_vals)
+        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+        vals = itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)
 
-        for B, M in B_M_vals:
+        for B, M, seqlen, sparsity in vals:
             tensors = []
 
             # greater sparsity --> shorter sequence lengths on ragged dimension
             seqlen_avg = math.floor(
-                self.seqlen * (1 - self.sparsity)
+                seqlen * (1 - sparsity)
             )  # average sequence length across all tensors in nested tensor
             seqlen_margin = math.floor(
-                self.seqlen * RANDOM_CHOICE_MARGIN
+                seqlen * RANDOM_CHOICE_MARGIN
             )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
 
             for _ in range(B):
@@ -174,7 +207,7 @@ def get_input_iter(self) -> Generator:
                         seqlen_avg - seqlen_margin, 1
                     ),  # seqlen_randint must be at least 1
                     min(
-                        seqlen_avg + seqlen_margin, self.seqlen
+                        seqlen_avg + seqlen_margin, seqlen
                     ),  # seqlen_randint must not exceed self.seqlen
                 )
                 tensor_2d = torch.randn(
@@ -189,7 +222,7 @@ def get_input_iter(self) -> Generator:
                 dtype=self.dtype,
             )
 
-            yield (nt,)
+            yield (nt, B, M, seqlen, sparsity)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
@@ -205,15 +238,17 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * GIGABYTES_PER_BYTE
         )
 
-    @register_metric(x_only=True)
+    @register_metric(x_only=True)  # TODO modify!!!!
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
    ):
         return (
-            example_inputs[0].shape[0],
+            f"B: {example_inputs[1]}",  # B
             "*",
-            example_inputs[0].shape[2],
-        )  # return (B, '*', M) for each example input
+            f"M: {example_inputs[2]}",  # M
+            f"max seqlen: {example_inputs[3]}",  # seqlen
+            f"sparsity: {example_inputs[4]}",  # sparsity
+        )  # return (B, '*', M, max seqlen, sparsity) for each example input
 
     @register_metric(skip_baseline=True)
     def best_config(
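
Usage sketch (illustration only, not part of the patch; assumes the torchbenchmark package is importable from the repo root): with this change the new flags are optional, so any of B, M, seqlen, and sparsity that are omitted parse to None and get_x_vals() falls back to its built-in sweeps for those dimensions.

    from torchbenchmark.operators.jagged_sum.operator import parse_op_args

    # Pin B and M from the command line; leave seqlen and sparsity unset.
    args = parse_op_args(["--B", "32", "--M", "128"])
    assert args.B == 32 and args.M == 128
    # Optional flags no longer carry defaults, so unset values stay None.
    assert args.seqlen is None and args.sparsity is None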