forked from pytorch/benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests on CPU for TritonBench features (pytorch#2323)
Summary: Pull Request resolved: pytorch#2323 Add unit tests that run on the CPU to verify the behavior of the following: - `x_only = True` for metric registration in [`register_metric()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=337) - custom `label` argument for benchmark registration in [`register_benchmark()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=316) Reviewed By: xuzhao9 Differential Revision: D58558868
- Loading branch information
1 parent
caa76d8
commit f985e9d
Showing
3 changed files
with
46 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .operator import Operator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Generator, List, Optional | ||
|
||
import torch | ||
|
||
from torchbenchmark.util.triton_op import ( | ||
BenchmarkOperator, | ||
BenchmarkOperatorMetrics, | ||
register_benchmark, | ||
register_metric, | ||
) | ||
|
||
|
||
class Operator(BenchmarkOperator): | ||
|
||
DEFAULT_METRICS = ["test_metric"] | ||
|
||
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): | ||
super().__init__(mode=mode, device=device, extra_args=extra_args) | ||
|
||
@register_benchmark(label="new_op_label") | ||
def test_op(self, x: torch.Tensor): | ||
return lambda: x | ||
|
||
def get_x_val(self, example_inputs): | ||
return example_inputs[0].shape | ||
|
||
def get_x_vals(self) -> List[int]: | ||
return [2**n for n in [1, 2, 3]] | ||
|
||
def get_input_iter(self) -> Generator: | ||
for x in self.get_x_vals(): | ||
yield (torch.Tensor(torch.randn(x, device=self.device, dtype=self.dtype)),) | ||
|
||
@register_metric(x_only=True) | ||
def test_metric( | ||
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics | ||
): | ||
return [ex.shape[0] + 2 for ex in example_inputs] | ||
|
||
@register_metric() | ||
def test_metric_per_benchmark( | ||
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics | ||
): | ||
return [ex.shape[0] + 3 for ex in example_inputs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters