From 61e17db4f738072a3373772605e5f0d6536c0b54 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 6 Jul 2023 14:36:02 -0700 Subject: [PATCH] [TESTS] smaller problem sizes in matmul tests (#1908) --- python/test/unit/operators/test_matmul.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 78d2ab3f6d5e..01dd813537ee 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -63,25 +63,24 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE), # variable input - (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE, DTYPE), (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE, DTYPE), - (128, 256, 64, 1, 8, 3, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ [ - (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE, DTYPE), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE), # split-k - (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE, DTYPE), - (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE, DTYPE), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] + (64, 64, 16, 8, 4, STAGES, 128, 128, 768, AT, BT, DTYPE, DTYPE), + (64, 64, 16, 8, 4, STAGES, 128, 128, 32, AT, BT, DTYPE, DTYPE), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] ], # mixed-precision *[ @@ -90,7 +89,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE), (32, 64, 16, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE), - (128, 128, 32, 8, 4, 2, 1024, 1024, 1024, AT, BT, ADTYPE, BDTYPE), + (128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE), ] for ADTYPE, BDTYPE in [("float8", "float16"), ("float16", "float32"), ("float32", "float16"), ("bfloat16", "float32"), ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] ]