diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..dd7801f4 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,17 @@ +name: code-format-check + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 9abd24f9..614fb31b 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -13,7 +13,7 @@ on: jobs: container-unit-test: runs-on: [self-hosted, docker] - timeout-minutes: 30 + timeout-minutes: 50 container: image: localhost:5000/flag-gems-ci:v1.0 ports: @@ -30,7 +30,7 @@ jobs: CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py & CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py & CUDA_VISIBLE_DEVICES=4 pytest -s tests/test_special_ops.py && wait - + container-model-test: runs-on: [self-hosted, docker] timeout-minutes: 5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f46714d7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: flake8 + language_version: python3.11 + args: ["--ignore=F405,E731,F403,W503,E722,E203", --max-line-length=120] + # F405 : Name may be undefined, or defined from star imports: module + # E731 : Do not assign a lambda expression, use a def + # F403 : 'from module import *' used; unable to detect undefined names + # W503 : Line break before binary operator + # E722 : Do not use bare 'except' + # E203 : Whitespace before ':' + +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + language_version: python3.11 + args: ["--profile", "black"] + +- repo: https://github.com/psf/black.git + rev: 23.7.0 + hooks: + - id: black + language_version: python3.11 + - id: black-jupyter diff --git a/LICENSE b/LICENSE index 4a70bd07..f18fcdbc 100644 --- a/LICENSE +++ b/LICENSE @@ -175,4 +175,4 @@ Copyright © 2024 BAAI. All rights reserved. incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - END OF TERMS AND CONDITIONS \ No newline at end of file + END OF TERMS AND CONDITIONS diff --git a/README.md b/README.md index a8925a24..0aee0981 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ ## Introduction -FlagGems is a high-performance general operator library implemented in [OpenAI Triton](https://github.com/openai/triton). It aims to provide a suite of kernel functions to accelerate LLM training and inference. +FlagGems is a high-performance general operator library implemented in [OpenAI Triton](https://github.com/openai/triton). It aims to provide a suite of kernel functions to accelerate LLM training and inference. -By registering with the ATen backend of PyTorch, FlagGems facilitates a seamless transition, allowing users to switch to the Triton function library without the need to modify their model code. Users can still utilize the ATen backend as usual while experiencing significant performance enhancement. The Triton language offers benefits in readability, user-friendliness and performance comparable to CUDA. This convenience allows developers to engage in the development of FlagGems with minimal learning investment. +By registering with the ATen backend of PyTorch, FlagGems facilitates a seamless transition, allowing users to switch to the Triton function library without the need to modify their model code. Users can still utilize the ATen backend as usual while experiencing significant performance enhancement. The Triton language offers benefits in readability, user-friendliness and performance comparable to CUDA. This convenience allows developers to engage in the development of FlagGems with minimal learning investment. ## Feature @@ -49,25 +49,25 @@ def ge(x, y): ## Changelog ### v1.0 -- support BLAS operators: addmm, bmm, mm -- support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu -- support reduction operators: cumsum, layernorm, mean, softmax +- support BLAS operators: addmm, bmm, mm +- support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu +- support reduction operators: cumsum, layernorm, mean, softmax ### v2.0 -- support BLAS operator: mv, outer -- support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid -- support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm -- support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding +- support BLAS operator: mv, outer +- support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid +- support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm +- support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding ## Quick Start ### Requirements -1. Triton >= 2.2.0 -2. PyTorch >= 2.1.2 -3. Transformers >= 4.40.2 +1. Triton >= 2.2.0 +2. PyTorch >= 2.1.2 +3. Transformers >= 4.40.2 -### Installation +### Installation ```shell git clone https://github.com/FlagOpen/FlagGems.git @@ -75,24 +75,24 @@ cd FlagGems pip install . ``` -## Usage +## Usage ### Import -1. Enable permanently +1. Enable permanently ```python import flag_gems flag_gems.enable() ``` -2. Enable temporarily +2. Enable temporarily ```python import flag_gems with flag_gems.use_gems(): pass ``` -3. Example +3. Example ```python import torch import flag_gems @@ -106,41 +106,41 @@ pip install . ### Execute -1. Test Operator Accuracy - - Run reference on cuda +1. Test Operator Accuracy + - Run reference on cuda ```shell cd tests pytest test_xx_ops.py ``` - - Run reference on cpu + - Run reference on cpu ```shell cd tests pytest test_xx_ops.py --device cpu ``` -2. Test Model Accuracy +2. Test Model Accuracy ```shell cd examples pytest model_xx_test.py ``` -3. Test Operator Performance - - Test CUDA performance +3. Test Operator Performance + - Test CUDA performance ```shell cd benchmark pytest test_xx_perf.py -s ``` - - Test end-to-end performance + - Test end-to-end performance ```shell cd benchmark pytest test_xx_perf.py -s --mode cpu ``` -4. Run tests with logging infomation +4. Run tests with logging infomation ```shell pytest program.py --log-cli-level debug ``` - Not recommended in performance testing. + Not recommended in performance testing. ## Supported Operators @@ -148,8 +148,8 @@ Operators will be implemented according to [OperatorList.md](https://github.com/ ## Supported Models -- Bert-base-uncased -- Llama-2-7b +- Bert-base-uncased +- Llama-2-7b ## Supported Platforms diff --git a/README_cn.md b/README_cn.md index 7f9402df..74ddbaff 100644 --- a/README_cn.md +++ b/README_cn.md @@ -2,9 +2,9 @@ ## 介绍 -FlagGems是一个使用OpenAI推出的[Triton编程语言](https://github.com/openai/triton)实现的高性能通用算子库,旨在为大语言模型提供一系列可应用于PyTorch框架的算子,加速模型的推理与训练。 +FlagGems是一个使用OpenAI推出的[Triton编程语言](https://github.com/openai/triton)实现的高性能通用算子库,旨在为大语言模型提供一系列可应用于PyTorch框架的算子,加速模型的推理与训练。 -FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库的无缝替换,使用户能够在不修改模型代码的情况下平稳地切换到triton算子库。FlagGems不会影响aten后端的正常使用,并且会带来良好的性能提升。Triton语言为算子库提供了更好的可读性和易用性,同时保持了不逊于CUDA的算子性能,因此开发者只需付出较低的学习成本,即可参与FlagGems的算子开发与建设。 +FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库的无缝替换,使用户能够在不修改模型代码的情况下平稳地切换到triton算子库。FlagGems不会影响aten后端的正常使用,并且会带来良好的性能提升。Triton语言为算子库提供了更好的可读性和易用性,同时保持了不逊于CUDA的算子性能,因此开发者只需付出较低的学习成本,即可参与FlagGems的算子开发与建设。 ## 特性 @@ -49,25 +49,25 @@ def ge(x, y): ## 更新日志 ### v1.0 -- 支持BLAS类算子:addmm, bmm, mm -- 支持pointwise类算子:abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu -- 支持reduction类算子:cumsum, layernorm, mean, softmax +- 支持BLAS类算子:addmm, bmm, mm +- 支持pointwise类算子:abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu +- 支持reduction类算子:cumsum, layernorm, mean, softmax ### v2.0 -- 支持BLAS类算子: mv, outer -- 支持pointwise类算子: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid -- 支持reduction类算子: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm -- 支持融合算子: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding +- 支持BLAS类算子: mv, outer +- 支持pointwise类算子: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid +- 支持reduction类算子: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm +- 支持融合算子: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding ## 快速入门 ### 依赖 -1. Triton >= 2.2.0 -2. PyTorch >= 2.1.2 -3. Transformers >= 4.40.2 +1. Triton >= 2.2.0 +2. PyTorch >= 2.1.2 +3. Transformers >= 4.40.2 -### 安装 +### 安装 ```shell git clone https://github.com/FlagOpen/FlagGems.git @@ -75,24 +75,24 @@ cd FlagGems pip install . ``` -## 使用 +## 使用 ### 导入 -1. 在进程中永久启用 +1. 在进程中永久启用 ```python import flag_gems flag_gems.enable() ``` -2. 暂时启用 +2. 暂时启用 ```python import flag_gems with flag_gems.use_gems(): pass ``` -3. 示例 +3. 示例 ```python import torch import flag_gems @@ -106,40 +106,40 @@ pip install . ### 执行 -1. 算子正确性测试 - - 在CUDA上运行参考实现 +1. 算子正确性测试 + - 在CUDA上运行参考实现 ```shell cd tests/flag_gems pytest op_accu_test.py ``` - - 在CPU上运行参考实现 + - 在CPU上运行参考实现 ```shell cd tests pytest test_xx_ops.py --device cpu ``` -2. 模型正确性测试 +2. 模型正确性测试 ```shell cd examples pytest model_xx_test.py ``` -3. 算子性能测试 - - 测试CUDA性能 +3. 算子性能测试 + - 测试CUDA性能 ```shell cd benchmark pytest test_xx_perf.py -s ``` - - 测试端到端性能 + - 测试端到端性能 ```shell cd benchmark pytest test_xx_perf.py -s --mode cpu ``` -2. 运行时打印日志信息 +2. 运行时打印日志信息 ```shell pytest program.py --log-cli-level debug ``` - 测试性能时不建议打开。 + 测试性能时不建议打开。 ## 支持算子 @@ -147,8 +147,8 @@ pip install . ## 支持模型 -- Bert-base-uncased -- Llama-2-7b +- Bert-base-uncased +- Llama-2-7b ## 支持平台 diff --git a/benchmark/performance_utils.py b/benchmark/performance_utils.py index 8618e8aa..6f11a3ba 100644 --- a/benchmark/performance_utils.py +++ b/benchmark/performance_utils.py @@ -1,9 +1,11 @@ +import time + import torch import triton -import time + import flag_gems -from .conftest import CPU_MODE +from .conftest import CPU_MODE WARMUP = 10 REPETITION = 1000 @@ -42,8 +44,8 @@ def profile(self, op, *args): def run(self): print(f"Operator {self.op_name} Performance Test ({self.dtype})") - print(f"Size Torch Latency (ms) Gems Latency (ms)") - print(f"--------------------------------------------------") + print("Size Torch Latency (ms) Gems Latency (ms)") + print("--------------------------------------------------") for size in self.sizes: args = self.arg_func(self.dtype, self.batch, size) torch_perf = self.profile(self.torch_op, *args) diff --git a/benchmark/test_blas_perf.py b/benchmark/test_blas_perf.py index 08088f87..1040e6ec 100644 --- a/benchmark/test_blas_perf.py +++ b/benchmark/test_blas_perf.py @@ -1,12 +1,11 @@ -import torch import pytest -import flag_gems +import torch + from .performance_utils import * @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_addmm(dtype): - def addmm_args(dtype, batch, size): bias = torch.randn( [ @@ -32,7 +31,6 @@ def addmm_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_bmm(dtype): - def bmm_args(dtype, batch, size): inp1 = torch.randn([batch, size, size], dtype=dtype, device="cuda") inp2 = torch.randn([batch, size, size], dtype=dtype, device="cuda") @@ -51,7 +49,6 @@ def bmm_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_mm(dtype): - def mm_args(dtype, batch, size): inp1 = torch.randn([size, size], dtype=dtype, device="cuda") inp2 = torch.randn([size, size], dtype=dtype, device="cuda") @@ -70,7 +67,6 @@ def mm_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_mv(dtype): - def mv_args(dtype, batch, size): inp1 = torch.randn([size, size], dtype=dtype, device="cuda") inp2 = torch.randn([size], dtype=dtype, device="cuda") @@ -89,7 +85,6 @@ def mv_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_outer(dtype): - def outer_args(dtype, batch, size): inp1 = torch.randn([size], dtype=dtype, device="cuda") inp2 = torch.randn([size], dtype=dtype, device="cuda") diff --git a/benchmark/test_fused_perf.py b/benchmark/test_fused_perf.py index 4b97ea24..7c1b28b2 100644 --- a/benchmark/test_fused_perf.py +++ b/benchmark/test_fused_perf.py @@ -1,12 +1,13 @@ -import torch import pytest +import torch + import flag_gems + from .performance_utils import * @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_gelu_and_mul(dtype): - def torch_op(x, y): return torch.mul(torch.nn.functional.gelu(x), y) @@ -26,7 +27,6 @@ def torch_op(x, y): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_silu_and_mul(dtype): - def torch_op(x, y): return torch.mul(torch.nn.functional.silu(x), y) @@ -46,7 +46,6 @@ def torch_op(x, y): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_skip_layernorm(dtype): - def skip_layernorm_args(dtype, batch, size): inp = torch.randn([batch, size], dtype=dtype, device="cuda") residual = torch.randn([batch, size], dtype=dtype, device="cuda") @@ -93,7 +92,6 @@ def torch_op(inp, residual, layer_shape, weight, bias): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_skip_rmsnorm(dtype): - def skip_rmsnorm_args(dtype, batch, size): inp = torch.randn([batch, size], dtype=dtype, device="cuda") residual = torch.randn([batch, size], dtype=dtype, device="cuda") diff --git a/benchmark/test_pointwise_perf.py b/benchmark/test_pointwise_perf.py index b91e4f84..f9b8ae86 100644 --- a/benchmark/test_pointwise_perf.py +++ b/benchmark/test_pointwise_perf.py @@ -1,6 +1,6 @@ -import torch import pytest -import flag_gems +import torch + from .performance_utils import * diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 75a95127..49f15e9b 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -1,6 +1,6 @@ -import torch import pytest -import flag_gems +import torch + from .performance_utils import * @@ -58,7 +58,6 @@ def test_perf_argmax(dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_cross_entropy_loss(dtype): - def cross_entropy_loss_args(dtype, batch, size): inp = torch.randn([batch, size], dtype=dtype, device="cuda") target = torch.randint( @@ -84,7 +83,6 @@ def cross_entropy_loss_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_cumsum(dtype): - def cumsum_args(dtype, batch, size): inp = torch.randn([batch, size], dtype=dtype, device="cuda") return inp, 1 @@ -102,7 +100,6 @@ def cumsum_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_groupnorm(dtype): - def group_norm_args(dtype, batch, size): C = 16 G = 16 @@ -136,7 +133,6 @@ def group_norm_args(dtype, batch, size): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_perf_layernorm(dtype): - def layer_norm_args(dtype, batch, size): inp = torch.randn([batch, size], dtype=dtype, device="cuda") weight = torch.randn( diff --git a/examples/model_bert_test.py b/examples/model_bert_test.py index 25c7f41f..7b49b625 100644 --- a/examples/model_bert_test.py +++ b/examples/model_bert_test.py @@ -1,9 +1,11 @@ -import torch -import pytest import copy -import flag_gems + +import pytest +import torch from transformers import AutoTokenizer, BertConfig, BertModel +import flag_gems + @pytest.mark.parametrize( "prompt", @@ -49,4 +51,5 @@ def test_accuracy_bert(prompt, dtype): succeed = score >= 0.99 assert ( succeed - ), f"BERT_{dtype} FAIL with maxdiff {maxdiff} and score {score}\nREF: {ref_outputs}\nRES: {res_outputs}" + ), f"BERT_{dtype} FAIL with maxdiff {maxdiff} and score {score}\nREF: \ + {ref_outputs}\nRES: {res_outputs}" diff --git a/examples/model_llama_test.py b/examples/model_llama_test.py index 3c18ae6e..5b264227 100644 --- a/examples/model_llama_test.py +++ b/examples/model_llama_test.py @@ -1,8 +1,8 @@ -import torch import pytest -import copy +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + import flag_gems -from transformers import AutoTokenizer, AutoModelForCausalLM @pytest.mark.parametrize( diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 5e96a67d..0dc34a7c 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -1,7 +1,7 @@ import torch -from .ops import * from .fused import * +from .ops import * __version__ = "2.0" diff --git a/src/flag_gems/fused/__init__.py b/src/flag_gems/fused/__init__.py index fccd6d37..db1d0e6d 100644 --- a/src/flag_gems/fused/__init__.py +++ b/src/flag_gems/fused/__init__.py @@ -1,9 +1,8 @@ +from .gelu_and_mul import gelu_and_mul from .rotary_embedding import apply_rotary_pos_emb +from .silu_and_mul import silu_and_mul from .skip_layernorm import skip_layer_norm from .skip_rms_norm import skip_rms_norm -from .silu_and_mul import silu_and_mul -from .gelu_and_mul import gelu_and_mul - __all__ = [ "apply_rotary_pos_emb", diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py index c655ae48..c2cb52f6 100644 --- a/src/flag_gems/fused/gelu_and_mul.py +++ b/src/flag_gems/fused/gelu_and_mul.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -37,12 +39,11 @@ class GeluAndMul(torch.autograd.Function): def forward(ctx, A, B, approximate="none"): logging.debug("GEMS GELU AND MUL FORWARD") if approximate == "none": - O = gelu_none_and_mul_kernel(A, B) + return gelu_none_and_mul_kernel(A, B) elif approximate == "tanh": - O = gelu_tanh_and_mul_kernel(A, B) + return gelu_tanh_and_mul_kernel(A, B) else: raise ValueError(f"Invalid approximate value: {approximate}") - return O def gelu_and_mul(A, B, approximate="none"): diff --git a/src/flag_gems/fused/rotary_embedding.py b/src/flag_gems/fused/rotary_embedding.py index 7f065f6c..d4ec7725 100644 --- a/src/flag_gems/fused/rotary_embedding.py +++ b/src/flag_gems/fused/rotary_embedding.py @@ -1,9 +1,8 @@ import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/fused/silu_and_mul.py b/src/flag_gems/fused/silu_and_mul.py index f86f4844..0d1271ec 100644 --- a/src/flag_gems/fused/silu_and_mul.py +++ b/src/flag_gems/fused/silu_and_mul.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -17,8 +19,7 @@ class SiluAndMul(torch.autograd.Function): @staticmethod def forward(ctx, A, B): logging.debug("GEMS SILU AND MUL FORWARD") - O = silu_and_mul_kernel(A, B) - return O + return silu_and_mul_kernel(A, B) def silu_and_mul(A, B): diff --git a/src/flag_gems/fused/skip_layernorm.py b/src/flag_gems/fused/skip_layernorm.py index b2847aae..910b2eef 100644 --- a/src/flag_gems/fused/skip_layernorm.py +++ b/src/flag_gems/fused/skip_layernorm.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/fused/skip_rms_norm.py b/src/flag_gems/fused/skip_rms_norm.py index 1db47f0a..7c3bd0c9 100644 --- a/src/flag_gems/fused/skip_rms_norm.py +++ b/src/flag_gems/fused/skip_rms_norm.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 83fc116e..4157e0b4 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -1,21 +1,24 @@ -from .all import all, all_dim, all_dims -from .any import any, any_dim, any_dims from .abs import abs from .add import add from .addmm import addmm +from .all import all, all_dim, all_dims +from .amax import amax +from .any import any, any_dim, any_dims +from .argmax import argmax from .bitwise_and import ( - bitwise_and_tensor, bitwise_and_scalar, bitwise_and_scalar_tensor, + bitwise_and_tensor, ) from .bitwise_not import bitwise_not -from .bitwise_or import bitwise_or_tensor, bitwise_or_scalar, bitwise_or_scalar_tensor +from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor from .bmm import bmm from .clamp import clamp, clamp_tensor from .cos import cos +from .cross_entropy_loss import cross_entropy_loss from .cumsum import cumsum -from .dropout import native_dropout from .div import div +from .dropout import native_dropout from .eq import eq, eq_scalar from .exp import exp from .ge import ge, ge_scalar @@ -26,39 +29,34 @@ from .isnan import isnan from .layernorm import layer_norm from .le import le, le_scalar +from .log_softmax import log_softmax from .lt import lt, lt_scalar -from .rms_norm import rms_norm +from .max import max, max_dim from .mean import mean, mean_dim +from .min import min, min_dim from .mm import mm from .mul import mul from .mv import mv from .ne import ne, ne_scalar from .neg import neg +from .outer import outer from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor +from .prod import prod, prod_dim from .reciprocal import reciprocal from .relu import relu +from .rms_norm import rms_norm from .rsqrt import rsqrt from .sigmoid import sigmoid from .silu import silu from .sin import sin from .softmax import softmax from .sub import sub +from .sum import sum, sum_dim from .tanh import tanh from .triu import triu - -from .max import max, max_dim -from .min import min, min_dim -from .amax import amax -from .sum import sum, sum_dim -from .argmax import argmax -from .prod import prod, prod_dim -from .log_softmax import log_softmax -from .outer import outer -from .cross_entropy_loss import cross_entropy_loss - from .var_mean import var_mean from .vector_norm import vector_norm -from .where import where_self, where_scalar_self, where_scalar_other +from .where import where_scalar_other, where_scalar_self, where_self __all__ = [ "all", diff --git a/src/flag_gems/ops/abs.py b/src/flag_gems/ops/abs.py index 6f73c912..1f8eabf9 100644 --- a/src/flag_gems/ops/abs.py +++ b/src/flag_gems/ops/abs.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def abs_func(x): def abs(A): logging.debug("GEMS ABS") - O = abs_func(A) - return O + return abs_func(A) diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index 7f849e64..5c6dfee3 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -1,7 +1,8 @@ +import logging + import torch import triton -import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -26,13 +27,10 @@ def add_func_scalar_tensor(x, y, alpha): def add(A, B, *, alpha=1): logging.debug("GEMS ADD") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - O = add_func(A, B, alpha) - return O + return add_func(A, B, alpha) elif isinstance(A, torch.Tensor): - O = add_func_tensor_scalar(A, B, alpha) - return O + return add_func_tensor_scalar(A, B, alpha) elif isinstance(B, torch.Tensor): - O = add_func_scalar_tensor(A, B, alpha) - return O + return add_func_scalar_tensor(A, B, alpha) else: return A + B * alpha diff --git a/src/flag_gems/ops/addmm.py b/src/flag_gems/ops/addmm.py index 852784cf..2b5bb41c 100644 --- a/src/flag_gems/ops/addmm.py +++ b/src/flag_gems/ops/addmm.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/all.py b/src/flag_gems/ops/all.py index 19360f39..9d9352ab 100644 --- a/src/flag_gems/ops/all.py +++ b/src/flag_gems/ops/all.py @@ -1,13 +1,15 @@ +import logging +import math + import torch import triton -import math import triton.language as tl + from ..utils import libentry -import logging -# torch.all: Tests if all elements in input evaluate to True. -# If the dtype of input is not BOOL, then test if all elements in input evaluate to non-zero value +# torch.all: Tests if all elements in input evaluate to True. If the dtype of input +# is not BOOL, then test if all elements in input evaluate to non-zero value # In triton function, test if all elements in input evaluate to non-zero value is ok. def cfggen(): block_m = [1, 2, 4, 8] @@ -89,7 +91,6 @@ def all(inp): mid_size = triton.cdiv(n_elements, block_size) block_mid = triton.next_power_of_2(mid_size) - dtype = inp.dtype mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device) out = torch.empty([], dtype=torch.bool, device=inp.device) diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py index ebb3eba3..2dd5a0e0 100644 --- a/src/flag_gems/ops/amax.py +++ b/src/flag_gems/ops/amax.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/ops/any.py b/src/flag_gems/ops/any.py index 24741408..1a8c120c 100644 --- a/src/flag_gems/ops/any.py +++ b/src/flag_gems/ops/any.py @@ -1,13 +1,15 @@ +import logging +import math + import torch import triton -import math import triton.language as tl + from ..utils import libentry -import logging -# torch.any: Tests if any elements in input evaluate to True. -# If the dtype of input is not BOOL, then test if any elements in input evaluate to non-zero value +# torch.any: Tests if any elements in input evaluate to True. If the dtype of input +# is not BOOL, then test if any elements in input evaluate to non-zero value # In triton function, test if any elements in input evaluate to non-zero value is ok. def cfggen(): block_m = [1, 2, 4, 8] @@ -88,7 +90,6 @@ def any(inp): block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements))) mid_size = triton.cdiv(n_elements, block_size) block_mid = triton.next_power_of_2(mid_size) - dtype = inp.dtype mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device) out = torch.empty([], dtype=torch.bool, device=inp.device) diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index 54e1b7ec..a30ec931 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index b83dcead..72a994e0 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -1,5 +1,7 @@ -import triton import logging + +import triton + from ..utils import pointwise_dynamic @@ -11,8 +13,7 @@ def bitwise_and_func(x, y): def bitwise_and_tensor(A, B): logging.debug("GEMS BITWISE AND") - O = bitwise_and_func(A, B) - return O + return bitwise_and_func(A, B) @pointwise_dynamic(is_tensor=[True, False]) @@ -23,11 +24,9 @@ def bitwise_and_func_scalar(x, y): def bitwise_and_scalar(A, B): logging.debug("GEMS BITWISE AND SCALAR") - O = bitwise_and_func_scalar(A, B) - return O + return bitwise_and_func_scalar(A, B) def bitwise_and_scalar_tensor(A, B): logging.debug("GEMS BITWISE AND SCALAR TENSOR") - O = bitwise_and_func_scalar(B, A) - return O + return bitwise_and_func_scalar(B, A) diff --git a/src/flag_gems/ops/bitwise_not.py b/src/flag_gems/ops/bitwise_not.py index f4250120..b5977514 100644 --- a/src/flag_gems/ops/bitwise_not.py +++ b/src/flag_gems/ops/bitwise_not.py @@ -1,5 +1,7 @@ -import triton import logging + +import triton + from ..utils import pointwise_dynamic @@ -11,5 +13,4 @@ def bitwise_not_func(x): def bitwise_not(A): logging.debug("GEMS BITWISE NOT") - O = bitwise_not_func(A) - return O + return bitwise_not_func(A) diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index 0237ff35..99f654e4 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -1,5 +1,7 @@ -import triton import logging + +import triton + from ..utils import pointwise_dynamic @@ -11,8 +13,7 @@ def bitwise_or_func(x, y): def bitwise_or_tensor(A, B): logging.debug("GEMS BITWISE OR") - O = bitwise_or_func(A, B) - return O + return bitwise_or_func(A, B) @pointwise_dynamic(is_tensor=[True, False]) @@ -23,11 +24,9 @@ def bitwise_or_func_scalar(x, y): def bitwise_or_scalar(A, B): logging.debug("GEMS BITWISE OR SCALAR") - O = bitwise_or_func_scalar(A, B) - return O + return bitwise_or_func_scalar(A, B) def bitwise_or_scalar_tensor(A, B): logging.debug("GEMS BITWISE OR SCALAR TENSOR") - O = bitwise_or_func_scalar(B, A) - return O + return bitwise_or_func_scalar(B, A) diff --git a/src/flag_gems/ops/bmm.py b/src/flag_gems/ops/bmm.py index 774444d8..0ff940c4 100644 --- a/src/flag_gems/ops/bmm.py +++ b/src/flag_gems/ops/bmm.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry @@ -184,12 +186,12 @@ def bmm(A, B): _, _, N = B.shape A = A.contiguous() B = B.contiguous() - O = torch.empty((batch, M, N), dtype=A.dtype, device=A.device) + out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device) grid_fn = lambda meta: ( triton.cdiv(meta["M"], meta["TILE_M"]), triton.cdiv(meta["N"], meta["TILE_N"]), batch, ) - bmm_kernel[grid_fn](A, B, O, M, N, K) - return O + bmm_kernel[grid_fn](A, B, out, M, N, K) + return out diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index a5d51edc..c603fe52 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -27,12 +29,11 @@ def clamp_tensor(A, mini=None, maxi=None): if mini is None and maxi is None: raise ValueError("At least one of mini or maxi must not be None") elif mini is None: - O = clamp_func_max_tensor(A, maxi) + return clamp_func_max_tensor(A, maxi) elif maxi is None: - O = clamp_func_min_tensor(A, mini) + return clamp_func_min_tensor(A, mini) else: - O = clamp_func_tensor(A, mini, maxi) - return O + return clamp_func_tensor(A, mini, maxi) @pointwise_dynamic(is_tensor=[True, False, False]) @@ -58,9 +59,8 @@ def clamp(A, mini=None, maxi=None): if mini is None and maxi is None: raise ValueError("At least one of mini or maxi must not be None") elif mini is None: - O = clamp_func_max(A, maxi) + return clamp_func_max(A, maxi) elif maxi is None: - O = clamp_func_min(A, mini) + return clamp_func_min(A, mini) else: - O = clamp_func(A, mini, maxi) - return O + return clamp_func(A, mini, maxi) diff --git a/src/flag_gems/ops/cos.py b/src/flag_gems/ops/cos.py index 07fd64b3..fee00a4b 100644 --- a/src/flag_gems/ops/cos.py +++ b/src/flag_gems/ops/cos.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def cos_func(x): def cos(A): logging.debug("GEMS COS") - O = cos_func(A) - return O + return cos_func(A) diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py index d224a4a8..1f498baf 100644 --- a/src/flag_gems/ops/cross_entropy_loss.py +++ b/src/flag_gems/ops/cross_entropy_loss.py @@ -1,8 +1,10 @@ +import logging +from enum import IntEnum + import torch import triton import triton.language as tl -import logging -from enum import IntEnum + from ..utils import libentry from .sum import sum, sum_dim @@ -285,7 +287,8 @@ def backward(ctx, out_grad): return out, None, None, None, None, None -# todo: reducetion(dtype: int,default mean->1), support other scenarios as follows: (none->0, sum->2) +# todo: reducetion(dtype: int,default mean->1), support other scenarios as follows: +# (none->0, sum->2) def cross_entropy_loss( input, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0 ): diff --git a/src/flag_gems/ops/cumsum.py b/src/flag_gems/ops/cumsum.py index 123faa4b..5693b9ed 100644 --- a/src/flag_gems/ops/cumsum.py +++ b/src/flag_gems/ops/cumsum.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index ad711a0b..20e4b60c 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -1,7 +1,8 @@ +import logging + import torch import triton -import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -26,14 +27,11 @@ def div_func_scalar_tensor(x, y): def div(A, B): logging.debug("GEMS DIV") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - O = div_func(A, B) - return O + return div_func(A, B) elif isinstance(A, torch.Tensor): - O = div_func_tensor_scalar(A, B) - return O + return div_func_tensor_scalar(A, B) elif isinstance(B, torch.Tensor): - O = div_func_scalar_tensor(A, B) - return O + return div_func_scalar_tensor(A, B) else: # Both scalar return A / B diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py index dff36e96..3567916e 100644 --- a/src/flag_gems/ops/dropout.py +++ b/src/flag_gems/ops/dropout.py @@ -1,17 +1,18 @@ +import logging + +import torch import triton import triton.language as tl -import torch -import logging -from ..utils.random_utils import philox_cuda_seed_offset -from ..utils import libentry +from ..utils import libentry +from ..utils.random_utils import philox_cuda_seed_offset try: tl_rand_dtype = tl.int64 + @triton.jit def _rand(seed, offset): offset = offset.to(tl_rand_dtype) - z = tl.rand(seed, offset, n_rounds=6) _grid = (1,) _seed, _offset = philox_cuda_seed_offset(0) @@ -114,18 +115,18 @@ def forward(ctx, x, p, train): logging.debug("GEMS NATIVE DROPOUT FORWARD") assert p > 0.0 and p < 1.0, "p must be in (0, 1)" x = x.contiguous() - O = torch.empty_like(x) + out = torch.empty_like(x) N = x.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["N_BLOCK_SIZE"]),) # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = N philox_seed, philox_offset = philox_cuda_seed_offset(increment) - dropout_forward_kernel[grid_fn](x, O, N, p, philox_seed, philox_offset) + dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset) ctx.p = p ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset - return O, None + return out, None @staticmethod def backward(ctx, grad_outputs, kwargs): diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index 868f5737..38ee0447 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def eq_func(x, y): def eq(A, B): logging.debug("GEMS EQ") - O = eq_func(A, B) - return O + return eq_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def eq_func_scalar(x, y): def eq_scalar(A, B): logging.debug("GEMS EQ SCALAR") - O = eq_func_scalar(A, B) - return O + return eq_func_scalar(A, B) diff --git a/src/flag_gems/ops/exp.py b/src/flag_gems/ops/exp.py index 9986883e..5d51c5b4 100644 --- a/src/flag_gems/ops/exp.py +++ b/src/flag_gems/ops/exp.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def exp_func(x): def exp(A): logging.debug("GEMS EXP") - O = exp_func(A) - return O + return exp_func(A) diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py index 7614369b..2edd5dfc 100644 --- a/src/flag_gems/ops/ge.py +++ b/src/flag_gems/ops/ge.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def ge_func(x, y): def ge(A, B): logging.debug("GEMS GE") - O = ge_func(A, B) - return O + return ge_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def ge_func_scalar(x, y): def ge_scalar(A, B): logging.debug("GEMS GE SCALAR") - O = ge_func_scalar(A, B) - return O + return ge_func_scalar(A, B) diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index d8d21261..2df36194 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic diff --git a/src/flag_gems/ops/groupnorm.py b/src/flag_gems/ops/groupnorm.py index 45639955..4fd5fda7 100644 --- a/src/flag_gems/ops/groupnorm.py +++ b/src/flag_gems/ops/groupnorm.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py index 6260fe1a..4fe53628 100644 --- a/src/flag_gems/ops/gt.py +++ b/src/flag_gems/ops/gt.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def gt_func(x, y): def gt(A, B): logging.debug("GEMS GT") - O = gt_func(A, B) - return O + return gt_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def gt_func_scalar(x, y): def gt_scalar(A, B): logging.debug("GEMS GT SCALAR") - O = gt_func_scalar(A, B) - return O + return gt_func_scalar(A, B) diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index fba408d9..17c8488c 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,5 +15,4 @@ def isinf_func(x): def isinf(A): logging.debug("GEMS ISINF") - O = isinf_func(A) - return O + return isinf_func(A) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index 8c77fd00..f58f2ce1 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,5 +15,4 @@ def isnan_func(x): def isnan(A): logging.debug("GEMS ISNAN") - O = isnan_func(A) - return O + return isnan_func(A) diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py index 7a096a4a..a9cb6291 100644 --- a/src/flag_gems/ops/layernorm.py +++ b/src/flag_gems/ops/layernorm.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math def cfggen(): diff --git a/src/flag_gems/ops/le.py b/src/flag_gems/ops/le.py index e31437a6..70d84be5 100644 --- a/src/flag_gems/ops/le.py +++ b/src/flag_gems/ops/le.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def le_func(x, y): def le(A, B): logging.debug("GEMS LE") - O = le_func(A, B) - return O + return le_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def le_func_scalar(x, y): def le_scalar(A, B): logging.debug("GEMS LE SCALAR") - O = le_func_scalar(A, B) - return O + return le_func_scalar(A, B) diff --git a/src/flag_gems/ops/log_softmax.py b/src/flag_gems/ops/log_softmax.py index 3cff8a71..be39e357 100644 --- a/src/flag_gems/ops/log_softmax.py +++ b/src/flag_gems/ops/log_softmax.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py index 8d560ea1..26e828d1 100644 --- a/src/flag_gems/ops/lt.py +++ b/src/flag_gems/ops/lt.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def lt_func(x, y): def lt(A, B): logging.debug("GEMS LT") - O = lt_func(A, B) - return O + return lt_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def lt_func_scalar(x, y): def lt_scalar(A, B): logging.debug("GEMS LT SCALAR") - O = lt_func_scalar(A, B) - return O + return lt_func_scalar(A, B) diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index 8e27f7a2..aac983d5 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -1,10 +1,12 @@ +import logging +import math +from collections import namedtuple + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math -from collections import namedtuple @libentry() diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py index daeee423..4e24d363 100644 --- a/src/flag_gems/ops/mean.py +++ b/src/flag_gems/ops/mean.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() @@ -85,7 +87,7 @@ def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr) def mean_dim(x, dim, keepdim=False, *, dtype=None): logging.debug("GEMS MEAN DIM") - if dtype == None: + if dtype is None: dtype = x.dtype if dim is None: dim = list(range(x.ndim)) diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index 1a17dde2..2c5e9c0d 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -1,10 +1,12 @@ +import logging +import math +from collections import namedtuple + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math -from collections import namedtuple @libentry() diff --git a/src/flag_gems/ops/mm.py b/src/flag_gems/ops/mm.py index 8b81996c..d5906e38 100644 --- a/src/flag_gems/ops/mm.py +++ b/src/flag_gems/ops/mm.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index bf65ecba..e5745f98 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -1,7 +1,8 @@ +import logging + import torch import triton -import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -20,14 +21,11 @@ def mul_func_scalar(x, y): def mul(A, B): logging.debug("GEMS MUL") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - O = mul_func(A, B) - return O + return mul_func(A, B) elif isinstance(A, torch.Tensor): - O = mul_func_scalar(A, B) - return O + return mul_func_scalar(A, B) elif isinstance(B, torch.Tensor): - O = mul_func_scalar(B, A) - return O + return mul_func_scalar(B, A) else: # Both scalar return A * B diff --git a/src/flag_gems/ops/mv.py b/src/flag_gems/ops/mv.py index ab449862..02d53f5e 100644 --- a/src/flag_gems/ops/mv.py +++ b/src/flag_gems/ops/mv.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py index 28696d0f..f322ffff 100644 --- a/src/flag_gems/ops/ne.py +++ b/src/flag_gems/ops/ne.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +15,7 @@ def ne_func(x, y): def ne(A, B): logging.debug("GEMS NE") - O = ne_func(A, B) - return O + return ne_func(A, B) @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @@ -25,5 +26,4 @@ def ne_func_scalar(x, y): def ne_scalar(A, B): logging.debug("GEMS NE SCALAR") - O = ne_func_scalar(A, B) - return O + return ne_func_scalar(A, B) diff --git a/src/flag_gems/ops/neg.py b/src/flag_gems/ops/neg.py index 8a6f8072..88ea6a26 100644 --- a/src/flag_gems/ops/neg.py +++ b/src/flag_gems/ops/neg.py @@ -1,5 +1,7 @@ -import triton import logging + +import triton + from ..utils import pointwise_dynamic @@ -11,5 +13,4 @@ def neg_func(x): def neg(A): logging.debug("GEMS NEG") - O = neg_func(A) - return O + return neg_func(A) diff --git a/src/flag_gems/ops/outer.py b/src/flag_gems/ops/outer.py index 4bd62576..1c8a8974 100644 --- a/src/flag_gems/ops/outer.py +++ b/src/flag_gems/ops/outer.py @@ -1,9 +1,9 @@ -import torch -import triton.language as tl import logging -from ..utils import libentry -from .mul import mul + +import torch + from .mm import mm +from .mul import mul class Outer(torch.autograd.Function): diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index 9430d4d5..60abf1fa 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,8 +14,7 @@ def pow_func(x, exponent): def pow_tensor_tensor(A, exponent): logging.debug("GEMS POW_TENSOR_TENSOR") - O = pow_func(A, exponent) - return O + return pow_func(A, exponent) @pointwise_dynamic(is_tensor=[True, False]) @@ -24,8 +25,7 @@ def pow_func_tensor_scalar(x, exponent): def pow_tensor_scalar(A, exponent): logging.debug("GEMS POW_TENSOR_SCALAR") - O = pow_func_tensor_scalar(A, exponent) - return O + return pow_func_tensor_scalar(A, exponent) @pointwise_dynamic(is_tensor=[False, True]) @@ -36,5 +36,4 @@ def pow_func_scalar_tensor(x, exponent): def pow_scalar(A, exponent): logging.debug("GEMS POW_SCALAR") - O = pow_func_scalar_tensor(A, exponent) - return O + return pow_func_scalar_tensor(A, exponent) diff --git a/src/flag_gems/ops/prod.py b/src/flag_gems/ops/prod.py index 572147d3..090ff949 100644 --- a/src/flag_gems/ops/prod.py +++ b/src/flag_gems/ops/prod.py @@ -1,8 +1,10 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging -import math + from ..utils import libentry diff --git a/src/flag_gems/ops/reciprocal.py b/src/flag_gems/ops/reciprocal.py index 28904147..432d45ea 100644 --- a/src/flag_gems/ops/reciprocal.py +++ b/src/flag_gems/ops/reciprocal.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def reciprocal_func(x): def reciprocal(A): logging.debug("GEMS RECIPROCAL") - O = reciprocal_func(A) - return O + return reciprocal_func(A) diff --git a/src/flag_gems/ops/relu.py b/src/flag_gems/ops/relu.py index 35c3199a..32d59af1 100644 --- a/src/flag_gems/ops/relu.py +++ b/src/flag_gems/ops/relu.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -21,9 +23,9 @@ class Relu(torch.autograd.Function): @staticmethod def forward(ctx, A): logging.debug("GEMS RELU FORWARD") - O = relu_forward(A) + out = relu_forward(A) ctx.save_for_backward(A) - return O + return out @staticmethod def backward(ctx, out_grad): diff --git a/src/flag_gems/ops/rms_norm.py b/src/flag_gems/ops/rms_norm.py index c1d21658..9f4308ff 100644 --- a/src/flag_gems/ops/rms_norm.py +++ b/src/flag_gems/ops/rms_norm.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/ops/rsqrt.py b/src/flag_gems/ops/rsqrt.py index f6475468..24090a0e 100644 --- a/src/flag_gems/ops/rsqrt.py +++ b/src/flag_gems/ops/rsqrt.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def rsqrt_func(x): def rsqrt(A): logging.debug("GEMS RSQRT") - O = rsqrt_func(A) - return O + return rsqrt_func(A) diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index e93abc90..833e9559 100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -1,8 +1,10 @@ +import logging import math + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -25,9 +27,9 @@ class Sigmoid(torch.autograd.Function): @staticmethod def forward(ctx, A): logging.debug("GEMS SIGMOID FORWARD") - O = sigmoid_forward(A.to(torch.float32)) - ctx.save_for_backward(O) - return O.to(A.dtype) + out = sigmoid_forward(A.to(torch.float32)) + ctx.save_for_backward(out) + return out.to(A.dtype) @staticmethod def backward(ctx, out_grad): diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index 8e1d0170..358ab3aa 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -27,9 +29,9 @@ class Silu(torch.autograd.Function): @staticmethod def forward(ctx, A): logging.debug("GEMS SILU FORWARD") - O = silu_forward(A) + out = silu_forward(A) ctx.save_for_backward(A) - return O + return out @staticmethod def backward(ctx, out_grad): diff --git a/src/flag_gems/ops/sin.py b/src/flag_gems/ops/sin.py index cd39e4f8..4431f75e 100644 --- a/src/flag_gems/ops/sin.py +++ b/src/flag_gems/ops/sin.py @@ -1,6 +1,8 @@ +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -12,5 +14,4 @@ def sin_func(x): def sin(A): logging.debug("GEMS SIN") - O = sin_func(A) - return O + return sin_func(A) diff --git a/src/flag_gems/ops/softmax.py b/src/flag_gems/ops/softmax.py index 491267f7..07596ed8 100644 --- a/src/flag_gems/ops/softmax.py +++ b/src/flag_gems/ops/softmax.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index a4e4404e..72804744 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -1,7 +1,8 @@ +import logging + import torch import triton -import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -26,14 +27,11 @@ def sub_func_scalar_tensor(x, y, alpha): def sub(A, B, *, alpha=1): logging.debug("GEMS SUB") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - O = sub_func(A, B, alpha) - return O + return sub_func(A, B, alpha) elif isinstance(A, torch.Tensor): - O = sub_func_tensor_scalar(A, B, alpha) - return O + return sub_func_tensor_scalar(A, B, alpha) elif isinstance(B, torch.Tensor): - O = sub_func_scalar_tensor(A, B, alpha) - return O + return sub_func_scalar_tensor(A, B, alpha) else: # Both scalar return A - B * alpha diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py index 8830c46c..eea1736f 100644 --- a/src/flag_gems/ops/sum.py +++ b/src/flag_gems/ops/sum.py @@ -1,9 +1,11 @@ +import logging +import math + import torch import triton import triton.language as tl -import logging + from ..utils import libentry -import math @libentry() diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index bed0e5e4..b2cff858 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -21,9 +23,9 @@ class Tanh(torch.autograd.Function): @staticmethod def forward(ctx, A): logging.debug("GEMS TANH FORWARD") - O = tanh_forward(A.to(torch.float32)) - ctx.save_for_backward(O) - return O.to(A.dtype) + out = tanh_forward(A.to(torch.float32)) + ctx.save_for_backward(out) + return out.to(A.dtype) @staticmethod def backward(ctx, out_grad): diff --git a/src/flag_gems/ops/triu.py b/src/flag_gems/ops/triu.py index bf2c2fe7..2a2aae56 100644 --- a/src/flag_gems/ops/triu.py +++ b/src/flag_gems/ops/triu.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry @@ -84,12 +86,12 @@ def triu_batch_kernel( def triu(A, diagonal=0): logging.debug("GEMS TRIU") A = A.contiguous() - O = torch.empty_like(A) + out = torch.empty_like(A) assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions" M, N = A.shape[-2:] if len(A.shape) == 2: grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - triu_kernel[grid](A, O, M, N, diagonal) + triu_kernel[grid](A, out, M, N, diagonal) else: batch = int(torch.numel(A) / M / N) B = A.view(batch, -1) @@ -97,6 +99,6 @@ def triu(A, diagonal=0): triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]), triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]), ) - triu_batch_kernel[grid](B, O, batch, M * N, N, diagonal) - O = O.view(A.shape) - return O + triu_batch_kernel[grid](B, out, batch, M * N, N, diagonal) + out = out.view(A.shape) + return out diff --git a/src/flag_gems/ops/var_mean.py b/src/flag_gems/ops/var_mean.py index d76b528a..4ebc3954 100644 --- a/src/flag_gems/ops/var_mean.py +++ b/src/flag_gems/ops/var_mean.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/vector_norm.py b/src/flag_gems/ops/vector_norm.py index 9b82a12e..9d0e0297 100644 --- a/src/flag_gems/ops/vector_norm.py +++ b/src/flag_gems/ops/vector_norm.py @@ -1,7 +1,9 @@ +import logging + import torch import triton import triton.language as tl -import logging + from ..utils import libentry diff --git a/src/flag_gems/ops/where.py b/src/flag_gems/ops/where.py index 78e3be02..68b09868 100644 --- a/src/flag_gems/ops/where.py +++ b/src/flag_gems/ops/where.py @@ -1,7 +1,8 @@ -import torch +import logging + import triton import triton.language as tl -import logging + from ..utils import pointwise_dynamic @@ -13,8 +14,7 @@ def where_self_func(self, condition, other): def where_self(condition, self, other): logging.debug("GEMS WHERE_SELF") - O = where_self_func(self, condition, other) - return O + return where_self_func(self, condition, other) @pointwise_dynamic(is_tensor=[True, True, False]) @@ -25,8 +25,7 @@ def where_scalar_self_func(other, condition, self): def where_scalar_self(condition, self, other): logging.debug("GEMS WHERE_SCALAR_SELF") - O = where_scalar_self_func(other, condition, self) - return O + return where_scalar_self_func(other, condition, self) @pointwise_dynamic(is_tensor=[True, True, False]) @@ -37,5 +36,4 @@ def where_scalar_other_func(self, condition, other): def where_scalar_other(condition, self, other): logging.debug("GEMS WHERE_SCALAR_OTHER") - O = where_scalar_other_func(self, condition, other) - return O + return where_scalar_other_func(self, condition, other) diff --git a/src/flag_gems/utils/code_cache.py b/src/flag_gems/utils/code_cache.py index 496a1163..43db3bc5 100644 --- a/src/flag_gems/utils/code_cache.py +++ b/src/flag_gems/utils/code_cache.py @@ -1,7 +1,7 @@ import functools import os -from pathlib import Path import shutil +from pathlib import Path @functools.lru_cache(maxsize=None) # this is the same as functools.cache in Python 3.9+ diff --git a/src/flag_gems/utils/code_utils.py b/src/flag_gems/utils/code_utils.py index 93ea75a9..6f71fa14 100644 --- a/src/flag_gems/utils/code_utils.py +++ b/src/flag_gems/utils/code_utils.py @@ -1,5 +1,7 @@ -# The code for IndentedBuffer is adapted from https://github.com/pytorch/pytorch/blob/ed48ea9997c2b04736096e4b6669543ab2e627d5/torch/_inductor/utils.py#L742 -# The code for Namespace is adapted from https://github.com/pytorch/pytorch/blob/ed48ea9997c2b04736096e4b6669543ab2e627d5/torch/fx/graph.py#L115 +# The code for IndentedBuffer is adapted from +# https://github.com/pytorch/pytorch/blob/ed48ea9997c2b04736096e4b6669543ab2e627d5/torch/_inductor/utils.py#L742 +# The code for Namespace is adapted from +# https://github.com/pytorch/pytorch/blob/ed48ea9997c2b04736096e4b6669543ab2e627d5/torch/fx/graph.py#L115 # License from pytorch(https://github.com/pytorch/pytorch) @@ -54,13 +56,13 @@ # All rights reserved. -import keyword import builtins -from collections import defaultdict +import contextlib +import keyword import re -from typing import Set, Dict +from collections import defaultdict from io import StringIO -import contextlib +from typing import Dict, Set class IndentedBuffer: diff --git a/src/flag_gems/utils/inliner.py b/src/flag_gems/utils/inliner.py index f504e909..577df7aa 100644 --- a/src/flag_gems/utils/inliner.py +++ b/src/flag_gems/utils/inliner.py @@ -1,7 +1,8 @@ -from typing import List import ast +from typing import List from triton.runtime import JITFunction + from flag_gems.utils.code_utils import NameSpace diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 369cda5f..c6f47c32 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -1,16 +1,16 @@ -import os import importlib -from typing import Tuple, List, Callable, Mapping, Optional, Any +import os +from typing import Any, Callable, List, Mapping, Optional, Tuple import torch import triton -from triton.runtime.jit import JITFunction from triton import language as tl +from triton.runtime.jit import JITFunction -from flag_gems.utils.shape_utils import broadcast_shapes from flag_gems.utils.code_cache import cache_dir -from flag_gems.utils.inliner import inline_function from flag_gems.utils.code_utils import IndentedBuffer, NameSpace +from flag_gems.utils.inliner import inline_function +from flag_gems.utils.shape_utils import broadcast_shapes # ------------------ Operation Description --------------------------- @@ -246,7 +246,8 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("from triton import language as tl") code.newline() code.writeline( - "from flag_gems.utils.shape_utils import broadcast_shapes, broadcasted_stride, c_contiguous_stride, volume, Stride" + "from flag_gems.utils.shape_utils import broadcast_shapes, \ + broadcasted_stride, c_contiguous_stride, volume, Stride" ) code.writeline("from flag_gems.utils.libentry import libentry") code.newline() @@ -279,17 +280,20 @@ def generate_functional_pointwise_wrapper( for i in range(op_desc.num_outputs()): if op_desc.output_dtype(i) is None: code.writeline( - f"out{num_output_tensor_index} = torch.empty(shape, dtype=in0.dtype, device=in0.device)" + f"out{num_output_tensor_index} = \ + torch.empty(shape, dtype=in0.dtype, device=in0.device)" ) else: code.writeline( - f"out{num_output_tensor_index} = torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)" + f"out{num_output_tensor_index} = \ + torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)" ) num_output_tensor_index += 1 # call destination_passing_func output_names: str = output_ref_for_wrapper(op_desc) - call_str = f"{output_names} = {destination_passing_func_name}({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + call_str = f"{output_names} = {destination_passing_func_name} \ + ({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" code.writeline(call_str) return_str = f"return {output_names}" @@ -326,7 +330,7 @@ def generate_destination_passing_pointwise_wrapper( f"in{i}.shape" for i in range(op_desc.num_input_tensors()) ) code.writeline(f"shape = broadcast_shapes([{shapes_str}])") - code.writeline(f"num_tasks = volume(shape)") + code.writeline("num_tasks = volume(shape)") code.newline() # input strides for each input tensor w.r.t. the task index space @@ -368,7 +372,7 @@ def generate_destination_passing_pointwise_wrapper( shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank)) if rank > 0: code.writeline(f"{shape_args}, # task indexing space") - code.writeline(f"num_tasks, # num tasks") + code.writeline("num_tasks, # num tasks") code.writeline(f"tile_size={tile_size},") code.writeline(f"num_warps={num_warps},") @@ -451,7 +455,7 @@ def generate_pointwise_kernel( code.writeline(f"{task_space_args}, # task_space") # number of tasks, used to compute mask - code.writeline(f"num_tasks: int,") + code.writeline("num_tasks: int,") function_ns.create_name("num_tasks") code.writeline("tile_size: tl.constexpr,") diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py index 70e64afc..41053d28 100644 --- a/src/flag_gems/utils/shape_utils.py +++ b/src/flag_gems/utils/shape_utils.py @@ -1,6 +1,6 @@ -from typing import Tuple, Iterable import functools import operator +from typing import Iterable, Tuple Shape = Tuple[int] Stride = Tuple[int] diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 9a4ed90e..83c4888c 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -1,6 +1,6 @@ import torch -from .conftest import TO_CPU +from .conftest import TO_CPU major, minor = torch.__version__.split(".")[:2] skip_expr = major < "2" or minor < "2" diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index 8da780d2..aa06d127 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -1,6 +1,8 @@ -import torch import pytest +import torch + import flag_gems + from .accuracy_utils import * diff --git a/tests/test_blas_ops.py b/tests/test_blas_ops.py index a791adc4..e05afa71 100644 --- a/tests/test_blas_ops.py +++ b/tests/test_blas_ops.py @@ -1,6 +1,8 @@ -import torch import pytest +import torch + import flag_gems + from .accuracy_utils import * diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 8bdacc75..4f4fd920 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1,6 +1,8 @@ -import torch import pytest +import torch + import flag_gems + from .accuracy_utils import * diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index 4ce1c082..24305244 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -1,6 +1,8 @@ -import torch import pytest +import torch + import flag_gems + from .accuracy_utils import * diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py index 794ee382..3c5998a2 100644 --- a/tests/test_unary_pointwise_ops.py +++ b/tests/test_unary_pointwise_ops.py @@ -1,6 +1,8 @@ -import torch import pytest +import torch + import flag_gems + from .accuracy_utils import *