From 75293599e9e5c166318d1c6ced979289eacb6d60 Mon Sep 17 00:00:00 2001
From: Adhhitha Dias
Date: Fri, 24 Feb 2023 15:09:22 -0500
Subject: [PATCH 01/14] initial implementation of kernel fuse directive

---
 include/taco/index_notation/index_notation.h  |   5 +
 include/taco/index_notation/transformations.h |  24 ++
 include/taco/parser/lexer.h                   |   1 +
 src/index_notation/index_notation.cpp         |  21 ++
 src/index_notation/transformations.cpp        | 313 ++++++++++++++++++
 src/parser/lexer.cpp                          |   3 +
 src/parser/schedule_parser.cpp                |   4 +
 tools/taco.cpp                                |  24 ++
 8 files changed, 395 insertions(+)

diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h
index dd451b337..d93bbda9b 100644
--- a/include/taco/index_notation/index_notation.h
+++ b/include/taco/index_notation/index_notation.h
@@ -646,6 +646,11 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   IndexStmt divide(IndexVar i, IndexVar i1, IndexVar i2, size_t divideFactor) const;
   // TODO: TailStrategy

+  /// The loopfuse transformation fuses the common outer loops of
+  /// two iteration graphs.
+  IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;
+
+
   /// The reorder transformation swaps two directly nested index
   /// variables in an iteration graph. This changes the order of
   /// iteration through the space and the order of tensor accesses.
diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h
index b750e3961..5d59261fa 100644
--- a/include/taco/index_notation/transformations.h
+++ b/include/taco/index_notation/transformations.h
@@ -17,6 +17,7 @@ class IndexStmt;
 class TransformationInterface;
 class Reorder;
 class Precompute;
+class LoopFuse;
 class ForAllReplace;
 class AddSuchThatPredicates;
 class Parallelize;
@@ -32,6 +33,7 @@ class Transformation {
 public:
   Transformation(Reorder);
   Transformation(Precompute);
+  Transformation(LoopFuse);
   Transformation(ForAllReplace);
   Transformation(Parallelize);
   Transformation(TopoReorder);
@@ -114,6 +116,28 @@ class Precompute : public TransformationInterface {
 /// Print a precompute command.
 std::ostream &operator<<(std::ostream &, const Precompute &);

+/// The loopfuse optimization rewrites an index expression so that a producer
+/// subexpression is computed into a workspace and consumed inside the fused
+/// common outer loops of the producer and the consumer.
+class LoopFuse : public TransformationInterface {
+public:
+  LoopFuse();
+  LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path);
+
+  int getPos() const;
+  bool getIsProducerOnLeft() const;
+  std::vector<int>& getPath() const;
+
+  /// Apply the loopfuse optimization to a concrete index statement.
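+  /// `pos` gives the number of operands of the right-hand side that form the
+  /// producer, `isProducerOnLeft` says whether those operands are taken from
+  /// the left or the right end of the expression, and `path` selects a branch
+  /// through where clauses introduced by earlier fusions (0 descends into a
+  /// producer, 1 into a consumer; the path is only consulted once the visitors
+  /// added later in this series, PATCH 03, are in place).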
+  IndexStmt apply(IndexStmt, std::string *reason = nullptr) const;
+
+  void print(std::ostream &os) const;
+
+private:
+  struct Content;
+  std::shared_ptr<Content> content;
+};
+
+std::ostream &operator<<(std::ostream &, const LoopFuse &);

 /// Replaces all occurrences of directly nested forall nodes of pattern with
 /// directly nested loops of replacement
diff --git a/include/taco/parser/lexer.h b/include/taco/parser/lexer.h
index 55dc74410..c9e185dcd 100644
--- a/include/taco/parser/lexer.h
+++ b/include/taco/parser/lexer.h
@@ -22,6 +22,7 @@ enum class Token {
   sub,
   mul,
   div,
+  colon,  // lexed from ';'; separates the loopfuse where-path from the position
   eq,
   eot,  // End of tokens
   error
diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp
index e38e3d2d3..bf8973b39 100644
--- a/src/index_notation/index_notation.cpp
+++ b/src/index_notation/index_notation.cpp
@@ -1854,6 +1854,24 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
   return transformed;
 }

+IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {
+
+  std::cout << "Loop fuse pos: " << pos;
+  std::cout << ", Loop fuse isProducerOnLeft: " << isProducerOnLeft;
+  for (const auto& p : path) {
+    std::cout << " " << p;
+  }
+  std::cout << std::endl;
+
+  string reason;
+  IndexStmt transformed = *this;
+  transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
+  if (!transformed.defined()) {
+    taco_uerror << reason;
+  }
+  return transformed;
+}
+
 IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
                                 std::vector<IndexVar> iw_vars,
                                 TensorVar workspace) const {
@@ -2048,6 +2068,7 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
   return transformed;
 }

+
 IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
   if (accelIndexVars.size() == 0) {
     ws.setAccelIndexVars(accelIndexVars, shouldAccel);
diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp
index d53ec58c3..0373cdc2f 100644
--- a/src/index_notation/transformations.cpp
+++ b/src/index_notation/transformations.cpp
@@ -4,6 +4,7 @@
 #include "taco/index_notation/index_notation_rewriter.h"
 #include "taco/index_notation/index_notation_nodes.h"
 #include "taco/error/error_messages.h"
+#include "taco/storage/index.h"
 #include "taco/util/collections.h"
 #include "taco/lower/iterator.h"
 #include "taco/lower/merge_lattice.h"
@@ -30,6 +31,10 @@ Transformation::Transformation(Precompute precompute)
     : transformation(new Precompute(precompute)) {
 }

+Transformation::Transformation(LoopFuse loopfuse)
+    : transformation(new LoopFuse(loopfuse)) {
+}
+
 Transformation::Transformation(ForAllReplace forallreplace)
     : transformation(new ForAllReplace(forallreplace)) {
 }
@@ -233,6 +238,314 @@ std::ostream& operator<<(std::ostream& os, const SetMergeStrategy& setmergestrat
   return os;
 }

+// class LoopFuse
+struct LoopFuse::Content {
+  int pos;
+  bool isProducerOnLeft;
+  std::vector<int> path;
+};
+
+LoopFuse::LoopFuse() : content(nullptr) {
+}
+
+LoopFuse::LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path) : content(new Content) {
+  content->pos = pos;
+  content->path = path;
+  content->isProducerOnLeft = isProducerOnLeft;
+}
+
+int LoopFuse::getPos() const {
+  return content->pos;
+}
+
+bool LoopFuse::getIsProducerOnLeft() const {
+  return content->isProducerOnLeft;
+}
+
+std::vector<int>& LoopFuse::getPath() const {
+  return content->path;
+}
+
+IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
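+  // Worked example, a sketch of the rewrite this function performs (`ws` is
+  // the workspace it creates below):
+  //   A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m)  with  loopfuse(2, true, {})
+  // splits the right-hand side after the second operand, so the producer is
+  // B(i,j)*C(j,k) and the consumer is D(k,l)*E(l,m). With loop order
+  // (i,j,k,l,m), only i is a common outer loop and k is the only producer
+  // index that survives into the consumer, so the statement becomes
+  //   forall(i, where(forall(k, forall(l, forall(m,
+  //                     A(i,m) += ws(k) * D(k,l) * E(l,m)))),  // consumer
+  //                   forall(j, forall(k,
+  //                     ws(k) += B(i,j) * C(j,k)))))           // producer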
INIT_REASON(reason); + + auto printVector = [](const vector& array) { + for (auto& var : array) { + cout << var << " "; + } + cout << endl; + }; + auto printSet = [](const set& array) { + for (auto& var : array) { + cout << var << " "; + } + cout << endl; + }; + + cout << "pos: " << getPos() << std::endl; + cout << "isProducerOnLeft: " << getIsProducerOnLeft() << endl; + cout << "path: "; + for (const auto& p : getPath()) { + cout << p << " " << std::endl; + } + cout << endl; + + struct GetAssignment : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + Assignment innerAssignment; + vector indexAccessVars; + + void visit(const ForallNode* node) { + Forall forall(node); + indexAccessVars.push_back(forall.getIndexVar()); + + if (isa(forall.getStmt())) { + innerAssignment = to(forall.getStmt()); + } + else { + IndexNotationVisitor::visit(node); + } + } + }; + GetAssignment getAssignment; + stmt.accept(&getAssignment); + + std::cout << getAssignment.innerAssignment << std::endl; + cout << "Index access order: "; printVector(getAssignment.indexAccessVars); + + // saves the result, producer and consumer of the assignment + // result = producer * consumer + // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) + // if the pos is 1, the result is A(i,j), the producer is B(i,j) and the consumer is C(j,k) * D(k,l) + // if the post is 2, the result is A(i,j), the producer is B(i,j) * C(j,k) and the consumer is D(k,l) + // resultVars, producerVars and consumerVars are the index variables of the result, producer and consumer + struct GetProducerAndConsumer : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + int pos; + bool isProducerOnLeft; + IndexExpr result; + IndexExpr producer; + IndexExpr consumer; + vector resultVars; + set producerVars; + set consumerVars; + map> varTypes; + IndexExpr op; + + GetProducerAndConsumer(int _pos, int _isProducerOnLeft) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {} + + void addIndexVar(Access access) { + // get the dimension and type of each index variable in tensor + for (unsigned long i = 0; i < access.getIndexVars().size(); i++) { + auto tensorVar = access.getTensorVar(); // Tensor variable like A, B + auto indexVar = access.getIndexVars()[i]; // Index variable like i, j + auto tensorVarType = tensorVar.getType(); + varTypes[indexVar] = make_pair(tensorVarType, tensorVarType.getShape().getDimension(i)); + } + } + + void visit(const AssignmentNode* node) { + Assignment assignment(node); + // result is stored in the left hand side of the assignment + result = assignment.getLhs(); + resultVars = assignment.getLhs().getIndexVars(); + std::cout << "result: " << result + << ", rhs: " << assignment.getRhs() + << ", freeVars: " << assignment.getFreeVars() + << ", indexVars: " << assignment.getIndexVars() + << ", indexSetRelation: " << assignment.getIndexSetRel() + << std::endl; + + // add the index variables of the result to the map + addIndexVar(to(assignment.getLhs())); + + // visit the tensor contraction expression on the left hand side of += or = + IndexNotationVisitor::visit(assignment.getRhs()); + } + + // lhs is a multiplication in the tensor contraction + void visit(const MulNode* node) { + Mul mul(node); + IndexNotationVisitor::visit(mul.getA()); + IndexNotationVisitor::visit(mul.getB()); + } + + void visit(const AccessNode* node) { + Access access(node); + cout << "pos: " << pos << ", access: " << access << endl; + IndexExpr* it; + set* vars; + if 
((pos > 0 && isProducerOnLeft) || (pos <= 0 && !isProducerOnLeft)) { it = &producer; vars = &producerVars; } + else { it = &consumer; vars = &consumerVars; } + + if (*it == nullptr) *it = access; + else *it = *it * access; + + for (const auto& var : access.getIndexVars()) { + vars->insert(var); + } + + // add the index variables of the access to the map + addIndexVar(access); + pos--; + } + }; + GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft()); + stmt.accept(&getProducerAndConsumer); + + std::cout << "result: " << getProducerAndConsumer.result << std::endl; + std::cout << "producer: " << getProducerAndConsumer.producer << std::endl; + std::cout << "consumer: " << getProducerAndConsumer.consumer << std::endl; + std::cout << "resultVars: " << getProducerAndConsumer.resultVars << std::endl; + cout << "producerVars: "; printSet(getProducerAndConsumer.producerVars); + cout << "consumerVars: "; printSet(getProducerAndConsumer.consumerVars); + + // indices in the temporary comes from the producer indices (IndexVars) + // that are either in result indices or in consumer indices + // indices in the producer that are neither in producer indices nor in consumer indices + // gets contracted within the producer computation + vector temporaryVars; + for (auto& var : getProducerAndConsumer.producerVars) { + auto itC = getProducerAndConsumer.consumerVars.find(var); + auto itR = find(getProducerAndConsumer.resultVars.begin(), getProducerAndConsumer.resultVars.end(), var); + if (itC != getProducerAndConsumer.consumerVars.end() || + itR != getProducerAndConsumer.resultVars.end()) { + temporaryVars.push_back(var); + } + } + cout << "temporaryVars: "; printVector(temporaryVars); + + // get the producer index access pattern + // get the consumer index access pattern + vector producerLoopVars; + vector consumerLoopVars; + for (auto& var : getAssignment.indexAccessVars) { + auto itP = getProducerAndConsumer.producerVars.find(var); + auto itC = getProducerAndConsumer.consumerVars.find(var); + auto itR = find(getProducerAndConsumer.resultVars.begin(), getProducerAndConsumer.resultVars.end(), var); + // check if variable is in the producer + if (itP != getProducerAndConsumer.producerVars.end()) { + producerLoopVars.push_back(var); + } + // check if variable is in the consumer or result + if (itC != getProducerAndConsumer.consumerVars.end() || itR != getProducerAndConsumer.resultVars.end()) { + consumerLoopVars.push_back(var); + } + } + + // check if there are common outer loops in producerAccessOrder and consumerAccessOrder + vector commonLoopVars; + for (auto& var : getAssignment.indexAccessVars) { + auto itC = find(consumerLoopVars.begin(), consumerLoopVars.end(), var); + auto itP = find(producerLoopVars.begin(), producerLoopVars.end(), var); + if (itC != consumerLoopVars.end() && itP != producerLoopVars.end()) { + commonLoopVars.push_back(var); + temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end()); + } + else { + break; + } + } + // for (auto& var : producerLoopVars) { + // auto it = find(consumerLoopVars.begin(), consumerLoopVars.end(), var); + // if (it != consumerLoopVars.end()) { + // commonLoopVars.push_back(var); + // temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end()); + // } + // else { + // break; + // } + // } + cout << "commonOuterLoops: "; printVector(commonLoopVars); + cout << "temporaryVars: "; printVector(temporaryVars); + + // remove commonLoopVars from producerLoopVars and 
consumerLoopVars + for (auto& var : commonLoopVars) { + producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); + consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); + } + cout << "producerLoopVars: "; printVector(producerLoopVars); + cout << "consumerLoopVars: "; printVector(consumerLoopVars); + + // create the intermediate tensor + vector temporaryDims; + vector temporaryModes; + // populate shape of the intermediate tensor + auto populateDimension = + [&](map>& varTypes) { + for (auto& var : temporaryVars) { + temporaryDims.push_back(varTypes[var].second); + temporaryModes.push_back(ModeFormat{Dense}); + } + }; + populateDimension(getProducerAndConsumer.varTypes); + TensorVar intermediateTensor("ws", Type(Float64, temporaryDims)); + Access workspace(intermediateTensor, temporaryVars); + cout << "intermediateTensor: " << intermediateTensor << endl; + cout << "workspace: " << workspace << endl; + + Assignment producerAssignment(workspace, getProducerAndConsumer.producer, getAssignment.innerAssignment.getOperator()); + cout << "producerAssignment: " << producerAssignment << endl; + + Assignment consumerAssignment; + if (!getIsProducerOnLeft()) { + consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator()); + } else { + consumerAssignment = Assignment(to(getProducerAndConsumer.result), workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator()); + } + cout << "consumerAssignment: " << consumerAssignment << endl; + + // check if there are common outer loops + // if there are common outer loops, then remove those common outer loops from the temporaryVars + + // rewrite the index notation to use the temporary + // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) + // T(i,k) += B(i,j) * C(j,k) is the producer and A(i,j) += T(i,k) * D(k,l) is the consumer + struct ProducerConsumerRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + Assignment& producer; + Assignment& consumer; + vector& commonLoopVars; + vector& producerLoopVars; + vector& consumerLoopVars; + + // constructor + ProducerConsumerRewriter(Assignment& producer, Assignment& consumer, vector& commonLoopVars, vector& producerLoopVars, vector& consumerLoopVars) : + producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {} + + IndexStmt generateForalls(IndexStmt innerStmt, vector indexVars) { + auto returnStmt = innerStmt; + for (auto it = indexVars.rbegin(); it != indexVars.rend(); ++it) { + returnStmt = forall(*it, returnStmt); + } + + return returnStmt; + }; + + // should find the path to get to this loop to perform the rewrite + void visit(const ForallNode* node) { + IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars); + IndexStmt producer = generateForalls(this->producer, producerLoopVars); + Where where(consumer, producer); + stmt = generateForalls(where, commonLoopVars); + return; + } + + }; + + ProducerConsumerRewriter rewriter(producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars); + stmt = rewriter.rewrite(stmt); + cout << "stmt: " << stmt << endl; + + return stmt; +} + + + +void LoopFuse::print(std::ostream &os) const { + os << "fuse(" << getPos() << ", " << util::join(getPath()) << ")"; +} + // class Precompute 
 struct Precompute::Content {
   IndexExpr expr;
diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp
index a490840a8..e4cf3830b 100644
--- a/src/parser/lexer.cpp
+++ b/src/parser/lexer.cpp
@@ -84,6 +84,9 @@ Token Lexer::getToken() {
     case '/':
       token = Token::div;
       break;
+    case ';':
+      token = Token::colon;
+      break;
     case '=':
       token = Token::eq;
       break;
diff --git a/src/parser/schedule_parser.cpp b/src/parser/schedule_parser.cpp
index 0666601da..2b099f4b3 100644
--- a/src/parser/schedule_parser.cpp
+++ b/src/parser/schedule_parser.cpp
@@ -52,6 +52,10 @@ vector<vector<string>> ScheduleParser(const string argValue) {
         current_element += lexer.tokenString(tok);
         parenthesesCnt--;
         break;
+      case parser::Token::colon:
+        current_schedule.push_back(current_element);
+        current_element = "";
+        break;
       case parser::Token::comma:
         if (curlyParenthesesCnt > 0) {
           // multiple indexes inside of a {} list; pass it through
diff --git a/tools/taco.cpp b/tools/taco.cpp
index 30558c830..45124a2d2 100644
--- a/tools/taco.cpp
+++ b/tools/taco.cpp
@@ -123,6 +123,12 @@ static void printUsageInfo() {
             "-help=scheduling for a list of scheduling commands. "
             "Examples: split(i,i0,i1,16), precompute(A(i,j)*x(j),i,i).");
   cout << endl;
+  printFlag("s=\"loopfuse(<path>;<pos>)\"",
+            "Specify a loopfuse command to apply the fusion directive to the code. "
+            "Parameters take the form of a comma-delimited where-path, followed by "
+            "a semicolon and the branching position. See -help=loopfuse for a list "
+            "of loopfuse commands. "
+            "Examples: loopfuse(3).loopfuse(1;2).loopfuse(2;2).loopfuse(1,1;1).");
+  cout << endl;
   printFlag("c",
             "Generate compute kernel that simultaneously does assembly.");
   cout << endl;
@@ -352,6 +358,23 @@ static bool setSchedulingCommands(vector<vector<string>> scheduleCommands, parse
       IndexVar fused(f);

       stmt = stmt.fuse(findVar(i), findVar(j), fused);
+    } else if (command == "loopfuse") {
+      taco_uassert(scheduleCommand.size() >= 1)
+          << "'loopfuse' scheduling directive takes at least 1 parameter, e.g. loopfuse(1)";
+      cout << "loopfuse directive found\n";
+
+      vector<int> path;
+      transform(scheduleCommand.begin(), scheduleCommand.end(), back_inserter(path),
+                [&](std::string &s) { return stoi(s); });
+      int split = path.back();
+      if (path.size() > 1) path.pop_back();
+      cout << "split: " << split << endl;
+      for (unsigned int i=0; i> parsed = parser::ScheduleParser(argValue);

From 99adfca736e9392045bc272d501d86a8f67529a3 Mon Sep 17 00:00:00 2001
From: Adhhitha Dias
Date: Fri, 24 Feb 2023 15:25:47 -0500
Subject: [PATCH 02/14] add a few initial test cases

---
 test/tests-indexstmt.cpp  |  43 ++++-
 test/tests-workspaces.cpp | 363 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 405 insertions(+), 1 deletion(-)

diff --git a/test/tests-indexstmt.cpp b/test/tests-indexstmt.cpp
index e2a972430..a62776bad 100644
--- a/test/tests-indexstmt.cpp
+++ b/test/tests-indexstmt.cpp
@@ -2,6 +2,8 @@
 #include "test_tensors.h"
 #include "taco/tensor.h"
 #include "taco/index_notation/index_notation.h"
+#include "taco/index_notation/kernel.h"
+#include "taco/index_notation/transformations.h"

 using namespace taco;
 const IndexVar i("i"), j("j"), k("k");
@@ -84,4 +86,43 @@ TEST(indexstmt, spmm) {
 }

-
+TEST(indexstmt, sddmmPlusSpmm) {
+  Type t(type<double>(), {3,3});
+  const IndexVar i("i"), j("j"), k("k"), l("l");
+
+  TensorVar A("A", t, Format{Dense, Dense});
+  TensorVar B("B", t, Format{Dense, Sparse});
+  TensorVar C("C", t, Format{Dense, Dense});
+  TensorVar D("D", t, Format{Dense, Dense});
+  TensorVar E("E", t, Format{Dense, Dense});
+
+  TensorVar tmp("tmp", Type(), Format());
+
+  // A(i,l) = B(i,j) *
C(i,k) * D(j,k) * E(j,l) + IndexStmt fused = + forall(i, + forall(j, + forall(k, + forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)) + ) + ) + ); + + std::cout << "before topological sort: " << fused << std::endl; + fused = reorderLoopsTopologically(fused); + std::cout << "after topological sort: " << fused << std::endl; + + Kernel kernel = compile(fused); + + IndexStmt fusedNested = + forall(i, + forall(j, + where( + forall(l, A(i,l) += tmp * E(j,l)), // consumer + forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer + ) + ) + ); + + std::cout << "nested loop stmt: " << fusedNested << std::endl; +} diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index 350fb8538..f3f6071b1 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -1,6 +1,8 @@ +#include #include #include #include +#include #include "test.h" #include "test_tensors.h" #include "taco/tensor.h" @@ -10,6 +12,23 @@ using namespace taco; +void printCodeToFile(string filename, IndexStmt stmt) { + stringstream source; + + string file_path = "eval_generated/"; + mkdir(file_path.c_str(), 0777); + + std::shared_ptr codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen); + ir::Stmt compute = lower(stmt, "compute", true, true); + codegen->compile(compute, true); + + ofstream source_file; + string file_ending = should_use_CUDA_codegen() ? ".cu" : ".c"; + source_file.open(file_path + filename + file_ending); + source_file << source.str(); + source_file.close(); +} + TEST(workspaces, tile_vecElemMul_NoTail) { Tensor A("A", {16}, Format{Dense}); @@ -36,6 +55,7 @@ TEST(workspaces, tile_vecElemMul_NoTail) { .split(i_bounded, i0, i1, 4) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_NoTail", stmt); A.compile(stmt); A.assemble(); A.compute(); @@ -73,6 +93,7 @@ TEST(workspaces, tile_vecElemMul_Tail1) { stmt = stmt.bound(i, i_bounded, 16, BoundType::MaxExact) .split(i_bounded, i0, i1, 5) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_Tail1", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -111,6 +132,7 @@ TEST(workspaces, tile_vecElemMul_Tail2) { stmt = stmt.bound(i, i_bounded, 17, BoundType::MaxExact) .split(i_bounded, i0, i1, 4) .precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_vecElemMul_Tail2", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -161,6 +183,7 @@ TEST(workspaces, tile_denseMatMul) { .split(i_bounded, i0, i1, 4); stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed); + printCodeToFile("tile_denseMatMul", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -209,6 +232,9 @@ TEST(workspaces, precompute2D_add) { TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j}, {i, j}, ws); + std::cout << stmt << endl; + printCodeToFile("precompute2D_ad", stmt); + A.compile(stmt.concretize()); A.assemble(); A.compute(); @@ -253,6 +279,8 @@ TEST(workspaces, precompute4D_add) { Format{Dense, Dense, Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j, k, l}, {i, j, k, l}, ws1) .precompute(ws1(i, j, k, l) + D(i, j, k, l), {i, j, k, l}, {i, j, k ,l}, ws2); + std::cout << stmt << endl; + printCodeToFile("precompute4D_add", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -295,6 +323,9 @@ TEST(workspaces, precompute4D_multireduce) { TensorVar ws2("ws2", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); stmt = stmt.precompute(precomputedExpr, {i, j, m}, {i, j, m}, ws1) 
.precompute(ws1(i, j, m) * D(m, n), {i, j}, {i, j}, ws2); + + std::cout << stmt << endl; + printCodeToFile("precompute4D_multireduce", stmt); A.compile(stmt.concretize()); A.assemble(); @@ -335,6 +366,9 @@ TEST(workspaces, precompute3D_TspV) { stmt = stmt.precompute(precomputedExpr, {i, j, k}, {i, j, k}, ws); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_TspV", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -379,6 +413,9 @@ TEST(workspaces, precompute3D_multipleWS) { stmt = stmt.precompute(ws(i, j, k) * c(k), {i, j}, {i, j}, t); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_multipleWS", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -422,6 +459,9 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) { stmt = stmt.precompute(precomputedExpr, {i, j, k}, {iw, jw, kw}, ws); stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("precompute3D_renamedIVars_TspV", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -475,6 +515,9 @@ TEST(workspaces, tile_dotProduct_1) { stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("tile_dotProduct_1", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -538,6 +581,9 @@ TEST(workspaces, tile_dotProduct_2) { stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("tile_dotProduct_2", stmt); + stmt = stmt.wsaccel(precomputed, false); A.compile(stmt); A.assemble(); @@ -588,6 +634,9 @@ TEST(workspaces, tile_dotProduct_3) { stmt = stmt.concretize(); + std::cout << stmt << endl; + printCodeToFile("tile_dotProduct_3", stmt); + A.compile(stmt); A.assemble(); A.compute(); @@ -599,3 +648,317 @@ TEST(workspaces, tile_dotProduct_3) { expected.compute(); ASSERT_TENSOR_EQ(expected, A); } + + +TEST(workspaces, loopfuse) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + IndexExpr precomputedExpr = B(i,j) * C(j,k); + IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); + // A(i,l) = precomputedExpr2; + A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + vector path; + stmt = stmt + // .reorder({i,j,k,l}) + .reorder({i,j,k, l, m}) + .loopfuse(2, true, path) + .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("loopfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + return; + +} + +TEST(workspaces, precompute2D_mul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Dense}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, 
(double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = B(i,j) * C(j,k); + IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + vector path; + stmt = stmt.loopfuse(2, true, path); + + stmt = stmt.precompute(precomputedExpr, {i,k}, {i,k}, ws); + stmt = stmt.precompute(ws(i,k) * D(k,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute2D_mul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, precompute_sparseMul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = B(i,j) * C(j,k); + IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + stmt = stmt.precompute(precomputedExpr, {i,k}, {i,k}, ws); + stmt = stmt.precompute(ws(i,k) * D(k,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute2D_sparseMul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + +TEST(workspaces, precompute_changedSparseMul) { + int N = 16; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"); + IndexExpr precomputedExpr = C(j,k) * D(k,l); + IndexExpr precomputedExpr2 = B(i,j) * precomputedExpr; + A(i,l) = precomputedExpr2; + + IndexStmt stmt = A.getAssignment().concretize(); + TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + stmt = stmt.precompute(precomputedExpr, {j,l}, {j,l}, ws); + stmt = stmt.precompute(B(i,j) * ws(j,l), {i,l}, {i,l}, t); + stmt = stmt.concretize(); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute_changedSparseMul", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N}, 
Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(j,k) * D(k,l); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + + +TEST(workspaces, precompute_tensorContraction) { + int N = 16; + + Tensor X("X", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor A("A", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor B("B", {N, N}, Format{Dense, Dense}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + for (int k = 0; k < N; k++) { + A.insert({i,j,k}, (double) i*j*k); + } + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + TensorVar tmp("tmp", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + IndexStmt stmt = + forall(l, + where( + forall(m, + forall(k, + forall(j, + forall(n, + X(l,m,n) += tmp(j,k) * C(j,m) * D(k,n) + ) + ) + ) + ), + forall(i, + forall(j, + forall(k, + tmp(j,k) += A(i,j,k) * B(i,l) + ) + ) + ) + ) + ); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute_tensorContraction", stmt); + + X(l,m,n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n); + X.compile(stmt.concretize()); + X.assemble(); + X.compute(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l, m, n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, X); +} + + +TEST(workspaces, precompute_tensorContraction2) { + int N = 16; + + Tensor X("X", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor A("A", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor B("B", {N, N}, Format{Dense, Dense}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + for (int k = 0; k < N; k++) { + A.insert({i,j,k}, (double) i*j*k); + } + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + TensorVar tmp1("tmp1", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + TensorVar tmp2("tmp2", Type(Float64, {(size_t)N}), Format{Dense}); + IndexStmt stmt = + forall(l, + where( + forall(m, + where( + forall(k, + forall(n, + X(l,m,n) += tmp2(k) * D(k,n) // contracts k + ) + ) + , + forall(j, + forall(k, + tmp2(k) += tmp1(j,k) * C(j,m) // contracts j + ) + ) + ) + ), + forall(i, + forall(j, + forall(k, + tmp1(j,k) += A(i,j,k) * B(i,l) // contracts i + ) + ) + ) + ) + ); + + std::cout << "stmt: " << stmt << std::endl; + printCodeToFile("precompute_tensorContraction2", stmt); + + X(l,m,n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n); + X.compile(stmt.concretize()); + X.assemble(); + X.compute(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l, m, n) = A(i,j,k) * B(i,l) * C(j,m) * D(k,n); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, X); +} + + From 6a31a0881e8bed1c7e5a747478f2f53b1860c40b Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Sat, 25 Feb 2023 18:13:27 -0500 Subject: [PATCH 03/14] add recursive functionality for kernel fussion --- src/index_notation/transformations.cpp | 120 ++++++++++++++++++++----- test/tests-workspaces.cpp | 68 +++++++++++--- 2 files changed, 156 insertions(+), 32 deletions(-) diff --git 
a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 0373cdc2f..fad8a934c 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -118,6 +118,7 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { } }) ); + cout << "currentOrdering: " << util::join(currentOrdering) << endl; if (!content->pattern_ordered && currentOrdering == getreplacepattern()) { taco_iassert(getreplacepattern().size() == 2); @@ -294,10 +295,27 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { using IndexNotationVisitor::visit; Assignment innerAssignment; vector indexAccessVars; + vector indexVarsUntilBranch; + unsigned int pathIdx = 0; + vector path; + + // insert constructor with path + GetAssignment(vector& _path) : path(_path) {} void visit(const ForallNode* node) { Forall forall(node); + cout << "Forall: " << forall << endl; + cout << "pathIdx: " << pathIdx << endl; + // print path + cout << "path: "; + for (const auto& p : path) { + cout << p << " " << std::endl; + } + cout << endl; indexAccessVars.push_back(forall.getIndexVar()); + if (pathIdx < path.size()) { + indexVarsUntilBranch.push_back(forall.getIndexVar()); + } if (isa(forall.getStmt())) { innerAssignment = to(forall.getStmt()); @@ -306,8 +324,22 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { IndexNotationVisitor::visit(node); } } + + void visit(const WhereNode* node) { + Where where(node); + cout << "Where: " << where << endl; + + if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer + pathIdx++; + IndexNotationVisitor::visit(node->producer); + } else { + pathIdx++; + IndexNotationVisitor::visit(node->consumer); + } + + } }; - GetAssignment getAssignment; + GetAssignment getAssignment(getPath()); stmt.accept(&getAssignment); std::cout << getAssignment.innerAssignment << std::endl; @@ -322,7 +354,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { struct GetProducerAndConsumer : public IndexNotationVisitor { using IndexNotationVisitor::visit; int pos; + int pathIdx = 0; bool isProducerOnLeft; + vector path; IndexExpr result; IndexExpr producer; IndexExpr consumer; @@ -332,7 +366,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { map> varTypes; IndexExpr op; - GetProducerAndConsumer(int _pos, int _isProducerOnLeft) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {} + GetProducerAndConsumer(int _pos, int _isProducerOnLeft, vector& _path) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {} void addIndexVar(Access access) { // get the dimension and type of each index variable in tensor @@ -363,6 +397,20 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { IndexNotationVisitor::visit(assignment.getRhs()); } + void visit(const WhereNode* node) { + Where where(node); + cout << "Where: " << where << endl; + + // select the path to visit + if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer + pathIdx++; + IndexNotationVisitor::visit(node->producer); + } else { + pathIdx++; + IndexNotationVisitor::visit(node->consumer); + } + } + // lhs is a multiplication in the tensor contraction void visit(const MulNode* node) { Mul mul(node); @@ -390,7 +438,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { pos--; } }; - GetProducerAndConsumer 
getProducerAndConsumer(getPos(), getIsProducerOnLeft()); + GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft(), getPath()); stmt.accept(&getProducerAndConsumer); std::cout << "result: " << getProducerAndConsumer.result << std::endl; @@ -432,6 +480,19 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { consumerLoopVars.push_back(var); } } + cout << "producerLoopVars2: "; printVector(producerLoopVars); + cout << "consumerLoopVars2: "; printVector(consumerLoopVars); + + // remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch + cout << "indexVarsUntilBranch: "; printVector(getAssignment.indexVarsUntilBranch); + for (auto& var : getAssignment.indexVarsUntilBranch) { + producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); + consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); + } + + cout << "producerLoopVars3: "; printVector(producerLoopVars); + cout << "consumerLoopVars3: "; printVector(consumerLoopVars); + // check if there are common outer loops in producerAccessOrder and consumerAccessOrder vector commonLoopVars; @@ -446,16 +507,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { break; } } - // for (auto& var : producerLoopVars) { - // auto it = find(consumerLoopVars.begin(), consumerLoopVars.end(), var); - // if (it != consumerLoopVars.end()) { - // commonLoopVars.push_back(var); - // temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end()); - // } - // else { - // break; - // } - // } cout << "commonOuterLoops: "; printVector(commonLoopVars); cout << "temporaryVars: "; printVector(temporaryVars); @@ -479,7 +530,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { } }; populateDimension(getProducerAndConsumer.varTypes); - TensorVar intermediateTensor("ws", Type(Float64, temporaryDims)); + Access resultAccess = to(getProducerAndConsumer.result); + TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName(), Type(Float64, temporaryDims)); Access workspace(intermediateTensor, temporaryVars); cout << "intermediateTensor: " << intermediateTensor << endl; cout << "workspace: " << workspace << endl; @@ -503,6 +555,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // T(i,k) += B(i,j) * C(j,k) is the producer and A(i,j) += T(i,k) * D(k,l) is the consumer struct ProducerConsumerRewriter : public IndexNotationRewriter { using IndexNotationRewriter::visit; + vector& path; + vector visited; Assignment& producer; Assignment& consumer; vector& commonLoopVars; @@ -510,8 +564,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { vector& consumerLoopVars; // constructor - ProducerConsumerRewriter(Assignment& producer, Assignment& consumer, vector& commonLoopVars, vector& producerLoopVars, vector& consumerLoopVars) : - producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {} + ProducerConsumerRewriter(vector& _path, Assignment& producer, Assignment& consumer, vector& commonLoopVars, vector& producerLoopVars, vector& consumerLoopVars) : + path(_path), producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {} IndexStmt generateForalls(IndexStmt innerStmt, vector indexVars) { auto 
returnStmt = innerStmt; @@ -524,16 +578,38 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // should find the path to get to this loop to perform the rewrite void visit(const ForallNode* node) { - IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars); - IndexStmt producer = generateForalls(this->producer, producerLoopVars); - Where where(consumer, producer); - stmt = generateForalls(where, commonLoopVars); - return; + if (visited == path) { + IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars); + IndexStmt producer = generateForalls(this->producer, producerLoopVars); + Where where(consumer, producer); + stmt = generateForalls(where, commonLoopVars); + return; + } + IndexNotationRewriter::visit(node); + } + + void visit(const WhereNode* node) { + Where where(node); + cout << "Where: " << where << endl; + + visited.push_back(0); + IndexStmt producer = rewrite(node->producer); + visited.pop_back(); + visited.push_back(1); + IndexStmt consumer = rewrite(node->consumer); + visited.pop_back(); + if (producer == node->producer && consumer == node->consumer) { + stmt = node; + } + else { + stmt = new WhereNode(consumer, producer); + } + } }; - ProducerConsumerRewriter rewriter(producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars); + ProducerConsumerRewriter rewriter(getPath(), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars); stmt = rewriter.rewrite(stmt); cout << "stmt: " << stmt << endl; diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index f3f6071b1..beef666e3 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -668,9 +668,6 @@ TEST(workspaces, loopfuse) { } IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); - IndexExpr precomputedExpr = B(i,j) * C(j,k); - IndexExpr precomputedExpr2 = precomputedExpr * D(k,l); - // A(i,l) = precomputedExpr2; A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); IndexStmt stmt = A.getAssignment().concretize(); @@ -678,26 +675,31 @@ TEST(workspaces, loopfuse) { TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); std::cout << stmt << endl; - vector path; + vector path1; + vector path2 = {0}; stmt = stmt - // .reorder({i,j,k,l}) .reorder({i,j,k, l, m}) - .loopfuse(2, true, path) + .loopfuse(3, true, path1) + .loopfuse(2, true, path2) + ; + stmt = stmt .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) ; stmt = stmt.concretize(); cout << "final stmt: " << stmt << endl; - - std::cout << "stmt: " << stmt << std::endl; printCodeToFile("loopfuse", stmt); A.compile(stmt.concretize()); A.assemble(); A.compute(); - return; - + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); } TEST(workspaces, precompute2D_mul) { @@ -962,3 +964,49 @@ TEST(workspaces, precompute_tensorContraction2) { } + +TEST(workspaces, sddmmPlusSpmm) { + Type t(type(), {3,3}); + const IndexVar i("i"), j("j"), k("k"), l("l"); + + TensorVar A("A", t, Format{Dense, Dense}); + TensorVar B("B", t, Format{Dense, Sparse}); + TensorVar C("C", t, Format{Dense, Dense}); + TensorVar D("D", t, Format{Dense, Dense}); + TensorVar E("E", t, Format{Dense, Dense}); + + TensorVar tmp("tmp", Type(), Format()); + + // A(i,j) = B(i,j) * C(i,k) * D(j,k) * E(j,l) + IndexStmt fused = + forall(i, + forall(j, + forall(k, + forall(l, A(i,l) += B(i,j) * 
C(i,k) * D(j,k) * E(j,l)) + ) + ) + ); + + std::cout << "before topological sort: " << fused << std::endl; + fused = reorderLoopsTopologically(fused); + // std::vector order{"i", "j", "k", "l"}; + fused = fused.reorder({i, j, k, l}); + std::cout << "after topological sort: " << fused << std::endl; + + // fused = fused.precompute(B(i,j) * C(i,k) * D(j,k), {}, {}, tmp); + std::cout << "after precompute: " << fused << std::endl; + + // Kernel kernel = compile(fused); + + // IndexStmt fusedNested = + // forall(i, + // forall(j, + // where( + // forall(l, A(i,l) += tmp * E(j,l)), // consumer + // forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer + // ) + // ) + // ); + + // std::cout << "nested loop stmt: " << fusedNested << std::endl; +} \ No newline at end of file From 2479221b390ad32ea4f73b63fb0eb1c0f45063a4 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Sat, 25 Feb 2023 22:03:21 -0500 Subject: [PATCH 04/14] remove comments --- src/index_notation/transformations.cpp | 64 +++----------------------- 1 file changed, 6 insertions(+), 58 deletions(-) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index fad8a934c..1716a2034 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -283,14 +283,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { cout << endl; }; - cout << "pos: " << getPos() << std::endl; - cout << "isProducerOnLeft: " << getIsProducerOnLeft() << endl; - cout << "path: "; - for (const auto& p : getPath()) { - cout << p << " " << std::endl; - } - cout << endl; - struct GetAssignment : public IndexNotationVisitor { using IndexNotationVisitor::visit; Assignment innerAssignment; @@ -304,14 +296,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { void visit(const ForallNode* node) { Forall forall(node); - cout << "Forall: " << forall << endl; - cout << "pathIdx: " << pathIdx << endl; - // print path - cout << "path: "; - for (const auto& p : path) { - cout << p << " " << std::endl; - } - cout << endl; + indexAccessVars.push_back(forall.getIndexVar()); if (pathIdx < path.size()) { indexVarsUntilBranch.push_back(forall.getIndexVar()); @@ -327,7 +312,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { void visit(const WhereNode* node) { Where where(node); - cout << "Where: " << where << endl; if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer pathIdx++; @@ -342,9 +326,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { GetAssignment getAssignment(getPath()); stmt.accept(&getAssignment); - std::cout << getAssignment.innerAssignment << std::endl; - cout << "Index access order: "; printVector(getAssignment.indexAccessVars); - // saves the result, producer and consumer of the assignment // result = producer * consumer // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) @@ -364,7 +345,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { set producerVars; set consumerVars; map> varTypes; - IndexExpr op; GetProducerAndConsumer(int _pos, int _isProducerOnLeft, vector& _path) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {} @@ -383,12 +363,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // result is stored in the left hand side of the assignment result = assignment.getLhs(); resultVars = assignment.getLhs().getIndexVars(); - std::cout << "result: " << 
result - << ", rhs: " << assignment.getRhs() - << ", freeVars: " << assignment.getFreeVars() - << ", indexVars: " << assignment.getIndexVars() - << ", indexSetRelation: " << assignment.getIndexSetRel() - << std::endl; // add the index variables of the result to the map addIndexVar(to(assignment.getLhs())); @@ -420,7 +394,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { void visit(const AccessNode* node) { Access access(node); - cout << "pos: " << pos << ", access: " << access << endl; IndexExpr* it; set* vars; if ((pos > 0 && isProducerOnLeft) || (pos <= 0 && !isProducerOnLeft)) { it = &producer; vars = &producerVars; } @@ -441,13 +414,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft(), getPath()); stmt.accept(&getProducerAndConsumer); - std::cout << "result: " << getProducerAndConsumer.result << std::endl; - std::cout << "producer: " << getProducerAndConsumer.producer << std::endl; - std::cout << "consumer: " << getProducerAndConsumer.consumer << std::endl; - std::cout << "resultVars: " << getProducerAndConsumer.resultVars << std::endl; - cout << "producerVars: "; printSet(getProducerAndConsumer.producerVars); - cout << "consumerVars: "; printSet(getProducerAndConsumer.consumerVars); - // indices in the temporary comes from the producer indices (IndexVars) // that are either in result indices or in consumer indices // indices in the producer that are neither in producer indices nor in consumer indices @@ -461,7 +427,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { temporaryVars.push_back(var); } } - cout << "temporaryVars: "; printVector(temporaryVars); // get the producer index access pattern // get the consumer index access pattern @@ -480,20 +445,13 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { consumerLoopVars.push_back(var); } } - cout << "producerLoopVars2: "; printVector(producerLoopVars); - cout << "consumerLoopVars2: "; printVector(consumerLoopVars); // remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch - cout << "indexVarsUntilBranch: "; printVector(getAssignment.indexVarsUntilBranch); for (auto& var : getAssignment.indexVarsUntilBranch) { producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); } - cout << "producerLoopVars3: "; printVector(producerLoopVars); - cout << "consumerLoopVars3: "; printVector(consumerLoopVars); - - // check if there are common outer loops in producerAccessOrder and consumerAccessOrder vector commonLoopVars; for (auto& var : getAssignment.indexAccessVars) { @@ -507,16 +465,12 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { break; } } - cout << "commonOuterLoops: "; printVector(commonLoopVars); - cout << "temporaryVars: "; printVector(temporaryVars); // remove commonLoopVars from producerLoopVars and consumerLoopVars for (auto& var : commonLoopVars) { producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end()); consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end()); } - cout << "producerLoopVars: "; printVector(producerLoopVars); - cout << "consumerLoopVars: "; printVector(consumerLoopVars); // create the intermediate tensor vector 
temporaryDims; @@ -533,22 +487,18 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { Access resultAccess = to(getProducerAndConsumer.result); TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName(), Type(Float64, temporaryDims)); Access workspace(intermediateTensor, temporaryVars); - cout << "intermediateTensor: " << intermediateTensor << endl; - cout << "workspace: " << workspace << endl; Assignment producerAssignment(workspace, getProducerAndConsumer.producer, getAssignment.innerAssignment.getOperator()); - cout << "producerAssignment: " << producerAssignment << endl; Assignment consumerAssignment; + // if the producer is on left, then consumer is constructed by + // multiplying workspace * consumer and if the producer is on right, + // then the consumer is constructed by multiplying consumer * workspace if (!getIsProducerOnLeft()) { consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator()); } else { consumerAssignment = Assignment(to(getProducerAndConsumer.result), workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator()); } - cout << "consumerAssignment: " << consumerAssignment << endl; - - // check if there are common outer loops - // if there are common outer loops, then remove those common outer loops from the temporaryVars // rewrite the index notation to use the temporary // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l) @@ -578,6 +528,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // should find the path to get to this loop to perform the rewrite void visit(const ForallNode* node) { + // at the end of the path, rewrite should happen using the producer and consumer if (visited == path) { IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars); IndexStmt producer = generateForalls(this->producer, producerLoopVars); @@ -590,8 +541,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { void visit(const WhereNode* node) { Where where(node); - cout << "Where: " << where << endl; + // add 0 to visited if the producer is visited and 1 if the consumer is visited visited.push_back(0); IndexStmt producer = rewrite(node->producer); visited.pop_back(); @@ -604,14 +555,11 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { else { stmt = new WhereNode(consumer, producer); } - } - }; ProducerConsumerRewriter rewriter(getPath(), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars); stmt = rewriter.rewrite(stmt); - cout << "stmt: " << stmt << endl; return stmt; } From 58640114af550ee428b9925639e5eca36f793b56 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Mon, 6 Mar 2023 17:18:40 -0500 Subject: [PATCH 05/14] fix naming of temporary tensor --- src/index_notation/transformations.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 1716a2034..ea7654189 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -485,7 +485,11 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { }; populateDimension(getProducerAndConsumer.varTypes); Access resultAccess = to(getProducerAndConsumer.result); - TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName(), Type(Float64, temporaryDims)); + string path = ""; + for (auto& p : 
getPath()) { + path += to_string(p); + } + TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName() + path, Type(Float64, temporaryDims)); Access workspace(intermediateTensor, temporaryVars); Assignment producerAssignment(workspace, getProducerAndConsumer.producer, getAssignment.innerAssignment.getOperator()); From 672facf6854978f55f0f4d9c209ccc46593fbcae Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Mon, 6 Mar 2023 18:32:59 -0500 Subject: [PATCH 06/14] fix index merge inside a where clause --- src/index_notation/transformations.cpp | 7 +++- test/tests-workspaces.cpp | 55 ++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index ea7654189..c60b52b52 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -4,6 +4,7 @@ #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_nodes.h" #include "taco/error/error_messages.h" +#include "taco/index_notation/index_notation_visitor.h" #include "taco/storage/index.h" #include "taco/util/collections.h" #include "taco/lower/iterator.h" @@ -288,6 +289,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { Assignment innerAssignment; vector indexAccessVars; vector indexVarsUntilBranch; + vector indexVarsAfterBranch; unsigned int pathIdx = 0; vector path; @@ -301,6 +303,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { if (pathIdx < path.size()) { indexVarsUntilBranch.push_back(forall.getIndexVar()); } + if (pathIdx >= path.size()) { + indexVarsAfterBranch.push_back(forall.getIndexVar()); + } if (isa(forall.getStmt())) { innerAssignment = to(forall.getStmt()); @@ -454,7 +459,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // check if there are common outer loops in producerAccessOrder and consumerAccessOrder vector commonLoopVars; - for (auto& var : getAssignment.indexAccessVars) { + for (auto& var : getAssignment.indexVarsAfterBranch) { auto itC = find(consumerLoopVars.begin(), consumerLoopVars.end(), var); auto itP = find(producerLoopVars.begin(), producerLoopVars.end(), var); if (itC != consumerLoopVars.end() && itP != producerLoopVars.end()) { diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index beef666e3..0ddfb8b2d 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -702,6 +702,61 @@ TEST(workspaces, loopfuse) { ASSERT_TENSOR_EQ(expected, A); } + + +TEST(workspaces, loopcontractfuse) { + int N = 16; + Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + vector path1; + vector path2 = {1}; + stmt = stmt + .reorder({l,i,m, j, k, n}) + .loopfuse(2, true, path1) + .loopfuse(2, true, path2) + ; + stmt = stmt + .parallelize(l, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + 
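+  // Sketch of the shape the two fusions above produce, assuming the reordered
+  // loop order (l,i,m,j,k,n): the first loopfuse(2, true, {}) splits
+  // B(i,j,k)*C(i,l) from D(j,m)*E(k,n) at the root, giving
+  // forall(l, where(consumer, producer)); the second call's path {1} descends
+  // into that consumer (0 would select a producer) and splits it again, so
+  // the final statement nests two where clauses under the parallelized l loop.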
+
+  stmt = stmt.concretize();
+  cout << "final stmt: " << stmt << endl;
+  printCodeToFile("loopcontractfuse", stmt);
+
+  A.compile(stmt.concretize());
+  A.assemble();
+  A.compute();
+
+  Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
+  expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}
+
 TEST(workspaces, precompute2D_mul) {
   int N = 16;
   Tensor<double> A("A", {N, N}, Format{Dense, Dense});

From c3caf52ee04b0b1489b2b243f6038a646ae34b5c Mon Sep 17 00:00:00 2001
From: Adhhitha Dias
Date: Mon, 6 Mar 2023 18:59:14 -0500
Subject: [PATCH 07/14] add reordering for index stmt with branches

---
 include/taco/index_notation/index_notation.h  |   3 +
 include/taco/index_notation/transformations.h |   2 +
 src/index_notation/index_notation.cpp         |  11 ++
 src/index_notation/transformations.cpp        | 123 ++++++++++++++++--
 test/tests-workspaces.cpp                     |  54 ++++++++
 5 files changed, 185 insertions(+), 8 deletions(-)

diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h
index d93bbda9b..26a130ccd 100644
--- a/include/taco/index_notation/index_notation.h
+++ b/include/taco/index_notation/index_notation.h
@@ -668,6 +668,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   /// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
   IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
 
+  /// reorders the index variables in a nested structure with where clauses
+  IndexStmt reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const;
+
   /// The mergeby transformation specifies how to merge iterators on
   /// the given index variable. By default, if an iterator is used for windowing
   /// it will be merged with the "gallop" strategy.
diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h
index 5d59261fa..a38494387 100644
--- a/include/taco/index_notation/transformations.h
+++ b/include/taco/index_notation/transformations.h
@@ -67,10 +67,12 @@ class Reorder : public TransformationInterface {
 public:
   Reorder(IndexVar i, IndexVar j);
   Reorder(std::vector<IndexVar> replacePattern);
+  Reorder(std::vector<int> path, std::vector<IndexVar> replacePattern);
 
   IndexVar geti() const;
   IndexVar getj() const;
   const std::vector<IndexVar>& getreplacepattern() const;
+  const std::vector<int>& getpath() const;
 
   /// Apply the reorder optimization to a concrete index statement. Returns
   /// an undefined statement and a reason if the statement cannot be lowered.
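To illustrate the new overload (a minimal sketch; the statement shape and the
index variables are assumed, following the tests added below), the path
argument lets reorder target one branch of an already fused statement:

    // after a loopfuse, stmt is roughly forall(l, where(consumer, producer));
    // a path entry of 0 descends into the producer, 1 into the consumer
    std::vector<int> path = {1};
    stmt = stmt.reorder(path, {m, k, j});  // reorder loops within that branch only
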
diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index bf8973b39..3079255c9 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1927,6 +1927,17 @@ IndexStmt IndexStmt::reorder(std::vector reorderedvars) const { return transformed; } +IndexStmt IndexStmt::reorder(std::vector path, std::vector reorderedvars) const { + string reason; + cout << "Index statement path: " << util::join(path) << endl; + cout << "Index statement reorderedvars: " << reorderedvars << endl; + IndexStmt transformed = Reorder(path, reorderedvars).apply(*this, &reason); + if (!transformed.defined()) { + taco_uerror << reason; + } + return transformed; +} + IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const { string reason; IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason); diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index c60b52b52..55ca39501 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -64,6 +64,7 @@ std::ostream& operator<<(std::ostream& os, const Transformation& t) { // class Reorder struct Reorder::Content { + std::vector path; std::vector replacePattern; bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder }; @@ -78,6 +79,12 @@ Reorder::Reorder(std::vector replacePattern) : content(new Conte content->pattern_ordered = true; } +Reorder::Reorder(std::vector path, std::vector replacePattern) : content(new Content) { + content->path = path; + content->replacePattern = replacePattern; + content->pattern_ordered = true; +} + IndexVar Reorder::geti() const { return content->replacePattern[0]; } @@ -93,13 +100,66 @@ const std::vector& Reorder::getreplacepattern() const { return content->replacePattern; } +const std::vector& Reorder::getpath() const { + return content->path; +} + IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { INIT_REASON(reason); string r; - if (!isConcreteNotation(stmt, &r)) { - *reason = "The index statement is not valid concrete index notation: " + r; - return IndexStmt(); + + // TODO - Add a different check for concrete index notation with branching + // if (!isConcreteNotation(stmt, &r)) { + // *reason = "The index statement is not valid concrete index notation: " + r; + // return IndexStmt(); + // } + + IndexStmt originalStmt = stmt; + struct ReorderVisitor : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + vector& path; + unsigned int pathIdx = 0; + IndexStmt innerStmt; + + ReorderVisitor(vector& path) : path(path) {} + + void visit(const ForallNode* node) { + if (pathIdx == path.size()) { + innerStmt = IndexStmt(node); + return; + } + IndexNotationVisitor::visit(node); + } + + void visit(const WhereNode* node) { + + Where where(node); + + if (pathIdx == path.size()) { + innerStmt = IndexStmt(node); + return; + } + + if (!path[pathIdx]) { + pathIdx++; + IndexNotationVisitor::visit(node->producer); + } else { + pathIdx++; + IndexNotationVisitor::visit(node->consumer); + } + } + }; + + cout << "original statement: " << originalStmt << endl; + ReorderVisitor reorderVisitor(content->path); + + auto p = getpath(); + cout << "path: " << util::join(p) << endl; + if (p.size() > 0) { + originalStmt.accept(&reorderVisitor); + cout << "reordering statment: " << reorderVisitor.innerStmt << endl; + stmt = reorderVisitor.innerStmt; } // collect current ordering of IndexVars @@ -130,7 
+190,52 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { *reason = "The foralls of reorder pattern: " + util::join(getreplacepattern()) + " were not directly nested."; return IndexStmt(); } - return ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason); + + cout << "replacePattern: " << util::join(getreplacepattern()) << endl; + auto reorderedStmt = ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason); + + + struct ReorderedRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + IndexStmt reorderedStmt; + vector& path; + vector visited; + + ReorderedRewriter(IndexStmt reorderedStmt, vector& path) : reorderedStmt(reorderedStmt), path(path) {} + + void visit(const ForallNode* node) { + // at the end of the path, rewrite should happen using the producer and consumer + if (visited == path) { + stmt = reorderedStmt; + return; + } + IndexNotationRewriter::visit(node); + } + + void visit(const WhereNode* node) { + Where where(node); + + // add 0 to visited if the producer is visited and 1 if the consumer is visited + visited.push_back(0); + IndexStmt producer = rewrite(node->producer); + visited.pop_back(); + visited.push_back(1); + IndexStmt consumer = rewrite(node->consumer); + visited.pop_back(); + if (producer == node->producer && consumer == node->consumer) { + stmt = node; + } + else { + stmt = new WhereNode(consumer, producer); + } + + } + }; + ReorderedRewriter reorderedRewriter(reorderedStmt, content->path); + stmt = reorderedRewriter.rewrite(originalStmt); + + return stmt; } void Reorder::print(std::ostream& os) const { @@ -1068,10 +1173,12 @@ IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const { INIT_REASON(reason); string r; - if (!isConcreteNotation(stmt, &r)) { - *reason = "The index statement is not valid concrete index notation: " + r; - return IndexStmt(); - } + + // TODO - Add a different check for concrete index notation with branching + // if (!isConcreteNotation(stmt, &r)) { + // *reason = "The index statement is not valid concrete index notation: " + r; + // return IndexStmt(); + // } /// Since all IndexVars can only appear once, assume replacement will work and error if it doesn't struct ForAllReplaceRewriter : public IndexNotationRewriter { diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index 0ddfb8b2d..ec084456e 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -757,6 +757,60 @@ TEST(workspaces, loopcontractfuse) { ASSERT_TENSOR_EQ(expected, A); } +TEST(workspaces, loopreordercontractfuse) { + int N = 16; + Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + vector path1; + vector path2 = {1}; + stmt = stmt + .reorder({l,i,m, j, k, n}) + .loopfuse(2, true, path1) + .reorder(path2, {m,k,j,n}) + .loopfuse(2, true, path2) + ; + stmt = stmt + .parallelize(l, ParallelUnit::CPUThread, 
OutputRaceStrategy::NoRaces) + ; + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopreordercontractfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + TEST(workspaces, precompute2D_mul) { int N = 16; Tensor A("A", {N, N}, Format{Dense, Dense}); From f0b25a1ae2fd03b819f473a770d2fd4db34e589a Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Mon, 6 Mar 2023 19:00:34 -0500 Subject: [PATCH 08/14] remove comments in reordering --- src/index_notation/transformations.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 55ca39501..504fc17fb 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -150,18 +150,8 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { } } }; - - cout << "original statement: " << originalStmt << endl; ReorderVisitor reorderVisitor(content->path); - auto p = getpath(); - cout << "path: " << util::join(p) << endl; - if (p.size() > 0) { - originalStmt.accept(&reorderVisitor); - cout << "reordering statment: " << reorderVisitor.innerStmt << endl; - stmt = reorderVisitor.innerStmt; - } - // collect current ordering of IndexVars bool startedMatch = false; std::vector currentOrdering; @@ -179,7 +169,6 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { } }) ); - cout << "currentOrdering: " << util::join(currentOrdering) << endl; if (!content->pattern_ordered && currentOrdering == getreplacepattern()) { taco_iassert(getreplacepattern().size() == 2); @@ -191,10 +180,8 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const { return IndexStmt(); } - cout << "replacePattern: " << util::join(getreplacepattern()) << endl; auto reorderedStmt = ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason); - struct ReorderedRewriter : public IndexNotationRewriter { using IndexNotationRewriter::visit; From 270daf8690d3eb2832a543a87c196dec3ccef821 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Mon, 6 Mar 2023 19:08:30 -0500 Subject: [PATCH 09/14] add comments to address review and remove more prints --- include/taco/index_notation/index_notation.h | 7 +++++++ include/taco/parser/lexer.h | 2 +- src/index_notation/index_notation.cpp | 12 +----------- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 26a130ccd..f2072d756 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -648,6 +648,13 @@ class IndexStmt : public util::IntrusivePtr { /// The loopfuse transformation fuses common outer loops in /// 2 iteration graphs. 
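+  /// A typical call sequence on a concrete statement (illustrative only;
+  /// cf. the loopfuse tests in test/tests-workspaces.cpp):
+  ///   stmt = stmt.loopfuse(2, true, {});   // split at the root
+  ///   stmt = stmt.loopfuse(2, true, {1});  // then split the consumer branch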
+  /// When performing a loopfuse operation on an already branched index statement,
+  /// e.g. forall(l, where(forall(ijk, T(j,k) += A*B), forall(mjkn, X(l,m,n) += T*C*D))),
+  /// where we want to further break T*C*D down into T2 = T*C and T2*D,
+  /// the path vector specifies the branch the fusion is applied to.
+  /// e.g. in loopfuse(2, true, {1}), 2 refers to breaking T*C*D at the 2nd position,
+  /// true makes T*C the producer (with false and position 1, C*D would be the
+  /// producer instead), and {1} selects the branch to apply the fuse on.
  IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;


diff --git a/include/taco/parser/lexer.h b/include/taco/parser/lexer.h
index c9e185dcd..e304c2f71 100644
--- a/include/taco/parser/lexer.h
+++ b/include/taco/parser/lexer.h
@@ -22,7 +22,7 @@ enum class Token {
   sub,
   mul,
   div,
-  colon,
+  colon, // numbers before the colon indicate the path to the branch in a branched iteration graph
   eq,
   eot, // End of tokens
   error
diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp
index 3079255c9..721580606 100644
--- a/src/index_notation/index_notation.cpp
+++ b/src/index_notation/index_notation.cpp
@@ -1855,15 +1855,7 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
 }
 
 IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {
-
-  std::cout << "Loop fuse pos: " << pos;
-  std::cout << ", Loop fuse isProducerOnLeft: " << isProducerOnLeft;
-  for (const auto& p : path) {
-    std::cout << " " << p;
-  }
-  std::cout << std::endl;
-
-  string reason;
+  string reason; // reason saves the error message if the transformation fails
   IndexStmt transformed = *this;
   transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
   if (!transformed.defined()) {
@@ -1929,8 +1921,6 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
 
 IndexStmt IndexStmt::reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const {
   string reason;
-  cout << "Index statement path: " << util::join(path) << endl;
-  cout << "Index statement reorderedvars: " << reorderedvars << endl;
   IndexStmt transformed = Reorder(path, reorderedvars).apply(*this, &reason);
   if (!transformed.defined()) {
     taco_uerror << reason;

From 2ebfbc71a48261bb0617598dc4ffc5525f2515aa Mon Sep 17 00:00:00 2001
From: Adhhitha Dias
Date: Tue, 7 Mar 2023 17:07:29 -0500
Subject: [PATCH 10/14] fix producer-consumer interchange

When the producer is at the end of the assignment, the argument packing
needs to be changed according to the transformed index statement.
---
 include/taco/tensor.h     |  1 +
 src/tensor.cpp            | 87 +++++++++++++++++++++++++++++++++++++++
 test/tests-workspaces.cpp | 56 ++++++++++++++++++++++++-
 3 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/include/taco/tensor.h b/include/taco/tensor.h
index c462cbd32..75af68ba2 100644
--- a/include/taco/tensor.h
+++ b/include/taco/tensor.h
@@ -429,6 +429,7 @@ class TensorBase {
   /// Compute the given expression and put the values in the tensor storage.
   void compute();
+  void compute(IndexStmt stmt);
 
   /// Compile, assemble and compute as needed.
  void evaluate();
diff --git a/src/tensor.cpp b/src/tensor.cpp
index 257c396c3..eb4b4595a 100644
--- a/src/tensor.cpp
+++ b/src/tensor.cpp
@@ -775,6 +775,41 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
   return getOperands.arguments;
 }
 
+static inline map<TensorVar, TensorBase> getTensors(const IndexStmt& stmt, vector<TensorVar>& operands) {
+  struct GetOperands : public IndexNotationVisitor {
+    using IndexNotationVisitor::visit;
+    vector<TensorVar>& operands;
+    map<TensorVar, TensorBase> arguments;
+
+    GetOperands(vector<TensorVar>& operands) : operands(operands) {}
+
+    void visit(const AccessNode* node) {
+      if (!isa<AccessTensorNode>(node)) {
+        return; // skip temporaries, which have no backing TensorBase
+      }
+      Access ac = Access(node);
+      taco_iassert(isa<AccessTensorNode>(node)) << "Unknown subexpression";
+
+      if (!util::contains(arguments, node->tensorVar)) {
+        arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->tensor});
+        operands.push_back(node->tensorVar);
+      }
+
+      // Also add any tensors backing index sets of tensor accesses.
+      for (auto& p : node->indexSetModes) {
+        auto tv = p.second.tensor.getTensorVar();
+        if (!util::contains(arguments, tv)) {
+          arguments.insert({tv, p.second.tensor});
+          operands.push_back(tv);
+        }
+      }
+    }
+  };
+  GetOperands getOperands(operands);
+  stmt.accept(&getOperands);
+  return getOperands.arguments;
+}
+
 static inline
 vector<void*> packArguments(const TensorBase& tensor) {
   vector<void*> arguments;
@@ -805,6 +840,35 @@ vector<void*> packArguments(const TensorBase& tensor) {
   return arguments;
 }
 
+static inline
+vector<void*> packArguments(const TensorBase& tensor, const IndexStmt stmt) {
+  vector<void*> arguments;
+
+  // Pack the result tensor
+  arguments.push_back(tensor.getStorage());
+
+  // Pack any index sets on the result tensor at the front of the arguments list.
+  auto lhs = getNode(tensor.getAssignment().getLhs());
+  // We check isa<AccessNode> rather than isa<AccessTensorNode> to catch cases
+  // where the underlying access is represented with the base AccessNode class.
+  if (isa<AccessNode>(lhs)) {
+    auto indexSetModes = to<AccessNode>(lhs)->indexSetModes;
+    for (auto& it : indexSetModes) {
+      arguments.push_back(it.second.tensor.getStorage());
+    }
+  }
+
+  // Pack operand tensors
+  std::vector<TensorVar> operands;
+  auto tensors = getTensors(stmt, operands);
+  for (auto& operand : operands) {
+    taco_iassert(util::contains(tensors, operand));
+    arguments.push_back(tensors.at(operand).getStorage());
+  }
+
+  return arguments;
+}
+
 void TensorBase::assemble() {
   taco_uassert(!needsCompile()) << error::assemble_without_compile;
   if (!needsAssemble()) {
@@ -849,6 +913,29 @@ void TensorBase::compute() {
   }
 }
 
+void TensorBase::compute(IndexStmt stmt) {
+  taco_uassert(!needsCompile()) << error::compute_without_compile;
+  if (!needsCompute()) {
+    return;
+  }
+  setNeedsCompute(false);
+  // Sync operand tensors if needed.
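+  // Note: the sync below walks the operands of the original assignment, while
+  // packArguments(*this, stmt) further down gathers arguments from the fused
+  // statement, whose operand order may differ from the assignment's right-hand side.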
+ auto operands = getTensors(getAssignment().getRhs()); + for (auto& operand : operands) { + operand.second.syncValues(); + operand.second.removeDependentTensor(*this); + } + + auto arguments = packArguments(*this, stmt); + this->content->module->callFuncPacked("compute", arguments.data()); + + if (content->assembleWhileCompute) { + setNeedsAssemble(false); + taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); + content->valuesSize = unpackTensorData(*tensorData, *this); + } +} + void TensorBase::evaluate() { this->compile(); if (!getAssignment().getOperator().defined()) { diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index ec084456e..e34935305 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -652,6 +652,7 @@ TEST(workspaces, tile_dotProduct_3) { TEST(workspaces, loopfuse) { int N = 16; + float SPARSITY = 0.3; Tensor A("A", {N, N}, Format{Dense, Dense}); Tensor B("B", {N, N}, Format{Dense, Sparse}); Tensor C("C", {N, N}, Format{Dense, Dense}); @@ -660,7 +661,9 @@ TEST(workspaces, loopfuse) { for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) { - B.insert({i, j}, (double) i); + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); C.insert({i, j}, (double) j); E.insert({i, j}, (double) i*j); D.insert({i, j}, (double) i*j); @@ -703,6 +706,57 @@ TEST(workspaces, loopfuse) { } +TEST(workspaces, loopreversefuse) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) rand_float); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + vector path1; + stmt = stmt + .reorder({m,k,l,i,j}) + .loopfuse(2, false, path1) + ; + stmt = stmt + .parallelize(m, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopreversefuse", stmt); + + A.compile(stmt); + B.pack(); + A.assemble(); + A.compute(stmt); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} TEST(workspaces, loopcontractfuse) { int N = 16; From 1cc5cf398ea098981e39b1a0936f3e8a51b44743 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Wed, 8 Mar 2023 11:31:22 -0500 Subject: [PATCH 11/14] fix seg fault when consumer is NULL and add sddmm test case --- src/index_notation/transformations.cpp | 6 ++-- test/tests-workspaces.cpp | 50 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 504fc17fb..07dc3f363 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -12,6 +12,7 @@ #include "taco/lower/mode.h" #include "taco/lower/mode_format_impl.h" +#include #include #include #include @@ 
-470,7 +471,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { void visit(const WhereNode* node) { Where where(node); - cout << "Where: " << where << endl; // select the path to visit if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer @@ -596,9 +596,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const { // multiplying workspace * consumer and if the producer is on right, // then the consumer is constructed by multiplying consumer * workspace if (!getIsProducerOnLeft()) { - consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator()); + consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer == NULL ? workspace : getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator()); } else { - consumerAssignment = Assignment(to(getProducerAndConsumer.result), workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator()); + consumerAssignment = Assignment(to(getProducerAndConsumer.result), getProducerAndConsumer.consumer == NULL ? workspace : workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator()); } // rewrite the index notation to use the temporary diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index e34935305..c253dc5a3 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -865,6 +865,56 @@ TEST(workspaces, loopreordercontractfuse) { ASSERT_TENSOR_EQ(expected, A); } +TEST(workspaces, sddmm) { + int N = 16; + float SPARSITY = 0.3; + vector dims{N,N}; + const IndexVar i("i"), j("j"), k("k"), l("l"); + + Tensor A("A", dims, Format{Dense, Dense}); + Tensor B("B", dims, Format{Dense, Sparse}); + Tensor C("C", dims, Format{Dense, Dense}); + Tensor D("D", dims, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + A(i,j) = B(i,j) * C(i,k) * D(j,k); + + IndexStmt stmt = A.getAssignment().concretize(); + + vector path1; + stmt = stmt + .reorder({i,k,j}); + stmt = stmt + .loopfuse(3, true, path1); + stmt = stmt + .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + ; + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + A.compute(); + + Tensor expected("expected", dims, Format{Dense, Dense}); + expected(i,j) = B(i,j) * C(i,k) * D(j,k); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, A); +} + TEST(workspaces, precompute2D_mul) { int N = 16; Tensor A("A", {N, N}, Format{Dense, Dense}); From 78b56f2d61f253f8d09bd97bd48823ee337a9e09 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Wed, 8 Mar 2023 18:54:14 -0500 Subject: [PATCH 12/14] fix failing reorder with branching add missing inner statement search if branch path is given --- src/index_notation/transformations.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 07dc3f363..4d708e111 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -152,6 +152,10 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* 
reason) const { } }; ReorderVisitor reorderVisitor(content->path); + if (content->path.size() > 0) { + reorderVisitor.visit(stmt); + stmt = reorderVisitor.innerStmt; + } // collect current ordering of IndexVars bool startedMatch = false; From 92f96e463ff8719ef3eeffa36aacea36cbe5c723 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Thu, 9 Mar 2023 12:32:48 -0500 Subject: [PATCH 13/14] hoist higher dimensional temporaries to the top --- include/taco/index_notation/index_notation.h | 2 +- include/taco/lower/lowerer_impl_imperative.h | 2 +- src/index_notation/index_notation.cpp | 48 ++++++++++----- src/lower/lowerer_impl_imperative.cpp | 62 +++++++++++++------- test/tests-workspaces.cpp | 4 +- 5 files changed, 79 insertions(+), 39 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index f2072d756..6927752d2 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -1323,7 +1323,7 @@ std::vector getAttrQueryResults(IndexStmt stmt); // [Olivia] /// Returns the temporaries in the index statement, in the order they appear. -std::map getTemporaryLocations(IndexStmt stmt); +std::map > getTemporaryLocations(IndexStmt stmt); /// Returns the results in the index statement that should be assembled by /// ungrouped insertion. diff --git a/include/taco/lower/lowerer_impl_imperative.h b/include/taco/lower/lowerer_impl_imperative.h index fa97e3cd9..b19c19034 100644 --- a/include/taco/lower/lowerer_impl_imperative.h +++ b/include/taco/lower/lowerer_impl_imperative.h @@ -513,7 +513,7 @@ class LowererImplImperative : public LowererImpl { std::set nonFullyInitializedResults; /// Map used to hoist temporary workspace initialization - std::map temporaryInitialization; + std::map > temporaryInitialization; /// Map used to hoist parallel temporary workspaces. 
Maps workspace shared by all threads to where statement std::map whereToTemporaryVar; diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 721580606..7cead8387 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -11,6 +11,7 @@ #include "error/error_checks.h" #include "taco/error/error_messages.h" +#include "taco/index_notation/index_notation_visitor.h" #include "taco/type.h" #include "taco/format.h" @@ -3474,20 +3475,39 @@ bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt) { return true; } -std::map getTemporaryLocations(IndexStmt stmt) { - map temporaryLocs; - Forall f = Forall(); - match(stmt, - function([&](const ForallNode* op, Matcher* ctx) { - f = op; - ctx->match(op->stmt); - }), - function([&](const WhereNode* w, Matcher* ctx) { - if (!(f == IndexStmt())) - temporaryLocs.insert({f, Where(w)}); - }) - ); - return temporaryLocs; +std::map > getTemporaryLocations(IndexStmt stmt) { + struct TemporaryLocsGetter : public IndexNotationVisitor { + map > temporaryLocs; + Forall f; + + using IndexNotationVisitor::visit; + + void visit(const ForallNode *op) { + Forall forall = Forall(op); + + if (f == NULL) { + f = op; + } + IndexNotationVisitor::visit(op); + } + + void visit(const WhereNode *op) { + Where where = Where(op); + if (temporaryLocs.find(f) != temporaryLocs.end()) { + temporaryLocs[f].push_back(where); + } + else { + vector whereVec; + whereVec.push_back(where); + temporaryLocs.insert({f, whereVec}); + } + IndexNotationVisitor::visit(op); + } + }; + TemporaryLocsGetter getter; + getter.visit(stmt); + + return getter.temporaryLocs; } diff --git a/src/lower/lowerer_impl_imperative.cpp b/src/lower/lowerer_impl_imperative.cpp index c1f614463..614693b3f 100644 --- a/src/lower/lowerer_impl_imperative.cpp +++ b/src/lower/lowerer_impl_imperative.cpp @@ -783,14 +783,22 @@ Stmt LowererImplImperative::lowerForall(Forall forall) // Emit temporary initialization if forall is sequential or parallelized by // cpu threads and leads to a where statement // This is for workspace hoisting by 1-level + vector> temporaryValuesInit; vector temporaryValuesInitFree = {Stmt(), Stmt()}; auto temp = temporaryInitialization.find(forall); - if (temp != temporaryInitialization.end() && forall.getParallelUnit() == - ParallelUnit::NotParallel && !isScalar(temp->second.getTemporary().getType())) - temporaryValuesInitFree = codeToInitializeTemporary(temp->second); - else if (temp != temporaryInitialization.end() && forall.getParallelUnit() == - ParallelUnit::CPUThread && !isScalar(temp->second.getTemporary().getType())) { - temporaryValuesInitFree = codeToInitializeTemporaryParallel(temp->second, forall.getParallelUnit()); + if (temp != temporaryInitialization.end()) { + auto whereClauses = temp->second; + // iterate over whereClauses + for (auto& where : whereClauses) { + if (forall.getParallelUnit() == ParallelUnit::NotParallel && !isScalar(where.getTemporary().getType())) { + temporaryValuesInitFree = codeToInitializeTemporary(where); + temporaryValuesInit.push_back(temporaryValuesInitFree); + } + else if (forall.getParallelUnit() == ParallelUnit::CPUThread && !isScalar(where.getTemporary().getType())) { + temporaryValuesInitFree = codeToInitializeTemporaryParallel(where, forall.getParallelUnit()); + temporaryValuesInit.push_back(temporaryValuesInitFree); + } + } } Stmt loops; @@ -890,10 +898,21 @@ Stmt LowererImplImperative::lowerForall(Forall forall) parallelUnitIndexVars.erase(forall.getParallelUnit()); 
parallelUnitSizes.erase(forall.getParallelUnit()); } + + vector inits; + vector frees; + // iterate over temporaryValuesInit and add to inits and frees + for (auto& s : temporaryValuesInit) { + inits.push_back(s[0]); + frees.push_back(s[1]); + } + Stmt initsBlock = Block::make(inits); + Stmt freesBlock = Block::make(frees); + return Block::blanks(preInitValues, - temporaryValuesInitFree[0], + initsBlock, loops, - temporaryValuesInitFree[1]); + freesBlock); } Stmt LowererImplImperative::lowerForallCloned(Forall forall) { @@ -2523,20 +2542,23 @@ Stmt LowererImplImperative::lowerWhere(Where where) { vector temporaryValuesInitFree = {Stmt(), Stmt()}; bool temporaryHoisted = false; for (auto it = temporaryInitialization.begin(); it != temporaryInitialization.end(); ++it) { - if (it->second == where && it->first.getParallelUnit() == - ParallelUnit::NotParallel && !isScalar(temporary.getType())) { - temporaryHoisted = true; - } else if (it->second == where && it->first.getParallelUnit() == - ParallelUnit::CPUThread && !isScalar(temporary.getType())) { - temporaryHoisted = true; - auto decls = codeToInitializeLocalTemporaryParallel(where, it->first.getParallelUnit()); - - temporaryValuesInitFree[0] = ir::Block::make(decls); + auto whereClauses = it->second; + for (auto& whereClause : whereClauses) { + if (whereClause == where && it->first.getParallelUnit() == + ParallelUnit::NotParallel && !isScalar(temporary.getType())) { + temporaryHoisted = true; + } else if (whereClause == where && it->first.getParallelUnit() == + ParallelUnit::CPUThread && !isScalar(temporary.getType())) { + temporaryHoisted = true; + auto decls = codeToInitializeLocalTemporaryParallel(where, it->first.getParallelUnit()); + + temporaryValuesInitFree[0] = ir::Block::make(decls); + } } - } - if (!temporaryHoisted) { - temporaryValuesInitFree = codeToInitializeTemporary(where); + if (!temporaryHoisted) { + temporaryValuesInitFree = codeToInitializeTemporary(where); + } } Stmt initializeTemporary = temporaryValuesInitFree[0]; diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index c253dc5a3..daaf4a273 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -939,9 +939,7 @@ TEST(workspaces, precompute2D_mul) { TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); - vector path; - stmt = stmt.loopfuse(2, true, path); - + vector path; stmt = stmt.precompute(precomputedExpr, {i,k}, {i,k}, ws); stmt = stmt.precompute(ws(i,k) * D(k,l), {i,l}, {i,l}, t); stmt = stmt.concretize(); From 1c413ab4aad37aebb63ec66c153dde2036129d43 Mon Sep 17 00:00:00 2001 From: Adhhitha Dias Date: Thu, 15 Jun 2023 11:04:21 -0400 Subject: [PATCH 14/14] add test cases for evaluation --- src/codegen/module.cpp | 4 +- src/tensor.cpp | 8 +- test/tests-workspaces.cpp | 765 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 743 insertions(+), 34 deletions(-) diff --git a/src/codegen/module.cpp b/src/codegen/module.cpp index 89738f22d..08593bcca 100644 --- a/src/codegen/module.cpp +++ b/src/codegen/module.cpp @@ -128,7 +128,7 @@ string Module::compile() { #ifdef TACO_DEBUG // In debug mode, compile the generated code with debug symbols and a // low optimization level. - string defaultFlags = "-g -O0 -std=c99"; + string defaultFlags = "-O3 -ffast-math -std=c99"; #else // Otherwise, use the standard set of optimizing flags. 
string defaultFlags = "-O3 -ffast-math -std=c99"; @@ -145,6 +145,8 @@ string Module::compile() { prefix + file_ending + " " + shims_file + " " + "-o " + fullpath + " -lm"; + // std::cout << "Compiling generated code with command:\n" << cmd << "\n"; + // open the output file & write out the source compileToSource(tmpdir, libname); diff --git a/src/tensor.cpp b/src/tensor.cpp index eb4b4595a..78e30a3b7 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -895,7 +895,7 @@ void TensorBase::compute() { if (!needsCompute()) { return; } - setNeedsCompute(false); + // setNeedsCompute(false); // Sync operand tensors if needed. auto operands = getTensors(getAssignment().getRhs()); for (auto& operand : operands) { @@ -907,7 +907,7 @@ void TensorBase::compute() { this->content->module->callFuncPacked("compute", arguments.data()); if (content->assembleWhileCompute) { - setNeedsAssemble(false); + // setNeedsAssemble(false); taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); content->valuesSize = unpackTensorData(*tensorData, *this); } @@ -918,7 +918,7 @@ void TensorBase::compute(IndexStmt stmt) { if (!needsCompute()) { return; } - setNeedsCompute(false); + // setNeedsCompute(false); // Sync operand tensors if needed. auto operands = getTensors(getAssignment().getRhs()); for (auto& operand : operands) { @@ -930,7 +930,7 @@ void TensorBase::compute(IndexStmt stmt) { this->content->module->callFuncPacked("compute", arguments.data()); if (content->assembleWhileCompute) { - setNeedsAssemble(false); + // setNeedsAssemble(false); taco_tensor_t* tensorData = ((taco_tensor_t*)arguments[0]); content->valuesSize = unpackTensorData(*tensorData, *this); } diff --git a/test/tests-workspaces.cpp b/test/tests-workspaces.cpp index daaf4a273..62c2f28db 100644 --- a/test/tests-workspaces.cpp +++ b/test/tests-workspaces.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -9,6 +10,8 @@ #include "taco/index_notation/index_notation.h" #include "codegen/codegen.h" #include "taco/lower/lower.h" +#include "taco/util/env.h" +#include "time.h" using namespace taco; @@ -669,42 +672,433 @@ TEST(workspaces, loopfuse) { D.insert({i, j}, (double) i*j); } } + B.pack(); IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); IndexStmt stmt = A.getAssignment().concretize(); - TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); - TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); std::cout << stmt << endl; - vector path1; - vector path2 = {0}; + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + // stmt = stmt - .reorder({i,j,k, l, m}) - .loopfuse(3, true, path1) - .loopfuse(2, true, path2) - ; + .reorder({i, l, j, k, m}) + .loopfuse(1, true, path0); + + std::cout << "inter: " << stmt << std::endl; + stmt = stmt - .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + .reorder(path1, {l, j}) + .loopfuse(2, false, path1) + // .loopfuse(1, false, path2) ; + // stmt = stmt + // .parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + // ; + // stmt = stmt.concretize(); cout << "final stmt: " << stmt << endl; printCodeToFile("loopfuse", stmt); - A.compile(stmt.concretize()); + A.compile(stmt); A.assemble(); - A.compute(); + + clock_t begin = clock(); + A.compute(stmt); + clock_t end = clock(); + double elapsed_secs = 
double(end - begin) / CLOCKS_PER_SEC; + + std::cout << "executed\n"; Tensor expected("expected", {N, N}, Format{Dense, Dense}); expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m); expected.compile(); expected.assemble(); + begin = clock(); expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; +} + +TEST(workspaces, sddmm_spmm) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + } + } + B.pack(); + + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"); + A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + + } +TEST(workspaces, sddmm_spmm_gemm) { + int N = 16; + float SPARSITY = 0.3; + Tensor A("A", {N, N}, Format{Dense, Dense}); + Tensor B("B", {N, N}, Format{Dense, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + Tensor E("E", {N, N}, Format{Dense, Dense}); + Tensor F("F", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + float rand_float = (float) rand() / (float) RAND_MAX; + if (rand_float < SPARSITY) + B.insert({i, j}, (double) i); + C.insert({i, j}, (double) j); + E.insert({i, j}, (double) i*j); + D.insert({i, j}, (double) i*j); + F.insert({i, j}, (double) i*j); + } + } + B.pack(); + + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, 
{(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l, m}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {N, N}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + + +} + +TEST(workspaces, sddmm_spmm_gemm_real) { + + int K = 16; + int L = 16; + int M = 16; + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + std::cout << mat_file << std::endl; + + Tensor B = read(mat_file, Format({Dense, Sparse}), true); + B.setName("B"); + B.pack(); + + if (mat_file == "") { + std::cout << "No tensor file specified!\n"; + return; + } + + Tensor C("C", {B.getDimension(0), K}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), K}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(1), L}, Format{Dense, Dense}); + for (int j=0; j F("F", {L, M}, Format{Dense, Dense}); + for (int j=0; j A("A", {B.getDimension(0), M}, Format{Dense, Dense}); + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + A(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm_gemm_real TEST */ + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 0, 0}; + vector path4 = {1, 1}; + vector path5 = {1, 0, 1}; + vector path6 = {1, 0, 0, 0}; + stmt = stmt + .reorder({i, k, j, l, m}) + .loopfuse(1, true, path0) + // .loopfuse(4, true, path1) + // .loopfuse(3, true, path2) + // .loopfuse(1, false, path3) + // .reorder(path4, {m, l}) + // .reorder(path5, {l, j}) + // .reorder(path6, {j, k}) + ; + /* END sddmm_spmm_gemm_real TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {B.getDimension(0), M}, Format{Dense, Dense}); + expected(i,m) = B(i,j) * C(i,k) * D(j,k) * E(j,l) * F(l,m); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + 
// ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + std::cout << "workspaces, sddmm_spmm_gemm -> execution completed for matrix: " << mat_file << std::endl; + +} + +TEST(workspaces, sddmm_spmm_real) { + int K = 16; + int L = 16; + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + Tensor B = read(mat_file, Format({Dense, Sparse}), true); + B.setName("B"); + B.pack(); + + if (mat_file == "") { + std::cout << "No tensor file specified!\n"; + return; + } + + Tensor C("C", {B.getDimension(0), K}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), K}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(1), L}, Format{Dense, Dense}); + for (int j=0; j A("A", {B.getDimension(0), L}, Format{Dense, Dense}); + + + // 3 -> A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l) - + IndexVar i("i"), j("j"), k("k"), l("l"); + A(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + + IndexStmt stmt = A.getAssignment().concretize(); + // TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + // TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense}); + + std::cout << stmt << endl; + + /* BEGIN sddmm_spmm_real TEST */ + vector path0; + stmt = stmt + .reorder({i, j, k, l}) + .loopfuse(3, true, path0) + ; + /* END sddmm_spmm_real TEST */ + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("sddmm_spmm", stmt); + + A.compile(stmt); + A.assemble(); + + Tensor expected("expected", {B.getDimension(0), L}, Format{Dense, Dense}); + expected(i,l) = B(i,j) * C(i,k) * D(j,k) * E(j,l); + IndexStmt exp = makeReductionNotation(expected.getAssignment()); + exp = insertTemporaries(exp); + exp = exp.concretize(); + expected.compile(exp); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i = 0; i< 10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + + std::cout << "workspaces, sddmm_spmm -> execution completed for matrix: " << mat_file << std::endl; + +} TEST(workspaces, loopreversefuse) { int N = 16; @@ -734,12 +1128,12 @@ TEST(workspaces, loopreversefuse) { std::cout << stmt << endl; vector path1; stmt = stmt - .reorder({m,k,l,i,j}) - .loopfuse(2, false, path1) - ; - stmt = stmt - .parallelize(m, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + .reorder({m,i,l,k,j}) + .loopfuse(3, false, path1) ; + // stmt = stmt + // .parallelize(m, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) + // ; stmt = stmt.concretize(); cout << "final stmt: " << stmt << endl; @@ -783,16 +1177,20 @@ TEST(workspaces, loopcontractfuse) { IndexStmt stmt = A.getAssignment().concretize(); std::cout << stmt << endl; - vector path1; - vector path2 = {1}; - stmt = stmt - .reorder({l,i,m, j, k, n}) - .loopfuse(2, true, path1) - .loopfuse(2, true, path2) - ; - stmt = stmt - .parallelize(l, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces) - ; + + /* BEGIN loopcontractfuse TEST */ + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 1}; + stmt = stmt + .reorder({l, i, j, k, m, n}) + .loopfuse(2, true, path0) + .loopfuse(2, true, path1) + .reorder(path2, {m, k, j}) + 
.reorder(path3, {n, m, k}) + ; + /* END loopcontractfuse TEST */ stmt = stmt.concretize(); @@ -801,14 +1199,321 @@ TEST(workspaces, loopcontractfuse) { A.compile(stmt.concretize()); A.assemble(); - A.compute(); Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); expected.compile(); expected.assemble(); - expected.compute(); - ASSERT_TENSOR_EQ(expected, A); + + clock_t begin; + clock_t end; + + for (int i=0; i<10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + +} + +TEST(workspaces, loopcontractfuse_real) { + int L = 16; + int M = 16; + int N = 16; + Tensor A("A", {L, M, N}, Format{Dense, Dense, Dense}); + // Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + // Tensor C("C", {N, N}, Format{Dense, Dense}); + // Tensor D("D", {N, N}, Format{Dense, Dense}); + // Tensor E("E", {N, N}, Format{Dense, Dense}); + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + // std::cout << mat_file << std::endl; + + Tensor B = read(mat_file, Format({Dense, Sparse, Sparse}), true); + B.setName("B"); + B.pack(); + + // std::cout << "B tensor successfully read and packed!\n"; + // return; + + Tensor C("C", {B.getDimension(0), L}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(1), M}, Format{Dense, Dense}); + for (int j=0; j E("E", {B.getDimension(2), N}, Format{Dense, Dense}); + for (int k=0; k path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 1}; + stmt = stmt + .reorder({l, i, j, k, m, n}) + .loopfuse(2, true, path0) + .loopfuse(2, true, path1) + .reorder(path2, {k, m, j}) + .reorder(path3, {m, n, k}) + ; + /* END loopcontractfuse_real TEST */ + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("loopcontractfuse", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n); + expected.compile(); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i=0; i<3; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + +std::cout << "workspaces, loopcontractfuse -> execution completed for matrix: " << mat_file << std::endl; + +} + +TEST(workspaces, spttm_ttm) { + int N = 16; + Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + Tensor C("C", {N, N}, Format{Dense, Dense}); + Tensor D("D", {N, N}, Format{Dense, Dense}); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < N; k++) { + B.insert({i, j, k}, (double) i); + } + C.insert({i, j}, (double) j); + D.insert({i, j}, (double) i*j); + } + } + + // 5 -> A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(i,l,m) = B(i,j,k) * C(j,l) 
* D(k,m); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + + /* BEGIN spttm_ttm TEST */ + vector path0; + vector path1 = {1}; + stmt = stmt + .reorder({l, i, j, k, m}) + .loopfuse(2, true, path0) + .reorder(path1, {m, k}) + ; + /* END spttm_ttm TEST */ + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("spttm_ttm", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + + Tensor expected("expected", {N, N, N}, Format{Dense, Dense, Dense}); + expected(i,l,m) = B(i,j,k) * C(j,l) * D(k,m); + expected.compile(); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i=0; i<10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << elapsed_secs << std::endl; + std::cout << elapsed_secs_ref << std::endl; + } + +} + +TEST(workspaces, spttm_ttm_real) { + // int N = 16; + // Tensor A("A", {N, N, N}, Format{Dense, Dense, Dense}); + // Tensor B("B", {N, N, N}, Format{Dense, Sparse, Sparse}); + // Tensor C("C", {N, N}, Format{Dense, Dense}); + // Tensor D("D", {N, N}, Format{Dense, Dense}); + + // for (int i = 0; i < N; i++) { + // for (int j = 0; j < N; j++) { + // for (int k = 0; k < N; k++) { + // B.insert({i, j, k}, (double) i); + // } + // C.insert({i, j}, (double) j); + // D.insert({i, j}, (double) i*j); + // } + // } + + int L = 16; + int M = 16; + + std::string mat_file = util::getFromEnv("TENSOR_FILE", ""); + + // std::cout << mat_file << std::endl; + + Tensor B = read(mat_file, Format({Dense, Sparse, Sparse}), true); + B.setName("B"); + B.pack(); + + // std::cout << "B tensor successfully read and packed!\n"; + // return; + + Tensor C("C", {B.getDimension(1), L}, Format{Dense, Dense}); + for (int i=0; i D("D", {B.getDimension(2), M}, Format{Dense, Dense}); + for (int j=0; j A("A", {B.getDimension(0), L, M}, Format{Dense, Dense, Dense}); + + // 5 -> A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m) - + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); + A(i,l,m) = B(i,j,k) * C(j,l) * D(k,m); + + IndexStmt stmt = A.getAssignment().concretize(); + + std::cout << stmt << endl; + + /* BEGIN spttm_ttm TEST */ + vector path0; + vector path1 = {1}; + vector path2 = {1, 0}; + vector path3 = {1, 0, 0}; + vector path4 = {1, 1}; + vector path5 = {1, 0, 1}; + vector path6 = {1, 0, 0, 0}; + stmt = stmt + .reorder({i, k, j, l, m}) + .loopfuse(1, true, path0) + .loopfuse(4, true, path1) + .loopfuse(3, true, path2) + .loopfuse(1, false, path3) + .reorder(path4, {m, l}) + .reorder(path5, {l, j}) + .reorder(path6, {j, k}) + ; + /* END spttm_ttm TEST */ + + + stmt = stmt.concretize(); + cout << "final stmt: " << stmt << endl; + printCodeToFile("spttm_ttm", stmt); + + A.compile(stmt.concretize()); + A.assemble(); + + Tensor expected("expected", {B.getDimension(0), L, M}, Format{Dense, Dense, Dense}); + expected(i,l,m) = B(i,j,k) * C(j,l) * D(k,m); + expected.compile(); + expected.assemble(); + + clock_t begin; + clock_t end; + + for (int i=0; i<10; i++) { + begin = clock(); + A.compute(stmt); + end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC * 1000; + + begin = clock(); + expected.compute(); + end = clock(); + double elapsed_secs_ref = double(end - begin) / CLOCKS_PER_SEC * 1000; + // ASSERT_TENSOR_EQ(expected, A); + + std::cout << 
elapsed_secs << std::endl;
+    std::cout << elapsed_secs_ref << std::endl;
+  }
+
+}
+
 TEST(workspaces, loopreordercontractfuse) {
@@ -905,7 +1610,9 @@ TEST(workspaces, sddmm) {
 
   A.compile(stmt.concretize());
   A.assemble();
+  // begin timing
   A.compute();
+  // end timing
 
   Tensor<double> expected("expected", dims, Format{Dense, Dense});
   expected(i,j) = B(i,j) * C(i,k) * D(j,k);