Skip to content

Commit

Permalink
Pipelinemod (#33)
Browse files Browse the repository at this point in the history
* With mosaic

* Extend pipeline capabilities

* fixups

* continuing

* continuing

* Update WORKSPACE

* fix

* bump commits

* working

* format

* fix

* fix

* cleanup print

* format

* fixup

* hash

* fix

* bump enzyme commit

* fixup

* fixup

* fixup

* Add dot [fwd]

* fix

* add abi tests
  • Loading branch information
wsmoses authored Feb 18, 2024
1 parent 185b4c0 commit 6dbea3b
Show file tree
Hide file tree
Showing 26 changed files with 1,536 additions and 473 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ steps:
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
cp bazel-bin/*.whl .
python -m pip install *.whl
cd test && python -m pip install "jax[cpu]" && python test.py
cd test && python -m pip install "jax[cpu]" && python test.py && python bench_vs_xla.py
artifact_paths:
- "*.whl"

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
cd test
nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
python3 test.py
python3 bench_vs_xla.py
cd lit_tests
lit . --verbose
Expand Down
14 changes: 7 additions & 7 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")

rules_cc_dependencies()

LLVM_COMMIT = "3a82a1c3f6bdc9259cc4641f66fc76d1e171e382"
LLVM_SHA256 = "c525cdb14bb239695852d696bcd13a6d47e579be18386ba2048515fe7f059153"
LLVM_COMMIT = "5932f3f861f84305bd01050d0af8e0dcb232a8b3"
LLVM_SHA256 = "ffbb065b6b9c2aef72949e84484ce5db3a86f682e7f4910a79eb5236856d259a"
LLVM_TARGETS = ["X86", "AArch64", "AMDGPU", "NVPTX"]

http_archive(
Expand All @@ -30,8 +30,8 @@ http_archive(
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project", targets = LLVM_TARGETS)

XLA_COMMIT = "28d55494cbaf896e52fff2d9a5255eff19e8c072"
XLA_SHA256 = ""
XLA_COMMIT = "6ee7005b0dbe29ba0cd077a690db1555ec6de346"
XLA_SHA256 = "76f36ca2eecb246eb7931f2d77c5e9a32859c42aeb5b235f5c17ef6fc6fa71ae"

http_archive(
name = "xla",
Expand Down Expand Up @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "bad7df07dd7657c0c6884667d62ad1d9bcfd1d16"
ENZYME_SHA256 = ""
ENZYME_COMMIT = "1e1c0eb1c9b4ae3fa6b0acc2394e305b3fc4e042"
ENZYME_SHA256 = "07eb58bb2b4d877f940b88b87bcb2a8e9f02a9320a23b31697c5a4105cb6f031"

http_archive(
name = "enzyme",
Expand All @@ -70,7 +70,7 @@ http_archive(
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)

JAX_COMMIT = "f691fe468a8e1f8545f7d624055d58b823ee3201"
JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248"
JAX_SHA256 = ""

http_archive(
Expand Down
16 changes: 8 additions & 8 deletions patches/jax.patch
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
--- a/jaxlib/cpu/BUILD
+++ a/jaxlib/cpu/BUILD
@@ -79,7 +79,7 @@ cc_library(
":ducc_fft_flatbuffers_cc",
"@xla//xla/service:custom_call_status",
"@com_github_google_flatbuffers//:flatbuffers",
- "@ducc//:fft",
+ "@ducc//:fft_wrapper",
--- a/jaxlib/mosaic/BUILD
+++ b/jaxlib/mosaic/BUILD
@@ -20,7 +20,7 @@ licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = [
- "//:__subpackages__",
+ "//visibility:public",
],
)

96 changes: 95 additions & 1 deletion src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@llvm-project//llvm:tblgen.bzl", "gentbl")

licenses(["notice"])

Expand Down Expand Up @@ -39,11 +41,96 @@ py_library(
visibility = ["//visibility:public"]
)


gentbl(
name = "mhlo-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Implementations/MHLODerivatives.inc",
)],
tblgen = "@enzyme//:enzyme-tblgen",
td_file = "Implementations/MHLODerivatives.td",
td_srcs = [
"Implementations/MHLODerivatives.td",
"Implementations/Common.td",
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
],
)

gentbl(
name = "stablehlo-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Implementations/StableHLODerivatives.inc",
)],
tblgen = "@enzyme//:enzyme-tblgen",
td_file = "Implementations/StableHLODerivatives.td",
td_srcs = [
"Implementations/StableHLODerivatives.td",
"Implementations/Common.td",
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
],
)

td_library(
name = "EnzymeXLAPassesTdFiles",
srcs = [
],
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "EnzymeXLAPassesIncGen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=enzymexla",
],
"Passes/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes/Passes.td",
deps = [":EnzymeXLAPassesTdFiles"],
)

cc_library(
name = "XLADerivatives",
srcs = glob(
[
"Implementations/*.cpp",
"Passes/*.cpp",
],
),
hdrs = glob([
"Implementations/*.h",
"Passes/*.h",
]),
deps = [
":EnzymeXLAPassesIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
"@stablehlo//:stablehlo_ops",
"@xla//xla/mlir_hlo",
"@enzyme//:EnzymeMLIR",
]
)

pybind_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
hdrs = ["compile_with_xla.h"],
hdrs = glob(["compile_with_xla.h", "Implementations/*.h", "Passes/*.h"]),
deps = [
":XLADerivatives",
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
Expand Down Expand Up @@ -99,6 +186,12 @@ pybind_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Parser",

# EnzymeMLIR
"@enzyme//:EnzymeMLIR",

# Mosaic
"@jax//jaxlib/mosaic:tpu_dialect",
],
)

Expand All @@ -116,6 +209,7 @@ pybind_extension(
":clang_compile",
":compile_with_xla",
"@com_google_absl//absl/status:statusor",
"@stablehlo//:stablehlo_passes",
"@xla//xla/stream_executor:stream_executor_impl",
],
visibility = ["//visibility:public"],
Expand Down
89 changes: 89 additions & 0 deletions src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
class InactiveOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class AllocationOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class ControlFlowOp<string dialect_, string opName_, string impl_> {
string dialect = dialect_;
string opName = opName_;
string impl = impl_;
}

class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list<int> storedargs_ = []> {
string dialect = dialect_;
string opName = opName_;
list<int> ptrargs = ptrargs_;
list<int> storedargs = storedargs_;
}

class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_> : MemoryIdentityOp<dialect_, opName_, ptrargs_>;

class BranchOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class RegionTerminatorOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class ForwardFromSummedReverseInternal<int unused_> {
int unused = unused_;
}
def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>;


class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<dag> ArgDerivatives = resultOps;
dag ArgDuals = forwardOps;
}

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}

class DiffeRetIndex<list<int> indices_> {
list<int> indices = indices_;
}
def DiffeRet : DiffeRetIndex<[-1]>;

def Shadow : Operation</*primal*/0, /*shadow*/1> {
}

class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{
string value = val;
}

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
}

class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
string dialect = dialect_;
string opName = op_;
string type = type_;
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

def Op {
}

def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">;


Loading

0 comments on commit 6dbea3b

Please sign in to comment.