Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 16, 2024
1 parent 66c7cdc commit a996c20
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 43 deletions.
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ pybind_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Parser",

"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@llvm-project//mlir:MLIRBindingsPythonCore",

# EnzymeMLIR
"@enzyme//:EnzymeMLIR",
Expand All @@ -209,6 +212,7 @@ pybind_extension(
":clang_compile",
":compile_with_xla",
"@com_google_absl//absl/status:statusor",
"@stablehlo//:stablehlo_passes",
"@xla//xla/stream_executor:stream_executor_impl",
],
visibility = ["//visibility:public"],
Expand Down
40 changes: 40 additions & 0 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
@@ -1,8 +1,48 @@

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"SubtractOp">;
def Neg : HLOInst<"NegOp">;
def Mul : HLOInst<"MulOp">;
def Div : HLOInst<"DivOp">;
def Rem : HLOInst<"RemainderOp">;


def CheckedMul : HLOInst<"MulOp">;
def CheckedDiv : HLOInst<"DivOp">;

def : HLODerivative<"AddOp", (Op $x, $y),
[
(DiffeRet),
(DiffeRet),
]
>;

def : HLODerivative<"SubtractOp", (Op $x, $y),
[
(DiffeRet),
(Neg (DiffeRet)),
]
>;
def : HLODerivative<"NegOp", (Op $x),
[
(Neg (DiffeRet)),
]
>;
def : HLODerivative<"MulOp", (Op $x, $y),
[
(CheckedMul (DiffeRet), $y),
(CheckedMul (DiffeRet), $x)
]
>;
def : HLODerivative<"DivOp", (Op $x, $y),
[
(CheckedDiv (DiffeRet), $y),
(Neg (Mul (CheckedDiv (DiffeRet), $y), (Div $x, $y)))
]
// (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y))
>;

def : HLOReadOnlyIdentityOp<"ReshapeOp">;
def : HLOReadOnlyIdentityOp<"SliceOp">;
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;

7 changes: 1 addition & 6 deletions src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : M

class HLOInst<string m> : Inst<m, "mhlo">;

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"mhlo::SubOp">;
def Neg : HLOInst<"mhlo::NegOp">;
def Mul : HLOInst<"mhlo::MulOp">;
def Div : HLOInst<"mhlo::DivOp">;
def Rem : HLOInst<"mhlo::RemOp">;
class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"mhlo", opName_, ptrargs_>;

include "HLODerivatives.td"
11 changes: 3 additions & 8 deletions src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@ include "Common.td"

class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps>;

class HLOInst<string m> : Inst<m, "mhlo">;

def Add : HLOInst<"AddOp">;
def Sub : HLOInst<"mhlo::SubOp">;
def Neg : HLOInst<"mhlo::NegOp">;
def Mul : HLOInst<"mhlo::MulOp">;
def Div : HLOInst<"mhlo::DivOp">;
def Rem : HLOInst<"mhlo::RemOp">;
class HLOInst<string m> : Inst<m, "stablehlo">;

class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"stablehlo", opName_, ptrargs_>;

include "HLODerivatives.td"
137 changes: 134 additions & 3 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,125 @@
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Implementations/XLADerivatives.h"

#include "mlir/lib/Bindings/Python/IRModule.h"
#include "mlir/CAPI/IR.h"
#include "mlir-c/IR.h"

void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
}


/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName, unsigned &lastUsedID,
mlir::ModuleOp module) {
using namespace llvm;
using namespace mlir;
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');

MLIRContext *ctx = module->getContext();

while (true) {
auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
if (!SymbolTable::lookupSymbolIn(module, possible))
return possible;
}

}

/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
mlir::ModuleOp target,
mlir::ModuleOp source,
unsigned &lastUsedID) {
using namespace llvm;
using namespace mlir;
if (!SymbolTable::lookupSymbolIn(source, op.getName()))
return success();

StringRef oldSymName = op.getName();
StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
return op.emitError("unable to update all symbol uses for ")
<< oldSymName << " to " << newSymName;

SymbolTable::setSymbolName(op, newSymName);
return success();
}

