Add variable seqlen and sparsity parameters to jagged_sum benchmark #2324

Closed (wants to merge 1 commit)
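In effect, the patch lets each of B, M, seqlen, and sparsity in the (B, *, M) nested-tensor shape be pinned from the command line, sweeping any parameter that is left unset. A typical invocation might look like the following (the entry point and flag spelling follow TorchBench's usual triton-benchmark harness; treat the exact command as an assumption, not part of this diff):

python run_benchmark.py triton --op jagged_sum --B 32 --M 128 --seqlen 500 --sparsity 0.3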
torchbenchmark/operators/jagged_sum/kernels.py (6 additions, 2 deletions)
@@ -59,7 +59,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end

@@ -132,7 +134,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end
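
Both hunks only reflow one assignment for line length; the loop logic is unchanged. For context, here is a plain-Python sketch of what the offset/mask computation does for one ragged sequence (the constants are illustrative, not the benchmark's defaults):

BLOCK_SIZE_RAGGED = 4                # illustrative tile size
MAX_SEQLEN = 12                      # illustrative loop bound over the ragged dimension
ragged_start, ragged_end = 10, 17    # one sequence's slice of the flattened values buffer

for block_pos in range(0, MAX_SEQLEN, BLOCK_SIZE_RAGGED):
    block_start_ragged = ragged_start + block_pos
    offsets_ragged = [block_start_ragged + i for i in range(BLOCK_SIZE_RAGGED)]
    mask_ragged = [off < ragged_end for off in offsets_ragged]
    # offsets at or past ragged_end are masked out, so rows beyond this sequence never contribute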

torchbenchmark/operators/jagged_sum/operator.py (77 additions, 42 deletions)
@@ -31,17 +31,25 @@

 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--B",
+        type=int,
+        help="[Optional] Size of dimension 0 in shape (B, *, M) (integer)",
+    )
+    parser.add_argument(
+        "--M",
+        type=int,
+        help="[Optional] Size of dimension 2 in shape (B, *, M) (integer)",
+    )
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=500,
-        help="Maximum sequence length on ragged dimension (integer)",
+        help="[Optional] Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
         "--sparsity",
         type=float,
-        default=0.5,
-        help="Average sparsity for nested tensor (float, (0.0-1.0))",
+        help="[Optional] Average sparsity for nested tensor (float, (0.0-1.0))",
     )
     parser.add_argument(
         "--sum-then-buffer",
@@ -91,12 +99,16 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         )  # bias towards larger sizes, which are more representative of real-world shapes

         args = parse_op_args(self.extra_args)
-        self.seqlen = args.seqlen
-        self.sparsity = args.sparsity
+        self.B = args.B if args.B is not None else None
+        self.M = args.M if args.M is not None else None
+        self.seqlen = args.seqlen if args.seqlen is not None else None
+        self.sparsity = args.sparsity if args.sparsity is not None else None
         self.sum_then_buffer = args.sum_then_buffer

     @register_benchmark(baseline=True)
-    def torch_jagged_sum_no_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.tensor(
             [
                 torch.sum(t, dim=0).tolist() for t in x.unbind()
@@ -106,66 +118,87 @@ def torch_jagged_sum_pad(self, x: torch.Tensor):
             ]
         )

     @register_benchmark()
-    def torch_jagged_sum_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.sum(
             torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
                 [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[self.seqlen],  # max length of ragged dimension
+                max_lengths=[seqlen],  # max length of ragged dimension
             ),
             dim=1,
         )  # sum along ragged dimension (dim == 1)
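
torch_jagged_sum_pad relies on zero padding being sum-neutral: densifying the jagged tensor to shape (B, seqlen, M) and summing over dim 1 gives the same per-sequence sums as the unpadded baseline. A minimal sketch of that identity in plain torch (not the benchmark harness):

import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # sequence of length 2, M == 2
b = torch.tensor([[5.0, 6.0]])              # sequence of length 1

# pad b with zeros up to the max length, as _jagged_to_padded_dense_forward would
dense = torch.stack([a, torch.cat([b, torch.zeros(1, 2)])])  # shape (B=2, seqlen=2, M=2)
print(torch.sum(dense, dim=1))  # tensor([[4., 6.], [5., 6.]]) -- zero rows leave the sums unchanged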

     @register_benchmark()
-    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+    def triton_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         def _inner():
-            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+            return execute_kernel_simple_fused(x, seqlen, self.sum_then_buffer)

         return _inner

     def get_x_val(self, example_inputs):
         return len(example_inputs[0])

-    def get_x_vals(self) -> Tuple[List[int], List[int]]:
-        B_vals, M_vals = [], []
-
-        B_vals.extend([2**n for n in self.sizes])
-        B_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in B_vals
-            ]
-        )
-
-        M_vals.extend([2**n for n in self.sizes])
-        M_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in M_vals
-            ]
-        )
-
-        return B_vals, M_vals
+    def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
+        B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
+
+        def get_dim_vals():
+            vals = []
+            vals.extend([2**n for n in self.sizes])
+            vals.extend(
+                [
+                    (n - 1) * (n + 1)
+                    for n in self.sizes
+                    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
+                ]
+            )
+            return vals
+
+        if self.B is None:
+            B_vals.extend(get_dim_vals())
+        else:
+            B_vals.extend([self.B])
+
+        if self.M is None:
+            M_vals.extend(get_dim_vals())
+        else:
+            M_vals.extend([self.M])
+
+        if self.seqlen is None:
+            seqlen_vals.extend(
+                list(range(100, 1000, 100)) + list(range(1000, 10000, 1000))
+            )
+        else:
+            seqlen_vals.extend([self.seqlen])
+
+        if self.sparsity is None:
+            sparsity_vals.extend([n / 10 for n in range(1, 10)])
+        else:
+            sparsity_vals.extend([self.sparsity])
+
+        return B_vals, M_vals, seqlen_vals, sparsity_vals
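
The shared helper biases each dimension sweep toward a mix of powers of two and the near-power values (n - 1) * (n + 1) = n^2 - 1, deduplicating overlaps. A standalone sketch of the values it yields, assuming self.sizes = range(2, 8) (the actual range is configured in __init__ and is an assumption here):

sizes = range(2, 8)  # hypothetical stand-in for self.sizes
vals = [2**n for n in sizes]
vals += [
    (n - 1) * (n + 1)
    for n in sizes
    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
]
print(vals)  # [4, 8, 16, 32, 64, 128, 3, 15, 24, 35, 48] -- note 8 (n == 3) is deduplicated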

     def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """

-        B_vals, M_vals = self.get_x_vals()
-        B_M_vals = itertools.product(B_vals, M_vals)
+        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+        vals = itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)

-        for B, M in B_M_vals:
+        for B, M, seqlen, sparsity in vals:
             tensors = []

             # greater sparsity --> shorter sequence lengths on ragged dimension
             seqlen_avg = math.floor(
-                self.seqlen * (1 - self.sparsity)
+                seqlen * (1 - sparsity)
             )  # average sequence length across all tensors in nested tensor
             seqlen_margin = math.floor(
-                self.seqlen * RANDOM_CHOICE_MARGIN
+                seqlen * RANDOM_CHOICE_MARGIN
             )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity

             for _ in range(B):
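
Worked through concretely: with seqlen = 500 and sparsity = 0.5, and assuming RANDOM_CHOICE_MARGIN = 0.3 (the constant is defined elsewhere in the module; its value here is an assumption), per-tensor lengths are drawn uniformly from [100, 400]:

import math, random

seqlen, sparsity = 500, 0.5
RANDOM_CHOICE_MARGIN = 0.3  # assumed value, for illustration only

seqlen_avg = math.floor(seqlen * (1 - sparsity))           # 250
seqlen_margin = math.floor(seqlen * RANDOM_CHOICE_MARGIN)  # 150
lo = max(seqlen_avg - seqlen_margin, 1)                    # 100
hi = min(seqlen_avg + seqlen_margin, seqlen)               # 400
seqlen_randint = random.randint(lo, hi)                    # uniform draw in [100, 400]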
@@ -174,7 +207,7 @@ def get_input_iter(self) -> Generator:
                         seqlen_avg - seqlen_margin, 1
                     ),  # seqlen_randint must be at least 1
                     min(
-                        seqlen_avg + seqlen_margin, self.seqlen
+                        seqlen_avg + seqlen_margin, seqlen
                     ),  # seqlen_randint must not exceed the maximum seqlen
                 )
                 tensor_2d = torch.randn(
@@ -189,7 +222,7 @@ def get_input_iter(self) -> Generator:
                 dtype=self.dtype,
             )

-            yield (nt,)
+            yield (nt, B, M, seqlen, sparsity)

     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
@@ -205,15 +238,17 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             * GIGABYTES_PER_BYTE
         )

-    @register_metric(x_only=True)
+    @register_metric(x_only=True)  # TODO modify!!!!
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
     ):
         return (
-            example_inputs[0].shape[0],
+            f"B: {example_inputs[1]}",  # B
             "*",
-            example_inputs[0].shape[2],
-        )  # return (B, '*', M) for each example input
+            f"M: {example_inputs[2]}",  # M
+            f"max seqlen: {example_inputs[3]}",  # seqlen
+            f"sparsity: {example_inputs[4]}",  # sparsity
+        )  # return (B, '*', M, max seqlen, sparsity) for each example input
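
For an input tuple generated as (nt, 32, 128, 500, 0.3), for example, the reported shape becomes ('B: 32', '*', 'M: 128', 'max seqlen: 500', 'sparsity: 0.3'), keeping '*' for the ragged dimension; the values here are illustrative.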

     @register_metric(skip_baseline=True)
     def best_config(