Skip to content

Commit

Permalink
[TESTS] smaller problem sizes in matmul tests (triton-lang#1908)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Jul 6, 2023
1 parent 9638f5c commit 61e17db
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions python/test/unit/operators/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,24 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
(64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
# variable input
(128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE),
(128, 256, 64, 1, 8, 3, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
],
# n-stage
*[
[
(16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE, DTYPE),
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE),
# split-k
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4]
(64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE),
(64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
],
# mixed-precision
*[
Expand All @@ -90,7 +89,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(32, 64, 16, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
(128, 128, 32, 8, 4, 2, 1024, 1024, 1024, AT, BT, ADTYPE, BDTYPE),
(128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
] for ADTYPE, BDTYPE in [("float8", "float16"), ("float16", "float32"), ("float32", "float16"),
("bfloat16", "float32"), ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
]
Expand Down

0 comments on commit 61e17db

Please sign in to comment.