[WIP]: Jax tensor donation #3224

Closed
wants to merge 22 commits
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/broken-jax-parsing.mlir
@@ -0,0 +1,20 @@
// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s
// other passes like convert-linalg-to-loops have the same problems

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module @jit_matmul attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x4xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x4xf32> {mhlo.layout_mode = "default", tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = tensor.empty() : tensor<2x4xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4xf32>) -> tensor<2x4xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<2x3xf32>, tensor<3x4xf32>) outs(%1 : tensor<2x4xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.mulf %in, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<2x4xf32>
    return %2 : tensor<2x4xf32>
  }
}
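The module above is representative of what JAX emits for a jitted matmul whose output buffer is donated; the tf.aliasing_output = 0 attribute on %arg2 is the donation marker this PR consumes. A hedged sketch of the kind of JAX program that produces such a module is shown below; the further lowering from StableHLO to the linalg form above is assumed to happen in a separate pipeline and is not part of this PR.

import jax
import jax.numpy as jnp

def matmul(a, b, out):
    # `out` is only a donated destination; the computation never reads it.
    return a @ b

jitted = jax.jit(matmul, donate_argnums=(2,))
lowered = jitted.lower(
    jnp.zeros((2, 3), jnp.float32),
    jnp.zeros((3, 4), jnp.float32),
    jnp.zeros((2, 4), jnp.float32),
)
# The textual IR carries `tf.aliasing_output` on the donated argument.
print(lowered.compiler_ir())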
45 changes: 45 additions & 0 deletions tests/filecheck/transforms/jax-use-donated-arguments.mlir
@@ -0,0 +1,45 @@
// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s

func.func public @one_donation(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
  %res = "test.op"() : () -> tensor<2x4xf32>
  return %res : tensor<2x4xf32>
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func public @one_donation(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>, %arg2 : tensor<2x4xf32> {"tf.aliasing_output" = 0 : i32}) -> tensor<2x4xf32> {
// CHECK-NEXT: %res = "test.op"() : () -> tensor<2x4xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res in %arg2 : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK-NEXT: func.return %0 : tensor<2x4xf32>
// CHECK-NEXT: }

func.func public @same_type_donation(%arg0: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
  %res1 = "test.op"() : () -> tensor<2x3xf32>
  %res2 = "test.op"() : () -> tensor<2x3xf32>
  return %res1, %res2 : tensor<2x3xf32>, tensor<2x3xf32>
}

// CHECK-NEXT: func.func public @same_type_donation(%arg0 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg0 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res2 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: func.return %0, %1 : tensor<2x3xf32>, tensor<2x3xf32>
// CHECK-NEXT: }

func.func public @non_trivial_donation(%arg0: tensor<4x5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<2x3xf32> {tf.aliasing_output = 0 : i32}, %arg2: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
  %res1 = "test.op"() : () -> tensor<2x3xf32>
  %res2 = "test.op"() : () -> tensor<2x3xf32>
  %res3 = "test.op"() : () -> tensor<4x5xf32>
  return %res1, %res2, %res3 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
}

// CHECK-NEXT: func.func public @non_trivial_donation(%arg0 : tensor<4x5xf32> {"tf.aliasing_output" = 0 : i32}, %arg1 : tensor<2x3xf32> {"tf.aliasing_output" = 0 : i32}, %arg2 : tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>) {
// CHECK-NEXT: %res1 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res2 = "test.op"() : () -> tensor<2x3xf32>
// CHECK-NEXT: %res3 = "test.op"() : () -> tensor<4x5xf32>
// CHECK-NEXT: %0 = bufferization.materialize_in_destination %res1 in %arg1 : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %1 = bufferization.materialize_in_destination %res3 in %arg0 : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
// CHECK-NEXT: func.return %0, %res2, %1 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<4x5xf32>
// CHECK-NEXT: }

// CHECK-NEXT: }
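The @non_trivial_donation checks encode the pass's matching policy: each returned tensor greedily claims the first remaining donated argument of the same type, and results with no remaining match (%res2 above) are returned unchanged. A standalone Python sketch of that policy, purely for illustration and not part of the PR:

def match_donations(result_types: list[str], donated_types: list[str]) -> list[int | None]:
    # Each result claims the first remaining donated argument of the same type;
    # results without a match get None.
    remaining = list(enumerate(donated_types))
    picks: list[int | None] = []
    for ty in result_types:
        for j, (idx, donated_ty) in enumerate(remaining):
            if donated_ty == ty:
                picks.append(idx)
                remaining.pop(j)
                break
        else:
            picks.append(None)
    return picks

# @non_trivial_donation: results (2x3, 2x3, 4x5), donated args (4x5, 2x3).
assert match_donations(
    ["tensor<2x3xf32>", "tensor<2x3xf32>", "tensor<4x5xf32>"],
    ["tensor<4x5xf32>", "tensor<2x3xf32>"],
) == [1, None, 0]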
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
@@ -86,6 +86,11 @@ def get_convert_stencil_to_ll_mlir():

    return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass

def get_jax_use_donated_arguments():
    from xdsl.transforms import jax_use_donated_arguments

    return jax_use_donated_arguments.JaxUseDonatedArguments


def get_convert_riscv_scf_to_riscv_cf():
    from xdsl.backend.riscv.lowering import convert_riscv_scf_to_riscv_cf

@@ -455,6 +460,7 @@ def get_stencil_shape_minimize():
        "convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil,
        "inline-snrt": get_convert_snrt_to_riscv,
        "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir,
        "jax-use-donated-arguments": get_jax_use_donated_arguments,
        "cse": get_cse,
        "csl-stencil-bufferize": get_csl_stencil_bufferize,
        "csl-stencil-materialize-stores": get_csl_stencil_materialize_stores,
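With the pass registered under "jax-use-donated-arguments", it can be invoked like any other xdsl-opt pass, mirroring the RUN lines in the tests above (donated.mlir is a hypothetical input file carrying tf.aliasing_output annotations):

xdsl-opt donated.mlir -p jax-use-donated-arguments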
64 changes: 64 additions & 0 deletions xdsl/transforms/jax_use_donated_arguments.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.bufferization import MaterializeInDestination
from xdsl.dialects.builtin import TensorType
from xdsl.dialects.func import FuncOp, Return
from xdsl.ir import Operation, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
    GreedyRewritePatternApplier,
    PatternRewriter,
    PatternRewriteWalker,
    RewritePattern,
    op_type_rewrite_pattern,
)


@dataclass
class SubstituteDonatedTensors(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /):
        func_op = op.parent_op()
        assert isinstance(func_op, FuncOp)

        if func_op.arg_attrs is None:
            return

        # Tensor arguments annotated with `tf.aliasing_output` are the ones JAX
        # has donated and may be reused as output destinations.
        donated_inputs = [
            inp
            for inp, attr in zip(func_op.args, func_op.arg_attrs, strict=True)
            if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data
        ]

        value_mapper: dict[SSAValue, SSAValue] = {}
        new_ops: list[Operation] = []
        for output in op.arguments:
            # Greedily pair each returned tensor with the first unused donated
            # argument of the same type.
            for i, arg in enumerate(donated_inputs):
                if arg.type == output.type:
                    new_ops.append(
                        MaterializeInDestination(
                            operands=[output, donated_inputs.pop(i)],
                            result_types=[output.type],
                        )
                    )
                    value_mapper[output] = new_ops[-1].results[0]
                    break

        # Return the materialized values in place of the original results.
        new_ops.append(op.clone(value_mapper))
        rewriter.replace_matched_op(new_ops)


@dataclass(frozen=True)
class JaxUseDonatedArguments(ModulePass):
    name = "jax-use-donated-arguments"

    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
        the_one_pass = PatternRewriteWalker(
            GreedyRewritePatternApplier([SubstituteDonatedTensors()]),
            apply_recursively=False,
            walk_reverse=True,
            walk_regions_first=True,
        )
        the_one_pass.rewrite_module(op)
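For completeness, a minimal sketch of applying the pass programmatically instead of through xdsl-opt; `ctx` and `module` are placeholder names for the MLContext used when parsing and the resulting builtin.ModuleOp.

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.transforms.jax_use_donated_arguments import JaxUseDonatedArguments

def donate_outputs(ctx: MLContext, module: builtin.ModuleOp) -> None:
    # Rewrites func.func returns so that results whose types match a donated
    # tensor argument are materialized into that argument.
    JaxUseDonatedArguments().apply(ctx, module)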