Skip to content

Commit

Permalink
broken parsing example
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Manainen authored and Max Manainen committed Oct 14, 2024
1 parent 8fc4d95 commit 93a8703
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/broken-jax-parsing.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s
// other passes like convert-linalg-to-loops have the same problems

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = tensor.empty() : tensor<2x4xf32>
%cst = arith.constant 0.000000e+00 : f32
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32>
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%3 = arith.mulf %in, %in_0 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> tensor<2x4xf32>
return %2 : tensor<2x4xf32>
}
}

0 comments on commit 93a8703

Please sign in to comment.