diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..f5330bb
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,24 @@
+name: build
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main, "*"]
+
+jobs:
+
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+      - name: Install pip dependencies
+        run: |
+          python3 -m pip install --upgrade pip
+          python3 -m pip install -r requirements.txt
+      - name: Run Python unit tests
+        run: python3 -u -m pytest tests/tests.py
diff --git a/models/llama/requirements.txt b/requirements.txt
similarity index 100%
rename from models/llama/requirements.txt
rename to requirements.txt
diff --git a/test/test_blocksparse.py b/test/test_blocksparse.py
index cf456af..abc71b6 100644
--- a/test/test_blocksparse.py
+++ b/test/test_blocksparse.py
@@ -2,7 +2,7 @@
 import torch
 
 import triton
-import triton.ops
+import kernels
 
 
 def is_hip_mi200():
@@ -57,7 +57,7 @@ def mask_tensor(x, mask, block, value=0):
 @pytest.mark.parametrize("BLOCK", [16, 32, 64])
 @pytest.mark.parametrize("DTYPE", [torch.float16])
 def test_matmul(
-    MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256
+    MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device="cuda", Z=3, H=2, M=512, N=384, K=256
 ):
     seed = 0
     torch.manual_seed(seed)
@@ -103,7 +103,7 @@ def test_matmul(
     b_tri = do_sparsify(b_tri) if is_dds else b_tri
     a_tri.retain_grad()
     b_tri.retain_grad()
-    op = triton.ops.blocksparse.matmul(
+    op = kernels.blocksparse.matmul(
         layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device
     )
     c_tri = op(a_tri, b_tri)
@@ -132,7 +132,9 @@
 
 @pytest.mark.parametrize("is_dense", [False, True])
 @pytest.mark.parametrize("BLOCK, WIDTH", configs)
-def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4):
+def test_softmax(
+    BLOCK, WIDTH, is_dense, device="cuda", Z=2, H=2, is_causal=True, scale=0.4
+):
     # set seed
     torch.random.manual_seed(0)
     Z, H, M, N = 2, 3, WIDTH, WIDTH
@@ -164,7 +166,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, device="cuda", Z=2, H=2, is_causal=True, scale
     a_tri = sparsify_tensor(a_tri, layout, BLOCK)
     a_tri.retain_grad()
     dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
-    op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense)
+    op = kernels.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense)
     out_tri = op(a_tri, scale=scale, is_causal=is_causal)
     out_tri.backward(dout_tri)
     da_tri = a_tri.grad
@@ -178,7 +180,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, device="cuda", Z=2, H=2, is_causal=True, scale
 def test_attention_fwd_bwd(
     block,
     dtype,
-    device,
+    device="cuda",
     input_scale=1.0,
     scale=1 / 8.0,
     n_ctx=256,
@@ -251,13 +253,13 @@ def triton_attention(
     value: torch.Tensor,
     scale: float,
 ):
-    sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(
+    sparse_dot_sdd_nt = kernels.blocksparse.matmul(
         layout, block, "sdd", trans_a=False, trans_b=True, device=value.device
     )
-    sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(
+    sparse_dot_dsd_nn = kernels.blocksparse.matmul(
         layout, block, "dsd", trans_a=False, trans_b=False, device=value.device
     )
-    sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device)
+    sparse_softmax = kernels.blocksparse.softmax(layout, block, device=value.device)
 
     w = sparse_dot_sdd_nt(query, key)
     w = sparse_softmax(w, scale=scale, is_causal=True)
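Note: the hunks above assume the new local `kernels` package exposes the same blocksparse API that `triton.ops` used to provide. A minimal, hypothetical sketch of how these ops are driven (the BLOCK size, shapes, and random layout below are illustrative only, not taken from the tests):

    import torch
    import kernels

    BLOCK, H, M, N, K = 32, 2, 256, 256, 64
    # one layout entry per (head, row-block, col-block); 1 keeps the block, 0 prunes it
    layout = torch.randint(0, 2, (H, M // BLOCK, N // BLOCK), dtype=torch.int64)
    dot_sdd_nt = kernels.blocksparse.matmul(
        layout, BLOCK, "sdd", trans_a=False, trans_b=True, device="cuda"
    )
    softmax = kernels.blocksparse.softmax(layout, BLOCK, device="cuda")

    q = torch.randn((1, H, M, K), dtype=torch.float16, device="cuda")
    k = torch.randn((1, H, N, K), dtype=torch.float16, device="cuda")
    w = dot_sdd_nt(q, k)                       # block-sparse q @ k^T over the layout
    w = softmax(w, scale=0.4, is_causal=True)  # block-sparse softmax on the kept blocks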
diff --git a/test/test_cross_entropy.py b/test/test_cross_entropy.py
index 701f54a..05ac67c 100644
--- a/test/test_cross_entropy.py
+++ b/test/test_cross_entropy.py
@@ -1,8 +1,7 @@
 import pytest
 import torch
 
-import triton
-import triton.ops
+import kernels
 
 
 @pytest.mark.parametrize(
@@ -15,7 +14,7 @@
         for mode in ["forward", "backward"]
     ],
 )
-def test_op(M, N, dtype, mode, device):
+def test_op(M, N, dtype, mode, device="cuda"):
     capability = torch.cuda.get_device_capability()
     if capability[0] < 8 and dtype == "bfloat16":
         pytest.skip("Only test bfloat16 on devices with sm >= 80")
@@ -28,7 +27,7 @@ def test_op(M, N, dtype, mode, device="cuda"):
     x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True)
     idx = 4 + torch.ones(M, dtype=torch.int64, device=device)
     # forward pass
-    tt_y = triton.ops.cross_entropy(x, idx)
+    tt_y = kernels.cross_entropy(x, idx)
     th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
     if mode == "forward":
         torch.testing.assert_close(th_y, tt_y)
diff --git a/test/test_flash_attention.py b/test/test_flash_attention.py
index 724028c..af8a645 100644
--- a/test/test_flash_attention.py
+++ b/test/test_flash_attention.py
@@ -2,8 +2,8 @@
 import torch
 
 import os
+import kernels
 import triton
-import triton.ops
 
 
 @pytest.mark.interpreter
@@ -19,7 +19,7 @@
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("seq_par", [True, False])
-def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
+def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device="cuda"):
     capability = torch.cuda.get_device_capability()
     if capability[0] < 8:
         pytest.skip("Flash attention only supported for compute capability >= 80")
@@ -56,7 +56,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
     ref_dk, k.grad = k.grad.clone(), None
     ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
-    tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
+    tri_out = kernels.attention(q, k, v, causal, sm_scale, seq_par)
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
@@ -151,7 +151,7 @@ def bench_flash_attention(
         (BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True
     )
     if provider == "triton":
-        fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par)
+        fn = lambda: kernels.attention(q, k, v, casual, sm_scale, seq_par)
         if mode == "bwd":
             o = fn()
             do = torch.randn_like(o)
diff --git a/test/test_inductor.py b/test/test_inductor.py
index 73c3b9b..2820876 100644
--- a/test/test_inductor.py
+++ b/test/test_inductor.py
@@ -5,7 +5,7 @@
 import triton.language as tl
 
 
-def test_normalization_with_remat(device):
+def test_normalization_with_remat(device="cuda"):
 
     @triton.jit
     def triton_(
@@ -80,7 +80,7 @@ def triton_(
     )
 
 
-def test_avg_pool_bw(device):
+def test_avg_pool_bw(device="cuda"):
 
     @triton.jit
     def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
@@ -200,7 +200,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
 
 @pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
 @pytest.mark.parametrize("num_warps", [1, 4])
-def test_scan2d_broadcast(RBLOCK, num_warps, device):
+def test_scan2d_broadcast(RBLOCK, num_warps, device="cuda"):
 
     @triton.jit(debug=True)
     def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
@@ -220,7 +220,7 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
     torch.testing.assert_close(output, ref)
 
 
-def test_scan2d_for(device):
+def test_scan2d_for(device="cuda"):
 
     @triton.jit
     def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
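Note: giving `device` a `"cuda"` default in these signatures appears intended to let the vendored tests run under a plain `pytest` invocation, without the `device` fixture the upstream Triton test suite provides. A standalone sketch of the cross-entropy parity check the test performs, assuming `kernels.cross_entropy` keeps the old `triton.ops.cross_entropy` interface of (logits, target indices) -> per-row loss:

    import torch
    import kernels

    M, N = 1024, 512
    x = torch.randn(M, N, dtype=torch.float16, device="cuda", requires_grad=True)
    idx = 4 + torch.ones(M, dtype=torch.int64, device="cuda")    # constant target index 5
    tt_y = kernels.cross_entropy(x, idx)                         # Triton kernel path
    th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)   # PyTorch reference
    torch.testing.assert_close(th_y, tt_y)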
diff --git a/test/test_matmul.py b/test/test_matmul.py
index dab9f82..632165e 100644
--- a/test/test_matmul.py
+++ b/test/test_matmul.py
@@ -5,7 +5,7 @@
 
 import triton
 import triton.language as tl
-import triton.ops
+import kernels
 
 
 def is_hip():
@@ -1006,7 +1006,7 @@ def test_op(
             kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook
         )
     ]
-    kernel = triton.ops._matmul.kernel
+    kernel = kernels._matmul.kernel
     kernel.configs = configs
     # kernel.run = kernel.run.run.run
 
@@ -1071,7 +1071,7 @@ def init_input(m, n, dtype, acc_dtype):
     # run test
     th_a = upcast_if_fp8(a, ADTYPE)
     th_b = upcast_if_fp8(b, BDTYPE)
-    ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype)
+    ab_dtype = kernels.get_higher_dtype(th_a.dtype, th_b.dtype)
     acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype
     output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype
     th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype))
@@ -1080,7 +1080,7 @@ def init_input(m, n, dtype, acc_dtype):
         a = triton.reinterpret(a, getattr(tl, ADTYPE))
     if is_fp8(BDTYPE):
         b = triton.reinterpret(b, getattr(tl, BDTYPE))
-    tt_c = triton.ops.matmul(
+    tt_c = kernels.matmul(
         a,
         b,
         acc_dtype if ACC_DTYPE else None,
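Note: the trailing arguments of the `kernels.matmul` call are truncated in the hunk above. A small, hypothetical usage sketch, assuming `kernels.matmul` and `kernels.get_higher_dtype` keep the old `triton.ops` interfaces shown in this diff, where the third positional argument is the accumulator dtype (None lets the op choose):

    import torch
    import kernels

    a = torch.randn((512, 512), dtype=torch.float16, device="cuda")
    b = torch.randn((512, 512), dtype=torch.float16, device="cuda")
    ab_dtype = kernels.get_higher_dtype(a.dtype, b.dtype)   # torch.float16 here
    c = kernels.matmul(a, b, None)                          # None -> default accumulator dtype
    # loose tolerances for an fp16 comparison against the PyTorch reference
    torch.testing.assert_close(c, torch.matmul(a, b), rtol=1e-2, atol=1e-2)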