Skip to content

Commit

Permalink
more black
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Apr 30, 2024
1 parent 52d6f3e commit 378661f
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):

# ==============================================================================


class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
Expand All @@ -46,22 +47,26 @@ class DiagonalWithStaticShapeModule(torch.nn.Module):
Please remove this module and associated test once the issue is fixed.
"""

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([5, 9], torch.float32, True),
])

@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)



@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))


# ==============================================================================


Expand Down

0 comments on commit 378661f

Please sign in to comment.