From a6c415c65c3b9909bffc3d0a43e8fea55c571c08 Mon Sep 17 00:00:00 2001 From: Polykarpos Thomadakis Date: Wed, 27 Sep 2023 16:24:05 -0700 Subject: [PATCH] This commit partially addresses issue #29. When one or more tensors are transposed before a contraction (C=A*B), the shape of the contraction result might differ from the shape expected for the output C. Thus, this intermediate result is stored in a temporary tensor, which will later be transposed into C, taking the expected shape. However, instead of just allocating this intermediate tensor, we would also transpose tensor C into it, making a redundant transposition. This commit removes this extra transposition. --- lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp b/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp index c383d1cd..1b4072c1 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp @@ -421,8 +421,11 @@ namespace MemRefType::get(lhsDims, lhsMemrefType.getElementType()), loc, rewriter); useLHSTranspose = true; + Value constantOp = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + + rewriter.create(loc, constantOp, lhsAlloc); // TODO(gkestor): we might need this copy if we support update C[] += A[] * B[] - rewriter.create(loc, lhsMemref, lhsAlloc, llvm::ArrayRef(lhsOutPerm_int64)); + // rewriter.create(loc, lhsMemref, lhsAlloc, llvm::ArrayRef(lhsOutPerm_int64)); } RankedTensorType collapsedTensorType;