Skip to content

Commit

Permalink
Cleanup bazel files (#51)
Browse files Browse the repository at this point in the history
* cleanup baze

* buildifier
  • Loading branch information
wsmoses authored Mar 10, 2024
1 parent 5c65020 commit 405b74d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 53 deletions.
32 changes: 19 additions & 13 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ package(

py_package(
name = "enzyme_jax_data",
# Only include these Python packages.
packages = [
"@//src/enzyme_ad/jax:enzyme_call.so",
"@llvm-project//clang:builtin_headers_gen",
],
deps = [
"//src/enzyme_ad/jax:enzyme_call.so",
"@llvm-project//clang:builtin_headers_gen",
],
# Only include these Python packages.
packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
)

cc_binary(
Expand Down Expand Up @@ -49,17 +52,11 @@ cc_binary(

py_wheel(
name = "enzyme_ad",
author = "Enzyme Authors",
author_email = "[email protected], [email protected]",
distribution = "enzyme_ad",
summary = "Enzyme automatic differentiation tool.",
homepage = "https://enzyme.mit.edu/",
project_urls = {
"GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/",
},
author="Enzyme Authors",
license="LLVM",
author_email="[email protected], [email protected]",
python_tag = "py3",
version = "0.0.6",
license = "LLVM",
platform = select({
"@bazel_tools//src/conditions:windows_x64": "win_amd64",
"@bazel_tools//src/conditions:darwin_arm64": "macosx_11_0_arm64",
Expand All @@ -68,11 +65,20 @@ py_wheel(
"@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64",
"@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le",
}),
deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"],
strip_path_prefixes = ["src/"],
project_urls = {
"GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/",
},
python_tag = "py3",
requires = [
"absl_py >= 2.0.0",
"jax >= 0.4.21",
"jaxlib >= 0.4.21",
],
strip_path_prefixes = ["src/"],
summary = "Enzyme automatic differentiation tool.",
version = "0.0.6",
deps = [
":enzyme_jax_data",
"//src/enzyme_ad/jax:enzyme_jax_internal",
],
)
73 changes: 37 additions & 36 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
load("@jax//jaxlib:symlink_files.bzl", "symlink_inputs")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@llvm-project//llvm:tblgen.bzl", "gentbl")

exports_files(["enzymexlamlir-opt.cpp"])

licenses(["notice"])

package(
Expand All @@ -30,9 +30,9 @@ pybind_library(
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:CodeGen",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
Expand All @@ -41,8 +41,11 @@ pybind_library(

py_library(
name = "enzyme_jax_internal",
srcs = ["primitives.py", "__init__.py"],
visibility = ["//visibility:public"]
srcs = [
"__init__.py",
"primitives.py",
],
visibility = ["//visibility:public"],
)

symlink_inputs(
Expand All @@ -56,7 +59,7 @@ symlink_inputs(
td_library(
name = "ImplementationsCommonTdFiles",
srcs = [
":EnzymeImplementationsCommonTdFiles",
":EnzymeImplementationsCommonTdFiles",
],
deps = [
":EnzymeImplementationsCommonTdFiles",
Expand All @@ -76,8 +79,8 @@ gentbl_cc_library(
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
":ImplementationsCommonTdFiles",
"@enzyme//:enzyme-tblgen",
],
)

Expand All @@ -94,8 +97,8 @@ gentbl_cc_library(
"Implementations/HLODerivatives.td",
],
deps = [
"@enzyme//:enzyme-tblgen",
":EnzymeImplementationsCommonTdFiles",
"@enzyme//:enzyme-tblgen",
],
)

Expand Down Expand Up @@ -146,27 +149,31 @@ cc_library(
":EnzymeXLAPassesIncGen",
":mhlo-derivatives",
":stablehlo-derivatives",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:reference_ops",
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:ControlFlowInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@stablehlo//:reference_ops",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@xla//xla/mlir_hlo",
"@enzyme//:EnzymeMLIR",
]
],
)

pybind_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
hdrs = glob(["compile_with_xla.h", "Implementations/*.h", "Passes/*.h"]),
hdrs = glob([
"compile_with_xla.h",
"Implementations/*.h",
"Passes/*.h",
]),
deps = [
":XLADerivatives",
# This is similar to xla_binary rule and is needed to make XLA client compile.
Expand All @@ -193,7 +200,7 @@ pybind_library(
"@xla//xla/client:client_library",
"@xla//xla/client:executable_build_options",
"@xla//xla/client:xla_computation",
"@xla//xla/service:service",
"@xla//xla/service",
"@xla//xla/service:local_service",
"@xla//xla/service:local_service_utils",
"@xla//xla/service:buffer_assignment_proto_cc",
Expand All @@ -212,7 +219,6 @@ pybind_library(
"@xla//xla:xla_data_proto_cc_impl",
"@xla//xla:xla_proto_cc",
"@xla//xla:xla_proto_cc_impl",

"@stablehlo//:stablehlo_ops",

# Make CPU target available to XLA.
Expand All @@ -221,7 +227,6 @@ pybind_library(
# MHLO stuff.
"@xla//xla/mlir_hlo",
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

"@xla//xla/hlo/ir:hlo",

# This is necessary for XLA protobufs to link
Expand All @@ -233,50 +238,46 @@ pybind_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:TensorDialect",

"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",

"@xla//xla/mlir_hlo:all_passes",
"@xla//xla:printer",

# EnzymeMLIR
# EnzymeMLIR
"@enzyme//:EnzymeMLIR",

"@com_google_absl//absl/status:statusor",

# Mosaic
"@jax//jaxlib/mosaic:tpu_dialect",
"@jax//jaxlib/mosaic:tpu_dialect",
],
)

pybind_extension(
name = "enzyme_call",
srcs = ["enzyme_call.cc"],
visibility = ["//visibility:public"],
deps = [
":clang_compile",
":compile_with_xla",
"@com_google_absl//absl/status:statusor",
"@enzyme//:EnzymeMLIR",
"@enzyme//:EnzymeStatic",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcTargetProcess",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
":clang_compile",
":compile_with_xla",
"@com_google_absl//absl/status:statusor",
"@stablehlo//:stablehlo_passes",
"@xla//xla/stream_executor:stream_executor_impl",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/mlir/backends/cpu/transforms:passes",
"@xla//xla/mlir/memref/transforms:passes",
"@xla//xla/mlir/runtime/transforms:passes",
"@xla//xla/mlir_hlo:all_passes",
"@xla//xla/mlir_hlo:deallocation_passes",
"@xla//xla/mlir_hlo:lhlo",
"@xla//xla/mlir_hlo:all_passes",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/service/cpu:cpu_executable",
"@enzyme//:EnzymeStatic",
"@enzyme//:EnzymeMLIR",

"@xla//xla/stream_executor:stream_executor_impl",
],
visibility = ["//visibility:public"],
)
8 changes: 4 additions & 4 deletions test/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("@rules_python//python:py_test.bzl", "py_test")
load("@llvm-project//llvm:lit_test.bzl", "package_path", "lit_test")
load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path")
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

expand_template(
Expand Down Expand Up @@ -30,8 +30,8 @@ exports_files(
data = [
":lit.cfg.py",
":lit_site_cfg_py",
"//src/enzyme_ad/jax:enzyme_jax_internal",
"//:enzymexlamlir-opt",
"//src/enzyme_ad/jax:enzyme_jax_internal",
"@llvm-project//clang:builtin_headers_gen",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
Expand All @@ -40,7 +40,8 @@ exports_files(
)
for src in glob(
[
"**/*.pyt", "**/*.mlir",
"**/*.pyt",
"**/*.mlir",
],
)
]
Expand All @@ -65,7 +66,6 @@ py_test(
],
)


py_test(
name = "llama",
srcs = [
Expand Down

0 comments on commit 405b74d

Please sign in to comment.