Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Optimize imports
Browse files Browse the repository at this point in the history
  • Loading branch information
daemyung committed Sep 7, 2023
1 parent 04852ea commit c0eb02e
Show file tree
Hide file tree
Showing 26 changed files with 22 additions and 111 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,29 @@ jobs:
matrix:
python-version: [ "3.8" ]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8
python -m pip install isort
python -m pip install black
python -m pip install black[jupyter]
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Check Python syntax errors
- name: Check Python syntax errors
run: |
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- name: Validate the coding style
- name: Validate the coding style
run: |
python -m isort . --check
python -m black . --check
- name: Install Trident
- name: Install Trident
run: |
bash install_package.sh
- name: Test with pytest
- name: Test with pytest
run: |
python -m pytest -n 2 .
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- id: isort
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- id: black-jupyter
- id: black
- id: black-jupyter
1 change: 0 additions & 1 deletion benchmarks/benchmark_shift_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import triton
import triton.language as tl
import util

import trident
Expand Down
7 changes: 0 additions & 7 deletions examples/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import triton
import triton.language as tl

import trident
1 change: 0 additions & 1 deletion tests/test_geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def geglu(input, weight, bias: torch.Tensor = None):
return state * torch.nn.functional.gelu(gate)


# @pytest.mark.skip
@pytest.mark.parametrize("num_batches, m_size, n_size, k_size", [(2, 4, 4, 4)])
def test_forward(num_batches, m_size, n_size, k_size, device):
input = torch.randn(num_batches, m_size, k_size, device=device)
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


class Argmax:
@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


class GroupNorm:
@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import kernel, language


class LayerNorm:
@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


class Max:
@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


class RMSNorm:
@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions trident/kernel/silu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


def silu_configs():
configs = []
Expand Down
2 changes: 0 additions & 2 deletions trident/language/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


@triton.jit
def combine_welford(m2_a, mean_a, count_a, m2_b, mean_b, count_b):
Expand Down
1 change: 0 additions & 1 deletion trident/language/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl

dim = [tl.constexpr(i) for i in range(3)]
Expand Down
2 changes: 0 additions & 2 deletions trident/language/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


@triton.jit
def batch(index, num_channels, num_rows, num_cols):
Expand Down
2 changes: 0 additions & 2 deletions trident/language/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import triton
import triton.language as tl

from trident import language


class Var:
@staticmethod
Expand Down
15 changes: 0 additions & 15 deletions trident/math/__init__.py

This file was deleted.

39 changes: 0 additions & 39 deletions trident/math/math.py

This file was deleted.

1 change: 0 additions & 1 deletion trident/operation/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any

import torch
import triton

from trident import kernel, util

Expand Down
2 changes: 1 addition & 1 deletion trident/operation/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import triton

from trident import function, kernel, util
from trident import kernel, util


class LayerNorm(torch.autograd.Function):
Expand Down
8 changes: 2 additions & 6 deletions trident/operation/max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton

from trident import kernel, math
from trident import kernel


class MaxPool2d(torch.autograd.Function):
Expand All @@ -39,11 +39,7 @@ def __forward(inp, knl_sz):
assert out.is_contiguous()

grid = lambda meta: (inp_bt * inp_ch * out_h * triton.cdiv(out_w, meta["grp_sz"]),)
grp_sz = math.clamp(
128 // triton.next_power_of_2(knl_sz),
1,
triton.next_power_of_2(out_w),
)
grp_sz = max(1, min(triton.next_power_of_2(out_w), 128 // triton.next_power_of_2(knl_sz)))

kernel.MaxPool2d.forward[grid](
inp,
Expand Down
2 changes: 1 addition & 1 deletion trident/operation/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton

from trident import kernel, math, util
from trident import kernel, util


class Mean(torch.autograd.Function):
Expand Down
2 changes: 1 addition & 1 deletion trident/operation/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton

from trident import kernel, util
from trident import kernel


class PReLU(torch.autograd.Function):
Expand Down
3 changes: 1 addition & 2 deletions trident/operation/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from typing import Any

import torch
import triton

from trident import kernel, math, util
from trident import kernel, util


class Sum(torch.autograd.Function):
Expand Down
2 changes: 1 addition & 1 deletion trident/operation/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton

from trident import kernel, math, util
from trident import kernel, util


class Var(torch.autograd.Function):
Expand Down
2 changes: 1 addition & 1 deletion trident/operation/var_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton

from trident import kernel, math, util
from trident import kernel, util


class VarMean(torch.autograd.Function):
Expand Down
3 changes: 1 addition & 2 deletions trident/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
# limitations under the License.

import torch
import triton
import triton.language as tl

from trident import math, module, operation
from trident import module


def fill(inp, val):
Expand Down

0 comments on commit c0eb02e

Please sign in to comment.