Skip to content

Commit

Permalink
rm python version and ignore of F403 & E722
Browse files Browse the repository at this point in the history
  • Loading branch information
Bowen12992 committed Jun 14, 2024
1 parent d3b5121 commit 8d4ac50
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 19 deletions.
7 changes: 1 addition & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,20 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: flake8
language_version: python3.11
args: ["--ignore=F405,E731,F403,W503,E722,E203", --max-line-length=120]
args: ["--ignore=F405,E731,W503,E203", --max-line-length=120]
# F405 : Name may be undefined, or defined from star imports: module
# E731 : Do not assign a lambda expression, use a def
# F403 : 'from module import *' used; unable to detect undefined names
# W503 : Line break before binary operator
# E722 : Do not use bare 'except'
# E203 : Whitespace before ':'

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
language_version: python3.11
args: ["--profile", "black"]

- repo: https://github.com/psf/black.git
rev: 23.7.0
hooks:
- id: black
language_version: python3.11
- id: black-jupyter
2 changes: 1 addition & 1 deletion benchmark/test_blas_perf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from .performance_utils import *
from .performance_utils import BLAS_BATCH, DEFAULT_BATCH, FLOAT_DTYPES, SIZES, Benchmark


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
9 changes: 8 additions & 1 deletion benchmark/test_fused_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

import flag_gems

from .performance_utils import *
from .performance_utils import (
FLOAT_DTYPES,
POINTWISE_BATCH,
REDUCTION_BATCH,
SIZES,
Benchmark,
binary_args,
)


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
13 changes: 12 additions & 1 deletion benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import pytest
import torch

from .performance_utils import *
from .performance_utils import (
FLOAT_DTYPES,
INT_DTYPES,
POINTWISE_BATCH,
SIZES,
Benchmark,
binary_args,
binary_int_args,
ternary_args,
unary_arg,
unary_int_arg,
)


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
9 changes: 8 additions & 1 deletion benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import pytest
import torch

from .performance_utils import *
from .performance_utils import (
BLAS_BATCH,
FLOAT_DTYPES,
REDUCTION_BATCH,
SIZES,
Benchmark,
unary_arg,
)


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
4 changes: 2 additions & 2 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from .fused import *
from .ops import *
from .fused import * # noqa: F403
from .ops import * # noqa: F403

__version__ = "2.0"

Expand Down
2 changes: 1 addition & 1 deletion src/flag_gems/ops/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _rand(seed, offset):
_grid = (1,)
_seed, _offset = philox_cuda_seed_offset(0)
_rand[_grid](_seed, _offset)
except:
except Exception:
tl_rand_dtype = tl.int32

del _grid
Expand Down
10 changes: 9 additions & 1 deletion tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@

import flag_gems

from .accuracy_utils import *
from .accuracy_utils import (
FLOAT_DTYPES,
INT_DTYPES,
POINTWISE_SHAPES,
SCALARS,
gems_assert_close,
gems_assert_equal,
to_reference,
)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
Expand Down
8 changes: 7 additions & 1 deletion tests/test_blas_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

import flag_gems

from .accuracy_utils import *
from .accuracy_utils import (
FLOAT_DTYPES,
MNK_SHAPES,
SCALARS,
gems_assert_close,
to_reference,
)


@pytest.mark.parametrize("M", MNK_SHAPES)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@

import flag_gems

from .accuracy_utils import *
from .accuracy_utils import (
DIM_LIST,
DIMS_LIST,
FLOAT_DTYPES,
REDUCTION_SHAPES,
gems_assert_close,
gems_assert_equal,
skip_expr,
skip_reason,
to_reference,
)


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
Expand Down
9 changes: 7 additions & 2 deletions tests/test_special_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import flag_gems

from .accuracy_utils import *
from .accuracy_utils import (
FLOAT_DTYPES,
POINTWISE_SHAPES,
gems_assert_close,
to_reference,
)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
Expand Down Expand Up @@ -52,7 +57,7 @@ def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x2 = x[..., x.shape[-1] // 2 :] # noqa: E203
return torch.cat((-x2, x1), dim=-1)


Expand Down
9 changes: 8 additions & 1 deletion tests/test_unary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

import flag_gems

from .accuracy_utils import *
from .accuracy_utils import (
FLOAT_DTYPES,
INT_DTYPES,
POINTWISE_SHAPES,
gems_assert_close,
gems_assert_equal,
to_reference,
)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
Expand Down

0 comments on commit 8d4ac50

Please sign in to comment.