WIP variadic #2226

Merged 1 commit on Jan 18, 2025
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
@@ -27,6 +27,9 @@ def AssertingInactiveArg : InactiveArgSpec {
bit asserting = 1;
}

class Variadic<string getter_> {
string getter = getter_;
}

def Unimplemented {

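The new Variadic class wraps a group of operands in a derivative pattern and records the accessor the generated code should call to reach them. A minimal usage sketch is below; the def name, the op, and the getArgOperands accessor are illustrative assumptions, only the Variadic<"..."> wrapper itself comes from this change.

```tablegen
// Hypothetical usage sketch. Only Variadic<"..."> is introduced by this PR;
// the def name, op, and accessor below are placeholders.

// A reusable wrapper naming the accessor for an op's variadic operand group.
// The tablegen emitter accepts either an inline Variadic<"..."> instantiation
// or a def deriving from Variadic (it checks isSubClassOf("Variadic")).
def CallArgs : Variadic<"getArgOperands">;

// Inside a pattern's source dag the wrapped group is bound to a name like an
// ordinary operand, e.g. (Op (CallArgs):$args, $fn); the emitter then expands
// $args to the named getter call instead of a fixed positional operand.
```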
17 changes: 15 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
@@ -55,6 +55,7 @@ class MEnzymeLogic {
unsigned width;
mlir::Type additionalType;
const MFnTypeInfo typeInfo;
bool omp;

inline bool operator<(const MForwardCacheKey &rhs) const {
if (todiff < rhs.todiff)
@@ -100,6 +101,12 @@ class MEnzymeLogic {
return true;
if (rhs.typeInfo < typeInfo)
return false;

if (omp < rhs.omp)
return true;
if (rhs.omp < omp)
return false;

// equal
return false;
}
@@ -117,6 +124,7 @@ class MEnzymeLogic {
mlir::Type additionalType;
const MFnTypeInfo typeInfo;
const std::vector<bool> volatileArgs;
bool omp;

inline bool operator<(const MReverseCacheKey &rhs) const {
if (todiff < rhs.todiff)
@@ -182,6 +190,11 @@ class MEnzymeLogic {
if (rhs.volatileArgs < volatileArgs)
return false;

if (omp < rhs.omp)
return true;
if (rhs.omp < omp)
return false;

// equal
return false;
}
@@ -196,7 +209,7 @@ class MEnzymeLogic {
std::vector<bool> returnPrimals, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);
void *augmented, bool omp, llvm::StringRef postpasses);

FunctionOpInterface
CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
@@ -205,7 +218,7 @@ class MEnzymeLogic {
std::vector<bool> returnShadows, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);
void *augmented, bool omp, llvm::StringRef postpasses);

void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
@@ -185,7 +185,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
llvm::StringRef postpasses) {
bool omp, llvm::StringRef postpasses) {

if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
@@ -217,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
retType, constants, addedType, postpasses);
retType, constants, addedType, omp, postpasses);

ReverseCachedFunctions[tup] = gutils->newFunc;

9 changes: 5 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
@@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, StringRef postpasses)
DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, returnPrimals, returnShadows,
constantvalues_, activevals_, ReturnActivity,
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
mode_, width, /*omp*/ false, postpasses) {}
mode_, width, omp, postpasses) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
@@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
mlir::Type additionalArg, llvm::StringRef postpasses) {
mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
std::string prefix;

switch (mode_) {
@@ -174,5 +174,6 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, retType,
constant_args, originalToNew, originalToNewOps, mode_, width, postpasses);
constant_args, originalToNew, originalToNewOps, mode_, width, omp,
postpasses);
}
40 changes: 38 additions & 2 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -1394,8 +1394,30 @@ static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
for (auto tree : ptree->getArgs()) {
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(i);
if (auto dg = dyn_cast<DagInit>(tree))
if (auto dg = dyn_cast<DagInit>(tree)) {
if (ptree->getArgNameStr(i).size()) {
auto opName = dg->getOperator()->getAsString();
auto Def = cast<DefInit>(dg->getOperator())->getDef();
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
auto expr = Def->getValueAsString("getter");
std::string op;
if (intrinsic != MLIRDerivatives)
op = (origName + "." + expr + "()").str();
else
op = (origName + "->" + expr + "()").str();
std::vector<int> extractions;
if (prev.size() > 0) {
for (unsigned i = 1; i < next.size(); i++) {
extractions.push_back(next[i]);
}
}
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
extractions);
continue;
}
}
insert(dg, next);
}

if (ptree->getArgNameStr(i).size()) {
std::string op;
@@ -1580,8 +1602,22 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
auto name = ptree->getArgNameStr(treeEn.index());
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(treeEn.index());
if (auto dg = dyn_cast<DagInit>(tree))
if (auto dg = dyn_cast<DagInit>(tree)) {
if (name.size()) {
auto opName = dg->getOperator()->getAsString();
auto Def = cast<DefInit>(dg->getOperator())->getDef();
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
auto expr = Def->getValueAsString("getter");
varNameToCondition[name] = std::make_tuple(
("llvm::is_contained(op->getOperand(idx), op." + expr +
"())")
.str(),
"", false);
continue;
}
}
insert(dg, next);
}

if (name.size()) {
varNameToCondition[name] = std::make_tuple(
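Taken together, the two emitter changes consume such a wrapper as follows: parseVariables resolves a name bound to a Variadic dag to a call on the named getter (recording any enclosing extraction indices), while emitMLIRReverse resolves the same name to a containment test over the variadic range rather than a fixed operand index. A hypothetical correspondence, reusing the placeholder accessor from the earlier sketch:

```tablegen
// Pattern fragment (placeholder op and accessor):
//   (Op (Variadic<"getArgOperands">):$args)
//
// parseVariables maps the bound name to a getter call (origName is the
// emitter's name for the matched op, shown here as "op"):
//   $args  ->  op->getArgOperands()   // MLIR derivatives: origName + "->" + getter + "()"
//   $args  ->  op.getArgOperands()    // other actions:    origName + "." + getter + "()"
//
// emitMLIRReverse instead records a containment condition for the name:
//   $args  ->  llvm::is_contained(op->getOperand(idx), op.getArgOperands())
```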