From dfa08fe4319d352cde367a2288aacff9f712fc21 Mon Sep 17 00:00:00 2001 From: manainen Date: Fri, 27 Sep 2024 17:29:30 +0100 Subject: [PATCH 01/22] dumb version for one operation type --- docs/marimo/mlir/onnx_demo.py | 8 +-- xdsl/dialects/builtin.py | 13 +++++ xdsl/tools/command_line_tool.py | 6 +++ .../experimental/convert_jax_to_linalg.py | 54 +++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 xdsl/transforms/experimental/convert_jax_to_linalg.py diff --git a/docs/marimo/mlir/onnx_demo.py b/docs/marimo/mlir/onnx_demo.py index 286999ef20..1da39cdc2d 100644 --- a/docs/marimo/mlir/onnx_demo.py +++ b/docs/marimo/mlir/onnx_demo.py @@ -4,7 +4,7 @@ app = marimo.App() -@app.cell(hide_code=True) +@app.cell def __(mo): mo.md( """ @@ -16,7 +16,7 @@ def __(mo): return -@app.cell(hide_code=True) +@app.cell def __(mo): rank = mo.ui.slider(1, 4, value=2, label="Rank") @@ -60,7 +60,7 @@ def __(mo, shape): return -@app.cell(hide_code=True) +@app.cell def __(): import onnx from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper @@ -121,7 +121,7 @@ def __(mo): return -@app.cell(hide_code=True) +@app.cell def __(mo, model_def): mo.accordion( { diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 443337b029..1846867dac 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -800,6 +800,19 @@ def get_shape(self) -> tuple[int, ...]: def get_element_type(self) -> AttributeCovT: return self.element_type + def is_same_type_with(self, other_tensor: TensorType): + current_shape = list(self.shape) + other_shape = list(other_tensor.shape) + if len(current_shape) != len(other_shape): + return False + + return ( + len(list(filter(lambda x: x[0] != x[1], zip(current_shape, other_shape)))) + == 0 + and self.element_type == other_tensor.element_type + and self.encoding == other_tensor.encoding + ) + AnyTensorType: TypeAlias = TensorType[Attribute] AnyTensorTypeConstr = BaseAttr[TensorType[Attribute]](TensorType) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index ac20b02ca8..a18581de6d 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -86,6 +86,11 @@ def get_convert_stencil_to_ll_mlir(): return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass + def get_convert_jax_to_linalg(): + from xdsl.transforms.experimental import convert_jax_to_linalg + + return convert_jax_to_linalg.ConvertJaxToLinalgPass + def get_convert_riscv_scf_to_riscv_cf(): from xdsl.backend.riscv.lowering import convert_riscv_scf_to_riscv_cf @@ -455,6 +460,7 @@ def get_stencil_shape_minimize(): "convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil, "inline-snrt": get_convert_snrt_to_riscv, "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir, + "convert-jax-to-linalg": get_convert_jax_to_linalg, "cse": get_cse, "csl-stencil-bufferize": get_csl_stencil_bufferize, "csl-stencil-materialize-stores": get_csl_stencil_materialize_stores, diff --git a/xdsl/transforms/experimental/convert_jax_to_linalg.py b/xdsl/transforms/experimental/convert_jax_to_linalg.py new file mode 100644 index 0000000000..e3123b766d --- /dev/null +++ b/xdsl/transforms/experimental/convert_jax_to_linalg.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import builtin +from xdsl.dialects.builtin import TensorType +from xdsl.dialects.func import FuncOp +from xdsl.dialects.linalg import FillOp +from xdsl.passes import ModulePass +from 
xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) + + +@dataclass +class SubstituteDonatedTensors(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): + donated_inputs = {} + for inp, attr in zip(op.regions[0].block._args, op.arg_attrs): + if ( + type(inp.type) is not TensorType + or "tf.aliasing_output" not in attr.data + ): + continue + donated_inputs[inp.name_hint] = inp + + for child_op in op.regions[0].ops: + if type(child_op) is FillOp: + value_mapper = {} + for output in child_op.outputs: + for arg_name, arg in list(donated_inputs.items()): + if arg.type.is_same_type_with(output.type): + value_mapper[output] = arg + break + new_op = child_op.clone(value_mapper) + rewriter.replace_op(child_op, [new_op]) + + +@dataclass(frozen=True) +class ConvertJaxToLinalgPass(ModulePass): + name = "convert-jax-to-linalg" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + the_one_pass = PatternRewriteWalker( + GreedyRewritePatternApplier([SubstituteDonatedTensors()]), + apply_recursively=False, + walk_reverse=True, + walk_regions_first=True, + ) + the_one_pass.rewrite_module(op) From c83999fd624fd608cfbc9394986864c9bbd65101 Mon Sep 17 00:00:00 2001 From: manainen Date: Fri, 27 Sep 2024 17:46:25 +0100 Subject: [PATCH 02/22] test added --- .../transforms/convert-jax-to-linalg.mlir | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/filecheck/transforms/convert-jax-to-linalg.mlir diff --git a/tests/filecheck/transforms/convert-jax-to-linalg.mlir b/tests/filecheck/transforms/convert-jax-to-linalg.mlir new file mode 100644 index 0000000000..ff45f653dc --- /dev/null +++ b/tests/filecheck/transforms/convert-jax-to-linalg.mlir @@ -0,0 +1,34 @@ +// RUN: xdsl-opt %s -p convert-jax-to-linalg --split-input-file --verify-diagnostics | filecheck %s + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +builtin.module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { +func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } +} + +// CHECK: builtin.module attributes {"mhlo.num_partitions" = 1 : i32, "mhlo.num_replicas" = 1 : i32} { +// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"mhlo.layout_mode" = "default"}, %arg1 : tensor<3x4xf32> {"mhlo.layout_mode" = "default"}, %arg2 : tensor<2x4xf32> {"mhlo.layout_mode" = "default", "tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { +// CHECK-NEXT: %0 = tensor.empty() : tensor<2x4xf32> +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%arg2 : 
tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { +// CHECK-NEXT: ^0(%in : f32, %in_1 : f32, %out : f32): +// CHECK-NEXT: %3 = arith.mulf %in, %in_1 : f32 +// CHECK-NEXT: %4 = arith.addf %out, %3 : f32 +// CHECK-NEXT: linalg.yield %4 : f32 +// CHECK-NEXT: } -> tensor<2x4xf32> +// CHECK-NEXT: func.return %2 : tensor<2x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } From 67b8a07b222170fb1306b834b7d21aa9193aa11d Mon Sep 17 00:00:00 2001 From: manainen Date: Fri, 27 Sep 2024 18:06:11 +0100 Subject: [PATCH 03/22] mlir step added --- .../filecheck/transforms/convert-jax-to-linalg.mlir | 13 ++++++------- .../experimental/convert_jax_to_linalg.py | 3 +++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/filecheck/transforms/convert-jax-to-linalg.mlir b/tests/filecheck/transforms/convert-jax-to-linalg.mlir index ff45f653dc..894c4e58e1 100644 --- a/tests/filecheck/transforms/convert-jax-to-linalg.mlir +++ b/tests/filecheck/transforms/convert-jax-to-linalg.mlir @@ -20,15 +20,14 @@ func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %a // CHECK: builtin.module attributes {"mhlo.num_partitions" = 1 : i32, "mhlo.num_replicas" = 1 : i32} { // CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"mhlo.layout_mode" = "default"}, %arg1 : tensor<3x4xf32> {"mhlo.layout_mode" = "default"}, %arg2 : tensor<2x4xf32> {"mhlo.layout_mode" = "default", "tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { -// CHECK-NEXT: %0 = tensor.empty() : tensor<2x4xf32> // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { +// CHECK-NEXT: %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) { // CHECK-NEXT: ^0(%in : f32, %in_1 : f32, %out : f32): -// CHECK-NEXT: %3 = arith.mulf %in, %in_1 : f32 -// CHECK-NEXT: %4 = arith.addf %out, %3 : f32 -// CHECK-NEXT: linalg.yield %4 : f32 +// CHECK-NEXT: %2 = arith.mulf %in, %in_1 : f32 +// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 +// CHECK-NEXT: linalg.yield %3 : f32 // CHECK-NEXT: } -> tensor<2x4xf32> -// CHECK-NEXT: func.return %2 : tensor<2x4xf32> +// CHECK-NEXT: func.return %1 : tensor<2x4xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/transforms/experimental/convert_jax_to_linalg.py b/xdsl/transforms/experimental/convert_jax_to_linalg.py index e3123b766d..cfec883251 100644 --- a/xdsl/transforms/experimental/convert_jax_to_linalg.py +++ b/xdsl/transforms/experimental/convert_jax_to_linalg.py @@ -13,6 +13,7 @@ RewritePattern, op_type_rewrite_pattern, ) +from xdsl.transforms.mlir_opt import MLIROptPass @dataclass @@ 
-35,6 +36,7 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): for arg_name, arg in list(donated_inputs.items()): if arg.type.is_same_type_with(output.type): value_mapper[output] = arg + del donated_inputs[arg_name] break new_op = child_op.clone(value_mapper) rewriter.replace_op(child_op, [new_op]) @@ -52,3 +54,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: walk_regions_first=True, ) the_one_pass.rewrite_module(op) + MLIROptPass(arguments=["--linalg-fuse-elementwise-ops"]).apply(ctx, op) From 2ffda16253b81b509e52c7072d88cdedc02eb177 Mon Sep 17 00:00:00 2001 From: manainen Date: Fri, 27 Sep 2024 18:13:24 +0100 Subject: [PATCH 04/22] more general version --- xdsl/transforms/experimental/convert_jax_to_linalg.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xdsl/transforms/experimental/convert_jax_to_linalg.py b/xdsl/transforms/experimental/convert_jax_to_linalg.py index cfec883251..aca4fd573a 100644 --- a/xdsl/transforms/experimental/convert_jax_to_linalg.py +++ b/xdsl/transforms/experimental/convert_jax_to_linalg.py @@ -4,7 +4,7 @@ from xdsl.dialects import builtin from xdsl.dialects.builtin import TensorType from xdsl.dialects.func import FuncOp -from xdsl.dialects.linalg import FillOp +from xdsl.dialects.linalg import NamedOpBase from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -30,11 +30,13 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): donated_inputs[inp.name_hint] = inp for child_op in op.regions[0].ops: - if type(child_op) is FillOp: + if issubclass(type(child_op), NamedOpBase): value_mapper = {} for output in child_op.outputs: for arg_name, arg in list(donated_inputs.items()): - if arg.type.is_same_type_with(output.type): + if type( + output.type + ) is TensorType and arg.type.is_same_type_with(output.type): value_mapper[output] = arg del donated_inputs[arg_name] break From 121548f96170c1146394903e99ce7c25a8cddb2f Mon Sep 17 00:00:00 2001 From: manainen Date: Sun, 29 Sep 2024 13:17:51 +0100 Subject: [PATCH 05/22] pyright fixes --- xdsl/dialects/builtin.py | 2 +- .../experimental/convert_jax_to_linalg.py | 40 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 1846867dac..bc5339439d 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -800,7 +800,7 @@ def get_shape(self) -> tuple[int, ...]: def get_element_type(self) -> AttributeCovT: return self.element_type - def is_same_type_with(self, other_tensor: TensorType): + def is_same_type_with(self, other_tensor: TensorType[Attribute]) -> bool: current_shape = list(self.shape) other_shape = list(other_tensor.shape) if len(current_shape) != len(other_shape): diff --git a/xdsl/transforms/experimental/convert_jax_to_linalg.py b/xdsl/transforms/experimental/convert_jax_to_linalg.py index aca4fd573a..fed10dd6df 100644 --- a/xdsl/transforms/experimental/convert_jax_to_linalg.py +++ b/xdsl/transforms/experimental/convert_jax_to_linalg.py @@ -4,7 +4,8 @@ from xdsl.dialects import builtin from xdsl.dialects.builtin import TensorType from xdsl.dialects.func import FuncOp -from xdsl.dialects.linalg import NamedOpBase +from xdsl.ir import BlockArgument, SSAValue +from xdsl.irdl import VarOperand from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -20,25 +21,26 @@ class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern def 
match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): - donated_inputs = {} - for inp, attr in zip(op.regions[0].block._args, op.arg_attrs): + if op.arg_attrs is None: + return + + donated_inputs: list[BlockArgument] = [] + for inp, attr in zip(op.args, op.arg_attrs): + if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data: + donated_inputs.append(inp) + + for child_op in op.body.ops: if ( - type(inp.type) is not TensorType - or "tf.aliasing_output" not in attr.data + hasattr(child_op, "outputs") + and type(getattr(child_op, "outputs")) is VarOperand ): - continue - donated_inputs[inp.name_hint] = inp - - for child_op in op.regions[0].ops: - if issubclass(type(child_op), NamedOpBase): - value_mapper = {} - for output in child_op.outputs: - for arg_name, arg in list(donated_inputs.items()): - if type( - output.type - ) is TensorType and arg.type.is_same_type_with(output.type): - value_mapper[output] = arg - del donated_inputs[arg_name] + value_mapper: dict[SSAValue, SSAValue] = {} + for output in getattr(child_op, "outputs"): + for i, arg in enumerate(donated_inputs): + if type(getattr(output, "type")) is TensorType and getattr( + arg, "type" + ).is_same_type_with(output.type): + value_mapper[output] = donated_inputs.pop(i) break new_op = child_op.clone(value_mapper) rewriter.replace_op(child_op, [new_op]) @@ -56,4 +58,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: walk_regions_first=True, ) the_one_pass.rewrite_module(op) - MLIROptPass(arguments=["--linalg-fuse-elementwise-ops"]).apply(ctx, op) + MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op) From 1319404a3b63608964bfd8fccfab8baeb83d48b5 Mon Sep 17 00:00:00 2001 From: manainen Date: Sun, 29 Sep 2024 13:37:34 +0100 Subject: [PATCH 06/22] change names --- ...-jax-to-linalg.mlir => jax-use-donated-arguments.mlir} | 2 +- xdsl/tools/command_line_tool.py | 8 ++++---- ...vert_jax_to_linalg.py => jax_use_donated_arguments.py} | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) rename tests/filecheck/transforms/{convert-jax-to-linalg.mlir => jax-use-donated-arguments.mlir} (95%) rename xdsl/transforms/experimental/{convert_jax_to_linalg.py => jax_use_donated_arguments.py} (96%) diff --git a/tests/filecheck/transforms/convert-jax-to-linalg.mlir b/tests/filecheck/transforms/jax-use-donated-arguments.mlir similarity index 95% rename from tests/filecheck/transforms/convert-jax-to-linalg.mlir rename to tests/filecheck/transforms/jax-use-donated-arguments.mlir index 894c4e58e1..e352aae774 100644 --- a/tests/filecheck/transforms/convert-jax-to-linalg.mlir +++ b/tests/filecheck/transforms/jax-use-donated-arguments.mlir @@ -1,4 +1,4 @@ -// RUN: xdsl-opt %s -p convert-jax-to-linalg --split-input-file --verify-diagnostics | filecheck %s +// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index a18581de6d..d5692ed7b0 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -86,10 +86,10 @@ def get_convert_stencil_to_ll_mlir(): return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass - def get_convert_jax_to_linalg(): - from xdsl.transforms.experimental import convert_jax_to_linalg + def get_jax_use_donated_arguments(): + from xdsl.transforms.experimental import jax_use_donated_arguments - return 
convert_jax_to_linalg.ConvertJaxToLinalgPass + return jax_use_donated_arguments.JaxUseDonatedArguments def get_convert_riscv_scf_to_riscv_cf(): from xdsl.backend.riscv.lowering import convert_riscv_scf_to_riscv_cf @@ -460,7 +460,7 @@ def get_stencil_shape_minimize(): "convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil, "inline-snrt": get_convert_snrt_to_riscv, "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir, - "convert-jax-to-linalg": get_convert_jax_to_linalg, + "jax-use-donated-arguments": get_jax_use_donated_arguments, "cse": get_cse, "csl-stencil-bufferize": get_csl_stencil_bufferize, "csl-stencil-materialize-stores": get_csl_stencil_materialize_stores, diff --git a/xdsl/transforms/experimental/convert_jax_to_linalg.py b/xdsl/transforms/experimental/jax_use_donated_arguments.py similarity index 96% rename from xdsl/transforms/experimental/convert_jax_to_linalg.py rename to xdsl/transforms/experimental/jax_use_donated_arguments.py index fed10dd6df..f7259b8994 100644 --- a/xdsl/transforms/experimental/convert_jax_to_linalg.py +++ b/xdsl/transforms/experimental/jax_use_donated_arguments.py @@ -47,8 +47,8 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): @dataclass(frozen=True) -class ConvertJaxToLinalgPass(ModulePass): - name = "convert-jax-to-linalg" +class JaxUseDonatedArguments(ModulePass): + name = "jax-use-donated-arguments" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: the_one_pass = PatternRewriteWalker( From 6513a6298601b256a92aeff8d5bd12ba26118027 Mon Sep 17 00:00:00 2001 From: manainen Date: Sun, 29 Sep 2024 13:53:05 +0100 Subject: [PATCH 07/22] move test to mlir ones --- .../with-mlir}/jax-use-donated-arguments.mlir | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/filecheck/{transforms => mlir-conversion/with-mlir}/jax-use-donated-arguments.mlir (100%) diff --git a/tests/filecheck/transforms/jax-use-donated-arguments.mlir b/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir similarity index 100% rename from tests/filecheck/transforms/jax-use-donated-arguments.mlir rename to tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir From 7c3d38454b134fcef988f8b6da27c01121b88536 Mon Sep 17 00:00:00 2001 From: manainen Date: Sun, 29 Sep 2024 17:50:15 +0100 Subject: [PATCH 08/22] trim test --- .../with-mlir/jax-use-donated-arguments.mlir | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir b/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir index e352aae774..44eace82b3 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir @@ -1,33 +1,18 @@ // RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -builtin.module attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { -func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { +builtin.module { +func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> 
{tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { %0 = tensor.empty() : tensor<2x4xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - return %2 : tensor<2x4xf32> + return %1 : tensor<2x4xf32> } } -// CHECK: builtin.module attributes {"mhlo.num_partitions" = 1 : i32, "mhlo.num_replicas" = 1 : i32} { -// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"mhlo.layout_mode" = "default"}, %arg1 : tensor<3x4xf32> {"mhlo.layout_mode" = "default"}, %arg2 : tensor<2x4xf32> {"mhlo.layout_mode" = "default", "tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { +// CHECK: builtin.module { +// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) { -// CHECK-NEXT: ^0(%in : f32, %in_1 : f32, %out : f32): -// CHECK-NEXT: %2 = arith.mulf %in, %in_1 : f32 -// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 -// CHECK-NEXT: linalg.yield %3 : f32 -// CHECK-NEXT: } -> tensor<2x4xf32> -// CHECK-NEXT: func.return %1 : tensor<2x4xf32> +// CHECK-NEXT: func.return %0 : tensor<2x4xf32> // CHECK-NEXT: } // CHECK-NEXT: } From 432ebd2e5a34af80a9893b44408d3552b52c2c57 Mon Sep 17 00:00:00 2001 From: manainen Date: Sun, 29 Sep 2024 17:51:34 +0100 Subject: [PATCH 09/22] move file --- xdsl/tools/command_line_tool.py | 2 +- xdsl/transforms/{experimental => }/jax_use_donated_arguments.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename xdsl/transforms/{experimental => }/jax_use_donated_arguments.py (100%) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index d5692ed7b0..0686924dc7 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -87,7 +87,7 @@ def get_convert_stencil_to_ll_mlir(): return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass def get_jax_use_donated_arguments(): - from xdsl.transforms.experimental import jax_use_donated_arguments + from xdsl.transforms import jax_use_donated_arguments return jax_use_donated_arguments.JaxUseDonatedArguments diff --git a/xdsl/transforms/experimental/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py similarity index 100% rename from xdsl/transforms/experimental/jax_use_donated_arguments.py rename to xdsl/transforms/jax_use_donated_arguments.py From 290adb3c5280ac8e89654f76418ca80211bf7c1a Mon Sep 17 00:00:00 2001 From: manainen Date: Mon, 7 Oct 2024 22:06:52 +0100 Subject: [PATCH 10/22] test mlirs --- buffered.mlir | 18 ++ docs/marimo/arith.mlir | 23 +++ docs/marimo/arith_donated_buffered.mlir | 22 +++ docs/marimo/buffered.mlir | 18 ++ docs/marimo/conv.mlir | 41 +++++ 
docs/marimo/conv_buffered.mlir | 43 +++++ docs/marimo/conv_donated.mlir | 44 +++++ docs/marimo/conv_donated_buffered.mlir | 43 +++++ docs/marimo/donated_buffered.mlir | 17 ++ docs/marimo/fused.mlir | 16 ++ docs/marimo/fused_buffered.mlir | 14 ++ docs/marimo/jax_experiments.py | 172 +++++++++++++++++++ docs/marimo/jax_mult.mlir | 17 ++ docs/marimo/jax_mult_donation.mlir | 20 +++ docs/marimo/jax_rewrite_exp.py | 62 +++++++ docs/marimo/multiple_donated_buffered.mlir | 38 ++++ docs/marimo/multiple_outputs.mlir | 39 +++++ docs/marimo/removed.mlir | 17 ++ xdsl/transforms/jax_use_donated_arguments.py | 24 ++- 19 files changed, 684 insertions(+), 4 deletions(-) create mode 100644 buffered.mlir create mode 100644 docs/marimo/arith.mlir create mode 100644 docs/marimo/arith_donated_buffered.mlir create mode 100644 docs/marimo/buffered.mlir create mode 100644 docs/marimo/conv.mlir create mode 100644 docs/marimo/conv_buffered.mlir create mode 100644 docs/marimo/conv_donated.mlir create mode 100644 docs/marimo/conv_donated_buffered.mlir create mode 100644 docs/marimo/donated_buffered.mlir create mode 100644 docs/marimo/fused.mlir create mode 100644 docs/marimo/fused_buffered.mlir create mode 100644 docs/marimo/jax_experiments.py create mode 100644 docs/marimo/jax_mult.mlir create mode 100644 docs/marimo/jax_mult_donation.mlir create mode 100644 docs/marimo/jax_rewrite_exp.py create mode 100644 docs/marimo/multiple_donated_buffered.mlir create mode 100644 docs/marimo/multiple_outputs.mlir create mode 100644 docs/marimo/removed.mlir diff --git a/buffered.mlir b/buffered.mlir new file mode 100644 index 0000000000..6c7ffb4a14 --- /dev/null +++ b/buffered.mlir @@ -0,0 +1,18 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%alloc : memref<2x4xf32>) + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %0 = arith.mulf %in, %in_0 : f32 + %1 = arith.addf %out, %0 : f32 + linalg.yield %1 : f32 + } + %cast = memref.cast %alloc : memref<2x4xf32> to memref<2x4xf32, strided<[?, ?], offset: ?>> + return %alloc : memref<2x4xf32> + } +} diff --git a/docs/marimo/arith.mlir b/docs/marimo/arith.mlir new file mode 100644 index 0000000000..3a7287d99b --- /dev/null +++ b/docs/marimo/arith.mlir @@ -0,0 +1,23 @@ +#map = affine_map<(d0, d1) -> (d0, d1)> +builtin.module { + func.func public @main(%arg0: tensor<5x5xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<5x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<5x5xf32> { + %0 = tensor.empty() : tensor<5x5xf32> + %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<5x5xf32>, tensor<5x5xf32>) outs(%0 : tensor<5x5xf32>) { + 
^bb0(%in: f32, %in_1: f32, %out: f32): + %4 = arith.addf %in, %in_1 : f32 + linalg.yield %4 : f32 + } -> tensor<5x5xf32> + %cst = arith.constant dense<1.000000e+00> : tensor + %cst_0 = arith.constant dense<1.000000e+00> : tensor<5x5xf32> + %2 = tensor.empty() : tensor<5x5xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %cst_0 : tensor<5x5xf32>, tensor<5x5xf32>) outs(%2 : tensor<5x5xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %4 = arith.addf %in, %in_1 : f32 + linalg.yield %4 : f32 + } -> tensor<5x5xf32> + + %output = bufferization.materialize_in_destination %3 in %arg2 : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> + + return %output : tensor<5x5xf32> + } +} diff --git a/docs/marimo/arith_donated_buffered.mlir b/docs/marimo/arith_donated_buffered.mlir new file mode 100644 index 0000000000..7498408215 --- /dev/null +++ b/docs/marimo/arith_donated_buffered.mlir @@ -0,0 +1,22 @@ +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + memref.global "private" constant @__constant_5x5xf32 : memref<5x5xf32> = dense<1.000000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_xf32 : memref = dense<1.000000e+00> {alignment = 64 : i64} + func.func public @main(%arg0: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<5x5xf32, strided<[?, ?], offset: ?>> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x5xf32> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : memref<5x5xf32, strided<[?, ?], offset: ?>>, memref<5x5xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<5x5xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.addf %in, %in_0 : f32 + linalg.yield %2 : f32 + } + %0 = memref.get_global @__constant_xf32 : memref + %1 = memref.get_global @__constant_5x5xf32 : memref<5x5xf32> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc, %1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.addf %in, %in_0 : f32 + linalg.yield %2 : f32 + } + memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> + return %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> + } +} diff --git a/docs/marimo/buffered.mlir b/docs/marimo/buffered.mlir new file mode 100644 index 0000000000..da09e5fb04 --- /dev/null +++ b/docs/marimo/buffered.mlir @@ -0,0 +1,18 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) 
outs(%alloc : memref<2x4xf32>) + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %0 = arith.mulf %in, %in_0 : f32 + %1 = arith.addf %out, %0 : f32 + linalg.yield %1 : f32 + } + %cast = memref.cast %alloc : memref<2x4xf32> to memref<2x4xf32, strided<[?, ?], offset: ?>> + return %alloc : memref<2x4xf32> + } +} diff --git a/docs/marimo/conv.mlir b/docs/marimo/conv.mlir new file mode 100644 index 0000000000..ec58f465cd --- /dev/null +++ b/docs/marimo/conv.mlir @@ -0,0 +1,41 @@ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> +#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +builtin.module { + func.func public @main(%arg0: tensor<1x1x10x10xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<1x1x8x8xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf32> { + %0 = call @_flip(%arg1) : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> + %1 = tensor.empty() : tensor<1x1x1x1x10x10xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x10x10xf32>) outs(%1 : tensor<1x1x1x1x10x10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x1x1x1x10x10xf32> + %3 = tensor.empty() : tensor<1x1x1x1x3x3xf32> + %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x1x3x3xf32>) outs(%3 : tensor<1x1x1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x1x1x1x3x3xf32> + %5 = tensor.empty() : tensor<1x1x1x1x8x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1x1x1x8x8xf32>) -> tensor<1x1x1x1x8x8xf32> + %7 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%2, %4 : tensor<1x1x1x1x10x10xf32>, tensor<1x1x1x1x3x3xf32>) outs(%6 : tensor<1x1x1x1x8x8xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %8 = arith.mulf %in, %in_0 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<1x1x1x1x8x8xf32> + %collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3, 4], [5]] : tensor<1x1x1x1x8x8xf32> into tensor<1x1x8x8xf32> + return %collapsed : tensor<1x1x8x8xf32> + } + func.func private @_flip(%arg0: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}) -> tensor<1x1x3x3xf32> { + %0 = tensor.empty() : tensor<1x1x3x3xf32> + %1 = linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x3x3xf32>) outs(%0 : tensor<1x1x3x3xf32>) { + ^bb0(%in: f32, %out: 
f32): + linalg.yield %in : f32 + } -> tensor<1x1x3x3xf32> + return %1 : tensor<1x1x3x3xf32> + } +} diff --git a/docs/marimo/conv_buffered.mlir b/docs/marimo/conv_buffered.mlir new file mode 100644 index 0000000000..911a2ca07a --- /dev/null +++ b/docs/marimo/conv_buffered.mlir @@ -0,0 +1,43 @@ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> +#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module { + func.func public @main(%arg0: memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<1x1x8x8xf32> { + %0 = call @_flip(%arg1) : (memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) -> memref<1x1x3x3xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x10x10xf32> + linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x1x1x10x10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x3x3xf32> + linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : memref<1x1x3x3xf32>) outs(%alloc_0 : memref<1x1x1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x8x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) + linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%alloc, %alloc_0 : memref<1x1x1x1x10x10xf32>, memref<1x1x1x1x3x3xf32>) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %1 = arith.mulf %in, %in_2 : f32 + %2 = arith.addf %out, %1 : f32 + linalg.yield %2 : f32 + } + %collapse_shape = memref.collapse_shape %alloc_1 [[0], [1], [2, 3, 4], [5]] : memref<1x1x1x1x8x8xf32> into memref<1x1x8x8xf32> + %cast = memref.cast %collapse_shape : memref<1x1x8x8xf32> to memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> + return %collapse_shape : memref<1x1x8x8xf32> + } + func.func private @_flip(%arg0: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}) -> memref<1x1x3x3xf32> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x3x3xf32> + linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %cast = memref.cast 
%alloc : memref<1x1x3x3xf32> to memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> + return %alloc : memref<1x1x3x3xf32> + } +} diff --git a/docs/marimo/conv_donated.mlir b/docs/marimo/conv_donated.mlir new file mode 100644 index 0000000000..900e695885 --- /dev/null +++ b/docs/marimo/conv_donated.mlir @@ -0,0 +1,44 @@ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> +#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +builtin.module { + func.func public @main(%arg0: tensor<1x1x10x10xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<1x1x8x8xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf32> { + %0 = call @_flip(%arg1) : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> + %1 = tensor.empty() : tensor<1x1x1x1x10x10xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x10x10xf32>) outs(%1 : tensor<1x1x1x1x10x10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x1x1x1x10x10xf32> + %3 = tensor.empty() : tensor<1x1x1x1x3x3xf32> + %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x1x3x3xf32>) outs(%3 : tensor<1x1x1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x1x1x1x3x3xf32> + %5 = tensor.empty() : tensor<1x1x1x1x8x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1x1x1x8x8xf32>) -> tensor<1x1x1x1x8x8xf32> + %7 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%2, %4 : tensor<1x1x1x1x10x10xf32>, tensor<1x1x1x1x3x3xf32>) outs(%6 : tensor<1x1x1x1x8x8xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %8 = arith.mulf %in, %in_0 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<1x1x1x1x8x8xf32> + %collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3, 4], [5]] : tensor<1x1x1x1x8x8xf32> into tensor<1x1x8x8xf32> + + %output_1 = bufferization.materialize_in_destination %collapsed in %arg2 : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> + + return %output_1 : tensor<1x1x8x8xf32> + } + func.func private @_flip(%arg0: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}) -> tensor<1x1x3x3xf32> { + %0 = tensor.empty() : tensor<1x1x3x3xf32> + %1 = linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x3x3xf32>) outs(%0 : tensor<1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x1x3x3xf32> + return %1 : tensor<1x1x3x3xf32> + } +} diff --git a/docs/marimo/conv_donated_buffered.mlir b/docs/marimo/conv_donated_buffered.mlir new file mode 100644 index 0000000000..e4edb55c6e --- /dev/null 
+++ b/docs/marimo/conv_donated_buffered.mlir @@ -0,0 +1,43 @@ +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> +#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> +#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> +#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module { + func.func public @main(%arg0: memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> { + %0 = call @_flip(%arg1) : (memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) -> memref<1x1x3x3xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x10x10xf32> + linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x1x1x10x10xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x3x3xf32> + linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : memref<1x1x3x3xf32>) outs(%alloc_0 : memref<1x1x1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x8x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) + linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%alloc, %alloc_0 : memref<1x1x1x1x10x10xf32>, memref<1x1x1x1x3x3xf32>) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %1 = arith.mulf %in, %in_2 : f32 + %2 = arith.addf %out, %1 : f32 + linalg.yield %2 : f32 + } + %collapse_shape = memref.collapse_shape %alloc_1 [[0], [1], [2, 3, 4], [5]] : memref<1x1x1x1x8x8xf32> into memref<1x1x8x8xf32> + memref.copy %collapse_shape, %arg2 : memref<1x1x8x8xf32> to memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> + return %arg2 : memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> + } + func.func private @_flip(%arg0: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}) -> memref<1x1x3x3xf32> { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x3x3xf32> + linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x3x3xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } + %cast = memref.cast %alloc : memref<1x1x3x3xf32> to memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> + return %alloc : memref<1x1x3x3xf32> + } +} diff --git 
a/docs/marimo/donated_buffered.mlir b/docs/marimo/donated_buffered.mlir new file mode 100644 index 0000000000..0a98d0baed --- /dev/null +++ b/docs/marimo/donated_buffered.mlir @@ -0,0 +1,17 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %0 = arith.mulf %in, %in_0 : f32 + %1 = arith.addf %out, %0 : f32 + linalg.yield %1 : f32 + } + memref.copy %arg2, %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> + return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> + } +} diff --git a/docs/marimo/fused.mlir b/docs/marimo/fused.mlir new file mode 100644 index 0000000000..17d6d16043 --- /dev/null +++ b/docs/marimo/fused.mlir @@ -0,0 +1,16 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<2x4xf32> + return %1 : tensor<2x4xf32> + } +} diff --git a/docs/marimo/fused_buffered.mlir b/docs/marimo/fused_buffered.mlir new file mode 100644 index 0000000000..93d7fbed92 --- /dev/null +++ b/docs/marimo/fused_buffered.mlir @@ -0,0 +1,14 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: memref<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>) outs(%arg2 : 
memref<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %0 = arith.mulf %in, %in_0 : f32 + %1 = arith.addf %out, %0 : f32 + linalg.yield %1 : f32 + } + return %arg2 : memref<2x4xf32> + } +} diff --git a/docs/marimo/jax_experiments.py b/docs/marimo/jax_experiments.py new file mode 100644 index 0000000000..6dace3bfc1 --- /dev/null +++ b/docs/marimo/jax_experiments.py @@ -0,0 +1,172 @@ +import marimo + +__generated_with = "0.8.5" +app = marimo.App(width="medium") + + +@app.cell +def __(): + import marimo as mo + return mo, + + +@app.cell +def __(): + import jax + import jax.numpy as jnp + import jax.scipy as jsp + from jax import lax + return jax, jnp, jsp, lax + + +@app.cell +def __(): + from jax import random + return random, + + +@app.cell +def __(get_linalg_module_str, jax, jnp, random): + def matmul(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): + return A @ B + + matmul_params = (2, 4, 3) + matmul_shapes = ((matmul_params[0], matmul_params[2]), (matmul_params[2], matmul_params[1]), (matmul_params[0], matmul_params[1])) + + key = jax.random.key(42) + + matmul_data = tuple(random.uniform(key, shape) for shape in matmul_shapes) + + matmul_jit = jax.jit(matmul, donate_argnames=['C'], keep_unused=True) + + get_linalg_module_str(matmul_jit, matmul_data) + return ( + key, + matmul, + matmul_data, + matmul_jit, + matmul_params, + matmul_shapes, + ) + + +@app.cell +def __(jax, jnp, jsp): + def conv(X: jnp.ndarray, K: jnp.ndarray, Z: jnp.ndarray): + return jsp.signal.convolve(X, K, mode="valid", method="direct") + + conv_jit = jax.jit(conv, donate_argnames=['Z'], keep_unused=True) + return conv, conv_jit + + +@app.cell +def __(key, random): + X_shape = (1, 1, 10, 10) + K_shape = (1, 1, 3, 3) + Z_shape = (1, 1, 1, 1, 8, 8) + + conv_data = tuple(random.uniform(key, shape) for shape in [X_shape, K_shape, Z_shape]) + return K_shape, X_shape, Z_shape, conv_data + + +@app.cell +def __(conv_data, conv_jit, get_linalg_module_str): + get_linalg_module_str(conv_jit, conv_data) + return + + +@app.cell +def __(get_linalg_module_str, jax, jnp, key, random): + def simple_arith(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): + return A + B + 1 + + arith_jit = jax.jit(simple_arith, donate_argnames=['C'], keep_unused=True) + arith_data = tuple(random.uniform(key, shape) for shape in [(5, 5), (5, 5), (5, 5)]) + + get_linalg_module_str(arith_jit, arith_data) + return arith_data, arith_jit, simple_arith + + +@app.cell +def __(get_linalg_module_str, jax, jnp, key, random): + def multiple_outputs(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray, D: jnp.ndarray): + return (B @ A, A + 1, A @ B) + + multiple_jit = jax.jit(multiple_outputs, donate_argnames=['C', 'D'], keep_unused=True) + multiple_data = tuple(random.uniform(key, shape) for shape in [(5, 2), (2, 5), (5, 5), (2, 2)]) + + get_linalg_module_str(multiple_jit, multiple_data) + return multiple_data, multiple_jit, multiple_outputs + + +@app.cell +def __(matmul_data, matmul_jit): + lowered_matmul = matmul_jit.lower(*matmul_data) + lowered_matmul + return lowered_matmul, + + +@app.cell +def __(lowered_matmul): + type(lowered_matmul.compile()).__doc__ + return + + +@app.cell +def __(lowered_matmul, matmul_data): + lowered_matmul.compile()(*matmul_data) + return + + +@app.cell +def __(matmul, matmul_data): + matmul(*matmul_data), matmul(*matmul_data) + return + + +@app.cell +def __(): + from jax._src.interpreters import mlir + from jaxlib.mlir.dialects import mhlo + from jaxlib.mlir.passmanager import PassManager + + def get_linalg_module_str(func, args): + 
lowered = func.lower(*args) + + mhlo_module = lowered.compiler_ir(dialect="mhlo") + + # print(mhlo_module) + + with mhlo_module.context as ctx: + ctx.append_dialect_registry(mlir.upstream_dialects) + # ctx.load_all_available_dialects() + # mhlo.register_mhlo_dialect(ctx) + mhlo.register_mhlo_passes() + pipeline = PassManager.parse("builtin.module(hlo-legalize-to-arithmetic,func.func(hlo-legalize-to-linalg))") + pipeline.run(mhlo_module.operation) + + mhlo_module_str = f"{mhlo_module}" + + return mhlo_module_str + return PassManager, get_linalg_module_str, mhlo, mlir + + +@app.cell +def __(): + from jax import make_jaxpr + return make_jaxpr, + + +@app.cell +def __(make_jaxpr, matmul, matmul_data): + type(make_jaxpr(matmul)(*matmul_data)) + return + + +@app.cell +def __(): + return + + +if __name__ == "__main__": + app.run() diff --git a/docs/marimo/jax_mult.mlir b/docs/marimo/jax_mult.mlir new file mode 100644 index 0000000000..d13498fa8a --- /dev/null +++ b/docs/marimo/jax_mult.mlir @@ -0,0 +1,17 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +builtin.module { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } +} diff --git a/docs/marimo/jax_mult_donation.mlir b/docs/marimo/jax_mult_donation.mlir new file mode 100644 index 0000000000..ef90bad128 --- /dev/null +++ b/docs/marimo/jax_mult_donation.mlir @@ -0,0 +1,20 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +builtin.module { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + + %output = bufferization.materialize_in_destination %2 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + + return %output : tensor<2x4xf32> + } +} diff --git a/docs/marimo/jax_rewrite_exp.py b/docs/marimo/jax_rewrite_exp.py new file mode 100644 index 0000000000..8a06eb53af --- /dev/null +++ b/docs/marimo/jax_rewrite_exp.py @@ -0,0 +1,62 @@ +import marimo + +__generated_with = "0.8.5" +app = marimo.App(width="medium") + + +@app.cell +def 
__(): + import marimo as mo + return mo, + + +@app.cell +def __(): + original_func = """ + #map = affine_map<(d0, d1, d2) -> (d0, d2)> + #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> + #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> + + func.func main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } + """ + return original_func, + + +@app.cell +def __(): + from xdsl.context import MLContext + from xdsl.parser import Parser + return MLContext, Parser + + +@app.cell +def __(MLContext, Parser, original_func): + ctx = MLContext() + parser = Parser(ctx, original_func) + return ctx, parser + + +@app.cell +def __(parser): + parser.parse_module(True) + return + + +@app.cell +def __(): + return + + +if __name__ == "__main__": + app.run() diff --git a/docs/marimo/multiple_donated_buffered.mlir b/docs/marimo/multiple_donated_buffered.mlir new file mode 100644 index 0000000000..b51b3be0da --- /dev/null +++ b/docs/marimo/multiple_donated_buffered.mlir @@ -0,0 +1,38 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module @jit_multiple_outputs attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<1.000000e+00> {alignment = 64 : i64} + memref.global "private" constant @__constant_xf32 : memref = dense<1.000000e+00> {alignment = 64 : i64} + func.func public @main(%arg0: memref<5x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: memref<2x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x2xf32, strided<[?, ?], offset: ?>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, memref<5x2xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, memref<5x5xf32, strided<[?, ?], offset: ?>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : memref<2x5xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32, strided<[?, ?], offset: ?>>) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %2 = arith.mulf %in, %in_1 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } + %0 = memref.get_global @__constant_xf32 : memref + %1 = 
memref.get_global @__constant_5x2xf32 : memref<5x2xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x2xf32> + linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>) outs(%alloc : memref<5x2xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %2 = arith.addf %in, %in_1 : f32 + linalg.yield %2 : f32 + } + %cst_0 = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst_0 : f32) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) + linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<2x5xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %2 = arith.mulf %in, %in_1 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } + memref.copy %arg3, %arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32, strided<[?, ?], offset: ?>> + memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> + %cast = memref.cast %alloc : memref<5x2xf32> to memref<5x2xf32, strided<[?, ?], offset: ?>> + return %arg3, %alloc, %arg2 : memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>> + } +} diff --git a/docs/marimo/multiple_outputs.mlir b/docs/marimo/multiple_outputs.mlir new file mode 100644 index 0000000000..62fc6dde24 --- /dev/null +++ b/docs/marimo/multiple_outputs.mlir @@ -0,0 +1,39 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +module @jit_multiple_outputs attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<5x2xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<2x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: tensor<2x2xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x2xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<5x2xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<5x5xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = tensor.empty() : tensor<2x2xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : tensor<2x5xf32>, tensor<5x2xf32>) outs(%1 : tensor<2x2xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.mulf %in, %in_3 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<2x2xf32> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<f32> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<5x2xf32> + %3 = tensor.empty() : tensor<5x2xf32> + %4 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%3 : tensor<5x2xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.addf %in, %in_3 : f32 + linalg.yield %8 : f32 + } -> tensor<5x2xf32> + %5 = tensor.empty() : tensor<5x5xf32> + %cst_2 = arith.constant 0.000000e+00 
: f32 + %6 = linalg.fill ins(%cst_2 : f32) outs(%5 : tensor<5x5xf32>) -> tensor<5x5xf32> + %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<5x2xf32>, tensor<2x5xf32>) outs(%6 : tensor<5x5xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.mulf %in, %in_3 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<5x5xf32> + + %output_1 = bufferization.materialize_in_destination %2 in %arg3 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %output_2 = bufferization.materialize_in_destination %7 in %arg2: (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> + + return %output_1, %4, %output_2 : tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32> + } +} diff --git a/docs/marimo/removed.mlir b/docs/marimo/removed.mlir new file mode 100644 index 0000000000..165e6680e7 --- /dev/null +++ b/docs/marimo/removed.mlir @@ -0,0 +1,17 @@ +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%arg2 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } +} diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index f7259b8994..89d4a1a696 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -1,11 +1,11 @@ from dataclasses import dataclass +from pprint import pprint from xdsl.context import MLContext from xdsl.dialects import builtin from xdsl.dialects.builtin import TensorType from xdsl.dialects.func import FuncOp -from xdsl.ir import BlockArgument, SSAValue -from xdsl.irdl import VarOperand +from xdsl.ir import BlockArgument from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -16,7 +16,7 @@ ) from xdsl.transforms.mlir_opt import MLIROptPass - +""" @dataclass class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern @@ -44,6 +44,22 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): break new_op = child_op.clone(value_mapper) rewriter.replace_op(child_op, [new_op]) +""" + + +@dataclass +class SubstituteDonatedTensors(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): + if op.arg_attrs is None: + return + + donated_inputs: list[BlockArgument] = [] + for inp, attr in zip(op.args, op.arg_attrs): + if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data: + donated_inputs.append(inp) + + pprint(vars(op.body)) @dataclass(frozen=True) @@ -51,6 +67,7 @@ class 
JaxUseDonatedArguments(ModulePass): name = "jax-use-donated-arguments" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op) the_one_pass = PatternRewriteWalker( GreedyRewritePatternApplier([SubstituteDonatedTensors()]), apply_recursively=False, @@ -58,4 +75,3 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: walk_regions_first=True, ) the_one_pass.rewrite_module(op) - MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op) From ad4450583bc6de2e6c2c836b8c1af61f366ee449 Mon Sep 17 00:00:00 2001 From: manainen Date: Tue, 8 Oct 2024 14:44:19 +0100 Subject: [PATCH 11/22] rewrite pass to use materialize_in_destination --- test.mlir | 25 +++++++ test_buffered.mlir | 23 ++++++ .../with-mlir/jax-use-donated-arguments.mlir | 43 ++++++++++- xdsl/transforms/jax_use_donated_arguments.py | 73 +++++++++---------- 4 files changed, 121 insertions(+), 43 deletions(-) create mode 100644 test.mlir create mode 100644 test_buffered.mlir diff --git a/test.mlir b/test.mlir new file mode 100644 index 0000000000..7da7465ead --- /dev/null +++ b/test.mlir @@ -0,0 +1,25 @@ +builtin.module { + builtin.module { + func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = bufferization.materialize_in_destination %1 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %2 : tensor<2x4xf32> + } + } + builtin.module { + func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32>, %arg2 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<2x3xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = tensor.empty() : tensor<2x3xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> + %4 = tensor.empty() : tensor<4x5xf32> + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> + %6 = bufferization.materialize_in_destination %1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %7 = bufferization.materialize_in_destination %5 in %arg2 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + func.return %6, %3, %7 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> + } + } +} diff --git a/test_buffered.mlir b/test_buffered.mlir new file mode 100644 index 0000000000..4f89819726 --- /dev/null +++ b/test_buffered.mlir @@ -0,0 +1,23 @@ +module { + module { + func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>>, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) + memref.copy %arg2, %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> + return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> + } + } + module { + func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}, %arg1: 
memref<2x3xf32, strided<[?, ?], offset: ?>>, %arg2: memref<4x5xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}) -> (memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<2x3xf32>, memref<4x5xf32, strided<[?, ?], offset: ?>>) { + %cst = arith.constant 0.000000e+00 : f32 + linalg.fill ins(%cst : f32) outs(%arg0 : memref<2x3xf32, strided<[?, ?], offset: ?>>) + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x3xf32> + linalg.fill ins(%cst : f32) outs(%alloc : memref<2x3xf32>) + linalg.fill ins(%cst : f32) outs(%arg2 : memref<4x5xf32, strided<[?, ?], offset: ?>>) + memref.copy %arg0, %arg0 : memref<2x3xf32, strided<[?, ?], offset: ?>> to memref<2x3xf32, strided<[?, ?], offset: ?>> + memref.copy %arg2, %arg2 : memref<4x5xf32, strided<[?, ?], offset: ?>> to memref<4x5xf32, strided<[?, ?], offset: ?>> + %cast = memref.cast %alloc : memref<2x3xf32> to memref<2x3xf32, strided<[?, ?], offset: ?>> + return %arg0, %alloc, %arg2 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<2x3xf32>, memref<4x5xf32, strided<[?, ?], offset: ?>> + } + } +} diff --git a/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir b/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir index 44eace82b3..34b4a69ebc 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir @@ -9,10 +9,47 @@ func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: te } } -// CHECK: builtin.module { +// CHECK: builtin.module { +// CHECK-NEXT: builtin.module { // CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { +// CHECK-NEXT: %0 = tensor.empty() : tensor<2x4xf32> +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %2 = bufferization.materialize_in_destination %1 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: func.return %2 : tensor<2x4xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } + +builtin.module { +func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32>, %arg2: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + + %0 = tensor.empty() : tensor<2x3xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32> + + %2 = tensor.empty() : tensor<2x3xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> + + %4 = tensor.empty() : tensor<4x5xf32> + %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> + + return %1, %3, %5 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> + } +} + +// CHECK-NEXT: builtin.module { +// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32>, %arg2 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: func.return %0 : tensor<2x4xf32> +// CHECK-NEXT: %0 = tensor.empty() : tensor<2x3xf32> +// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> 
tensor<2x3xf32> +// CHECK-NEXT: %2 = tensor.empty() : tensor<2x3xf32> +// CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %4 = tensor.empty() : tensor<4x5xf32> +// CHECK-NEXT: %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: %6 = bufferization.materialize_in_destination %1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %7 = bufferization.materialize_in_destination %5 in %arg2 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: func.return %6, %3, %7 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> // CHECK-NEXT: } // CHECK-NEXT: } + +// CHECK-NEXT: } diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index 89d4a1a696..3be54eb059 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from pprint import pprint from xdsl.context import MLContext from xdsl.dialects import builtin +from xdsl.dialects.bufferization import MaterializeInDestination from xdsl.dialects.builtin import TensorType -from xdsl.dialects.func import FuncOp -from xdsl.ir import BlockArgument +from xdsl.dialects.func import Return +from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -14,52 +14,46 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.transforms.mlir_opt import MLIROptPass +from xdsl.utils.exceptions import VerifyException -""" -@dataclass -class SubstituteDonatedTensors(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): - if op.arg_attrs is None: - return - donated_inputs: list[BlockArgument] = [] - for inp, attr in zip(op.args, op.arg_attrs): - if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data: - donated_inputs.append(inp) - - for child_op in op.body.ops: - if ( - hasattr(child_op, "outputs") - and type(getattr(child_op, "outputs")) is VarOperand - ): - value_mapper: dict[SSAValue, SSAValue] = {} - for output in getattr(child_op, "outputs"): - for i, arg in enumerate(donated_inputs): - if type(getattr(output, "type")) is TensorType and getattr( - arg, "type" - ).is_same_type_with(output.type): - value_mapper[output] = donated_inputs.pop(i) - break - new_op = child_op.clone(value_mapper) - rewriter.replace_op(child_op, [new_op]) -""" +def make_materialize_op(source: SSAValue, dest: SSAValue) -> MaterializeInDestination: + return MaterializeInDestination(operands=[source, dest], result_types=[source.type]) @dataclass class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): - if op.arg_attrs is None: + def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): + func_op = op.parent_op() + if func_op is None: + raise VerifyException("Return operation should be tied to a FuncOp") + + arg_attrs = getattr(func_op, "arg_attrs") + args = getattr(func_op, "args") + + if arg_attrs is None: return - donated_inputs: list[BlockArgument] = [] - for inp, attr in zip(op.args, op.arg_attrs): - if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data: - donated_inputs.append(inp) + donated_inputs = [ + inp + for inp, attr in zip(args, arg_attrs) + if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data + 
] + + value_mapper: dict[SSAValue, SSAValue] = {} + new_ops: list[Operation] = [] + for output in op.arguments: + for i, arg in enumerate(donated_inputs): + if type(getattr(output, "type")) is TensorType and getattr( + arg, "type" + ).is_same_type_with(output.type): + new_ops.append(make_materialize_op(output, donated_inputs.pop(i))) + value_mapper[output] = new_ops[-1].results[0] + break - pprint(vars(op.body)) + new_ops.append(op.clone(value_mapper)) + rewriter.replace_matched_op(new_ops) @dataclass(frozen=True) @@ -67,7 +61,6 @@ class JaxUseDonatedArguments(ModulePass): name = "jax-use-donated-arguments" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op) the_one_pass = PatternRewriteWalker( GreedyRewritePatternApplier([SubstituteDonatedTensors()]), apply_recursively=False, From 7fb15af7dee3b217c901359195b5a792d9e8a2c2 Mon Sep 17 00:00:00 2001 From: manainen Date: Tue, 8 Oct 2024 14:45:40 +0100 Subject: [PATCH 12/22] move test --- test.mlir | 25 ------------------- test_buffered.mlir | 23 ----------------- .../jax-use-donated-arguments.mlir | 0 3 files changed, 48 deletions(-) delete mode 100644 test.mlir delete mode 100644 test_buffered.mlir rename tests/filecheck/{mlir-conversion/with-mlir => transforms}/jax-use-donated-arguments.mlir (100%) diff --git a/test.mlir b/test.mlir deleted file mode 100644 index 7da7465ead..0000000000 --- a/test.mlir +++ /dev/null @@ -1,25 +0,0 @@ -builtin.module { - builtin.module { - func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = bufferization.materialize_in_destination %1 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> - func.return %2 : tensor<2x4xf32> - } - } - builtin.module { - func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32>, %arg2 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<2x3xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32> - %2 = tensor.empty() : tensor<2x3xf32> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> - %4 = tensor.empty() : tensor<4x5xf32> - %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> - %6 = bufferization.materialize_in_destination %1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> - %7 = bufferization.materialize_in_destination %5 in %arg2 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> - func.return %6, %3, %7 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> - } - } -} diff --git a/test_buffered.mlir b/test_buffered.mlir deleted file mode 100644 index 4f89819726..0000000000 --- a/test_buffered.mlir +++ /dev/null @@ -1,23 +0,0 @@ -module { - module { - func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>>, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>>, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) - 
memref.copy %arg2, %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> - return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> - } - } - module { - func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}, %arg1: memref<2x3xf32, strided<[?, ?], offset: ?>>, %arg2: memref<4x5xf32, strided<[?, ?], offset: ?>> {tf.aliasing_output = 0 : i32}) -> (memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<2x3xf32>, memref<4x5xf32, strided<[?, ?], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg0 : memref<2x3xf32, strided<[?, ?], offset: ?>>) - %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x3xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<2x3xf32>) - linalg.fill ins(%cst : f32) outs(%arg2 : memref<4x5xf32, strided<[?, ?], offset: ?>>) - memref.copy %arg0, %arg0 : memref<2x3xf32, strided<[?, ?], offset: ?>> to memref<2x3xf32, strided<[?, ?], offset: ?>> - memref.copy %arg2, %arg2 : memref<4x5xf32, strided<[?, ?], offset: ?>> to memref<4x5xf32, strided<[?, ?], offset: ?>> - %cast = memref.cast %alloc : memref<2x3xf32> to memref<2x3xf32, strided<[?, ?], offset: ?>> - return %arg0, %alloc, %arg2 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<2x3xf32>, memref<4x5xf32, strided<[?, ?], offset: ?>> - } - } -} diff --git a/tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir b/tests/filecheck/transforms/jax-use-donated-arguments.mlir similarity index 100% rename from tests/filecheck/mlir-conversion/with-mlir/jax-use-donated-arguments.mlir rename to tests/filecheck/transforms/jax-use-donated-arguments.mlir From c83f6fa72a03bb584f68cbcf180abb61d6085f5d Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 10:36:11 +0100 Subject: [PATCH 13/22] little flow fix --- xdsl/transforms/jax_use_donated_arguments.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index 3be54eb059..fcfe204254 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -44,10 +44,11 @@ def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): value_mapper: dict[SSAValue, SSAValue] = {} new_ops: list[Operation] = [] for output in op.arguments: + if type(getattr(output, "type")) is not TensorType: + break + for i, arg in enumerate(donated_inputs): - if type(getattr(output, "type")) is TensorType and getattr( - arg, "type" - ).is_same_type_with(output.type): + if getattr(arg, "type").is_same_type_with(output.type): new_ops.append(make_materialize_op(output, donated_inputs.pop(i))) value_mapper[output] = new_ops[-1].results[0] break From c77ffdae821c7d8f912acd89fdc6f712ef0fdf7e Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 10:37:13 +0100 Subject: [PATCH 14/22] little flow fix --- xdsl/transforms/jax_use_donated_arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index fcfe204254..ec93d56053 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -45,7 +45,7 @@ def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): new_ops: list[Operation] = [] for output in op.arguments: if type(getattr(output, "type")) is not TensorType: - break + continue for i, arg in 
enumerate(donated_inputs): if getattr(arg, "type").is_same_type_with(output.type): From 57a98eea6d5a5b7f166bcbd8c4235b6ee3c43f54 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 14:47:20 +0100 Subject: [PATCH 15/22] full integration test --- .../with-mlir/jax_argument_donation_full.mlir | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir diff --git a/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir b/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir new file mode 100644 index 0000000000..309468a7ad --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir @@ -0,0 +1,109 @@ +// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file | mlir-opt --eliminate-empty-tensors --one-shot-bufferize=bufferize-function-boundaries | filecheck %s + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +builtin.module { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } +} + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-NEXT: module { +// CHECK-NEXT: module { +// CHECK-NEXT: func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) +// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_0: f32, %out: f32): +// CHECK-NEXT: %0 = arith.mulf %in, %in_0 : f32 +// CHECK-NEXT: %1 = arith.addf %out, %0 : f32 +// CHECK-NEXT: linalg.yield %1 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: memref.copy %arg2, %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +#map3 = affine_map<(d0, d1) -> (d0, d1)> +builtin.module { + func.func public 
@main(%arg0: tensor<5x2xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<2x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: tensor<2x2xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32>) { + %0 = tensor.empty() : tensor<2x2xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : tensor<2x5xf32>, tensor<5x2xf32>) outs(%1 : tensor<2x2xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.mulf %in, %in_3 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<2x2xf32> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<f32> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<5x2xf32> + %3 = tensor.empty() : tensor<5x2xf32> + %4 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%3 : tensor<5x2xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.addf %in, %in_3 : f32 + linalg.yield %8 : f32 + } -> tensor<5x2xf32> + %5 = tensor.empty() : tensor<5x5xf32> + %cst_2 = arith.constant 0.000000e+00 : f32 + %6 = linalg.fill ins(%cst_2 : f32) outs(%5 : tensor<5x5xf32>) -> tensor<5x5xf32> + %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<5x2xf32>, tensor<2x5xf32>) outs(%6 : tensor<5x5xf32>) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %8 = arith.mulf %in, %in_3 : f32 + %9 = arith.addf %out, %8 : f32 + linalg.yield %9 : f32 + } -> tensor<5x5xf32> + return %2, %4, %7 : tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32> + } +} + +// CHECK-NEXT: module { +// CHECK-NEXT: memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<1.000000e+00> {alignment = 64 : i64} +// CHECK-NEXT: memref.global "private" constant @__constant_xf32 : memref<f32> = dense<1.000000e+00> {alignment = 64 : i64} +// CHECK-NEXT: func.func public @main(%arg0: memref<5x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: memref<2x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>>) { +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) +// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : memref<2x5xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32, strided<[?, ?], offset: ?>>) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): +// CHECK-NEXT: %2 = arith.mulf %in, %in_1 : f32 +// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 +// CHECK-NEXT: linalg.yield %3 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: %0 = memref.get_global @__constant_xf32 : memref<f32> +// CHECK-NEXT: %1 = memref.get_global @__constant_5x2xf32 : 
memref<5x2xf32> +// CHECK-NEXT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x2xf32> +// CHECK-NEXT: linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>) outs(%alloc : memref<5x2xf32>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): +// CHECK-NEXT: %2 = arith.addf %in, %in_1 : f32 +// CHECK-NEXT: linalg.yield %2 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill ins(%cst_0 : f32) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) +// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<2x5xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { +// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): +// CHECK-NEXT: %2 = arith.mulf %in, %in_1 : f32 +// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 +// CHECK-NEXT: linalg.yield %3 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: memref.copy %arg3, %arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: %cast = memref.cast %alloc : memref<5x2xf32> to memref<5x2xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: return %arg3, %alloc, %arg2 : memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>> +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK-NEXT: } \ No newline at end of file From c7279fe87c4e98ab287e7b562ad4a649c888781b Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 15:19:26 +0100 Subject: [PATCH 16/22] cleanup --- buffered.mlir | 18 -- docs/marimo/arith.mlir | 23 --- docs/marimo/arith_donated_buffered.mlir | 22 --- docs/marimo/buffered.mlir | 18 -- docs/marimo/conv.mlir | 41 ----- docs/marimo/conv_buffered.mlir | 43 ----- docs/marimo/conv_donated.mlir | 44 ----- docs/marimo/conv_donated_buffered.mlir | 43 ----- docs/marimo/donated_buffered.mlir | 17 -- docs/marimo/fused.mlir | 16 -- docs/marimo/fused_buffered.mlir | 14 -- docs/marimo/jax_experiments.py | 172 ------------------ docs/marimo/jax_mult.mlir | 17 -- docs/marimo/jax_mult_donation.mlir | 20 -- docs/marimo/jax_rewrite_exp.py | 62 ------- docs/marimo/multiple_donated_buffered.mlir | 38 ---- docs/marimo/multiple_outputs.mlir | 39 ---- docs/marimo/removed.mlir | 17 -- .../with-mlir/jax_argument_donation_full.mlir | 109 ----------- 19 files changed, 773 deletions(-) delete mode 100644 buffered.mlir delete mode 100644 docs/marimo/arith.mlir delete mode 100644 docs/marimo/arith_donated_buffered.mlir delete mode 100644 docs/marimo/buffered.mlir delete mode 100644 docs/marimo/conv.mlir delete mode 100644 docs/marimo/conv_buffered.mlir delete mode 100644 docs/marimo/conv_donated.mlir delete mode 100644 docs/marimo/conv_donated_buffered.mlir delete mode 100644 docs/marimo/donated_buffered.mlir delete mode 100644 docs/marimo/fused.mlir delete mode 100644 docs/marimo/fused_buffered.mlir delete mode 100644 docs/marimo/jax_experiments.py delete mode 100644 docs/marimo/jax_mult.mlir delete mode 100644 docs/marimo/jax_mult_donation.mlir delete mode 100644 docs/marimo/jax_rewrite_exp.py delete mode 100644 docs/marimo/multiple_donated_buffered.mlir delete mode 100644 
docs/marimo/multiple_outputs.mlir delete mode 100644 docs/marimo/removed.mlir delete mode 100644 tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir diff --git a/buffered.mlir b/buffered.mlir deleted file mode 100644 index 6c7ffb4a14..0000000000 --- a/buffered.mlir +++ /dev/null @@ -1,18 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module { - func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32> { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%alloc : memref<2x4xf32>) - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %0 = arith.mulf %in, %in_0 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - %cast = memref.cast %alloc : memref<2x4xf32> to memref<2x4xf32, strided<[?, ?], offset: ?>> - return %alloc : memref<2x4xf32> - } -} diff --git a/docs/marimo/arith.mlir b/docs/marimo/arith.mlir deleted file mode 100644 index 3a7287d99b..0000000000 --- a/docs/marimo/arith.mlir +++ /dev/null @@ -1,23 +0,0 @@ -#map = affine_map<(d0, d1) -> (d0, d1)> -builtin.module { - func.func public @main(%arg0: tensor<5x5xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<5x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<5x5xf32> { - %0 = tensor.empty() : tensor<5x5xf32> - %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<5x5xf32>, tensor<5x5xf32>) outs(%0 : tensor<5x5xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %4 = arith.addf %in, %in_1 : f32 - linalg.yield %4 : f32 - } -> tensor<5x5xf32> - %cst = arith.constant dense<1.000000e+00> : tensor<f32> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<5x5xf32> - %2 = tensor.empty() : tensor<5x5xf32> - %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %cst_0 : tensor<5x5xf32>, tensor<5x5xf32>) outs(%2 : tensor<5x5xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %4 = arith.addf %in, %in_1 : f32 - linalg.yield %4 : f32 - } -> tensor<5x5xf32> - - %output = bufferization.materialize_in_destination %3 in %arg2 : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> - - return %output : tensor<5x5xf32> - } -} diff --git a/docs/marimo/arith_donated_buffered.mlir b/docs/marimo/arith_donated_buffered.mlir deleted file mode 100644 index 7498408215..0000000000 --- a/docs/marimo/arith_donated_buffered.mlir +++ /dev/null @@ -1,22 +0,0 @@ -#map = affine_map<(d0, d1) -> (d0, d1)> -module { - memref.global "private" constant @__constant_5x5xf32 : memref<5x5xf32> = dense<1.000000e+00> {alignment = 64 : i64} - memref.global "private" constant @__constant_xf32 : memref<f32> = dense<1.000000e+00> {alignment = 64 : i64} - func.func public @main(%arg0: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: 
memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<5x5xf32, strided<[?, ?], offset: ?>> { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x5xf32> - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : memref<5x5xf32, strided<[?, ?], offset: ?>>, memref<5x5xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<5x5xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.addf %in, %in_0 : f32 - linalg.yield %2 : f32 - } - %0 = memref.get_global @__constant_xf32 : memref<f32> - %1 = memref.get_global @__constant_5x5xf32 : memref<5x5xf32> - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc, %1 : memref<5x5xf32>, memref<5x5xf32>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.addf %in, %in_0 : f32 - linalg.yield %2 : f32 - } - memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> - return %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> - } -} diff --git a/docs/marimo/buffered.mlir b/docs/marimo/buffered.mlir deleted file mode 100644 index da09e5fb04..0000000000 --- a/docs/marimo/buffered.mlir +++ /dev/null @@ -1,18 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%alloc : memref<2x4xf32>) - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%alloc : memref<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %0 = arith.mulf %in, %in_0 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - %cast = memref.cast %alloc : memref<2x4xf32> to memref<2x4xf32, strided<[?, ?], offset: ?>> - return %alloc : memref<2x4xf32> - } -} diff --git a/docs/marimo/conv.mlir b/docs/marimo/conv.mlir deleted file mode 100644 index ec58f465cd..0000000000 --- a/docs/marimo/conv.mlir +++ /dev/null @@ -1,41 +0,0 @@ -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> -#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> -#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> -#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -builtin.module { - func.func public @main(%arg0: 
tensor<1x1x10x10xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<1x1x8x8xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf32> { - %0 = call @_flip(%arg1) : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> - %1 = tensor.empty() : tensor<1x1x1x1x10x10xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x10x10xf32>) outs(%1 : tensor<1x1x1x1x10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x1x1x10x10xf32> - %3 = tensor.empty() : tensor<1x1x1x1x3x3xf32> - %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x1x3x3xf32>) outs(%3 : tensor<1x1x1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x1x1x3x3xf32> - %5 = tensor.empty() : tensor<1x1x1x1x8x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1x1x1x8x8xf32>) -> tensor<1x1x1x1x8x8xf32> - %7 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%2, %4 : tensor<1x1x1x1x10x10xf32>, tensor<1x1x1x1x3x3xf32>) outs(%6 : tensor<1x1x1x1x8x8xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %8 = arith.mulf %in, %in_0 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<1x1x1x1x8x8xf32> - %collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3, 4], [5]] : tensor<1x1x1x1x8x8xf32> into tensor<1x1x8x8xf32> - return %collapsed : tensor<1x1x8x8xf32> - } - func.func private @_flip(%arg0: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}) -> tensor<1x1x3x3xf32> { - %0 = tensor.empty() : tensor<1x1x3x3xf32> - %1 = linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x3x3xf32>) outs(%0 : tensor<1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x3x3xf32> - return %1 : tensor<1x1x3x3xf32> - } -} diff --git a/docs/marimo/conv_buffered.mlir b/docs/marimo/conv_buffered.mlir deleted file mode 100644 index 911a2ca07a..0000000000 --- a/docs/marimo/conv_buffered.mlir +++ /dev/null @@ -1,43 +0,0 @@ -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> -#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> -#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> -#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -module { - func.func public @main(%arg0: memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<1x1x8x8xf32> { - %0 = call @_flip(%arg1) : (memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) -> memref<1x1x3x3xf32> - 
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x10x10xf32> - linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x1x1x10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x3x3xf32> - linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : memref<1x1x3x3xf32>) outs(%alloc_0 : memref<1x1x1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x8x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) - linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%alloc, %alloc_0 : memref<1x1x1x1x10x10xf32>, memref<1x1x1x1x3x3xf32>) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %1 = arith.mulf %in, %in_2 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - %collapse_shape = memref.collapse_shape %alloc_1 [[0], [1], [2, 3, 4], [5]] : memref<1x1x1x1x8x8xf32> into memref<1x1x8x8xf32> - %cast = memref.cast %collapse_shape : memref<1x1x8x8xf32> to memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> - return %collapse_shape : memref<1x1x8x8xf32> - } - func.func private @_flip(%arg0: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}) -> memref<1x1x3x3xf32> { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x3x3xf32> - linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %cast = memref.cast %alloc : memref<1x1x3x3xf32> to memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> - return %alloc : memref<1x1x3x3xf32> - } -} diff --git a/docs/marimo/conv_donated.mlir b/docs/marimo/conv_donated.mlir deleted file mode 100644 index 900e695885..0000000000 --- a/docs/marimo/conv_donated.mlir +++ /dev/null @@ -1,44 +0,0 @@ -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> -#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> -#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> -#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -builtin.module { - func.func public @main(%arg0: tensor<1x1x10x10xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<1x1x8x8xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<1x1x8x8xf32> { - %0 = call @_flip(%arg1) : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x3xf32> - %1 = tensor.empty() : tensor<1x1x1x1x10x10xf32> - %2 = linalg.generic {indexing_maps = [#map, 
#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x10x10xf32>) outs(%1 : tensor<1x1x1x1x10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x1x1x10x10xf32> - %3 = tensor.empty() : tensor<1x1x1x1x3x3xf32> - %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x1x3x3xf32>) outs(%3 : tensor<1x1x1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x1x1x3x3xf32> - %5 = tensor.empty() : tensor<1x1x1x1x8x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1x1x1x8x8xf32>) -> tensor<1x1x1x1x8x8xf32> - %7 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%2, %4 : tensor<1x1x1x1x10x10xf32>, tensor<1x1x1x1x3x3xf32>) outs(%6 : tensor<1x1x1x1x8x8xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %8 = arith.mulf %in, %in_0 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<1x1x1x1x8x8xf32> - %collapsed = tensor.collapse_shape %7 [[0], [1], [2, 3, 4], [5]] : tensor<1x1x1x1x8x8xf32> into tensor<1x1x8x8xf32> - - %output_1 = bufferization.materialize_in_destination %collapsed in %arg2 : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> - - return %output_1 : tensor<1x1x8x8xf32> - } - func.func private @_flip(%arg0: tensor<1x1x3x3xf32> {mhlo.layout_mode = "default"}) -> tensor<1x1x3x3xf32> { - %0 = tensor.empty() : tensor<1x1x3x3xf32> - %1 = linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x1x3x3xf32>) outs(%0 : tensor<1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x1x3x3xf32> - return %1 : tensor<1x1x3x3xf32> - } -} diff --git a/docs/marimo/conv_donated_buffered.mlir b/docs/marimo/conv_donated_buffered.mlir deleted file mode 100644 index e4edb55c6e..0000000000 --- a/docs/marimo/conv_donated_buffered.mlir +++ /dev/null @@ -1,43 +0,0 @@ -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d0, d2 + d3, d4 + d5, d6 + d7, d8 + d9)> -#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d1, d0, d3, d5, d7, d9)> -#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d10, d1, d2, d4, d6, d8)> -#map5 = affine_map<(d0, d1, d2, d3) -> (-d0, -d1, -d2 + 2, -d3 + 2)> -#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -module { - func.func public @main(%arg0: memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> { - %0 = call @_flip(%arg1) : (memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) -> memref<1x1x3x3xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x10x10xf32> - linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", 
"parallel", "parallel"]} ins(%arg0 : memref<1x1x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x1x1x10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x3x3xf32> - linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : memref<1x1x3x3xf32>) outs(%alloc_0 : memref<1x1x1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x1x1x1x8x8xf32> - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) - linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["reduction", "parallel", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel", "reduction", "parallel"]} ins(%alloc, %alloc_0 : memref<1x1x1x1x10x10xf32>, memref<1x1x1x1x3x3xf32>) outs(%alloc_1 : memref<1x1x1x1x8x8xf32>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %1 = arith.mulf %in, %in_2 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - %collapse_shape = memref.collapse_shape %alloc_1 [[0], [1], [2, 3, 4], [5]] : memref<1x1x1x1x8x8xf32> into memref<1x1x8x8xf32> - memref.copy %collapse_shape, %arg2 : memref<1x1x8x8xf32> to memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> - return %arg2 : memref<1x1x8x8xf32, strided<[?, ?, ?, ?], offset: ?>> - } - func.func private @_flip(%arg0: memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> {mhlo.layout_mode = "default"}) -> memref<1x1x3x3xf32> { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x3x3xf32> - linalg.generic {indexing_maps = [#map5, #map6], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>>) outs(%alloc : memref<1x1x3x3xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - %cast = memref.cast %alloc : memref<1x1x3x3xf32> to memref<1x1x3x3xf32, strided<[?, ?, ?, ?], offset: ?>> - return %alloc : memref<1x1x3x3xf32> - } -} diff --git a/docs/marimo/donated_buffered.mlir b/docs/marimo/donated_buffered.mlir deleted file mode 100644 index 0a98d0baed..0000000000 --- a/docs/marimo/donated_buffered.mlir +++ /dev/null @@ -1,17 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module { - func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %0 = arith.mulf %in, %in_0 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - memref.copy %arg2, %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> - 
return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> - } -} diff --git a/docs/marimo/fused.mlir b/docs/marimo/fused.mlir deleted file mode 100644 index 17d6d16043..0000000000 --- a/docs/marimo/fused.mlir +++ /dev/null @@ -1,16 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2x4xf32>) -> tensor<2x4xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%0 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.mulf %in, %in_0 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } -> tensor<2x4xf32> - return %1 : tensor<2x4xf32> - } -} diff --git a/docs/marimo/fused_buffered.mlir b/docs/marimo/fused_buffered.mlir deleted file mode 100644 index 93d7fbed92..0000000000 --- a/docs/marimo/fused_buffered.mlir +++ /dev/null @@ -1,14 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: memref<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>) outs(%arg2 : memref<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %0 = arith.mulf %in, %in_0 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - return %arg2 : memref<2x4xf32> - } -} diff --git a/docs/marimo/jax_experiments.py b/docs/marimo/jax_experiments.py deleted file mode 100644 index 6dace3bfc1..0000000000 --- a/docs/marimo/jax_experiments.py +++ /dev/null @@ -1,172 +0,0 @@ -import marimo - -__generated_with = "0.8.5" -app = marimo.App(width="medium") - - -@app.cell -def __(): - import marimo as mo - return mo, - - -@app.cell -def __(): - import jax - import jax.numpy as jnp - import jax.scipy as jsp - from jax import lax - return jax, jnp, jsp, lax - - -@app.cell -def __(): - from jax import random - return random, - - -@app.cell -def __(get_linalg_module_str, jax, jnp, random): - def matmul(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): - return A @ B - - matmul_params = (2, 4, 3) - matmul_shapes = ((matmul_params[0], matmul_params[2]), (matmul_params[2], matmul_params[1]), (matmul_params[0], matmul_params[1])) - - key = jax.random.key(42) - - matmul_data = tuple(random.uniform(key, shape) for shape in matmul_shapes) - - matmul_jit = jax.jit(matmul, donate_argnames=['C'], keep_unused=True) - - get_linalg_module_str(matmul_jit, matmul_data) - return ( - key, - matmul, - matmul_data, - matmul_jit, - matmul_params, - 
matmul_shapes, - ) - - -@app.cell -def __(jax, jnp, jsp): - def conv(X: jnp.ndarray, K: jnp.ndarray, Z: jnp.ndarray): - return jsp.signal.convolve(X, K, mode="valid", method="direct") - - conv_jit = jax.jit(conv, donate_argnames=['Z'], keep_unused=True) - return conv, conv_jit - - -@app.cell -def __(key, random): - X_shape = (1, 1, 10, 10) - K_shape = (1, 1, 3, 3) - Z_shape = (1, 1, 1, 1, 8, 8) - - conv_data = tuple(random.uniform(key, shape) for shape in [X_shape, K_shape, Z_shape]) - return K_shape, X_shape, Z_shape, conv_data - - -@app.cell -def __(conv_data, conv_jit, get_linalg_module_str): - get_linalg_module_str(conv_jit, conv_data) - return - - -@app.cell -def __(get_linalg_module_str, jax, jnp, key, random): - def simple_arith(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray): - return A + B + 1 - - arith_jit = jax.jit(simple_arith, donate_argnames=['C'], keep_unused=True) - arith_data = tuple(random.uniform(key, shape) for shape in [(5, 5), (5, 5), (5, 5)]) - - get_linalg_module_str(arith_jit, arith_data) - return arith_data, arith_jit, simple_arith - - -@app.cell -def __(get_linalg_module_str, jax, jnp, key, random): - def multiple_outputs(A: jnp.ndarray, B: jnp.ndarray, C: jnp.ndarray, D: jnp.ndarray): - return (B @ A, A + 1, A @ B) - - multiple_jit = jax.jit(multiple_outputs, donate_argnames=['C', 'D'], keep_unused=True) - multiple_data = tuple(random.uniform(key, shape) for shape in [(5, 2), (2, 5), (5, 5), (2, 2)]) - - get_linalg_module_str(multiple_jit, multiple_data) - return multiple_data, multiple_jit, multiple_outputs - - -@app.cell -def __(matmul_data, matmul_jit): - lowered_matmul = matmul_jit.lower(*matmul_data) - lowered_matmul - return lowered_matmul, - - -@app.cell -def __(lowered_matmul): - type(lowered_matmul.compile()).__doc__ - return - - -@app.cell -def __(lowered_matmul, matmul_data): - lowered_matmul.compile()(*matmul_data) - return - - -@app.cell -def __(matmul, matmul_data): - matmul(*matmul_data), matmul(*matmul_data) - return - - -@app.cell -def __(): - from jax._src.interpreters import mlir - from jaxlib.mlir.dialects import mhlo - from jaxlib.mlir.passmanager import PassManager - - def get_linalg_module_str(func, args): - lowered = func.lower(*args) - - mhlo_module = lowered.compiler_ir(dialect="mhlo") - - # print(mhlo_module) - - with mhlo_module.context as ctx: - ctx.append_dialect_registry(mlir.upstream_dialects) - # ctx.load_all_available_dialects() - # mhlo.register_mhlo_dialect(ctx) - mhlo.register_mhlo_passes() - pipeline = PassManager.parse("builtin.module(hlo-legalize-to-arithmetic,func.func(hlo-legalize-to-linalg))") - pipeline.run(mhlo_module.operation) - - mhlo_module_str = f"{mhlo_module}" - - return mhlo_module_str - return PassManager, get_linalg_module_str, mhlo, mlir - - -@app.cell -def __(): - from jax import make_jaxpr - return make_jaxpr, - - -@app.cell -def __(make_jaxpr, matmul, matmul_data): - type(make_jaxpr(matmul)(*matmul_data)) - return - - -@app.cell -def __(): - return - - -if __name__ == "__main__": - app.run() diff --git a/docs/marimo/jax_mult.mlir b/docs/marimo/jax_mult.mlir deleted file mode 100644 index d13498fa8a..0000000000 --- a/docs/marimo/jax_mult.mlir +++ /dev/null @@ -1,17 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -builtin.module { - func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = 
"default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - return %2 : tensor<2x4xf32> - } -} diff --git a/docs/marimo/jax_mult_donation.mlir b/docs/marimo/jax_mult_donation.mlir deleted file mode 100644 index ef90bad128..0000000000 --- a/docs/marimo/jax_mult_donation.mlir +++ /dev/null @@ -1,20 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -builtin.module { - func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - - %output = bufferization.materialize_in_destination %2 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> - - return %output : tensor<2x4xf32> - } -} diff --git a/docs/marimo/jax_rewrite_exp.py b/docs/marimo/jax_rewrite_exp.py deleted file mode 100644 index 8a06eb53af..0000000000 --- a/docs/marimo/jax_rewrite_exp.py +++ /dev/null @@ -1,62 +0,0 @@ -import marimo - -__generated_with = "0.8.5" -app = marimo.App(width="medium") - - -@app.cell -def __(): - import marimo as mo - return mo, - - -@app.cell -def __(): - original_func = """ - #map = affine_map<(d0, d1, d2) -> (d0, d2)> - #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> - #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> - - func.func main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - return %2 : tensor<2x4xf32> - } - """ - return original_func, - - -@app.cell -def __(): - from xdsl.context import MLContext - from xdsl.parser import Parser - return MLContext, Parser - - -@app.cell -def __(MLContext, Parser, original_func): - ctx = MLContext() - parser = Parser(ctx, original_func) - 
return ctx, parser - - -@app.cell -def __(parser): - parser.parse_module(True) - return - - -@app.cell -def __(): - return - - -if __name__ == "__main__": - app.run() diff --git a/docs/marimo/multiple_donated_buffered.mlir b/docs/marimo/multiple_donated_buffered.mlir deleted file mode 100644 index b51b3be0da..0000000000 --- a/docs/marimo/multiple_donated_buffered.mlir +++ /dev/null @@ -1,38 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map3 = affine_map<(d0, d1) -> (d0, d1)> -module @jit_multiple_outputs attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<1.000000e+00> {alignment = 64 : i64} - memref.global "private" constant @__constant_xf32 : memref = dense<1.000000e+00> {alignment = 64 : i64} - func.func public @main(%arg0: memref<5x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: memref<2x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x2xf32, strided<[?, ?], offset: ?>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, memref<5x2xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, memref<5x5xf32, strided<[?, ?], offset: ?>> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : memref<2x5xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32, strided<[?, ?], offset: ?>>) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %2 = arith.mulf %in, %in_1 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } - %0 = memref.get_global @__constant_xf32 : memref - %1 = memref.get_global @__constant_5x2xf32 : memref<5x2xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x2xf32> - linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>) outs(%alloc : memref<5x2xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %2 = arith.addf %in, %in_1 : f32 - linalg.yield %2 : f32 - } - %cst_0 = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst_0 : f32) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) - linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<2x5xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %2 = arith.mulf %in, %in_1 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } - memref.copy %arg3, %arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32, strided<[?, ?], offset: ?>> - memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> - %cast = memref.cast %alloc : memref<5x2xf32> to memref<5x2xf32, strided<[?, ?], offset: ?>> - return %arg3, 
%alloc, %arg2 : memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>> - } -} diff --git a/docs/marimo/multiple_outputs.mlir b/docs/marimo/multiple_outputs.mlir deleted file mode 100644 index 62fc6dde24..0000000000 --- a/docs/marimo/multiple_outputs.mlir +++ /dev/null @@ -1,39 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map3 = affine_map<(d0, d1) -> (d0, d1)> -module @jit_multiple_outputs attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<5x2xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<2x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: tensor<2x2xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x2xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<5x2xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<5x5xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { - %0 = tensor.empty() : tensor<2x2xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : tensor<2x5xf32>, tensor<5x2xf32>) outs(%1 : tensor<2x2xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.mulf %in, %in_3 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<2x2xf32> - %cst_0 = arith.constant dense<1.000000e+00> : tensor - %cst_1 = arith.constant dense<1.000000e+00> : tensor<5x2xf32> - %3 = tensor.empty() : tensor<5x2xf32> - %4 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%3 : tensor<5x2xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.addf %in, %in_3 : f32 - linalg.yield %8 : f32 - } -> tensor<5x2xf32> - %5 = tensor.empty() : tensor<5x5xf32> - %cst_2 = arith.constant 0.000000e+00 : f32 - %6 = linalg.fill ins(%cst_2 : f32) outs(%5 : tensor<5x5xf32>) -> tensor<5x5xf32> - %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<5x2xf32>, tensor<2x5xf32>) outs(%6 : tensor<5x5xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.mulf %in, %in_3 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<5x5xf32> - - %output_1 = bufferization.materialize_in_destination %2 in %arg3 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %output_2 = bufferization.materialize_in_destination %7 in %arg2: (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> - - return %output_1, %4, %output_2 : tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32> - } -} diff --git a/docs/marimo/removed.mlir b/docs/marimo/removed.mlir deleted file mode 100644 index 165e6680e7..0000000000 --- a/docs/marimo/removed.mlir +++ /dev/null @@ -1,17 +0,0 @@ -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, 
%arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%arg2 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - return %2 : tensor<2x4xf32> - } -} diff --git a/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir b/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir deleted file mode 100644 index 309468a7ad..0000000000 --- a/tests/filecheck/mlir-conversion/with-mlir/jax_argument_donation_full.mlir +++ /dev/null @@ -1,109 +0,0 @@ -// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file | mlir-opt --eliminate-empty-tensors --one-shot-bufferize=bufferize-function-boundaries | filecheck %s - -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -builtin.module { - func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> tensor<2x4xf32> { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %3 = arith.mulf %in, %in_0 : f32 - %4 = arith.addf %out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<2x4xf32> - return %2 : tensor<2x4xf32> - } -} - -// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-NEXT: #map3 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-NEXT: module { -// CHECK-NEXT: module { -// CHECK-NEXT: func.func public @main(%arg0: memref<2x3xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<3x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<2x4xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> memref<2x4xf32, strided<[?, ?], offset: ?>> { -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) -// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<2x3xf32, strided<[?, ?], offset: ?>>, memref<3x4xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_0: f32, %out: f32): -// CHECK-NEXT: %0 = arith.mulf %in, %in_0 : f32 -// CHECK-NEXT: %1 = arith.addf %out, %0 : f32 -// CHECK-NEXT: linalg.yield %1 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: memref.copy %arg2, %arg2 : memref<2x4xf32, 
strided<[?, ?], offset: ?>> to memref<2x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: return %arg2 : memref<2x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: } -// CHECK-NEXT: } - -#map3 = affine_map<(d0, d1) -> (d0, d1)> -builtin.module { - func.func public @main(%arg0: tensor<5x2xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<2x5xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<5x5xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: tensor<2x2xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32>) { - %0 = tensor.empty() : tensor<2x2xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2xf32>) -> tensor<2x2xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : tensor<2x5xf32>, tensor<5x2xf32>) outs(%1 : tensor<2x2xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.mulf %in, %in_3 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<2x2xf32> - %cst_0 = arith.constant dense<1.000000e+00> : tensor - %cst_1 = arith.constant dense<1.000000e+00> : tensor<5x2xf32> - %3 = tensor.empty() : tensor<5x2xf32> - %4 = linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<5x2xf32>, tensor<5x2xf32>) outs(%3 : tensor<5x2xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.addf %in, %in_3 : f32 - linalg.yield %8 : f32 - } -> tensor<5x2xf32> - %5 = tensor.empty() : tensor<5x5xf32> - %cst_2 = arith.constant 0.000000e+00 : f32 - %6 = linalg.fill ins(%cst_2 : f32) outs(%5 : tensor<5x5xf32>) -> tensor<5x5xf32> - %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<5x2xf32>, tensor<2x5xf32>) outs(%6 : tensor<5x5xf32>) { - ^bb0(%in: f32, %in_3: f32, %out: f32): - %8 = arith.mulf %in, %in_3 : f32 - %9 = arith.addf %out, %8 : f32 - linalg.yield %9 : f32 - } -> tensor<5x5xf32> - return %2, %4, %7 : tensor<2x2xf32>, tensor<5x2xf32>, tensor<5x5xf32> - } -} - -// CHECK-NEXT: module { -// CHECK-NEXT: memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<1.000000e+00> {alignment = 64 : i64} -// CHECK-NEXT: memref.global "private" constant @__constant_xf32 : memref = dense<1.000000e+00> {alignment = 64 : i64} -// CHECK-NEXT: func.func public @main(%arg0: memref<5x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg1: memref<2x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default"}, %arg2: memref<5x5xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 2 : i32}, %arg3: memref<2x2xf32, strided<[?, ?], offset: ?>> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>>) { -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: linalg.fill ins(%cst : f32) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) -// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1, %arg0 : memref<2x5xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32, strided<[?, ?], offset: ?>>) outs(%arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %2 = 
arith.mulf %in, %in_1 : f32 -// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 -// CHECK-NEXT: linalg.yield %3 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: %0 = memref.get_global @__constant_xf32 : memref -// CHECK-NEXT: %1 = memref.get_global @__constant_5x2xf32 : memref<5x2xf32> -// CHECK-NEXT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x2xf32> -// CHECK-NEXT: linalg.generic {indexing_maps = [#map3, #map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%arg0, %1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>) outs(%alloc : memref<5x2xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %2 = arith.addf %in, %in_1 : f32 -// CHECK-NEXT: linalg.yield %2 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: linalg.fill ins(%cst_0 : f32) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) -// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : memref<5x2xf32, strided<[?, ?], offset: ?>>, memref<2x5xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %2 = arith.mulf %in, %in_1 : f32 -// CHECK-NEXT: %3 = arith.addf %out, %2 : f32 -// CHECK-NEXT: linalg.yield %3 : f32 -// CHECK-NEXT: } -// CHECK-NEXT: memref.copy %arg3, %arg3 : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2x2xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: memref.copy %arg2, %arg2 : memref<5x5xf32, strided<[?, ?], offset: ?>> to memref<5x5xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: %cast = memref.cast %alloc : memref<5x2xf32> to memref<5x2xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: return %arg3, %alloc, %arg2 : memref<2x2xf32, strided<[?, ?], offset: ?>>, memref<5x2xf32>, memref<5x5xf32, strided<[?, ?], offset: ?>> -// CHECK-NEXT: } -// CHECK-NEXT: } - -// CHECK-NEXT: } \ No newline at end of file From a720e3cf4cd408607d545d01bb00bb22d37111cf Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 15:21:00 +0100 Subject: [PATCH 17/22] little fix --- xdsl/transforms/jax_use_donated_arguments.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index ec93d56053..66ac72d266 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -4,7 +4,7 @@ from xdsl.dialects import builtin from xdsl.dialects.bufferization import MaterializeInDestination from xdsl.dialects.builtin import TensorType -from xdsl.dialects.func import Return +from xdsl.dialects.func import FuncOp, Return from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -26,18 +26,15 @@ class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): func_op = op.parent_op() - if func_op is None: + if func_op is None or type(func_op) is not FuncOp: raise VerifyException("Return operation should be tied to a FuncOp") - arg_attrs = getattr(func_op, "arg_attrs") - args = getattr(func_op, "args") - - if arg_attrs is None: + if func_op.arg_attrs is None: return donated_inputs = [ inp - for inp, attr in zip(args, arg_attrs) + for inp, attr in zip(func_op.args, func_op.arg_attrs) if isinstance(inp.type, TensorType) and "tf.aliasing_output" in 
attr.data ] From 7459e678ef46c3af69b0f99e670b12c7fcb144ed Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 15:25:58 +0100 Subject: [PATCH 18/22] fixed unnecessary onnx marimo changes --- docs/marimo/mlir/onnx_demo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/marimo/mlir/onnx_demo.py b/docs/marimo/mlir/onnx_demo.py index 1da39cdc2d..286999ef20 100644 --- a/docs/marimo/mlir/onnx_demo.py +++ b/docs/marimo/mlir/onnx_demo.py @@ -4,7 +4,7 @@ app = marimo.App() -@app.cell +@app.cell(hide_code=True) def __(mo): mo.md( """ @@ -16,7 +16,7 @@ def __(mo): return -@app.cell +@app.cell(hide_code=True) def __(mo): rank = mo.ui.slider(1, 4, value=2, label="Rank") @@ -60,7 +60,7 @@ def __(mo, shape): return -@app.cell +@app.cell(hide_code=True) def __(): import onnx from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper @@ -121,7 +121,7 @@ def __(mo): return -@app.cell +@app.cell(hide_code=True) def __(mo, model_def): mo.accordion( { From d95168d8bf07934b112c138e3d45377a6c30f2e0 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 15:45:42 +0100 Subject: [PATCH 19/22] style --- xdsl/transforms/jax_use_donated_arguments.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index 66ac72d266..f00f592586 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -41,11 +41,8 @@ def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): value_mapper: dict[SSAValue, SSAValue] = {} new_ops: list[Operation] = [] for output in op.arguments: - if type(getattr(output, "type")) is not TensorType: - continue - for i, arg in enumerate(donated_inputs): - if getattr(arg, "type").is_same_type_with(output.type): + if getattr(arg, "type").is_same_type_with(getattr(output, "type")): new_ops.append(make_materialize_op(output, donated_inputs.pop(i))) value_mapper[output] = new_ops[-1].results[0] break From 9f6ae3891febdb6136be31a1775250a75d00fb52 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 16:21:52 +0100 Subject: [PATCH 20/22] fixes and simplifications --- .../transforms/jax-use-donated-arguments.mlir | 64 ++++++++++--------- xdsl/dialects/builtin.py | 13 ---- xdsl/transforms/jax_use_donated_arguments.py | 13 ++-- 3 files changed, 40 insertions(+), 50 deletions(-) diff --git a/tests/filecheck/transforms/jax-use-donated-arguments.mlir b/tests/filecheck/transforms/jax-use-donated-arguments.mlir index 34b4a69ebc..7c4832839e 100644 --- a/tests/filecheck/transforms/jax-use-donated-arguments.mlir +++ b/tests/filecheck/transforms/jax-use-donated-arguments.mlir @@ -2,53 +2,55 @@ builtin.module { func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { - %0 = tensor.empty() : tensor<2x4xf32> - %cst = arith.constant 0.000000e+00 : f32 - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> - return %1 : tensor<2x4xf32> + %res = "test.op"() : () -> tensor<2x4xf32> + return %res : tensor<2x4xf32> } } // CHECK: builtin.module { // CHECK-NEXT: builtin.module { // CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { -// CHECK-NEXT: %0 = tensor.empty() : tensor<2x4xf32> -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// 
CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: %2 = bufferization.materialize_in_destination %1 in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK-NEXT: func.return %2 : tensor<2x4xf32> +// CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: func.return %0 : tensor<2x4xf32> // CHECK-NEXT: } // CHECK-NEXT: } builtin.module { -func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32>, %arg2: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - - %0 = tensor.empty() : tensor<2x3xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32> - - %2 = tensor.empty() : tensor<2x3xf32> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> +func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32> + } +} - %4 = tensor.empty() : tensor<4x5xf32> - %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: builtin.module { +// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: func.return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } - return %1, %3, %5 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> +builtin.module { +func.func public @main(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { + %res1 = "test.op"() : () -> tensor<2x3xf32> + %res2 = "test.op"() : () -> tensor<2x3xf32> + %res3 = "test.op"() : () -> tensor<4x5xf32> + return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> } } // CHECK-NEXT: builtin.module { -// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32>, %arg2 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %0 = tensor.empty() : tensor<2x3xf32> -// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x3xf32>) -> tensor<2x3xf32> -// CHECK-NEXT: %2 = tensor.empty() : tensor<2x3xf32> -// CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<2x3xf32>) -> tensor<2x3xf32> -// CHECK-NEXT: %4 = tensor.empty() : tensor<4x5xf32> -// CHECK-NEXT: %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<4x5xf32>) -> tensor<4x5xf32> -// CHECK-NEXT: 
%6 = bufferization.materialize_in_destination %1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> -// CHECK-NEXT: %7 = bufferization.materialize_in_destination %5 in %arg2 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> -// CHECK-NEXT: func.return %6, %3, %7 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> +// CHECK-NEXT: func.func public @main(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { +// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> +// CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32> +// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> +// CHECK-NEXT: func.return %0, %res2, %1 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index bc5339439d..443337b029 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -800,19 +800,6 @@ def get_shape(self) -> tuple[int, ...]: def get_element_type(self) -> AttributeCovT: return self.element_type - def is_same_type_with(self, other_tensor: TensorType[Attribute]) -> bool: - current_shape = list(self.shape) - other_shape = list(other_tensor.shape) - if len(current_shape) != len(other_shape): - return False - - return ( - len(list(filter(lambda x: x[0] != x[1], zip(current_shape, other_shape)))) - == 0 - and self.element_type == other_tensor.element_type - and self.encoding == other_tensor.encoding - ) - AnyTensorType: TypeAlias = TensorType[Attribute] AnyTensorTypeConstr = BaseAttr[TensorType[Attribute]](TensorType) diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index f00f592586..7df36a95b4 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -17,10 +17,6 @@ from xdsl.utils.exceptions import VerifyException -def make_materialize_op(source: SSAValue, dest: SSAValue) -> MaterializeInDestination: - return MaterializeInDestination(operands=[source, dest], result_types=[source.type]) - - @dataclass class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern @@ -42,8 +38,13 @@ def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): new_ops: list[Operation] = [] for output in op.arguments: for i, arg in enumerate(donated_inputs): - if getattr(arg, "type").is_same_type_with(getattr(output, "type")): - new_ops.append(make_materialize_op(output, donated_inputs.pop(i))) + if arg.type == output.type: + new_ops.append( + MaterializeInDestination( + operands=[output, donated_inputs.pop(i)], + result_types=[output.type], + ) + ) value_mapper[output] = new_ops[-1].results[0] break From 9b1bf3e1eb650ccfa550632e64a244b0faef9597 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 17:27:58 +0100 Subject: [PATCH 21/22] tests and some style --- .../transforms/jax-use-donated-arguments.mlir | 24 +++++-------------- xdsl/transforms/jax_use_donated_arguments.py | 6 ++--- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/tests/filecheck/transforms/jax-use-donated-arguments.mlir 
b/tests/filecheck/transforms/jax-use-donated-arguments.mlir index 7c4832839e..50beb4610b 100644 --- a/tests/filecheck/transforms/jax-use-donated-arguments.mlir +++ b/tests/filecheck/transforms/jax-use-donated-arguments.mlir @@ -1,50 +1,39 @@ // RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s -builtin.module { -func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { +func.func public @one_donation(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) { %res = "test.op"() : () -> tensor<2x4xf32> return %res : tensor<2x4xf32> } -} // CHECK: builtin.module { -// CHECK-NEXT: builtin.module { -// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { +// CHECK-NEXT: func.func public @one_donation(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> { // CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32> // CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: func.return %0 : tensor<2x4xf32> // CHECK-NEXT: } -// CHECK-NEXT: } -builtin.module { -func.func public @main(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { +func.func public @same_type_donation(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { %res1 = "test.op"() : () -> tensor<2x3xf32> %res2 = "test.op"() : () -> tensor<2x3xf32> return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32> } -} -// CHECK-NEXT: builtin.module { -// CHECK-NEXT: func.func public @main(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { +// CHECK-NEXT: func.func public @same_type_donation(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) { // CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> // CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> // CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: func.return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32> // CHECK-NEXT: } -// CHECK-NEXT: } -builtin.module { -func.func public @main(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { +func.func public @non_trivial_donation(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { %res1 = "test.op"() : () -> tensor<2x3xf32> %res2 = "test.op"() : () -> tensor<2x3xf32> %res3 = "test.op"() : () -> tensor<4x5xf32> return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> } 
-} -// CHECK-NEXT: builtin.module { -// CHECK-NEXT: func.func public @main(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { +// CHECK-NEXT: func.func public @non_trivial_donation(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) { // CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32> // CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32> // CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32> @@ -52,6 +41,5 @@ func.func public @main(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %a // CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> // CHECK-NEXT: func.return %0, %res2, %1 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32> // CHECK-NEXT: } -// CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/transforms/jax_use_donated_arguments.py b/xdsl/transforms/jax_use_donated_arguments.py index 7df36a95b4..743c08d0ab 100644 --- a/xdsl/transforms/jax_use_donated_arguments.py +++ b/xdsl/transforms/jax_use_donated_arguments.py @@ -14,7 +14,6 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.utils.exceptions import VerifyException @dataclass @@ -22,15 +21,14 @@ class SubstituteDonatedTensors(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /): func_op = op.parent_op() - if func_op is None or type(func_op) is not FuncOp: - raise VerifyException("Return operation should be tied to a FuncOp") + assert isinstance(func_op, FuncOp) if func_op.arg_attrs is None: return donated_inputs = [ inp - for inp, attr in zip(func_op.args, func_op.arg_attrs) + for inp, attr in zip(func_op.args, func_op.arg_attrs, strict=True) if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data ] From 7869d09eeb0fd2bdfae692e92eac3e3cdc3e0336 Mon Sep 17 00:00:00 2001 From: Max Manainen Date: Mon, 14 Oct 2024 17:44:16 +0100 Subject: [PATCH 22/22] broken parsing example --- .../transforms/broken-jax-parsing.mlir | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/filecheck/transforms/broken-jax-parsing.mlir diff --git a/tests/filecheck/transforms/broken-jax-parsing.mlir b/tests/filecheck/transforms/broken-jax-parsing.mlir new file mode 100644 index 0000000000..52b1cf5f88 --- /dev/null +++ b/tests/filecheck/transforms/broken-jax-parsing.mlir @@ -0,0 +1,20 @@ +// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s +// other passes like convert-linalg-to-loops have the same problems + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = tensor.empty() : tensor<2x4xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = linalg.generic 
{indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %3 = arith.mulf %in, %in_0 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } +}
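
For completeness, a minimal sketch of driving the new SubstituteDonatedTensors
pattern from Python instead of through xdsl-opt -p jax-use-donated-arguments.
The input is the first test case from jax-use-donated-arguments.mlir above;
MLContext, Parser and PatternRewriteWalker are the xdsl APIs already used in
this series, while the dialect-registration calls (load_dialect with Builtin,
Func and Test) are an assumption about the surrounding API and may need
adjusting to the exact xdsl version, so treat this as an illustrative sketch
rather than part of the patches.

from xdsl.context import MLContext
from xdsl.dialects.builtin import Builtin
from xdsl.dialects.func import Func
from xdsl.dialects.test import Test
from xdsl.parser import Parser
from xdsl.pattern_rewriter import GreedyRewritePatternApplier, PatternRewriteWalker
from xdsl.transforms.jax_use_donated_arguments import SubstituteDonatedTensors

SOURCE = """
builtin.module {
  func.func public @one_donation(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>,
      %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
    %res = "test.op"() : () -> tensor<2x4xf32>
    return %res : tensor<2x4xf32>
  }
}
"""

ctx = MLContext()
for dialect in (Builtin, Func, Test):
    ctx.load_dialect(dialect)  # assumed registration API for this xdsl version

module = Parser(ctx, SOURCE).parse_module()

# Rewrite func.return ops so their results are materialized into the donated
# (tf.aliasing_output) block arguments, as the filecheck tests above expect.
PatternRewriteWalker(
    GreedyRewritePatternApplier([SubstituteDonatedTensors()])
).rewrite_module(module)

# The return value of @one_donation now flows through
# bufferization.materialize_in_destination into %arg2, matching the first
# CHECK block of jax-use-donated-arguments.mlir.
print(module)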