Skip to content

Commit

Permalink
add FileCheck tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dixin Zhou committed Dec 30, 2024
1 parent cab423b commit c9c1c97
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,59 @@ func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.

// -----

// CHECK-LABEL: test_aten_trilinear_decompose
func.func @test_aten_trilinear_decompose(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[2,3,6],f32>, %arg2: !torch.vtensor<[6],f32>) -> !torch.vtensor<[3,6],f32> {
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[EXPAND0:.+]] = torch.aten.unsqueeze %arg0, %[[INT1]]
// CHECK: %[[EXPAND1:.+]] = torch.aten.unsqueeze %arg2, %[[INT0]]
// CHECK: %[[EXPAND2:.+]] = torch.aten.unsqueeze %[[EXPAND1]], %[[INT1]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.Tensor %[[EXPAND0]], %arg1
// CHECK: %[[MUL2:.+]] = torch.aten.mul.Tensor %[[MUL1]], %[[EXPAND2]]
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
// CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MUL2]], %[[LIST]], %[[FALSE]], %[[NONE]]
// CHECK: return %[[SUM]]
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct : () -> !torch.list<int>
%int0 = torch.constant.int 0
%2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%4 = torch.aten._trilinear %arg0, %arg1, %arg2, %0, %1, %2, %3, %int1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[2,3,6],f32>, !torch.vtensor<[6],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,6],f32>
return %4 : !torch.vtensor<[3,6],f32>
}

// -----

// CHECK-LABEL: test_aten_bilinear_decompose
func.func @test_aten_bilinear_decompose(%arg0: !torch.vtensor<[6,2],f32>, %arg1: !torch.vtensor<[6,3],f32>, %arg2: !torch.vtensor<[4,2,3],f32>, %arg3: !torch.vtensor<[4],f32>) -> !torch.vtensor<[6,4],f32> {
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[EXPAND0:.+]] = torch.aten.unsqueeze %arg0, %[[INT1]]
// CHECK: %[[EXPAND1:.+]] = torch.aten.unsqueeze %[[EXPAND0]], %[[INT3]]
// CHECK: %[[EXPAND2:.+]] = torch.aten.unsqueeze %arg2, %[[INT0]]
// CHECK: %[[EXPAND3:.+]] = torch.aten.unsqueeze %arg1, %[[INT1]]
// CHECK: %[[EXPAND4:.+]] = torch.aten.unsqueeze %[[EXPAND3]], %[[INT2]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.Tensor %[[EXPAND1]], %[[EXPAND2]]
// CHECK: %[[MUL2:.+]] = torch.aten.mul.Tensor %[[MUL1]], %[[EXPAND4]]
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[INT3]]
// CHECK: %[[SUM1:.+]] = torch.aten.sum.dim_IntList %[[MUL2]], %[[LIST1]], %[[FALSE]], %[[NONE]]
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[INT2]]
// CHECK: %[[SUM2:.+]] = torch.aten.sum.dim_IntList %[[SUM1]], %[[LIST2]], %[[FALSE]], %[[NONE]]
// CHECK: %[[OUT:.+]] = torch.aten.add.Tensor %[[SUM2]], %arg3, %[[INT1]]
// CHECK: return %[[OUT]]
%1 = torch.aten.bilinear %arg0, %arg1, %arg2, %arg3 : !torch.vtensor<[6,2],f32>, !torch.vtensor<[6,3],f32>, !torch.vtensor<[4,2,3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[6,4],f32>
return %1 : !torch.vtensor<[6,4],f32>
}

// -----

// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
Expand Down

0 comments on commit c9c1c97

Please sign in to comment.