diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d1451b7c..beabd017 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -17,12 +17,12 @@ jobs:
       matrix:
         python-version: [ "3.8" ]
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
+      - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           python -m pip install flake8
@@ -30,16 +30,16 @@ jobs:
           python -m pip install black
           python -m pip install black[jupyter]
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-      - name: Check Python syntax errors
+      - name: Check Python syntax errors
         run: |
           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-      - name: Validate the coding style
+      - name: Validate the coding style
         run: |
           python -m isort . --check
           python -m black . --check
-      - name: Install Trident
+      - name: Install Trident
         run: |
           bash install_package.sh
-      - name: Test with pytest
+      - name: Test with pytest
         run: |
           python -m pytest -n 2 .
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d85322bb..a707e85a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,15 +2,15 @@ repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v2.3.0
     hooks:
-      - id: check-yaml
-      - id: end-of-file-fixer
-      - id: trailing-whitespace
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
     hooks:
-      - id: isort
+      - id: isort
   - repo: https://github.com/psf/black
     rev: 22.10.0
     hooks:
-      - id: black
-      - id: black-jupyter
+      - id: black
+      - id: black-jupyter
diff --git a/benchmarks/benchmark_shift_gelu.py b/benchmarks/benchmark_shift_gelu.py
index 2ced211e..dd12bbc9 100644
--- a/benchmarks/benchmark_shift_gelu.py
+++ b/benchmarks/benchmark_shift_gelu.py
@@ -14,7 +14,6 @@
 
 import torch
 import triton
-import triton.language as tl
 import util
 
 import trident
diff --git a/examples/playground.py b/examples/playground.py
index a4a7b928..e39be181 100644
--- a/examples/playground.py
+++ b/examples/playground.py
@@ -11,10 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import torch
-import torch.nn as nn
-import triton
-import triton.language as tl
-
-import trident
diff --git a/tests/test_geglu.py b/tests/test_geglu.py
index 118b44dc..fa125ea3 100644
--- a/tests/test_geglu.py
+++ b/tests/test_geglu.py
@@ -24,7 +24,6 @@ def geglu(input, weight, bias: torch.Tensor = None):
     return state * torch.nn.functional.gelu(gate)
 
 
-# @pytest.mark.skip
 @pytest.mark.parametrize("num_batches, m_size, n_size, k_size", [(2, 4, 4, 4)])
 def test_forward(num_batches, m_size, n_size, k_size, device):
     input = torch.randn(num_batches, m_size, k_size, device=device)
diff --git a/trident/kernel/argmax.py b/trident/kernel/argmax.py
index 3fd0ee32..cf1c5734 100644
--- a/trident/kernel/argmax.py
+++ b/trident/kernel/argmax.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 class Argmax:
     @staticmethod
diff --git a/trident/kernel/group_norm.py b/trident/kernel/group_norm.py
index c5a71c66..5c191588 100644
--- a/trident/kernel/group_norm.py
+++ b/trident/kernel/group_norm.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 class GroupNorm:
     @staticmethod
diff --git a/trident/kernel/layer_norm.py b/trident/kernel/layer_norm.py
index 509fd711..2e59e7ae 100644
--- a/trident/kernel/layer_norm.py
+++ b/trident/kernel/layer_norm.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import kernel, language
-
 
 class LayerNorm:
     @staticmethod
diff --git a/trident/kernel/max.py b/trident/kernel/max.py
index 5b723d2f..cb9c2240 100644
--- a/trident/kernel/max.py
+++ b/trident/kernel/max.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 class Max:
     @staticmethod
diff --git a/trident/kernel/rms_norm.py b/trident/kernel/rms_norm.py
index eb65e114..9c40608a 100644
--- a/trident/kernel/rms_norm.py
+++ b/trident/kernel/rms_norm.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 class RMSNorm:
     @staticmethod
diff --git a/trident/kernel/silu.py b/trident/kernel/silu.py
index 5df8e097..465c2f01 100644
--- a/trident/kernel/silu.py
+++ b/trident/kernel/silu.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 def silu_configs():
     configs = []
diff --git a/trident/language/combine.py b/trident/language/combine.py
index 83f524d3..129e414c 100644
--- a/trident/language/combine.py
+++ b/trident/language/combine.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 @triton.jit
 def combine_welford(m2_a, mean_a, count_a, m2_b, mean_b, count_b):
diff --git a/trident/language/constant.py b/trident/language/constant.py
index 7619a7e1..6aabe728 100644
--- a/trident/language/constant.py
+++ b/trident/language/constant.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import triton
 import triton.language as tl
 
 dim = [tl.constexpr(i) for i in range(3)]
diff --git a/trident/language/function.py b/trident/language/function.py
index 9ea2fd8f..a9dc0f8a 100644
--- a/trident/language/function.py
+++ b/trident/language/function.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 @triton.jit
 def batch(index, num_channels, num_rows, num_cols):
diff --git a/trident/language/var.py b/trident/language/var.py
index 4a534b85..3de6b363 100644
--- a/trident/language/var.py
+++ b/trident/language/var.py
@@ -15,8 +15,6 @@
 import triton
 import triton.language as tl
 
