From e72dc5c4ad17abf9fb91d134124fb37056282c74 Mon Sep 17 00:00:00 2001 From: apaz-cli Date: Sun, 24 Mar 2024 16:42:50 -0400 Subject: [PATCH] Add negative example for .mT. --- thunder/tests/opinfos.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 57e32e5175..bab8e58b1a 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3947,6 +3947,7 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs # shape cases = ( + (), (2, 3), (2, 3, 4), (2, 3, 4, 2), @@ -3959,14 +3960,8 @@ def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs): make = partial(make_tensor, device=device, dtype=dtype) # shape, error type, error message - RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.") cases = ( - ((4, 5, 6), RuntimeError, r"t\(\) expects a tensor with <= 2 dimensions, but self is 3D"), - ( - (4, 5, 6, 7), - RuntimeError, - r"t\(\) expects a tensor with <= 2 dimensions, but self is 4D", - ), + ((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."), ) for shape, err_type, err_msg in cases: @@ -3976,6 +3971,7 @@ def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs): transpose_opinfo = OpInfo( clang.matrix_transpose, sample_input_generator=matrix_transpose_sample_generator, + error_input_generator=matrix_transpose_error_generator, torch_reference=lambda x: x.mT, ) shape_ops.append(transpose_opinfo)