Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial implementation of kernel fuse directive #545

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,18 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
IndexStmt divide(IndexVar i, IndexVar i1, IndexVar i2, size_t divideFactor) const; // TODO: TailStrategy


/// The loopfuse transformation fuses common outer loops in
/// 2 iteration graphs.
/// when performing loopfuse operation on an already branched index statement
/// eg: forall(l, where(forall(ijk, T(j,k) += A*B), forall(mjkn, X(l,m,n) += T*C*D)))
/// and we want to further breakdown T*C*D into T2 = T*C and T2*D
/// we can use the path vector to specify the branch we want to apply the fuse on
/// eg: loopfuse(2, true, {1}) where 2 refers to breaking T*C*D at the 2nd position
/// and true refers to making T*C as the producer (if false, then C*D will be the producer if used with 1)
/// and {1} refers to the branch we want to apply the fuse on
IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;


/// The reorder transformation swaps two directly nested index
/// variables in an iteration graph. This changes the order of
/// iteration through the space and the order of tensor accesses.
Expand All @@ -663,6 +675,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;

/// reorders the index variables in a nested structure with where clauses
IndexStmt reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const;

/// The mergeby transformation specifies how to merge iterators on
/// the given index variable. By default, if an iterator is used for windowing
/// it will be merged with the "gallop" strategy.
Expand Down Expand Up @@ -1308,7 +1323,7 @@ std::vector<TensorVar> getAttrQueryResults(IndexStmt stmt);

// [Olivia]
/// Returns the temporaries in the index statement, in the order they appear.
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt);
std::map<Forall, std::vector<Where> > getTemporaryLocations(IndexStmt stmt);

