diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h
index dd451b337..6927752d2 100644
--- a/include/taco/index_notation/index_notation.h
+++ b/include/taco/index_notation/index_notation.h
@@ -646,6 +646,18 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   IndexStmt divide(IndexVar i, IndexVar i1, IndexVar i2, size_t divideFactor) const;
   // TODO: TailStrategy

+  /// The loopfuse transformation fuses the common outer loops of
+  /// two iteration graphs.
+  /// When loopfuse is applied to an already branched index statement,
+  /// e.g. forall(l, where(forall(ijk, T(j,k) += A*B), forall(mjkn, X(l,m,n) += T*C*D))),
+  /// and we want to further break T*C*D down into T2 = T*C and T2*D,
+  /// the path vector specifies the branch the fusion is applied to:
+  /// e.g. in loopfuse(2, true, {1}), 2 refers to breaking T*C*D at the 2nd
+  /// position, true refers to making T*C the producer (if false and used with
+  /// 1, C*D becomes the producer), and {1} is the branch the fusion is applied to.
+  IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;
+
+
   /// The reorder transformation swaps two directly nested index
   /// variables in an iteration graph. This changes the order of
   /// iteration through the space and the order of tensor accesses.
@@ -663,6 +675,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   /// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
   IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;

+  /// reorders the index variables in a nested structure with where clauses
+  IndexStmt reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const;
+
   /// The mergeby transformation specifies how to merge iterators on
   /// the given index variable. By default, if an iterator is used for windowing
   /// it will be merged with the "gallop" strategy.
@@ -1308,7 +1323,7 @@ std::vector<TensorVar> getAttrQueryResults(IndexStmt stmt);

 // [Olivia]
 /// Returns the temporaries in the index statement, in the order they appear.
-std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt);
+std::map<Forall, std::vector<Where>> getTemporaryLocations(IndexStmt stmt);

 /// Returns the results in the index statement that should be assembled by
 /// ungrouped insertion.
diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h
index b750e3961..a38494387 100644
--- a/include/taco/index_notation/transformations.h
+++ b/include/taco/index_notation/transformations.h
@@ -17,6 +17,7 @@ class IndexStmt;
 class TransformationInterface;
 class Reorder;
 class Precompute;
+class LoopFuse;
 class ForAllReplace;
 class AddSuchThatPredicates;
 class Parallelize;
@@ -32,6 +33,7 @@ class Transformation {
 public:
   Transformation(Reorder);
   Transformation(Precompute);
+  Transformation(LoopFuse);
   Transformation(ForAllReplace);
   Transformation(Parallelize);
   Transformation(TopoReorder);
@@ -65,10 +67,12 @@ class Reorder : public TransformationInterface {
 public:
   Reorder(IndexVar i, IndexVar j);
   Reorder(std::vector<IndexVar> replacePattern);
+  Reorder(std::vector<int> path, std::vector<IndexVar> replacePattern);

   IndexVar geti() const;
   IndexVar getj() const;
   const std::vector<IndexVar>& getreplacepattern() const;
+  const std::vector<int>& getpath() const;

   /// Apply the reorder optimization to a concrete index statement. Returns
   /// an undefined statement and a reason if the statement cannot be lowered.
@@ -114,6 +118,28 @@ class Precompute : public TransformationInterface {

 /// Print a precompute command.
 std::ostream &operator<<(std::ostream &, const Precompute &);
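// Editorial sketch (not part of the patch; statement shape taken from the
// loopfuse() documentation in index_notation.h above): how the path argument
// selects the branch of an already fused statement.
//
//   // stmt: forall(l, where(forall(ijk, T(j,k) += A*B),
//   //                       forall(mjkn, X(l,m,n) += T*C*D)))
//   std::vector<int> toConsumer = {1};          // 1 descends into the consumer
//   stmt = stmt.loopfuse(2, true, toConsumer);  // producer T2 = T*C; consumer X += T2*D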
+/// The loopfuse optimization rewrites an index expression to precompute
+/// part of the `expr` and store it in a workspace.
+class LoopFuse : public TransformationInterface {
+public:
+  LoopFuse();
+  LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path);
+
+  int getPos() const;
+  bool getIsProducerOnLeft() const;
+  std::vector<int>& getPath() const;
+
+  /// Apply the loopfuse optimization to a concrete index statement.
+  IndexStmt apply(IndexStmt, std::string *reason = nullptr) const;
+
+  void print(std::ostream &os) const;
+
+private:
+  struct Content;
+  std::shared_ptr<Content> content;
+};
+
+std::ostream &operator<<(std::ostream &, const LoopFuse &);

 /// Replaces all occurrences of directly nested forall nodes of pattern with
 /// directly nested loops of replacement
diff --git a/include/taco/lower/lowerer_impl_imperative.h b/include/taco/lower/lowerer_impl_imperative.h
index fa97e3cd9..b19c19034 100644
--- a/include/taco/lower/lowerer_impl_imperative.h
+++ b/include/taco/lower/lowerer_impl_imperative.h
@@ -513,7 +513,7 @@ class LowererImplImperative : public LowererImpl {
   std::set<TensorVar> nonFullyInitializedResults;

   /// Map used to hoist temporary workspace initialization
-  std::map<Forall, Where> temporaryInitialization;
+  std::map<Forall, std::vector<Where>> temporaryInitialization;

   /// Map used to hoist parallel temporary workspaces. Maps workspace shared by all threads to where statement
   std::map<TensorVar, Where> whereToTemporaryVar;
diff --git a/include/taco/parser/lexer.h b/include/taco/parser/lexer.h
index 55dc74410..e304c2f71 100644
--- a/include/taco/parser/lexer.h
+++ b/include/taco/parser/lexer.h
@@ -22,6 +22,7 @@ enum class Token {
   sub,
   mul,
   div,
+  colon, // lexed from ';'; the numbers before it indicate the path to the branch in a branched iteration graph
   eq,
   eot, // End of tokens
   error
diff --git a/include/taco/tensor.h b/include/taco/tensor.h
index c462cbd32..75af68ba2 100644
--- a/include/taco/tensor.h
+++ b/include/taco/tensor.h
@@ -429,6 +429,7 @@ class TensorBase {
   /// Compute the given expression and put the values in the tensor storage.
   void compute();
+  void compute(IndexStmt stmt);

   /// Compile, assemble and compute as needed.
   void evaluate();
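// Editorial sketch (usage pattern taken from the new tests in this patch):
// the compute(IndexStmt) overload takes the scheduled statement so that
// operand arguments can be packed in the order the fused kernel expects.
// Here path is an assumed std::vector<int> branch path (possibly empty).
//
//   IndexStmt stmt = A.getAssignment().concretize();
//   stmt = stmt.reorder({i, l, j, k, m}).loopfuse(1, true, path);
//   A.compile(stmt);
//   A.assemble();
//   A.compute(stmt);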
diff --git a/src/codegen/module.cpp b/src/codegen/module.cpp
index 89738f22d..08593bcca 100644
--- a/src/codegen/module.cpp
+++ b/src/codegen/module.cpp
@@ -128,7 +128,7 @@ string Module::compile() {
 #ifdef TACO_DEBUG
   // In debug mode, compile the generated code with debug symbols and a
   // low optimization level.
-  string defaultFlags = "-g -O0 -std=c99";
+  string defaultFlags = "-O3 -ffast-math -std=c99";
 #else
   // Otherwise, use the standard set of optimizing flags.
   string defaultFlags = "-O3 -ffast-math -std=c99";
@@ -145,6 +145,8 @@ string Module::compile() {
     prefix + file_ending + " " + shims_file + " " +
     "-o " + fullpath + " -lm";

+  // std::cout << "Compiling generated code with command:\n" << cmd << "\n";
+
   // open the output file & write out the source
   compileToSource(tmpdir, libname);
diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp
index e38e3d2d3..7cead8387 100644
--- a/src/index_notation/index_notation.cpp
+++ b/src/index_notation/index_notation.cpp
@@ -11,6 +11,7 @@
 #include "error/error_checks.h"
 #include "taco/error/error_messages.h"
+#include "taco/index_notation/index_notation_visitor.h"
 #include "taco/type.h"
 #include "taco/format.h"
@@ -1854,6 +1855,18 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFactor) const {
   return transformed;
 }

+IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {
+  string reason; // reason saves the error message if the transformation fails
+  IndexStmt transformed = *this;
+  transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
+  if (!transformed.defined()) {
+    taco_uerror << reason;
+  }
+  return transformed;
+}
+
 IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
                                 std::vector<IndexVar> iw_vars,
                                 TensorVar workspace) const {
@@ -1907,6 +1920,15 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
   return transformed;
 }

+IndexStmt IndexStmt::reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const {
+  string reason;
+  IndexStmt transformed = Reorder(path, reorderedvars).apply(*this, &reason);
+  if (!transformed.defined()) {
+    taco_uerror << reason;
+  }
+  return transformed;
+}
+
 IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const {
   string reason;
   IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason);
@@ -2048,6 +2070,7 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
   return transformed;
 }

+
 IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
   if (accelIndexVars.size() == 0) {
     ws.setAccelIndexVars(accelIndexVars, shouldAccel);
@@ -3452,20 +3475,39 @@ bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt) {
   return true;
 }

-std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt) {
-  map<Forall, Where> temporaryLocs;
-  Forall f = Forall();
-  match(stmt,
-    function<void(const ForallNode*, Matcher*)>([&](const ForallNode* op, Matcher* ctx) {
-      f = op;
-      ctx->match(op->stmt);
-    }),
-    function<void(const WhereNode*, Matcher*)>([&](const WhereNode* w, Matcher* ctx) {
-      if (!(f == IndexStmt()))
-        temporaryLocs.insert({f, Where(w)});
-    })
-  );
-  return temporaryLocs;
+std::map<Forall, std::vector<Where>> getTemporaryLocations(IndexStmt stmt) {
+  struct TemporaryLocsGetter : public IndexNotationVisitor {
+    map<Forall, vector<Where>> temporaryLocs;
+    Forall f;
+
+    using IndexNotationVisitor::visit;
+
+    void visit(const ForallNode *op) {
+      // record the first (outermost) forall as the hoisting point
+      if (!f.defined()) {
+        f = op;
+      }
+      IndexNotationVisitor::visit(op);
+    }
+
+    void visit(const WhereNode *op) {
+      // a forall may now own several where clauses, so append to its vector
+      temporaryLocs[f].push_back(Where(op));
+      IndexNotationVisitor::visit(op);
+    }
+  };
+  TemporaryLocsGetter getter;
+  getter.visit(stmt);
+
+  return getter.temporaryLocs;
 }
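// Editorial note (a reading of the visitor above, not additional behavior):
// after repeated fusions the loop nest contains several where() nodes, so a
// single hoisting forall may now map to more than one temporary.
//
//   std::map<Forall, std::vector<Where>> locs = getTemporaryLocations(stmt);
//   // locs[f] holds every Where encountered beneath the recorded forall f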
diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp
index d53ec58c3..4d708e111 100644
--- a/src/index_notation/transformations.cpp
+++ b/src/index_notation/transformations.cpp
@@ -4,12 +4,15 @@
 #include "taco/index_notation/index_notation_rewriter.h"
 #include "taco/index_notation/index_notation_nodes.h"
 #include "taco/error/error_messages.h"
+#include "taco/index_notation/index_notation_visitor.h"
+#include "taco/storage/index.h"
 #include "taco/util/collections.h"
 #include "taco/lower/iterator.h"
 #include "taco/lower/merge_lattice.h"
 #include "taco/lower/mode.h"
 #include "taco/lower/mode_format_impl.h"

+#include <algorithm>
 #include <iostream>
 #include <map>
 #include <set>
@@ -30,6 +33,10 @@ Transformation::Transformation(Precompute precompute)
     : transformation(new Precompute(precompute)) {
 }

+Transformation::Transformation(LoopFuse loopfuse)
+    : transformation(new LoopFuse(loopfuse)) {
+}
+
 Transformation::Transformation(ForAllReplace forallreplace)
     : transformation(new ForAllReplace(forallreplace)) {
 }
@@ -58,6 +65,7 @@ std::ostream& operator<<(std::ostream& os, const Transformation& t) {

 // class Reorder
 struct Reorder::Content {
+  std::vector<int> path;
   std::vector<IndexVar> replacePattern;
   bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder
 };
@@ -72,6 +80,12 @@ Reorder::Reorder(std::vector<IndexVar> replacePattern) : content(new Content) {
   content->pattern_ordered = true;
 }

+Reorder::Reorder(std::vector<int> path, std::vector<IndexVar> replacePattern) : content(new Content) {
+  content->path = path;
+  content->replacePattern = replacePattern;
+  content->pattern_ordered = true;
+}
+
 IndexVar Reorder::geti() const {
   return content->replacePattern[0];
 }
@@ -87,13 +101,60 @@ const std::vector<IndexVar>& Reorder::getreplacepattern() const {
   return content->replacePattern;
 }

+const std::vector<int>& Reorder::getpath() const {
+  return content->path;
+}
+
 IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
   INIT_REASON(reason);

   string r;
-  if (!isConcreteNotation(stmt, &r)) {
-    *reason = "The index statement is not valid concrete index notation: " + r;
-    return IndexStmt();
+
+  // TODO - Add a different check for concrete index notation with branching
+  // if (!isConcreteNotation(stmt, &r)) {
+  //   *reason = "The index statement is not valid concrete index notation: " + r;
+  //   return IndexStmt();
+  // }
+
+  IndexStmt originalStmt = stmt;
+  struct ReorderVisitor : public IndexNotationVisitor {
+    using IndexNotationVisitor::visit;
+    vector<int>& path;
+    unsigned int pathIdx = 0;
+    IndexStmt innerStmt;
+
+    ReorderVisitor(vector<int>& path) : path(path) {}
+
+    void visit(const ForallNode* node) {
+      if (pathIdx == path.size()) {
+        innerStmt = IndexStmt(node);
+        return;
+      }
+      IndexNotationVisitor::visit(node);
+    }
+
+    void visit(const WhereNode* node) {
+      if (pathIdx == path.size()) {
+        innerStmt = IndexStmt(node);
+        return;
+      }
+
+      // path[pathIdx] == 0 descends into the producer, 1 into the consumer
+      if (!path[pathIdx]) {
+        pathIdx++;
+        IndexNotationVisitor::visit(node->producer);
+      } else {
+        pathIdx++;
+        IndexNotationVisitor::visit(node->consumer);
+      }
+    }
+  };
+  ReorderVisitor reorderVisitor(content->path);
+  if (content->path.size() > 0) {
+    reorderVisitor.visit(stmt);
+    stmt = reorderVisitor.innerStmt;
   }

   // collect current ordering of IndexVars
@@ -123,7 +184,50 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
     *reason = "The foralls of reorder pattern: " + util::join(getreplacepattern()) +
               " were not directly nested.";
     return IndexStmt();
   }
-  return ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
+
+  auto reorderedStmt = ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
+
+  struct ReorderedRewriter : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    IndexStmt reorderedStmt;
+    vector<int>& path;
+    vector<int> visited;
+
+    ReorderedRewriter(IndexStmt reorderedStmt, vector<int>& path)
+        : reorderedStmt(reorderedStmt), path(path) {}
+
+    void visit(const ForallNode* node) {
+      // at the end of the path, splice in the reordered statement
+      if (visited == path) {
+        stmt = reorderedStmt;
+        return;
+      }
+      IndexNotationRewriter::visit(node);
+    }
+
+    void visit(const WhereNode* node) {
+      // push 0 to visited while rewriting the producer and 1 for the consumer
+      visited.push_back(0);
+      IndexStmt producer = rewrite(node->producer);
+      visited.pop_back();
+      visited.push_back(1);
+      IndexStmt consumer = rewrite(node->consumer);
+      visited.pop_back();
+      if (producer == node->producer && consumer == node->consumer) {
+        stmt = node;
+      }
+      else {
+        stmt = new WhereNode(consumer, producer);
+      }
+    }
+  };
+  ReorderedRewriter reorderedRewriter(reorderedStmt, content->path);
+  stmt = reorderedRewriter.rewrite(originalStmt);
+
+  return stmt;
 }

 void Reorder::print(std::ostream& os) const {
@@ -233,6 +337,344 @@ std::ostream& operator<<(std::ostream& os, const SetMergeStrategy& setmergestrategy) {
   return os;
 }

+// class LoopFuse
+struct LoopFuse::Content {
+  int pos;
+  bool isProducerOnLeft;
+  std::vector<int> path;
+};
+
+LoopFuse::LoopFuse() : content(nullptr) {
+}
+
+LoopFuse::LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path) : content(new Content) {
+  content->pos = pos;
+  content->path = path;
+  content->isProducerOnLeft = isProducerOnLeft;
+}
+
+int LoopFuse::getPos() const {
+  return content->pos;
+}
+
+bool LoopFuse::getIsProducerOnLeft() const {
+  return content->isProducerOnLeft;
+}
+
+std::vector<int>& LoopFuse::getPath() const {
+  return content->path;
+}
+
+IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
+  INIT_REASON(reason);
+
+  struct GetAssignment : public IndexNotationVisitor {
+    using IndexNotationVisitor::visit;
+
+    Assignment innerAssignment;
+    vector<IndexVar> indexAccessVars;
+    vector<IndexVar> indexVarsUntilBranch;
+    vector<IndexVar> indexVarsAfterBranch;
+    unsigned int pathIdx = 0;
+    vector<int> path;
+
+    // constructor taking the branch path
+    GetAssignment(vector<int>& _path) : path(_path) {}
+
+    void visit(const ForallNode* node) {
+      Forall forall(node);
+
+      indexAccessVars.push_back(forall.getIndexVar());
+      if (pathIdx < path.size()) {
+        indexVarsUntilBranch.push_back(forall.getIndexVar());
+      }
+      else {
+        indexVarsAfterBranch.push_back(forall.getIndexVar());
+      }
+
+      if (isa<Assignment>(forall.getStmt())) {
+        innerAssignment = to<Assignment>(forall.getStmt());
+      }
+      else {
+        IndexNotationVisitor::visit(node);
+      }
+    }
+
+    void visit(const WhereNode* node) {
+      // path[pathIdx] == 0 descends into the producer, 1 into the consumer
+      if (!path[pathIdx]) {
+        pathIdx++;
+        IndexNotationVisitor::visit(node->producer);
+      } else {
+        pathIdx++;
+        IndexNotationVisitor::visit(node->consumer);
+      }
+    }
+  };
+  GetAssignment getAssignment(getPath());
+  stmt.accept(&getAssignment);
+
+  // saves the result, producer and consumer of the assignment:
+  //   result = producer * consumer
+  // e.g. if the assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) and pos is 1,
+  // the result is A(i,j), the producer is B(i,j) and the consumer is C(j,k) * D(k,l);
+  // if pos is 2, the result is A(i,j), the producer is B(i,j) * C(j,k) and the
+  // consumer is D(k,l).
+  // resultVars, producerVars and consumerVars are the index variables of the
+  // result, producer and consumer respectively.
+  struct GetProducerAndConsumer : public IndexNotationVisitor {
+    using IndexNotationVisitor::visit;
+
+    int pos;
+    int pathIdx = 0;
+    bool isProducerOnLeft;
+    vector<int> path;
+    IndexExpr result;
+    IndexExpr producer;
+    IndexExpr consumer;
+    vector<IndexVar> resultVars;
+    set<IndexVar> producerVars;
+    set<IndexVar> consumerVars;
+    map<IndexVar, pair<Type, Dimension>> varTypes;
+
+    GetProducerAndConsumer(int _pos, bool _isProducerOnLeft, vector<int>& _path)
+        : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path),
+          result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {}
+
+    void addIndexVar(Access access) {
+      // record the type and dimension of each index variable in the tensor access
+      for (unsigned long i = 0; i < access.getIndexVars().size(); i++) {
+        auto tensorVar = access.getTensorVar();    // tensor variable like A, B
+        auto indexVar = access.getIndexVars()[i];  // index variable like i, j
+        auto tensorVarType = tensorVar.getType();
+        varTypes[indexVar] = make_pair(tensorVarType, tensorVarType.getShape().getDimension(i));
+      }
+    }
+
+    void visit(const AssignmentNode* node) {
+      Assignment assignment(node);
+      // the result is stored on the left-hand side of the assignment
+      result = assignment.getLhs();
+      resultVars = assignment.getLhs().getIndexVars();
+
+      // add the index variables of the result to the map
+      addIndexVar(to<Access>(assignment.getLhs()));
+
+      // visit the tensor contraction expression on the right-hand side of += or =
+      IndexNotationVisitor::visit(assignment.getRhs());
+    }
+
+    void visit(const WhereNode* node) {
+      // select the branch to visit
+      if (!path[pathIdx]) { // path[pathIdx] == 0 goes to the producer
+        pathIdx++;
+        IndexNotationVisitor::visit(node->producer);
+      } else {
+        pathIdx++;
+        IndexNotationVisitor::visit(node->consumer);
+      }
+    }
+
+    // multiplications in the tensor contraction are visited operand by operand
+    void visit(const MulNode* node) {
+      Mul mul(node);
+      IndexNotationVisitor::visit(mul.getA());
+      IndexNotationVisitor::visit(mul.getB());
+    }
+
+    void visit(const AccessNode* node) {
+      Access access(node);
+      IndexExpr* it;
+      set<IndexVar>* vars;
+      if ((pos > 0 && isProducerOnLeft) || (pos <= 0 && !isProducerOnLeft)) { it = &producer; vars = &producerVars; }
+      else { it = &consumer; vars = &consumerVars; }
+
+      if (*it == nullptr) *it = access;
+      else *it = *it * access;
+
+      for (const auto& var : access.getIndexVars()) {
+        vars->insert(var);
+      }
+
+      // add the index variables of the access to the map
+      addIndexVar(access);
+      pos--;
+    }
+  };
+  GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft(), getPath());
+  stmt.accept(&getProducerAndConsumer);
+
+  // the indices of the temporary come from the producer indices (IndexVars)
+  // that are either in the result indices or in the consumer indices;
+  // producer indices that appear in neither are contracted away within the
+  // producer computation
+  vector<IndexVar> temporaryVars;
+  for (auto& var : getProducerAndConsumer.producerVars) {
+    auto itC = getProducerAndConsumer.consumerVars.find(var);
+    auto itR = find(getProducerAndConsumer.resultVars.begin(), getProducerAndConsumer.resultVars.end(), var);
+    if (itC != getProducerAndConsumer.consumerVars.end() ||
+        itR != getProducerAndConsumer.resultVars.end()) {
+      temporaryVars.push_back(var);
+    }
+  }
+
+  // get the producer index access pattern
+  // get the consumer index access pattern
+
vector producerLoopVars; + vector consumerLoopVars; + for (auto& var : getAssignment.indexAccessVars) { + auto itP = getProducerAndConsumer.producerVars.find(var); + auto itC = getProducerAndConsumer.consumerVars.find(var); + auto itR = find(getProducerAndConsumer.resultVars.begin(), getProducerAndConsumer.resultVars.end(), var); + // check if variable is in the producer + if (itP != getProducerAndConsumer.producerVars.end()) { + producerLoopVars.push_back(var); + } + // check if variable is in the consumer or result + if (itC != getProducerAndConsumer.consumerVars.end() || itR != getProducerAndConsumer.resultVars.end()) { + consumerLoopVars.push_back(var); + } + } + + // remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch + for (auto& var : getAssignment.indexVarsUntilBranch) { + producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); + consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); + } + + // check if there are common outer loops in producerAccessOrder and consumerAccessOrder + vector commonLoopVars; + for (auto& var : getAssignment.indexVarsAfterBranch) { + auto itC = find(consumerLoopVars.begin(), consumerLoopVars.end(), var); + auto itP = find(producerLoopVars.begin(), producerLoopVars.end(), var); + if (itC != consumerLoopVars.end() && itP != producerLoopVars.end()) { + commonLoopVars.push_back(var); + temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end()); + } + else { + break; + } + } + + // remove commonLoopVars from producerLoopVars and consumerLoopVars + for (auto& var : commonLoopVars) { + producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); + consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); + } + + // create the intermediate tensor + vector temporaryDims; + vector temporaryModes; + // populate shape of the intermediate tensor + auto populateDimension = + [&](map>& varTypes) { + for (auto& var : temporaryVars) { + temporaryDims.push_back(varTypes[var].second); + temporaryModes.push_back(ModeFormat{Dense}); + } + }; + populateDimension(getProducerAndConsumer.varTypes); + Access resultAccess = to(getProducerAndConsumer.result); + string path = ""; + for (auto& p : getPath()) { + path += to_string(p); + } + TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName() + path, Type(Float64, temporaryDims)); + Access workspace(intermediateTensor, temporaryVars); + + Assignment producerAssignment(workspace, getProducerAndConsumer.producer, getAssignment.innerAssignment.getOperator()); + + Assignment consumerAssignment; + // if the producer is on left, then consumer is constructed by + // multiplying workspace * consumer and if the producer is on right, + // then the consumer is constructed by multiplying consumer * workspace + if (!getIsProducerOnLeft()) { + consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer == NULL ? workspace : getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator()); + } else { + consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer == NULL ? 
workspace : workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator()); + } + + // rewrite the index notation to use the temporary + // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) + // T(i,k) += B(i,j) * C(j,k) is the producer and A(i,j) += T(i,k) * D(k,l) is the consumer + struct ProducerConsumerRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + vector& path; + vector visited; + Assignment& producer; + Assignment& consumer; + vector& commonLoopVars; + vector& producerLoopVars; + vector& consumerLoopVars; + + // constructor + ProducerConsumerRewriter(vector& _path, Assignment& producer, Assignment& consumer, vector& commonLoopVars, vector& producerLoopVars, vector& consumerLoopVars) : + path(_path), producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {} + + IndexStmt generateForalls(IndexStmt innerStmt, vector indexVars) { + auto returnStmt = innerStmt; + for (auto it = indexVars.rbegin(); it != indexVars.rend(); ++it) { + returnStmt = forall(*it, returnStmt); + } + + return returnStmt; + }; + + // should find the path to get to this loop to perform the rewrite + void visit(const ForallNode* node) { + // at the end of the path, rewrite should happen using the producer and consumer + if (visited == path) { + IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars); + IndexStmt producer = generateForalls(this->producer, producerLoopVars); + Where where(consumer, producer); + stmt = generateForalls(where, commonLoopVars); + return; + } + IndexNotationRewriter::visit(node); + } + + void visit(const WhereNode* node) { + Where where(node); + + // add 0 to visited if the producer is visited and 1 if the consumer is visited + visited.push_back(0); + IndexStmt producer = rewrite(node->producer); + visited.pop_back(); + visited.push_back(1); + IndexStmt consumer = rewrite(node->consumer); + visited.pop_back(); + if (producer == node->producer && consumer == node->consumer) { + stmt = node; + } + else { + stmt = new WhereNode(consumer, producer); + } + } + }; + + ProducerConsumerRewriter rewriter(getPath(), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars); + stmt = rewriter.rewrite(stmt); + + return stmt; +} + + + +void LoopFuse::print(std::ostream &os) const { + os << "fuse(" << getPos() << ", " << util::join(getPath()) << ")"; +} + // class Precompute struct Precompute::Content { IndexExpr expr; @@ -722,10 +1164,12 @@ IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const { INIT_REASON(reason); string r; - if (!isConcreteNotation(stmt, &r)) { - *reason = "The index statement is not valid concrete index notation: " + r; - return IndexStmt(); - } + + // TODO - Add a different check for concrete index notation with branching + // if (!isConcreteNotation(stmt, &r)) { + // *reason = "The index statement is not valid concrete index notation: " + r; + // return IndexStmt(); + // } /// Since all IndexVars can only appear once, assume replacement will work and error if it doesn't struct ForAllReplaceRewriter : public IndexNotationRewriter { diff --git a/src/lower/lowerer_impl_imperative.cpp b/src/lower/lowerer_impl_imperative.cpp index c1f614463..614693b3f 100644 --- a/src/lower/lowerer_impl_imperative.cpp +++ b/src/lower/lowerer_impl_imperative.cpp @@ -783,14 +783,22 @@ Stmt LowererImplImperative::lowerForall(Forall forall) // Emit temporary initialization if forall is 
sequential or parallelized by // cpu threads and leads to a where statement // This is for workspace hoisting by 1-level + vector> temporaryValuesInit; vector temporaryValuesInitFree = {Stmt(), Stmt()}; auto temp = temporaryInitialization.find(forall); - if (temp != temporaryInitialization.end() && forall.getParallelUnit() == - ParallelUnit::NotParallel && !isScalar(temp->second.getTemporary().getType())) - temporaryValuesInitFree = codeToInitializeTemporary(temp->second); - else if (temp != temporaryInitialization.end() && forall.getParallelUnit() == - ParallelUnit::CPUThread && !isScalar(temp->second.getTemporary().getType())) { - temporaryValuesInitFree = codeToInitializeTemporaryParallel(temp->second, forall.getParallelUnit()); + if (temp != temporaryInitialization.end()) { + auto whereClauses = temp->second; + // iterate over whereClauses + for (auto& where : whereClauses) { + if (forall.getParallelUnit() == ParallelUnit::NotParallel && !isScalar(where.getTemporary().getType())) { + temporaryValuesInitFree = codeToInitializeTemporary(where); + temporaryValuesInit.push_back(temporaryValuesInitFree); + } + else if (forall.getParallelUnit() == ParallelUnit::CPUThread && !isScalar(where.getTemporary().getType())) { + temporaryValuesInitFree = codeToInitializeTemporaryParallel(where, forall.getParallelUnit()); + temporaryValuesInit.push_back(temporaryValuesInitFree); + } + } } Stmt loops; @@ -890,10 +898,21 @@ Stmt LowererImplImperative::lowerForall(Forall forall) parallelUnitIndexVars.erase(forall.getParallelUnit()); parallelUnitSizes.erase(forall.getParallelUnit()); } + + vector inits; + vector frees; + // iterate over temporaryValuesInit and add to inits and frees + for (auto& s : temporaryValuesInit) { + inits.push_back(s[0]); + frees.push_back(s[1]); + } + Stmt initsBlock = Block::make(inits); + Stmt freesBlock = Block::make(frees); + return Block::blanks(preInitValues, - temporaryValuesInitFree[0], + initsBlock, loops, - temporaryValuesInitFree[1]); + freesBlock); } Stmt LowererImplImperative::lowerForallCloned(Forall forall) { @@ -2523,20 +2542,23 @@ Stmt LowererImplImperative::lowerWhere(Where where) { vector temporaryValuesInitFree = {Stmt(), Stmt()}; bool temporaryHoisted = false; for (auto it = temporaryInitialization.begin(); it != temporaryInitialization.end(); ++it) { - if (it->second == where && it->first.getParallelUnit() == - ParallelUnit::NotParallel && !isScalar(temporary.getType())) { - temporaryHoisted = true; - } else if (it->second == where && it->first.getParallelUnit() == - ParallelUnit::CPUThread && !isScalar(temporary.getType())) { - temporaryHoisted = true; - auto decls = codeToInitializeLocalTemporaryParallel(where, it->first.getParallelUnit()); - - temporaryValuesInitFree[0] = ir::Block::make(decls); + auto whereClauses = it->second; + for (auto& whereClause : whereClauses) { + if (whereClause == where && it->first.getParallelUnit() == + ParallelUnit::NotParallel && !isScalar(temporary.getType())) { + temporaryHoisted = true; + } else if (whereClause == where && it->first.getParallelUnit() == + ParallelUnit::CPUThread && !isScalar(temporary.getType())) { + temporaryHoisted = true; + auto decls = codeToInitializeLocalTemporaryParallel(where, it->first.getParallelUnit()); + + temporaryValuesInitFree[0] = ir::Block::make(decls); + } } - } - if (!temporaryHoisted) { - temporaryValuesInitFree = codeToInitializeTemporary(where); + if (!temporaryHoisted) { + temporaryValuesInitFree = codeToInitializeTemporary(where); + } } Stmt initializeTemporary = 
temporaryValuesInitFree[0]; diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp index a490840a8..e4cf3830b 100644 --- a/src/parser/lexer.cpp +++ b/src/parser/lexer.cpp @@ -84,6 +84,9 @@ Token Lexer::getToken() { case '/': token = Token::div; break; + case ';': + token = Token::colon; + break; case '=': token = Token::eq; break; diff --git a/src/parser/schedule_parser.cpp b/src/parser/schedule_parser.cpp index 0666601da..2b099f4b3 100644 --- a/src/parser/schedule_parser.cpp +++ b/src/parser/schedule_parser.cpp @@ -52,6 +52,10 @@ vector> ScheduleParser(const string argValue) { current_element += lexer.tokenString(tok); parenthesesCnt--; break; + case parser::Token::colon: + current_schedule.push_back(current_element); + current_element = ""; + break; case parser::Token::comma: if (curlyParenthesesCnt > 0) { // multiple indexes inside of a {} list; pass it through diff --git a/src/tensor.cpp b/src/tensor.cpp index 257c396c3..78e30a3b7 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -775,6 +775,41 @@ static inline map getTensors(const IndexExpr& expr) { return getOperands.arguments; } +static inline map getTensors(const IndexStmt& stmt, vector& operands) { + struct GetOperands : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + vector& operands; + map arguments; + + GetOperands(vector& operands) : operands(operands) {} + + void visit(const AccessNode* node) { + if (!isa(node)) { + return; // temporary ignore + } + Access ac = Access(node); + taco_iassert(isa(node)) << "Unknown subexpression"; + + if (!util::contains(arguments, node->tensorVar)) { + arguments.insert({node->tensorVar, to(node)->tensor}); + operands.push_back(node->tensorVar); + } + + // Also add any tensors backing index sets of tensor accesses. + for (auto& p : node->indexSetModes) { + auto tv = p.second.tensor.getTensorVar(); + if (!util::contains(arguments, tv)) { + arguments.insert({tv, p.second.tensor}); + operands.push_back(tv); + } + } + } + }; + GetOperands getOperands(operands); + stmt.accept(&getOperands); + return getOperands.arguments; +} + static inline vector packArguments(const TensorBase& tensor) { vector arguments; @@ -805,6 +840,35 @@ vector packArguments(const TensorBase& tensor) { return arguments; } +static inline +vector packArguments(const TensorBase& tensor, const IndexStmt stmt) { + vector arguments; + + // Pack the result tensor + arguments.push_back(tensor.getStorage()); + + // Pack any index sets on the result tensor at the front of the arguments list. + auto lhs = getNode(tensor.getAssignment().getLhs()); + // We check isa rather than isa to catch cases + // where the underlying access is represented with the base AccessNode class. + if (isa(lhs)) { + auto indexSetModes = to(lhs)->indexSetModes; + for (auto& it : indexSetModes) { + arguments.push_back(it.second.tensor.getStorage()); + } + } + + // Pack operand tensors + std::vector operands; + auto tensors = getTensors(stmt, operands); + for (auto& operand : operands) { + taco_iassert(util::contains(tensors, operand)); + arguments.push_back(tensors.at(operand).getStorage()); + } + + return arguments; +} + void TensorBase::assemble() { taco_uassert(!needsCompile()) << error::assemble_without_compile; if (!needsAssemble()) { @@ -831,7 +895,7 @@ void TensorBase::compute() { if (!needsCompute()) { return; } - setNeedsCompute(false); + // setNeedsCompute(false); // Sync operand tensors if needed. 
auto operands = getTensors(getAssignment().getRhs()); for (auto& operand : operands) { @@ -843,7 +907,30 @@ void TensorBase::compute() { this->content->module->callFuncPacked("compute", arguments.data()); if (content->assembleWhileCompute) { - setNeedsAssemble(false); + // setNeedsAssemble(false); + taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); + content->valuesSize = unpackTensorData(*tensorData, *this); + } +} + +void TensorBase::compute(IndexStmt stmt) { + taco_uassert(!needsCompile()) << error::compute_without_compile; + if (!needsCompute()) { + return; + } + // setNeedsCompute(false); + // Sync operand tensors if needed. + auto operands = getTensors(getAssignment().getRhs()); + for (auto& operand : operands) { + operand.second.syncValues(); + operand.second.removeDependentTensor(*this); + } + + auto arguments = packArguments(*this, stmt); + this->content->module->callFuncPacked("compute", arguments.data()); + + if (content->assembleWhileCompute) { + // setNeedsAssemble(false); taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); content->valuesSize = unpackTensorData(*tensorData, *this); } diff --git a/test/tests-indexstmt.cpp b/test/tests-indexstmt.cpp index e2a972430..a62776bad 100644 --- a/test/tests-indexstmt.cpp +++ b/test/tests-indexstmt.cpp @@ -2,6 +2,8 @@ #include "test_tensors.h" #include "taco/tensor.h" #include "taco/index_notation/index_notation.h" +#include "taco/index_notation/kernel.h" +#include "taco/index_notation/transformations.h" using namespace taco; const IndexVar i("i"), j("j"), k("k"); @@ -84,4 +86,43 @@ TEST(indexstmt, spmm) { } - +TEST(indexstmt, sddmmPlusSpmm) { + Type t(type(), {3,3}); + const IndexVar i("i"), j("j"), k("k"), l("l"); + + TensorVar A("A", t, Format{Dense, Dense}); + TensorVar B("B", t, Format{Dense, Sparse}); + TensorVar C("C", t, Format{Dense, Dense}); + TensorVar D("D", t, Format{Dense, Dense}); + TensorVar E("E", t, Format{Dense, Dense}); + + TensorVar tmp("tmp", Type(), Format()); + + // A(i,j) = B(i,j) * C(i,k) * D(j,k) * E(j,l) + IndexStmt fused = + forall(i, + forall(j, + forall(k, + forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)) + ) + ) + ); + + std::cout << "before topological sort: " << fused << std::endl; + fused = reorderLoopsTopologically(fused); + std::cout << "after topological sort: " << fused << std::endl; + + Kernel kernel = compile(fused); + + IndexStmt fusedNested = + forall(i, + forall(j, + where( + forall(l, A(i,l) += tmp * E(j,l)), // consumer + forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer + ) + ) + ); + + std::cout << "nested loop stmt: " << fusedNested << std::endl; +} diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index 350fb8538..62c2f28db 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -1,15 +1,37 @@ +#include +#include #include #include #include +#include #include "test.h" #include "test_tensors.h" #include "taco/tensor.h" #include "taco/index_notation/index_notation.h" #include "codegen/codegen.h" #include "taco/lower/lower.h" +#include "taco/util/env.h" +#include "time.h" using namespace taco; +void printCodeToFile(string filename, IndexStmt stmt) { + stringstream source; + + string file_path = "eval_generated/"; + mkdir(file_path.c_str(), 0777); + + std::shared_ptr codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen); + ir::Stmt compute = lower(stmt, "compute", true, true); + codegen->compile(compute, true); + + ofstream source_file; + string file_ending = should_use_CUDA_codegen() ? 
".cu" : ".c"; + source_file.open(file_path + filename + file_ending); + source_file << source.str(); + source_file.close(); +} + TEST(workspaces, tile_vecElemMul_NoTail) { Tensor A("A", {16}, Format{Dense}); @@ -36,6 +58,7 @@ TEST(workspaces, tile_vecElemMul_NoTail) { .split(i_bounded, i0, i1, 4) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_NoTail", stmt); A.compile(stmt); A.assemble(); A.compute(); @@ -73,6 +96,7 @@ TEST(workspaces, tile_vecElemMul_Tail1) { stmt = stmt.bound(i, i_bounded, 16, BoundType::MaxExact) .split(i_bounded, i0, i1, 5) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_Tail1", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -111,6 +135,7 @@ TEST(workspaces, tile_vecElemMul_Tail2) { stmt = stmt.bound(i, i_bounded, 17, BoundType::MaxExact) .split(i_bounded, i0, i1, 4) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_Tail2", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -161,6 +186,7 @@ TEST(workspaces, tile_denseMatMul) { .split(i_bounded, i0, i1, 4); stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_denseMatMul", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -209,6 +235,9 @@ TEST(workspaces, precompute2D_add) { TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j}, {i, j}, ws); + std::cout << stmt << endl; + printCodeToFile("precompute2D_ad", stmt); + A.compile(stmt.concretize()); A.assemble(); A.compute(); @@ -253,6 +282,8 @@ TEST(workspaces, precompute4D_add) { Format{Dense, Dense, Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws1) .precompute(ws1(i, j, k, l) + D(i, j, k, l), {i, j, k, l}, {i, j, k ,l}, ws2); + std::cout << stmt << endl; + printCodeToFile("precompute4D_add", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -295,6 +326,9 @@ TEST(workspaces, precompute4D_multireduce) { TensorVar ws2("ws2", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j, m}, {i, j, m}, ws1) .precompute(ws1(i, j, m) * D(m, n), {i, j}, {i, j}, ws2); + + std::cout << stmt << endl; + printCodeToFile("precompute4D_multireduce", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -335,6 +369,9 @@ TEST(workspaces, precompute3D_TspV) { stmt = stmt.precompute(precomputedExpr, {i, j, k}, {i, j, k}, ws); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_TspV", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -379,6 +416,9 @@ TEST(workspaces, precompute3D_multipleWS) { stmt = stmt.precompute(ws(i, j, k) * c(k), {i, j}, {i, j}, t); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_multipleWS", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -422,6 +462,9 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) { stmt = stmt.precompute(precomputedExpr, {i, j, k}, {iw, jw, kw}, ws); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_renamedIVars_TspV", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -475,6 +518,9 @@ TEST(workspaces, tile_dotProduct_1) { stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("tile_dotProduct_1", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -538,6 +584,9 @@ TEST(workspaces, tile_dotProduct_2) { stmt = stmt.concretize(); + std::cout << stmt << endl; + 
printCodeToFile("tile_dotProduct_2", stmt); + stmt = stmt.wsaccel(precomputed, false); A.compile(stmt); A.assemble(); @@ -588,6 +637,9 @@ TEST(workspaces, tile_dotProduct_3) { stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("tile_dotProduct_3", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -599,3 +651,1280 @@ TEST(workspaces, tile_dotProduct_3) { expected.compute(); ASSERT_TENSOR_EQ(expected, A); } + + +TEST(workspaces, loopfuse) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + B.pack(); + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + // + stmt = stmt + .reorder({i, l, j, k, m}) + .loopfuse(1, true, path0); + + std::cout << "inter: " << stmt << std::endl; + + stmt = stmt + .reorder(path1, {l, j}) + .loopfuse(2, false, path1) + // .loopfuse(1, false, path2) + ; + // stmt = stmt + // .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + // ; + // + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopfuse", stmt); + + A.compile(stmt); + A.assemble(); + + clock_t begin = clock(); + A.compute(stmt); + clock_t end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + + std::cout << "executed\n"; + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + expected.compile(); + expected.assemble(); + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; + ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; +} + +TEST(workspaces, sddmm_spmm) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + B.pack(); + + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"); + A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN 
sddmm_spmm TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + + +} + +TEST(workspaces, sddmm_spmm_gemm) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + Tensor F("F", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + F.insert({i, j}, (double) i*j); + } + } + B.pack(); + + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l, m}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + + +} + +TEST(workspaces, sddmm_spmm_gemm_real) { + + int K = 16; + int L = 16; + int M = 16; + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + std::cout << mat_file << std::endl; + + Tensor B = read(mat_file, Format({Dense, Sparse}), true); + B.setName("B"); + B.pack(); + + if (mat_file == "") { + std::cout << "No tensor file specified!\n"; + return; + } + + 
Tensor C("C", {B.getDimension(0), K}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), K}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(1), L}, Format{Dense, Dense}); + for (int j=0; j F("F", {L, M}, Format{Dense, Dense}); + for (int j=0; j A("A", {B.getDimension(0), M}, Format{Dense, Dense}); + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm_gemm_real TEST */ + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 0, 0}; + vector path4 = {1, 1}; + vector path5 = {1, 0, 1}; + vector path6 = {1, 0, 0, 0}; + stmt = stmt + .reorder({i, k, j, l, m}) + .loopfuse(1, true, path0) + // .loopfuse(4, true, path1) + // .loopfuse(3, true, path2) + // .loopfuse(1, false, path3) + // .reorder(path4, {m, l}) + // .reorder(path5, {l, j}) + // .reorder(path6, {j, k}) + ; + /* END sddmm_spmm_gemm_real TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {B.getDimension(0), M}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + std::cout << "workspaces, sddmm_spmm_gemm -> execution completed for matrix: " << mat_file << std::endl; + +} + +TEST(workspaces, sddmm_spmm_real) { + int K = 16; + int L = 16; + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + Tensor B = read(mat_file, Format({Dense, Sparse}), true); + B.setName("B"); + B.pack(); + + if (mat_file == "") { + std::cout << "No tensor file specified!\n"; + return; + } + + Tensor C("C", {B.getDimension(0), K}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), K}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(1), L}, Format{Dense, Dense}); + for (int j=0; j A("A", {B.getDimension(0), L}, Format{Dense, Dense}); + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"); + A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm_real TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm_real TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", 
stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {B.getDimension(0), L}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + std::cout << "workspaces, sddmm_spmm -> execution completed for matrix: " << mat_file << std::endl; + +} + +TEST(workspaces, loopreversefuse) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) rand_float); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + vector path1; + stmt = stmt + .reorder({m,i,l,k,j}) + .loopfuse(3, false, path1) + ; + // stmt = stmt + // .parallelize(m, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + // ; + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopreversefuse", stmt); + + A.compile(stmt); + B.pack(); + A.assemble(); + A.compute(stmt); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, loopcontractfuse) { + int N = 16; + Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + + /* BEGIN loopcontractfuse TEST */ + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 1}; + stmt = stmt + .reorder({l, i, j, k, m, n}) + .loopfuse(2, true, path0) + .loopfuse(2, true, path1) + .reorder(path2, {m, k, j}) + .reorder(path3, {n, m, k}) + ; + /* END loopcontractfuse TEST */ + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopcontractfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + + Tensor 
expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + expected.compile(); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i=0; i<10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + +} + +TEST(workspaces, loopcontractfuse_real) { + int L = 16; + int M = 16; + int N = 16; + Tensor A("A", {L, M, N}, Format{Dense, Dense, Dense}); + // Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + // Tensor C("C", {N, N}, Format{Dense, Dense}); + // Tensor D("D", {N, N}, Format{Dense, Dense}); + // Tensor E("E", {N, N}, Format{Dense, Dense}); + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + // std::cout << mat_file << std::endl; + + Tensor B = read(mat_file, Format({Dense, Sparse, Sparse}), true); + B.setName("B"); + B.pack(); + + // std::cout << "B tensor successfully read and packed!\n"; + // return; + + Tensor C("C", {B.getDimension(0), L}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), M}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(2), N}, Format{Dense, Dense}); + for (int k=0; k path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 1}; + stmt = stmt + .reorder({l, i, j, k, m, n}) + .loopfuse(2, true, path0) + .loopfuse(2, true, path1) + .reorder(path2, {k, m, j}) + .reorder(path3, {m, n, k}) + ; + /* END loopcontractfuse_real TEST */ + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopcontractfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + expected.compile(); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i=0; i<3; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + +std::cout << "workspaces, loopcontractfuse -> execution completed for matrix: " << mat_file << std::endl; + +} + +TEST(workspaces, spttm_ttm) { + int N = 16; + Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + // 5 -> A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + + /* BEGIN spttm_ttm TEST */ + vector path0; + vector path1 = {1}; + stmt = stmt + .reorder({l, i, j, k, m}) + .loopfuse(2, true, path0) + .reorder(path1, {m, k}) + ; + 
+  /* BEGIN spttm_ttm TEST */
+  vector<int> path0;
+  vector<int> path1 = {1};
+  stmt = stmt
+    .reorder({l, i, j, k, m})
+    .loopfuse(2, true, path0)
+    .reorder(path1, {m, k})
+    ;
+  /* END spttm_ttm TEST */
+
+  stmt = stmt.concretize();
+  cout << "final stmt: " << stmt << endl;
+  printCodeToFile("spttm_ttm", stmt);
+
+  A.compile(stmt.concretize());
+  A.assemble();
+
+  Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
+  expected(i,l,m) = B(i,j,k) * C(j,l) * D(k,m);
+  expected.compile();
+  expected.assemble();
+
+  clock_t begin;
+  clock_t end;
+
+  for (int i = 0; i < 10; i++) {
+    begin = clock();
+    A.compute(stmt);
+    end = clock();
+    double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000;
+
+    begin = clock();
+    expected.compute();
+    end = clock();
+    double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000;
+    // ASSERT_TENSOR_EQ(expected, A);
+
+    std::cout << elapsed_secs << std::endl;
+    std::cout << elapsed_secs_ref << std::endl;
+  }
+}
+
+TEST(workspaces, spttm_ttm_real) {
+  // int N = 16;
+  // Tensor<double> A("A", {N, N, N}, Format{Dense, Dense, Dense});
+  // Tensor<double> B("B", {N, N, N}, Format{Dense, Sparse, Sparse});
+  // Tensor<double> C("C", {N, N}, Format{Dense, Dense});
+  // Tensor<double> D("D", {N, N}, Format{Dense, Dense});
+
+  // for (int i = 0; i < N; i++) {
+  //   for (int j = 0; j < N; j++) {
+  //     for (int k = 0; k < N; k++) {
+  //       B.insert({i, j, k}, (double) i);
+  //     }
+  //     C.insert({i, j}, (double) j);
+  //     D.insert({i, j}, (double) i*j);
+  //   }
+  // }
+
+  int L = 16;
+  int M = 16;
+
+  std::string mat_file = util::getFromEnv("TENSOR_FILE", "");
+
+  // std::cout << mat_file << std::endl;
+
+  Tensor<double> B = read(mat_file, Format({Dense, Sparse, Sparse}), true);
+  B.setName("B");
+  B.pack();
+
+  // std::cout << "B tensor successfully read and packed!\n";
+  // return;
+
+  Tensor<double> C("C", {B.getDimension(1), L}, Format{Dense, Dense});
+  for (int i = 0; i < B.getDimension(1); i++) { /* ... */ }
+  Tensor<double> D("D", {B.getDimension(2), M}, Format{Dense, Dense});
+  for (int j = 0; j < B.getDimension(2); j++) { /* ... */ }
+  Tensor<double> A("A", {B.getDimension(0), L, M}, Format{Dense, Dense, Dense});
+
+  // 5 -> A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m)
+  IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
+  A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m);
+
+  IndexStmt stmt = A.getAssignment().concretize();
+
+  std::cout << stmt << endl;
+
+  /* BEGIN spttm_ttm TEST */
+  vector<int> path0;
+  vector<int> path1 = {1};
+  vector<int> path2 = {1, 0};
+  vector<int> path3 = {1, 0, 0};
+  vector<int> path4 = {1, 1};
+  vector<int> path5 = {1, 0, 1};
+  vector<int> path6 = {1, 0, 0, 0};
+  // Repeated fusion walks deeper into the consumer branch each time; the
+  // trailing path-qualified reorders permute loops inside the resulting
+  // leaves of the branched iteration graph.
+  stmt = stmt
+    .reorder({i, k, j, l, m})
+    .loopfuse(1, true, path0)
+    .loopfuse(4, true, path1)
+    .loopfuse(3, true, path2)
+    .loopfuse(1, false, path3)
+    .reorder(path4, {m, l})
+    .reorder(path5, {l, j})
+    .reorder(path6, {j, k})
+    ;
+  /* END spttm_ttm TEST */
+
+  stmt = stmt.concretize();
+  cout << "final stmt: " << stmt << endl;
+  printCodeToFile("spttm_ttm", stmt);
+
+  A.compile(stmt.concretize());
+  A.assemble();
+
+  Tensor<double> expected("expected", {B.getDimension(0), L, M}, Format{Dense, Dense, Dense});
+  expected(i,l,m) = B(i,j,k) * C(j,l) * D(k,m);
+  expected.compile();
+  expected.assemble();
+
+  clock_t begin;
+  clock_t end;
+
+  for (int i = 0; i < 10; i++) {
+    begin = clock();
+    A.compute(stmt);
+    end = clock();
+    double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000;
+
+    begin = clock();
+    expected.compute();
+    end = clock();
+    double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000;
+    // ASSERT_TENSOR_EQ(expected, A);
+
+    std::cout << elapsed_secs << std::endl;
+    std::cout << elapsed_secs_ref << std::endl;
+  }
+}
+
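+// loopreordercontractfuse: the same TTMc expression as loopcontractfuse, but
+// the loops are reordered around each fuse and the outer l loop is
+// parallelized across CPU threads.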
C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + vector path1; + vector path2 = {1}; + stmt = stmt + .reorder({l,i,m, j, k, n}) + .loopfuse(2, true, path1) + .reorder(path2, {m,k,j,n}) + .loopfuse(2, true, path2) + ; + stmt = stmt + .parallelize(l, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopreordercontractfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, sddmm) { + int N = 16; + float SPARSITY = 0.3; + vector dims{N,N}; + const IndexVar i("i"), j("j"), k("k"), l("l"); + + Tensor A("A", dims, Format{Dense, Dense}); + Tensor B("B", dims, Format{Dense, Sparse}); + Tensor C("C", dims, Format{Dense, Dense}); + Tensor D("D", dims, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + A(i,j) = B(i,j) * C(i,k) * D(j,k); + + IndexStmt stmt = A.getAssignment().concretize(); + + vector path1; + stmt = stmt + .reorder({i,k,j}); + stmt = stmt + .loopfuse(3, true, path1); + stmt = stmt + .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + // beging timing + A.compute(); + // end timing + + Tensor expected("expected", dims, Format{Dense, Dense}); + expected(i,j) = B(i,j) * C(i,k) * D(j,k); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, precompute2D_mul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Dense}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = B(i,j) * C(j,k); + IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + vector path; + stmt = stmt.precompute(precomputedExpr, {i,k}, {i,k}, ws); + stmt = stmt.precompute(ws(i,k) * D(k,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + 
printCodeToFile("precompute2D_mul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, precompute_sparseMul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = B(i,j) * C(j,k); + IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + stmt = stmt.precompute(precomputedExpr, {i,k}, {i,k}, ws); + stmt = stmt.precompute(ws(i,k) * D(k,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute2D_sparseMul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, precompute_changedSparseMul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = C(j,k) * D(k,l); + IndexExpr precomputedExpr2 = B(i,j) * precomputedExpr; + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + stmt = stmt.precompute(precomputedExpr, {j,l}, {j,l}, ws); + stmt = stmt.precompute(B(i,j) * ws(j,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute_changedSparseMul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + + +TEST(workspaces, precompute_tensorContraction) { + int N = 16; + + Tensor X("X", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor A("A", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor B("B", {N, N}, Format{Dense, Dense}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + for (int k = 0; k < N; k++) { + A.insert({i,j,k}, 
+TEST(workspaces, precompute_tensorContraction) {
+  int N = 16;
+
+  Tensor<double> X("X", {N, N, N}, Format{Dense, Dense, Dense});
+  Tensor<double> A("A", {N, N, N}, Format{Dense, Sparse, Sparse});
+  Tensor<double> B("B", {N, N}, Format{Dense, Dense});
+  Tensor<double> C("C", {N, N}, Format{Dense, Dense});
+  Tensor<double> D("D", {N, N}, Format{Dense, Dense});
+
+  for (int i = 0; i < N; i++) {
+    for (int j = 0; j < N; j++) {
+      B.insert({i, j}, (double) i);
+      C.insert({i, j}, (double) j);
+      D.insert({i, j}, (double) i*j);
+      for (int k = 0; k < N; k++) {
+        A.insert({i,j,k}, (double) i*j*k);
+      }
+    }
+  }
+
+  IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
+  TensorVar tmp("tmp", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense});
+  IndexStmt stmt =
+    forall(l,
+      where(
+        forall(m,
+          forall(k,
+            forall(j,
+              forall(n,
+                X(l,m,n) += tmp(j,k) * C(j,m) * D(k,n)
+              )
+            )
+          )
+        ),
+        forall(i,
+          forall(j,
+            forall(k,
+              tmp(j,k) += A(i,j,k) * B(i,l)
+            )
+          )
+        )
+      )
+    );
+
+  std::cout << "stmt: " << stmt << std::endl;
+  printCodeToFile("precompute_tensorContraction", stmt);
+
+  X(l,m,n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n);
+  X.compile(stmt.concretize());
+  X.assemble();
+  X.compute();
+
+  Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
+  expected(l, m, n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, X);
+}
+
+TEST(workspaces, precompute_tensorContraction2) {
+  int N = 16;
+
+  Tensor<double> X("X", {N, N, N}, Format{Dense, Dense, Dense});
+  Tensor<double> A("A", {N, N, N}, Format{Dense, Sparse, Sparse});
+  Tensor<double> B("B", {N, N}, Format{Dense, Dense});
+  Tensor<double> C("C", {N, N}, Format{Dense, Dense});
+  Tensor<double> D("D", {N, N}, Format{Dense, Dense});
+
+  for (int i = 0; i < N; i++) {
+    for (int j = 0; j < N; j++) {
+      B.insert({i, j}, (double) i);
+      C.insert({i, j}, (double) j);
+      D.insert({i, j}, (double) i*j);
+      for (int k = 0; k < N; k++) {
+        A.insert({i,j,k}, (double) i*j*k);
+      }
+    }
+  }
+
+  IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
+  TensorVar tmp1("tmp1", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense});
+  TensorVar tmp2("tmp2", Type(Float64, {(size_t)N}), Format{Dense});
+  IndexStmt stmt =
+    forall(l,
+      where(
+        forall(m,
+          where(
+            forall(k,
+              forall(n,
+                X(l,m,n) += tmp2(k) * D(k,n) // contracts k
+              )
+            ),
+            forall(j,
+              forall(k,
+                tmp2(k) += tmp1(j,k) * C(j,m) // contracts j
+              )
+            )
+          )
+        ),
+        forall(i,
+          forall(j,
+            forall(k,
+              tmp1(j,k) += A(i,j,k) * B(i,l) // contracts i
+            )
+          )
+        )
+      )
+    );
+
+  std::cout << "stmt: " << stmt << std::endl;
+  printCodeToFile("precompute_tensorContraction2", stmt);
+
+  X(l,m,n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n);
+  X.compile(stmt.concretize());
+  X.assemble();
+  X.compute();
+
+  Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
+  expected(l, m, n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, X);
+}
+
+TEST(workspaces, sddmmPlusSpmm) {
+  Type t(type<double>(), {3,3});
+  const IndexVar i("i"), j("j"), k("k"), l("l");
+
+  TensorVar A("A", t, Format{Dense, Dense});
+  TensorVar B("B", t, Format{Dense, Sparse});
+  TensorVar C("C", t, Format{Dense, Dense});
+  TensorVar D("D", t, Format{Dense, Dense});
+  TensorVar E("E", t, Format{Dense, Dense});
+
+  TensorVar tmp("tmp", Type(), Format());
+
+  // A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l)
+  IndexStmt fused =
+    forall(i,
+      forall(j,
+        forall(k,
+          forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l))
+        )
+      )
+    );
+
+  std::cout << "before topological sort: " << fused << std::endl;
+  fused = reorderLoopsTopologically(fused);
+  // std::vector<std::string> order{"i", "j", "k", "l"};
+  fused = fused.reorder({i, j, k, l});
+  std::cout << "after topological sort: " << fused << std::endl;
+
+  // fused = fused.precompute(B(i,j) * C(i,k) * D(j,k), {}, {}, tmp);
+  std::cout << "after precompute: " << fused << std::endl;
+
+  // Kernel kernel = compile(fused);
+
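+  // The commented-out sketch below is the where() nesting such a fusion
+  // would produce: the producer accumulates the SDDMM partial product into
+  // tmp and the consumer multiplies it into the SpMM.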
+  // IndexStmt fusedNested =
+  //   forall(i,
+  //     forall(j,
+  //       where(
+  //         forall(l, A(i,l) += tmp * E(j,l)),          // consumer
+  //         forall(k, tmp += B(i,j) * C(i,k) * D(j,k))  // producer
+  //       )
+  //     )
+  //   );
+
+  // std::cout << "nested loop stmt: " << fusedNested << std::endl;
+}
\ No newline at end of file
diff --git a/tools/taco.cpp b/tools/taco.cpp
index 30558c830..45124a2d2 100644
--- a/tools/taco.cpp
+++ b/tools/taco.cpp
@@ -123,6 +123,12 @@ static void printUsageInfo() {
             "-help=scheduling for a list of scheduling commands. "
             "Examples: split(i,i0,i1,16), precompute(A(i,j)*x(j),i,i).");
   cout << endl;
+  printFlag("s=\"loopfuse(<params>;<path>)\"",
+            "Specify the loopfuse command to apply a fusion directive to the code. "
+            "Parameters take the form of a comma-delimited list, separated by a colon "
+            "from the branching point. See -help=loopfuse for a list of loopfuse commands. "
+            "Examples: loopfuse(3), loopfuse(1;2), loopfuse(2;2), loopfuse(1,1;1).");
+  cout << endl;
   printFlag("c",
             "Generate compute kernel that simultaneously does assembly.");
   cout << endl;
@@ -352,6 +358,23 @@ static bool setSchedulingCommands(vector<vector<string>> scheduleCommands, parser::Parser& parser, IndexStmt& stmt) {
       IndexVar fused(f);
 
       stmt = stmt.fuse(findVar(i), findVar(j), fused);
+    } else if (command == "loopfuse") {
+      taco_uassert(scheduleCommand.size() >= 1)
+          << "'loopfuse' scheduling directive takes at least one parameter; e.g. loopfuse(1)";
+      cout << "loopfuse directive found\n";
+
+      // The last parameter is the split position; any leading parameters
+      // form the path to the branch the fuse is applied to.
+      vector<int> path;
+      transform(scheduleCommand.begin(), scheduleCommand.end(), back_inserter(path),
+                [&](std::string &s) { return stoi(s); });
+      int split = path.back();
+      path.pop_back();
+      cout << "split: " << split << endl;
+      for (unsigned int i = 0; i < path.size(); i++) { /* ... */ }
+      /* ... */

   vector<vector<string>> parsed = parser::ScheduleParser(argValue);