From e8bc18716cdaf190a259f2dcb2b91a2cf7489516 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 9 Jan 2025 17:55:42 -0500 Subject: [PATCH] Fix nametoordinal (#2221) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 105 ++++++++++--------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 50efdaeae61..0c0a45c6ead 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -207,7 +207,7 @@ struct VariableSetting { StringMap> extractions; std::tuple> - lookup(StringRef name, const Record *pattern, const Init *resultRoot) { + lookup(StringRef name, const Record *pattern, const Init *resultRoot) const { auto ord = nameToOrdinal.find(name); if (ord == nameToOrdinal.end()) PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") + @@ -1192,14 +1192,16 @@ void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition); + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal); void handleUseArgument( StringRef name, const Init *arg, bool usesPrimal, bool usesShadow, const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition) { + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal) { auto arg2 = dyn_cast(arg); @@ -1218,7 +1220,8 @@ void handleUseArgument( handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, name.size() ? foundShadowUse2 : foundShadowUse, name.size() ? foundDiffRet2 : foundDiffRet, - usesPrimal ? precondition : "", tree, varNameToCondition); + usesPrimal ? precondition : "", tree, varNameToCondition, + nameToOrdinal); if (name.size()) { if (foundPrimalUse2.size() && @@ -1306,7 +1309,8 @@ void handleUse( const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse, std::string &foundShadowUse, bool &foundDiffRet, std::string precondition, const DagInit *tree, - StringMap> &varNameToCondition) { + StringMap> &varNameToCondition, + const VariableSetting &nameToOrdinal) { auto opName = resultTree->getOperator()->getAsString(); auto Def = cast(resultTree->getOperator())->getDef(); if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) { @@ -1339,7 +1343,9 @@ void handleUse( if (numArgs == 3) { if (isa(resultTree->getArg(0)) && resultTree->getArgName(0)) { auto name = resultTree->getArgName(0)->getAsUnquotedString(); - conditionStr = ReplaceAll(conditionStr, "imVal", name); + auto [ord, isVec, ext] = nameToOrdinal.lookup(name, nullptr, nullptr); + assert(!isVec); + conditionStr = ReplaceAll(conditionStr, "imVal", ord); } else assert("Requires name for arg"); } @@ -1362,7 +1368,7 @@ void handleUse( auto arg = resultTree->getArg(i); handleUseArgument(name, arg, true, false, root, resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, - precondition2, tree, varNameToCondition); + precondition2, tree, varNameToCondition, nameToOrdinal); } return; @@ -1375,16 +1381,57 @@ void handleUse( auto name = resultTree->getArgNameStr(argEn.index()); handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root, resultTree, foundPrimalUse, foundShadowUse, foundDiffRet, - precondition, tree, varNameToCondition); + precondition, tree, varNameToCondition, nameToOrdinal); } } +static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic, + StringRef origName) { + VariableSetting nameToOrdinal; + std::function)> insert = + [&](const DagInit *ptree, ArrayRef prev) { + unsigned i = 0; + for (auto tree : ptree->getArgs()) { + SmallVector next(prev.begin(), prev.end()); + next.push_back(i); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (ptree->getArgNameStr(i).size()) { + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); + std::vector extractions; + if (prev.size() > 0) { + for (unsigned i = 1; i < next.size(); i++) { + extractions.push_back(next[i]); + } + } + nameToOrdinal.insert(ptree->getArgNameStr(i), op, false, + extractions); + } + i++; + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + nameToOrdinal.insert(tree->getNameStr(), + (Twine("(&") + origName + ")").str(), false, {}); + return nameToOrdinal; +} + void printDiffUse( raw_ostream &os, Twine prefix, const ListInit *argOps, StringRef origName, ActionType intrinsic, const DagInit *tree, StringMap> &varNameToCondition) { os << prefix << " // Rule " << *tree << "\n"; + VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName); + for (auto argOpEn : enumerate(*argOps)) { size_t argIdx = argOpEn.index(); if (auto resultRoot = dyn_cast(argOpEn.value())) { @@ -1417,7 +1464,8 @@ void printDiffUse( // hasDiffeRet(resultTree) handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse, - foundDiffRet, /*precondition*/ "true", tree, varNameToCondition); + foundDiffRet, /*precondition*/ "true", tree, varNameToCondition, + nameToOrdinal); os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n"; @@ -1587,45 +1635,6 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern, os << " mlir::Value dif = nullptr;\n"; } -static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic, - StringRef origName) { - VariableSetting nameToOrdinal; - std::function)> insert = - [&](const DagInit *ptree, ArrayRef prev) { - unsigned i = 0; - for (auto tree : ptree->getArgs()) { - SmallVector next(prev.begin(), prev.end()); - next.push_back(i); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (ptree->getArgNameStr(i).size()) { - std::string op; - if (intrinsic != MLIRDerivatives) - op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); - else - op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); - std::vector extractions; - if (prev.size() > 0) { - for (unsigned i = 1; i < next.size(); i++) { - extractions.push_back(next[i]); - } - } - nameToOrdinal.insert(ptree->getArgNameStr(i), op, false, - extractions); - } - i++; - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - nameToOrdinal.insert(tree->getNameStr(), - (Twine("(&") + origName + ")").str(), false, {}); - return nameToOrdinal; -} - static void emitReverseCommon(raw_ostream &os, const Record *pattern, const DagInit *tree, ActionType intrinsic, StringRef origName, const ListInit *argOps) {