fixing tests and adding github action
adamomainz committed Aug 23, 2024
1 parent 8821ef3 commit e74a435
Showing 7 changed files with 50 additions and 25 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,24 @@
name: build

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main, "*"]

jobs:

  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'
      - name: Install pip dependencies
        run: |
          python3 -m pip install --upgrade pip
          python3 -m pip install -r requirements.txt
      - name: Run Python unit tests
        run: python3 -u -m pytest tests/tests.py
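To reproduce the workflow's test step locally, pytest can also be driven from Python. The snippet below is a minimal sketch, assuming the packages from requirements.txt are already installed and that tests/tests.py is the entry point the workflow names; the file name is hypothetical and not part of the commit.

# local_test_runner.py -- hypothetical helper, not part of the commit; mirrors the
# CI step "python3 -u -m pytest tests/tests.py".
import sys

import pytest

if __name__ == "__main__":
    # Run the same test file the workflow targets and forward pytest's exit code.
    sys.exit(pytest.main(["tests/tests.py"]))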
File renamed without changes.
20 changes: 11 additions & 9 deletions test/test_blocksparse.py
@@ -2,7 +2,7 @@
import torch

import triton
-import triton.ops
+import kernels


def is_hip_mi200():
@@ -57,7 +57,7 @@ def mask_tensor(x, mask, block, value=0):
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(
-MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256
+MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device="cuda", Z=3, H=2, M=512, N=384, K=256
):
seed = 0
torch.manual_seed(seed)
@@ -103,7 +103,7 @@ def test_matmul(
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.retain_grad()
b_tri.retain_grad()
-op = triton.ops.blocksparse.matmul(
+op = kernels.blocksparse.matmul(
layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device
)
c_tri = op(a_tri, b_tri)
@@ -132,7 +132,9 @@ def test_matmul(

@pytest.mark.parametrize("is_dense", [False, True])
@pytest.mark.parametrize("BLOCK, WIDTH", configs)
-def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4):
+def test_softmax(
+BLOCK, WIDTH, is_dense, device="cuda", Z=2, H=2, is_causal=True, scale=0.4
+):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 3, WIDTH, WIDTH
@@ -164,7 +166,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale
a_tri = sparsify_tensor(a_tri, layout, BLOCK)
a_tri.retain_grad()
dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
-op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense)
+op = kernels.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense)
out_tri = op(a_tri, scale=scale, is_causal=is_causal)
out_tri.backward(dout_tri)
da_tri = a_tri.grad
@@ -178,7 +180,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale
def test_attention_fwd_bwd(
block,
dtype,
-device,
+device="cuda",
input_scale=1.0,
scale=1 / 8.0,
n_ctx=256,
@@ -251,13 +253,13 @@ def triton_attention(
value: torch.Tensor,
scale: float,
):
-sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
+sparse_dot_sdd_nt = kernels.blocksparse.matmul(
layout, block, "sdd", trans_a=False, trans_b=True, device=value.device
)
-sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
+sparse_dot_dsd_nn = kernels.blocksparse.matmul(
layout, block, "dsd", trans_a=False, trans_b=False, device=value.device
)
-sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
+sparse_softmax = kernels.blocksparse.softmax(layout, block, device=value.device)

w = sparse_dot_sdd_nt(query, key)
w = sparse_softmax(w, scale=scale, is_causal=True)
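For readers unfamiliar with the block-sparse API these tests exercise, here is a minimal standalone sketch composing the same three primitives the way triton_attention above does. The shapes, the all-ones layout, the scale value, and the final dsd step (which the diff above truncates) are illustrative assumptions rather than content of the commit, and a CUDA device is assumed to be available.

import torch

import kernels

# Illustrative sizes (assumptions, not from the commit): single batch, 2 heads,
# 256-token context, 64-dim heads, 16x16 blocks.
device = "cuda"
block, Z, H, n_ctx, d_head = 16, 1, 2, 256, 64

# An all-ones layout keeps every block; the tests above sparsify it randomly.
layout = torch.ones(H, n_ctx // block, n_ctx // block, dtype=torch.int64)

# Constructors as used in triton_attention above.
sparse_dot_sdd_nt = kernels.blocksparse.matmul(
    layout, block, "sdd", trans_a=False, trans_b=True, device=device
)
sparse_dot_dsd_nn = kernels.blocksparse.matmul(
    layout, block, "dsd", trans_a=False, trans_b=False, device=device
)
sparse_softmax = kernels.blocksparse.softmax(layout, block, device=device)

q = torch.randn(Z, H, n_ctx, d_head, dtype=torch.float16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

w = sparse_dot_sdd_nt(q, k)                            # block-sparse Q @ K^T
w = sparse_softmax(w, scale=1 / 8.0, is_causal=True)   # softmax over kept blocks
out = sparse_dot_dsd_nn(w, v)                          # assumed final step: dense (Z, H, n_ctx, d_head) result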
7 changes: 3 additions & 4 deletions test/test_cross_entropy.py
@@ -1,8 +1,7 @@
import pytest
import torch

-import triton
-import triton.ops
+import kernels


@pytest.mark.parametrize(
@@ -15,7 +14,7 @@
for mode in ["forward", "backward"]
],
)
-def test_op(M, N, dtype, mode, device):
+def test_op(M, N, dtype, mode, device="cuda"):
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
@@ -28,7 +27,7 @@ def test_op(M, N, dtype, mode, device):
x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device=device)
# forward pass
-tt_y = triton.ops.cross_entropy(x, idx)
+tt_y = kernels.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == "forward":
torch.testing.assert_close(th_y, tt_y)
8 changes: 4 additions & 4 deletions test/test_flash_attention.py
@@ -2,8 +2,8 @@
import torch
import os

+import kernels
import triton
-import triton.ops


@pytest.mark.interpreter
@@ -19,7 +19,7 @@
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("seq_par", [True, False])
-def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
+def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device="cuda"):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability >= 80")
@@ -56,7 +56,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
-tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
+tri_out = kernels.attention(q, k, v, causal, sm_scale, seq_par)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -151,7 +151,7 @@ def bench_flash_attention(
(BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True
)
if provider == "triton":
-fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par)
+fn = lambda: kernels.attention(q, k, v, casual, sm_scale, seq_par)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
8 changes: 4 additions & 4 deletions test/test_inductor.py
@@ -5,7 +5,7 @@
import triton.language as tl


-def test_normalization_with_remat(device):
+def test_normalization_with_remat(device="cuda"):

@triton.jit
def triton_(
@@ -80,7 +80,7 @@ def triton_(
)


-def test_avg_pool_bw(device):
+def test_avg_pool_bw(device="cuda"):

@triton.jit
def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
@@ -200,7 +200,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):

@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
-def test_scan2d_broadcast(RBLOCK, num_warps, device):
+def test_scan2d_broadcast(RBLOCK, num_warps, device="cuda"):

@triton.jit(debug=True)
def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
@@ -220,7 +220,7 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
torch.testing.assert_close(output, ref)


-def test_scan2d_for(device):
+def test_scan2d_for(device="cuda"):

@triton.jit
def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
8 changes: 4 additions & 4 deletions test/test_matmul.py
@@ -5,7 +5,7 @@

import triton
import triton.language as tl
-import triton.ops
+import kernels


def is_hip():
@@ -1006,7 +1006,7 @@ def test_op(
kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook
)
]
-kernel = triton.ops._matmul.kernel
+kernel = kernels._matmul.kernel
kernel.configs = configs
# kernel.run = kernel.run.run.run

@@ -1071,7 +1071,7 @@ def init_input(m, n, dtype, acc_dtype):
# run test
th_a = upcast_if_fp8(a, ADTYPE)
th_b = upcast_if_fp8(b, BDTYPE)
-ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype)
+ab_dtype = kernels.get_higher_dtype(th_a.dtype, th_b.dtype)
acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype
output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype
th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype))
@@ -1080,7 +1080,7 @@ def init_input(m, n, dtype, acc_dtype):
a = triton.reinterpret(a, getattr(tl, ADTYPE))
if is_fp8(BDTYPE):
b = triton.reinterpret(b, getattr(tl, BDTYPE))
-tt_c = triton.ops.matmul(
+tt_c = kernels.matmul(
a,
b,
acc_dtype if ACC_DTYPE else None,
