[Tests] Fix tests 3rd ed. (#215)
1. Change `--device=cpu` to `--ref=cpu`.
2. Add `--mode=quick` so that CI CPU tests cover only one shape.
3. Add a scalar shape to the pointwise_dynamic op tests.
4. Move `assert_close` and `assert_equal` into `flag_gems.testing`.
5. Add a mark for each op test and register it in `pytest.ini`.
6. Add speedup to the printed benchmark results.

---------

Co-authored-by: zhengyang <[email protected]>
zhzhcookie and zhengyang authored Sep 18, 2024
1 parent f721fd7 commit 8536ebb
Showing 21 changed files with 308 additions and 104 deletions.
16 changes: 10 additions & 6 deletions CONTRIBUTING.md
@@ -29,6 +29,7 @@ pre-commit

### 2.2 Op Unit Test
Operator Unit Tests check the correctness of operators. If new operators are added, you need to add test cases in the corresponding file under the `tests` directory. If new test files are added, you should also add the test commands to the `cmd` variable in the `tools/coverage.sh` file.
For operator tests, decorate the test function with `@pytest.mark.{OP_NAME}` so that the unit tests of a given op can be selected with `pytest -m`. A unit test function can carry multiple custom marks.

### 2.3 Model Test
Model Tests check the correctness of models. Adding a new model follows a process similar to adding a new operator.
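As an aside, a marked op test following the convention described in 2.2 might look like this minimal sketch (the op name `abs`, the shapes, and the test body are illustrative assumptions, not code from this commit):

```python
import pytest
import torch

import flag_gems


@pytest.mark.abs  # custom mark: select this op's tests with `pytest -m abs`
@pytest.mark.parametrize("shape", [(1,), (1024, 1024)])
def test_accuracy_abs(shape):
    inp = torch.randn(shape, dtype=torch.float32, device="cuda")
    ref_out = torch.abs(inp)  # eager PyTorch as the reference
    with flag_gems.use_gems():  # route the call through FlagGems
        res_out = torch.abs(inp)
    flag_gems.testing.assert_close(res_out, ref_out, torch.float32)
```

Running `pytest -m abs` then selects only the tests carrying that mark, and stacking several `@pytest.mark.*` decorators lets one function answer to multiple marks.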
@@ -59,11 +60,12 @@ Currently, the pipeline does not check the performance of operators. You can wri
```
FlagGems
├── src: source code for library
│ ├──flag_gems
│ │ ├──utils: utilities for automatic code generation
│ │ ├──ops: single operators
│ │ ├──fused: fused operators
│ │ ├──__init__.py
│ └──flag_gems
│ ├──utils: utilities for automatic code generation
│ ├──ops: single operators
│ ├──fused: fused operators
│ ├──testing: testing utility
│ └──__init__.py
├── tests: accuracy test files
├── benchmark: performance test files
├── examples: model test files
@@ -72,7 +74,9 @@ FlagGems
├── README_cn.md
├── OperatorList.md
├── CONTRIBUTING.md
└── pyproject.toml
├── CONTRIBUTING_cn.md
├── pyproject.toml
└── pytest.ini
```

## 4. License
16 changes: 10 additions & 6 deletions CONTRIBUTING_cn.md
@@ -28,6 +28,7 @@ pre-commit

### 2.2 Op Unit Test
Op unit tests check the correctness of operators. If a new operator is added, test cases must be added to the corresponding file under the `tests` directory; if a new test file is added, the test command must also be added to the `cmd` variable in `tools/coverage.sh`.
For operator unit tests, decorate the test function with `@pytest.mark.{OP_NAME}` so that the unit tests of a given op can be selected with `pytest -m`. A unit test function can carry multiple custom marks.

### 2.3 Model Test
Model tests check the correctness of models. Adding a new model follows a process similar to adding a new operator.
@@ -58,11 +59,12 @@ tools/code_coverage/coverage.sh PR_ID
```
FlagGems
├── src: source code for library
│ ├──flag_gems
│ │ ├──utils: utilities for automatic code generation
│ │ ├──ops: single operators
│ │ ├──fused: fused operators
│ │ ├──__init__.py
│ └──flag_gems
│ ├──utils: utilities for automatic code generation
│ ├──ops: single operators
│ ├──fused: fused operators
│ ├──testing: testing utility
│ └──__init__.py
├── tests: accuracy test files
├── benchmark: performance test files
├── examples: model test files
@@ -71,7 +73,9 @@ FlagGems
├── README_cn.md
├── OperatorList.md
├── CONTRIBUTING.md
└── pyproject.toml
├── CONTRIBUTING_cn.md
├── pyproject.toml
└── pytest.ini
```

## 4. License
2 changes: 1 addition & 1 deletion README.md
@@ -147,7 +147,7 @@ pip install .
- Run reference on cpu
```shell
cd tests
pytest test_xx_ops.py --device cpu
pytest test_xx_ops.py --ref cpu
```

2. Test Model Accuracy
2 changes: 1 addition & 1 deletion README_cn.md
@@ -147,7 +147,7 @@ pip install .
- Run the reference implementation on CPU
```shell
cd tests
pytest test_xx_ops.py --device cpu
pytest test_xx_ops.py --ref cpu
```
2. Test Model Accuracy
```shell
9 changes: 6 additions & 3 deletions benchmark/performance_utils.py
@@ -68,8 +68,8 @@ def profile(self, op, *args, **kwargs):
def run(self):
for dtype in self.dtypes:
print(f"Operator {self.op_name} Performance Test ({dtype})")
print("Size Torch Latency (ms) Gems Latency (ms)")
print("--------------------------------------------------")
print("Size Torch Latency (ms) Gems Latency (ms) Gems Speedup")
print("---------------------------------------------------------------")
for size in self.sizes:
args = ()
if self.arg_func is not None:
@@ -92,7 +92,10 @@ def run(self):
else:
with flag_gems.use_gems():
gems_perf = self.profile(self.torch_op, *args, **kwargs)
print(f"{size: <10}{torch_perf: >20.6}{gems_perf: >20.6}")
speedup = torch_perf / gems_perf
print(
f"{size: <8}{torch_perf: >18.6}{gems_perf: >21.6}{speedup: >16.3}"
)


FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16]
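To illustrate the new table format, here is a small sketch of how the formatting behaves (the latency figures below are made up purely for demonstration, not measured values):

```python
# Hypothetical latencies in milliseconds, not measured values.
torch_perf, gems_perf = 0.512, 0.256
speedup = torch_perf / gems_perf  # > 1.0 means FlagGems is faster
print("Size      Torch Latency (ms)   Gems Latency (ms)    Gems Speedup")
print(f"{1024: <8}{torch_perf: >18.6}{gems_perf: >21.6}{speedup: >16.3}")
# prints something like:
# 1024                 0.512                0.256             2.0
```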
8 changes: 0 additions & 8 deletions pyproject.toml
@@ -32,14 +32,6 @@ test = [
[project.urls]
homepage = "https://github.com/FlagOpen/FlagGems"

[tool.pytest.ini_options]
testpaths = [
"tests",
]
pythonpath = [
"src",
]

[tool.coverage.run]
omit = [
"*/.flaggems/*",
4 changes: 4 additions & 0 deletions pytest.ini
@@ -0,0 +1,4 @@
[pytest]
testpaths = tests
pythonpath = src
filterwarnings = ignore::pytest.PytestUnknownMarkWarning
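A note on the design: instead of registering every op mark under a `markers =` entry, this configuration silences `PytestUnknownMarkWarning` wholesale. Listing the marks explicitly (optionally together with `--strict-markers`) would catch typos in mark names, at the cost of maintaining one entry per operator.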
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -1,5 +1,6 @@
import torch

from . import testing # noqa: F401
from .fused import * # noqa: F403
from .ops import * # noqa: F403

22 changes: 22 additions & 0 deletions src/flag_gems/testing/__init__.py
@@ -0,0 +1,22 @@
import torch

RESOLUTION = {
torch.bool: 0,
torch.int16: 0,
torch.int32: 0,
torch.float16: 1e-3,
torch.float32: 1.3e-6,
torch.bfloat16: 0.016,
}


def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype
ref = ref.to(dtype)
atol = 1e-4 * reduce_dim
rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)


def assert_equal(res, ref):
assert torch.equal(res, ref)
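A minimal usage sketch of these helpers (the tensor values are illustrative):

```python
import torch

import flag_gems

res = torch.tensor([1.0, 2.0], dtype=torch.float16)
ref = torch.tensor([1.0, 2.0], dtype=torch.float64)  # upcast reference
# ref is cast down to the result dtype internally; atol scales with reduce_dim
flag_gems.testing.assert_close(res, ref, torch.float16)
flag_gems.testing.assert_equal(torch.tensor([1, 2]), torch.tensor([1, 2]))
```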
57 changes: 28 additions & 29 deletions tests/accuracy_utils.py
@@ -2,7 +2,9 @@

import torch

from .conftest import TO_CPU
import flag_gems

from .conftest import QUICK_MODE, TO_CPU


def SkipTorchVersion(skip_pattern):
@@ -29,37 +31,30 @@ def SkipTorchVersion(skip_pattern):
INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max

RESOLUTION = {
torch.bool: 0,
torch.int16: 0,
torch.int32: 0,
torch.float16: 1e-3,
torch.float32: 1.3e-6,
torch.bfloat16: 0.016,
}

sizes_one = [1]
sizes_pow_2 = [2**d for d in range(4, 11, 2)]
sizes_noalign = [d + 17 for d in sizes_pow_2]
sizes_1d = sizes_one + sizes_pow_2 + sizes_noalign
sizes_2d_nc = [1] if TO_CPU else [1, 16, 64, 1000]
sizes_2d_nr = [1] if TO_CPU else [1, 5, 1024]
sizes_2d_nc = [1] if QUICK_MODE else [1, 16, 64, 1000]
sizes_2d_nr = [1] if QUICK_MODE else [1, 5, 1024]

UT_SHAPES_1D = list((n,) for n in sizes_1d)
UT_SHAPES_2D = list(itertools.product(sizes_2d_nr, sizes_2d_nc))
POINTWISE_SHAPES = (
[(2, 19, 7)]
if TO_CPU
else [(1,), (1024, 1024), (20, 320, 15), (16, 128, 64, 60), (16, 7, 57, 32, 29)]
if QUICK_MODE
else [(), (1,), (1024, 1024), (20, 320, 15), (16, 128, 64, 60), (16, 7, 57, 32, 29)]
)
SPECIAL_SHAPES = (
[(2, 19, 7)]
if TO_CPU
if QUICK_MODE
else [(1,), (1024, 1024), (20, 320, 15), (16, 128, 64, 1280), (16, 7, 57, 32, 29)]
)
DISTRIBUTION_SHAPES = [(20, 320, 15)]
REDUCTION_SHAPES = [(2, 32)] if TO_CPU else [(1, 2), (4096, 256), (200, 40999, 3)]
REDUCTION_SMALL_SHAPES = [(1, 32)] if TO_CPU else [(1, 2), (4096, 256), (200, 2560, 3)]
REDUCTION_SHAPES = [(2, 32)] if QUICK_MODE else [(1, 2), (4096, 256), (200, 40999, 3)]
REDUCTION_SMALL_SHAPES = (
[(1, 32)] if QUICK_MODE else [(1, 2), (4096, 256), (200, 2560, 3)]
)
STACK_SHAPES = [
[(16,), (16,)],
[(16, 256), (16, 256)],
@@ -80,26 +75,30 @@ def to_reference(inp, upcast=False):
if inp is None:
return None
ref_inp = inp
if TO_CPU:
ref_inp = ref_inp.to("cpu")
if upcast:
ref_inp = ref_inp.to(torch.float64)
if TO_CPU:
ref_inp = ref_inp.to("cpu")
return ref_inp


def gems_assert_close(a, b, dtype, equal_nan=False, reduce_dim=1):
def to_cpu(res, ref):
if TO_CPU:
a = a.to("cpu")
b = b.to(dtype)
atol = 1e-4 * reduce_dim
rtol = RESOLUTION[dtype]
torch.testing.assert_close(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan)
res = res.to("cpu")
assert ref.device == torch.device("cpu")
return res


def gems_assert_equal(a, b):
if TO_CPU:
a = a.to("cpu")
assert torch.equal(a, b)
def gems_assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
res = to_cpu(res, ref)
flag_gems.testing.assert_close(
res, ref, dtype, equal_nan=equal_nan, reduce_dim=reduce_dim
)


def gems_assert_equal(res, ref):
res = to_cpu(res, ref)
flag_gems.testing.assert_equal(res, ref)


def unsqueeze_tuple(t, max_len):
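Putting these helpers together, a typical op test built on `accuracy_utils` might look like this sketch (the op `mul`, the dtype choice, and the test body are illustrative assumptions):

```python
import pytest
import torch

import flag_gems

from .accuracy_utils import POINTWISE_SHAPES, gems_assert_close, to_reference


@pytest.mark.mul
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
def test_accuracy_mul(shape):
    inp1 = torch.randn(shape, dtype=torch.float32, device="cuda")
    inp2 = torch.randn(shape, dtype=torch.float32, device="cuda")
    ref_inp1 = to_reference(inp1, upcast=True)  # float64 reference; on CPU under --ref cpu
    ref_inp2 = to_reference(inp2, upcast=True)
    ref_out = torch.mul(ref_inp1, ref_inp2)
    with flag_gems.use_gems():
        res_out = torch.mul(inp1, inp2)
    gems_assert_close(res_out, ref_out, torch.float32)
```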
16 changes: 13 additions & 3 deletions tests/conftest.py
@@ -1,15 +1,25 @@
def pytest_addoption(parser):
parser.addoption(
"--device",
"--ref",
action="store",
default="cuda",
required=False,
choices=["cuda", "cpu"],
help="device to run reference tests on",
)
parser.addoption(
"--mode",
action="store",
default="normal",
required=False,
choices=["normal", "quick"],
help="run tests on normal or quick mode",
)


def pytest_configure(config):
value = config.getoption("--device")
global TO_CPU
TO_CPU = value == "cpu"
TO_CPU = config.getoption("--ref") == "cpu"

global QUICK_MODE
QUICK_MODE = config.getoption("--mode") == "quick"
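With these hooks in place, both knobs are chosen at invocation time: `pytest test_xx_ops.py --ref cpu --mode quick` computes references on CPU over the reduced shape set that CI uses, while a plain `pytest test_xx_ops.py` keeps the default CUDA reference and the full shape set.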
