Skip to content

Commit

Permalink
recursion test
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Jan 17, 2025
1 parent d3aab16 commit 0154874
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions enzyme/test/MLIR/Batch/recursion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %eopt -enzyme-batch %s | FileCheck %s

module {
func.func private @f(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> {
%0 = func.call @f(%arg0, %arg1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
return %0 : tensor<16xf32>
}
func.func @main(%arg0: tensor<4x16xf32>, %arg1: tensor<4x16xf32>) {
%0 = enzyme.batch @f(%arg0, %arg1) {batch_shape = array<i64: 4>} : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
return
}
}

// CHECK: func.func @main(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func private @batched_f(%[[arg0:.+]]: tensor<4x16xf32>, %[[arg1:.+]]: tensor<4x16xf32>) -> tensor<4x16xf32> {
// CHECK-NEXT: %[[v0:.+]] = call @batched_f(%[[arg0]], %[[arg1]]) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// CHECK-NEXT: return %[[v0]] : tensor<4x16xf32>
// CHECK-NEXT: }

0 comments on commit 0154874

Please sign in to comment.