From 37f8cd9517ec0498ddaa0f7d1a8efb9ae0692eb0 Mon Sep 17 00:00:00 2001 From: pthomadakis <136025103+pthomadakis@users.noreply.github.com> Date: Thu, 12 Oct 2023 15:05:50 -0700 Subject: [PATCH] Fix redundant traspose in TCtoTTGT pass (#36) This commit partially addresses issue #29. When one or more tensors are transposed before a contraction (C=A*B), the shape of the contraction result might be different than the expected output C. Thus this intermediate results is stored in a temporary tensor, which will later be tranposed in the C, taking the expected shape. However, instead of just allocating this intermediate tensor we would also transpose tensor C into it, making a redundant trasposition. This commit removes this extra transposition and maintains it only for the case of C+=,-= A*B. --- lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp b/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp index c383d1cd..475e7c45 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TCtoTTGT.cpp @@ -421,8 +421,17 @@ namespace MemRefType::get(lhsDims, lhsMemrefType.getElementType()), loc, rewriter); useLHSTranspose = true; - // TODO(gkestor): we might need this copy if we support update C[] += A[] * B[] - rewriter.create(loc, lhsMemref, lhsAlloc, llvm::ArrayRef(lhsOutPerm_int64)); + double beta_val = betaAttr.cast().getValueAsDouble(); + + if(beta_val == 0) + { + Value constantOp = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + rewriter.create(loc, constantOp, lhsAlloc); + } + else + { + rewriter.create(loc, lhsMemref, lhsAlloc, llvm::ArrayRef(lhsOutPerm_int64)); + } } RankedTensorType collapsedTensorType;