Skip to content

Commit

Permalink
Add negative example for .mT.
Browse files Browse the repository at this point in the history
  • Loading branch information
apaz-cli committed Mar 24, 2024
1 parent 98ab3e8 commit e72dc5c
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3947,6 +3947,7 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs

# shape
cases = (
(),
(2, 3),
(2, 3, 4),
(2, 3, 4, 2),
Expand All @@ -3959,14 +3960,8 @@ def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)

# shape, error type, error message
RuntimeError(f"tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor.")
cases = (
((4, 5, 6), RuntimeError, r"t\(\) expects a tensor with <= 2 dimensions, but self is 3D"),
(
(4, 5, 6, 7),
RuntimeError,
r"t\(\) expects a tensor with <= 2 dimensions, but self is 4D",
),
((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),
)

for shape, err_type, err_msg in cases:
Expand All @@ -3976,6 +3971,7 @@ def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs):
transpose_opinfo = OpInfo(
clang.matrix_transpose,
sample_input_generator=matrix_transpose_sample_generator,
error_input_generator=matrix_transpose_error_generator,
torch_reference=lambda x: x.mT,
)
shape_ops.append(transpose_opinfo)
Expand Down

0 comments on commit e72dc5c

Please sign in to comment.