MlirOperation run_pass_pipeline(llvm::StringRef mlir, const std::string &pass_pipeline) {
using namespace llvm;
using namespace mlir;

auto ins = mlir::python::PyThreadContextEntry::getTopOfStack()->getDefaultInsertionPoint();
auto blk = unwrap(ins->getBlock().get());

auto oldMod = blk->getParent()->getParentOfType<mlir::ModuleOp>();
// Parse MLIR.
mlir::DialectRegistry registry;
prepareRegistry(registry);
oldMod->getContext()->appendDialectRegistry(registry);
mlir::ParserConfig parser_config(oldMod->getContext());
mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
mlir::parseSourceString<mlir::ModuleOp>(mlir, parser_config);

mlir::PassManager pm(oldMod->getContext());

std::string error_message;
llvm::raw_string_ostream error_stream(error_message);
error_stream << "Failed to parse pipeline\n";
mlir::LogicalResult result = mlir::parsePassPipeline(pass_pipeline, pm, error_stream);
if (mlir::failed(result)) {
throw pybind11::value_error(error_message);
}
if (!mlir::succeeded(pm.run(*parsed_module))) {
throw pybind11::value_error("Pipeline failed");
}


StringRef entryfn = "main";

unsigned lastUsedID = 0;

OpBuilder combinedModuleBuilder(oldMod->getContext());
combinedModuleBuilder.setInsertionPointToStart(oldMod.getBody());

Operation* resultOp = nullptr;
for (auto &op : *parsed_module->getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;

StringRef oldSymName = symbolOp.getName();

if (failed(updateSymbolAndAllUses(symbolOp, *parsed_module, oldMod,
lastUsedID)))
throw pybind11::value_error("failed to update all uses");

StringRef newSymName = symbolOp.getName();
if (oldSymName != newSymName) {
if (oldSymName == entryfn) {
entryfn = newSymName;
}
}
Operation* const cloned = op.clone();
if (newSymName == entryfn) {
resultOp = cloned;
}
combinedModuleBuilder.insert(cloned);
}

SymbolTable::setSymbolVisibility(resultOp,
SymbolTable::Visibility::Private);

return wrap(resultOp);
}

// Compile an MHLO module given as a string to LLVM IR using XLA.
std::unique_ptr<xla::LocalExecutable>
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
Expand All @@ -54,12 +169,28 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
mlir::parseSourceString<mlir::ModuleOp>(mhlo_text, parser_config);

if (!xla_runtime) {
llvm::StringRef cur_pipeline = pass_pipeline;

mlir::PassManager pm(&context);
if (!xla_runtime) {
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
if (!mlir::succeeded(pm.run(*parsed_module))) {
} else {
std::string tofind = "stablehlo-legalize-to-hlo,";
auto pos = llvm::StringRef(pass_pipeline).find(tofind);
assert(pos != std::string::npos);
auto pre = llvm::StringRef(pass_pipeline.data(), pos + tofind.size() - 1);
cur_pipeline = llvm::StringRef(pass_pipeline.data() + pos + tofind.size(), pass_pipeline.size() - pos - tofind.size());

std::string error_message;
llvm::raw_string_ostream error_stream(error_message);
error_stream << "Failed to parse pre stablehlo pipeline\n";
mlir::LogicalResult result = mlir::parsePassPipeline(pre, pm, error_stream);
if (mlir::failed(result)) {
throw pybind11::value_error(error_message);
}
}
if (!mlir::succeeded(pm.run(*parsed_module))) {
throw pybind11::value_error("StableHLO => MHLO failed");
}
}

// Convert to XLA Computation.
Expand Down
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/compile_with_xla.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
#include "xla/client/local_client.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include "mlir-c/IR.h"
// Compile an MHLO module given as a string to LLVM IR using XLA.
std::unique_ptr<xla::LocalExecutable>
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
bool xla_runtime,
const std::string &pass_pipeline);

MlirOperation run_pass_pipeline(llvm::StringRef mlir, const std::string &pass_pipeline);
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include "Enzyme/MLIR/Passes/Passes.h"

#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "stablehlo/transforms/Passes.h"

enum class ABI { Primal, Forward, Augmented, Reverse, Tape };

Expand Down Expand Up @@ -1031,6 +1032,7 @@ PYBIND11_MODULE(enzyme_call, m) {
mlir::memref::registerMemRefPasses();
mlir::registerenzymePasses();
regsiterenzymeXLAPasses();
mlir::stablehlo::registerPasses();

pybind11::enum_<Language>(m, "Language")
.value("CPP", Language::CPP)
Expand Down Expand Up @@ -1143,6 +1145,8 @@ PYBIND11_MODULE(enzyme_call, m) {
return pybind11::capsule(reinterpret_cast<void *>(&CpuCallback),
"xla._CUSTOM_CALL_TARGET");
});

m.def("run_pass_pipeline", run_pass_pipeline);

m.def("compile_mhlo_to_llvm_with_xla",
[](const std::string &mhlo_text, bool xla_runtime,
Expand Down
Loading

0 comments on commit a996c20

Please sign in to comment.