-from trident import language
-
 
 class Var:
     @staticmethod
diff --git a/trident/math/__init__.py b/trident/math/__init__.py
deleted file mode 100644
index aed7fc3a..00000000
--- a/trident/math/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2023 ⓒ Kakao Brain Corp.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .math import *
diff --git a/trident/math/math.py b/trident/math/math.py
deleted file mode 100644
index 51cbfb88..00000000
--- a/trident/math/math.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2023 ⓒ Kakao Brain Corp.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import triton
-
-
-def clamp(v, lo, hi):
-    return max(lo, min(hi, v))
-
-
-def is_pow2(x):
-    return x > 0 and (x & (x - 1)) == 0
-
-
-def multiple_of(x, y):
-    return triton.cdiv(x, y) * y
-
-
-def prev_pow2(x):
-    x |= x >> 1
-    x |= x >> 2
-    x |= x >> 4
-    x |= x >> 8
-    x |= x >> 16
-    x |= x >> 32
-    x -= x >> 1
-
-    return x
diff --git a/trident/operation/cosine_similarity.py b/trident/operation/cosine_similarity.py
index 65079cc9..16931cff 100644
--- a/trident/operation/cosine_similarity.py
+++ b/trident/operation/cosine_similarity.py
@@ -15,7 +15,6 @@
 from typing import Any
 
 import torch
-import triton
 
 from trident import kernel, util
 
diff --git a/trident/operation/layer_norm.py b/trident/operation/layer_norm.py
index 06e250df..cf521ade 100644
--- a/trident/operation/layer_norm.py
+++ b/trident/operation/layer_norm.py
@@ -18,7 +18,7 @@
 import torch
 import triton
 
-from trident import function, kernel, util
+from trident import kernel, util
 
 
 class LayerNorm(torch.autograd.Function):
diff --git a/trident/operation/max_pool2d.py b/trident/operation/max_pool2d.py
index 696f6c04..a26afb5c 100644
--- a/trident/operation/max_pool2d.py
+++ b/trident/operation/max_pool2d.py
@@ -17,7 +17,7 @@
 import torch
 import triton
 
-from trident import kernel, math
+from trident import kernel
 
 
 class MaxPool2d(torch.autograd.Function):
@@ -39,11 +39,7 @@ def __forward(inp, knl_sz):
         assert out.is_contiguous()
 
         grid = lambda meta: (inp_bt * inp_ch * out_h * triton.cdiv(out_w, meta["grp_sz"]),)
-        grp_sz = math.clamp(
-            128 // triton.next_power_of_2(knl_sz),
-            1,
-            triton.next_power_of_2(out_w),
-        )
+        grp_sz = max(1, min(triton.next_power_of_2(out_w), 128 // triton.next_power_of_2(knl_sz)))
 
         kernel.MaxPool2d.forward[grid](
             inp,
diff --git a/trident/operation/mean.py b/trident/operation/mean.py
index d44f6653..0a23d677 100644
--- a/trident/operation/mean.py
+++ b/trident/operation/mean.py
@@ -17,7 +17,7 @@
 import torch
 import triton
 
-from trident import kernel, math, util
+from trident import kernel, util
 
 
 class Mean(torch.autograd.Function):
diff --git a/trident/operation/prelu.py b/trident/operation/prelu.py
index a069def4..692801ce 100644
--- a/trident/operation/prelu.py
+++ b/trident/operation/prelu.py
@@ -17,7 +17,7 @@
 import torch
 import triton
 
-from trident import kernel, util
+from trident import kernel
 
 
 class PReLU(torch.autograd.Function):
diff --git a/trident/operation/sum.py b/trident/operation/sum.py
index 799dadbb..0bc391d4 100644
--- a/trident/operation/sum.py
+++ b/trident/operation/sum.py
@@ -15,9 +15,8 @@
 from typing import Any
 
 import torch
-import triton
 
-from trident import kernel, math, util
+from trident import kernel, util
 
 
 class Sum(torch.autograd.Function):
diff --git a/trident/operation/var.py b/trident/operation/var.py
index 628f546c..00ea2ad5 100644
--- a/trident/operation/var.py
+++ b/trident/operation/var.py
@@ -17,7 +17,7 @@
 import torch
 import triton
 
-from trident import kernel, math, util
+from trident import kernel, util
 
 
 class Var(torch.autograd.Function):
diff --git a/trident/operation/var_mean.py b/trident/operation/var_mean.py
index a889e5fa..a581176a 100644
--- a/trident/operation/var_mean.py
+++ b/trident/operation/var_mean.py
@@ -17,7 +17,7 @@
 import torch
 import triton
 
-from trident import kernel, math, util
+from trident import kernel, util
 
 
 class VarMean(torch.autograd.Function):
diff --git a/trident/util/util.py b/trident/util/util.py
index d87fa49a..b5c810b1 100644
--- a/trident/util/util.py
+++ b/trident/util/util.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 
 import torch
-import triton
 import triton.language as tl
 
-from trident import math, module, operation
+from trident import module
 
 
 def fill(inp, val):