/// Returns the results in the index statement that should be assembled by
/// ungrouped insertion.
Expand Down
26 changes: 26 additions & 0 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class IndexStmt;
class TransformationInterface;
class Reorder;
class Precompute;
class LoopFuse;
class ForAllReplace;
class AddSuchThatPredicates;
class Parallelize;
Expand All @@ -32,6 +33,7 @@ class Transformation {
public:
Transformation(Reorder);
Transformation(Precompute);
Transformation(LoopFuse);
Transformation(ForAllReplace);
Transformation(Parallelize);
Transformation(TopoReorder);
Expand Down Expand Up @@ -65,10 +67,12 @@ class Reorder : public TransformationInterface {
public:
Reorder(IndexVar i, IndexVar j);
Reorder(std::vector<IndexVar> replacePattern);
Reorder(std::vector<int> path, std::vector<IndexVar> replacePattern);

IndexVar geti() const;
IndexVar getj() const;
const std::vector<IndexVar>& getreplacepattern() const;
const std::vector<int>& getpath() const;

/// Apply the reorder optimization to a concrete index statement. Returns
/// an undefined statement and a reason if the statement cannot be lowered.
Expand Down Expand Up @@ -114,6 +118,28 @@ class Precompute : public TransformationInterface {
/// Print a precompute command.
std::ostream &operator<<(std::ostream &, const Precompute &);

/// The loopfuse optimization rewrite an index expression to precompute
/// part of the `expr` and store it to a workspace.
class LoopFuse : public TransformationInterface {
public:
LoopFuse();
LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path);

int getPos() const;
bool getIsProducerOnLeft() const;
std::vector<int>& getPath() const;

/// Apply the loopfuse optimization to a concrete index statement.
IndexStmt apply(IndexStmt, std::string *reason = nullptr) const;

void print(std::ostream &os) const;

private:
struct Content;
std::shared_ptr<Content> content;
};

std::ostream &operator<<(std::ostream &, const LoopFuse &);

/// Replaces all occurrences of directly nested forall nodes of pattern with
/// directly nested loops of replacement
Expand Down
2 changes: 1 addition & 1 deletion include/taco/lower/lowerer_impl_imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ class LowererImplImperative : public LowererImpl {
std::set<ir::Expr> nonFullyInitializedResults;

/// Map used to hoist temporary workspace initialization
std::map<Forall, Where> temporaryInitialization;
std::map<Forall, std::vector<Where> > temporaryInitialization;

/// Map used to hoist parallel temporary workspaces. Maps workspace shared by all threads to where statement
std::map<Where, TensorVar> whereToTemporaryVar;
Expand Down
1 change: 1 addition & 0 deletions include/taco/parser/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum class Token {
sub,
mul,
div,
colon, // numbers before the colon indicate the path to branch in a branched iteration graph
eq,
eot, // End of tokens
error
Expand Down
1 change: 1 addition & 0 deletions include/taco/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ class TensorBase {

/// Compute the given expression and put the values in the tensor storage.
void compute();
void compute(IndexStmt stmt);

/// Compile, assemble and compute as needed.
void evaluate();
Expand Down
4 changes: 3 additions & 1 deletion src/codegen/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ string Module::compile() {
#ifdef TACO_DEBUG
// In debug mode, compile the generated code with debug symbols and a
// low optimization level.
string defaultFlags = "-g -O0 -std=c99";
string defaultFlags = "-O3 -ffast-math -std=c99";
#else
// Otherwise, use the standard set of optimizing flags.
string defaultFlags = "-O3 -ffast-math -std=c99";
Expand All @@ -145,6 +145,8 @@ string Module::compile() {
prefix + file_ending + " " + shims_file + " " +
"-o " + fullpath + " -lm";

// std::cout << "Compiling generated code with command:\n" << cmd << "\n";

// open the output file & write out the source
compileToSource(tmpdir, libname);

Expand Down
70 changes: 56 additions & 14 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "error/error_checks.h"
#include "taco/error/error_messages.h"
#include "taco/index_notation/index_notation_visitor.h"
#include "taco/type.h"
#include "taco/format.h"

Expand Down Expand Up @@ -1854,6 +1855,18 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
return transformed;
}

IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {
string reason; // reason saves the error message if the transformation fails
IndexStmt transformed = *this;
transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
if (!transformed.defined()) {
taco_uerror << reason;
}
return transformed;

return *this;
}

IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
std::vector<IndexVar> iw_vars, TensorVar workspace) const {

Expand Down Expand Up @@ -1907,6 +1920,15 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
return transformed;
}

IndexStmt IndexStmt::reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const {
string reason;
IndexStmt transformed = Reorder(path, reorderedvars).apply(*this, &reason);
if (!transformed.defined()) {
taco_uerror << reason;
}
return transformed;
}

IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const {
string reason;
IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason);
Expand Down Expand Up @@ -2048,6 +2070,7 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
return transformed;
}


IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
if (accelIndexVars.size() == 0) {
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
Expand Down Expand Up @@ -3452,20 +3475,39 @@ bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt) {
return true;
}

std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt) {
map<Forall, Where> temporaryLocs;
Forall f = Forall();
match(stmt,
function<void(const ForallNode*, Matcher*)>([&](const ForallNode* op, Matcher* ctx) {
f = op;
ctx->match(op->stmt);
}),
function<void(const WhereNode*, Matcher*)>([&](const WhereNode* w, Matcher* ctx) {
if (!(f == IndexStmt()))
temporaryLocs.insert({f, Where(w)});
})
);
return temporaryLocs;
std::map<Forall, vector<Where> > getTemporaryLocations(IndexStmt stmt) {
struct TemporaryLocsGetter : public IndexNotationVisitor {
map<Forall, vector<Where> > temporaryLocs;
Forall f;

using IndexNotationVisitor::visit;

void visit(const ForallNode *op) {
Forall forall = Forall(op);

if (f == NULL) {
f = op;
}
IndexNotationVisitor::visit(op);
}

void visit(const WhereNode *op) {
Where where = Where(op);
if (temporaryLocs.find(f) != temporaryLocs.end()) {
temporaryLocs[f].push_back(where);
}
else {
vector<Where> whereVec;
whereVec.push_back(where);
temporaryLocs.insert({f, whereVec});
}
IndexNotationVisitor::visit(op);
}
};
TemporaryLocsGetter getter;
getter.visit(stmt);

return getter.temporaryLocs;
}


Expand Down
Loading