diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 56819fab..9a4ed90e 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -14,7 +14,7 @@ } POINTWISE_SHAPES = [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)] -REDUCTION_SHAPES = [(1024, 64 * i) for i in range(1, 10, 2)] +REDUCTION_SHAPES = [(4096, 256 * i) for i in range(1, 10, 2)] MNK_SHAPES = [15, 160, 1024] FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16]