From 16fa50b8c7e30bf6884eed2ffb00d6e2a7fe551b Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Thu, 14 Nov 2024 14:16:17 -0800
Subject: [PATCH] Srinath's fixes for tiling interface

---
Reviewer notes (this section is ignored when the patch is applied):

* lib/Dialect/IR/TcpTilingInterfaceImpl.cpp: inside `struct SliceOpTiling`
  the returned `TilingResult` gains a third initializer, `{extractOp}`,
  after the tiled op list and the result values.
  NOTE(review): `extractOp` is defined outside this hunk, and the meaning
  of the third `TilingResult` member is not visible here -- presumably the
  list of generated slice ops; confirm against the TilingResult definition.

* temp.mlir (new file): a function `@fuse_tcp_slice` chaining
  linalg.elemwise_binary -> tcp.slice -> linalg.elemwise_unary, plus a
  transform named sequence `@__transform_main` that matches the
  elemwise_unary op, runs transform.structured.fuse with tile sizes [1, 1],
  and then applies the tensor fold_tensor_empty / fold_tensor_subset_ops
  patterns to func.func.
  NOTE(review): `%c16` is defined but not used by any op in the function.
  NOTE(review): the file is added without a trailing newline (see the
  marker at the end of the hunk).

* NOTE(review): line breaks and indentation in the hunks below were
  re-derived from the hunk line counts (7/8 and 35); verify the context
  lines against the target file before applying.

 lib/Dialect/IR/TcpTilingInterfaceImpl.cpp |  3 +-
 temp.mlir                                 | 35 +++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)
 create mode 100644 temp.mlir

diff --git a/lib/Dialect/IR/TcpTilingInterfaceImpl.cpp b/lib/Dialect/IR/TcpTilingInterfaceImpl.cpp
index 88162ee0..948bd260 100644
--- a/lib/Dialect/IR/TcpTilingInterfaceImpl.cpp
+++ b/lib/Dialect/IR/TcpTilingInterfaceImpl.cpp
@@ -143,7 +143,8 @@ struct SliceOpTiling
         getValueOrCreateConstantIndexOp(b, loc, sizes), sliceOp.getStrides());
 
     return TilingResult{{returnSliceOp},
-                        SmallVector(returnSliceOp->getResults())};
+                        SmallVector(returnSliceOp->getResults()),
+                        {extractOp}};
   }
 
   LogicalResult
diff --git a/temp.mlir b/temp.mlir
new file mode 100644
index 00000000..eb07b18a
--- /dev/null
+++ b/temp.mlir
@@ -0,0 +1,35 @@
+func.func @fuse_tcp_slice(%arg0: tensor<40x40xf32>) -> tensor<32x32xf32> {
+  %shape40 = tensor.empty() : tensor<40x40xf32>
+
+  %0 = linalg.elemwise_binary ins(%arg0, %arg0 : tensor<40x40xf32>, tensor<40x40xf32>)
+                              outs(%shape40: tensor<40x40xf32>) -> tensor<40x40xf32>
+
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index  // NOTE(review): not referenced by any op below
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %c1 = arith.constant 1 : index
+  %slice = tcp.slice %0 starts ( %c3, %c5 ) sizes ( %c32, %c32 ) strides ( %c1, %c1 ) : tensor<40x40xf32> -> tensor<32x32xf32>  // starts (3, 5), sizes (32, 32), strides (1, 1)
+
+  %shape = tensor.empty() : tensor<32x32xf32>
+  %ret = linalg.elemwise_unary ins(%slice: tensor<32x32xf32>) outs(%shape: tensor<32x32xf32>) -> tensor<32x32xf32>  // consumer op matched by the transform sequence below
+
+  return %ret : tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %unary = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // handle to the elemwise_unary op(s) in the payload
+
+    %1, %loops:2 = transform.structured.fuse %unary {tile_sizes = [1, 1], tile_interchange = [0, 1]}  // yields one op handle plus two loop handles
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.fold_tensor_empty
+      transform.apply_patterns.tensor.fold_tensor_subset_ops
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
\ No newline at end of file