-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR imports Jasc, an MLIR-based compiler for JAX kernels, to the repo. Most functionality works, but some doesn't, and some shortcuts have been taken. In particular: * The GPU tests fail, and I haven't investigated why beyond the fact that it can't run on my machine without GPU. Maybe we need a better mechanism to choose the backend before we can make GPUs work in OSS. * I have completely removed support for sparse and the related tests. * I have removed several tools that aren't supported in OSS. * The BUILD files are quite hacky and need to be cleaned up. They currently contain remains of two different approaches I tried, one with several shared object files and one with a single one. The latter is the one that is currently used but I *think* that the former can also work; I just gave up that approach at some point due to a weird error only to find out that the same error also occurs with the other approach and is unrelated. * I have also not made any attempt yet to reduce the diff between this repository and our internal version but there are a few easy things we could do to make potential future syncs easier. * I have not added copyright and license information yet. What does work, though, is to compile everything and run the tests (some of which fail): bazel build //... bazel test test:*
- Loading branch information
1 parent
c71cde6
commit 72b5cc7
Showing
68 changed files
with
8,140 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
build --announce_rc | ||
|
||
build --experimental_repo_remote_exec | ||
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 | ||
build --cxxopt=-w --host_cxxopt=-w | ||
build --define=grpc_no_ares=true | ||
build --define=tsl_link_protobuf=true | ||
build --define open_source_build=true | ||
|
||
build --define framework_shared_object=true | ||
build --define tsl_protobuf_header_only=true | ||
build --define=use_fast_cpp_protos=true | ||
build --define=allow_oversize_protos=true | ||
|
||
build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. | ||
|
||
build -c opt | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# Schedules for JAX. | ||
|
||
load("@rules_python//python:defs.bzl", "py_library") | ||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") | ||
load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") | ||
|
||
package( | ||
# default_applicable_licenses = ["//third_party/mlir_edge:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
py_library( | ||
name = "jasc", | ||
srcs = ["jasc.py", "__init__.py"], | ||
deps = [ | ||
":call_kernel", | ||
":primitives", | ||
"//dialect:python", | ||
"//transform_ops", | ||
"@jax1//jax:jax", | ||
"@jax1//jaxlib/mlir:bufferization_dialect", | ||
"@jax1//jaxlib/mlir:core", | ||
"@jax1//jaxlib/mlir:ir", | ||
"@jax1//jaxlib/mlir:pdl_dialect", | ||
"@jax1//jaxlib/mlir:transform_dialect", | ||
"@jax1//jaxlib/mlir:jasc_dialect", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "tuner", | ||
srcs = ["tuner.py"], | ||
deps = [ | ||
":jasc", | ||
"@jax1//jax:jax", | ||
"@jax1//jaxlib/mlir:ir", | ||
"@jax1//jaxlib/mlir:jasc_dialect", | ||
"@jax1//jaxlib/mlir:transform_dialect", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "primitives", | ||
srcs = ["primitives.py"], | ||
deps = [ | ||
":call_kernel", | ||
"//dialect:python", | ||
"@jax1//jax:jax", | ||
"@jax1//jax:extend", | ||
"@jax1//jaxlib/mlir:ir", | ||
"@jax1//jaxlib/mlir:pdl_dialect", | ||
"@jax1//jaxlib/mlir:stablehlo_dialect", | ||
"@jax1//jaxlib/mlir:transform_dialect", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "call_kernel_shared_library_deps", | ||
deps = [ | ||
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", | ||
":mlir_lowering_shared_library", | ||
# "//third_party/gpus/cuda:cuda_headers", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:CAPIIR", | ||
"@llvm-project//mlir:ExecutionEngine", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:MLIRBindingsPythonHeaders", | ||
"@pybind11_abseil//pybind11_abseil:import_status_module", | ||
"@pybind11_abseil//pybind11_abseil:status_casters", | ||
], | ||
) | ||
|
||
cc_headers_only( | ||
name = "call_kernel_shared_library_deps_headers", | ||
src = "call_kernel_shared_library_deps", | ||
) | ||
|
||
cc_binary( | ||
name = "libcallkernel.so", | ||
linkopts = [ | ||
"-Wl,-soname=libcallkernel.so", | ||
"-Wl,-rpath='$$ORIGIN'", | ||
], | ||
linkshared = 1, | ||
deps = [":call_kernel_shared_library_deps"], | ||
) | ||
|
||
cc_library( | ||
name = "call_kernel_shared_library", | ||
srcs = [":libcallkernel.so"], | ||
deps = [":call_kernel_shared_library_deps_headers"], | ||
) | ||
|
||
cc_binary( | ||
name = "libmlir_c_runner_utils.so", | ||
linkopts = [ | ||
"-Wl,-soname=libmlir_c_runner_utils.so", | ||
"-Wl,-rpath='$$ORIGIN'", | ||
], | ||
linkshared = 1, | ||
deps = ["@llvm-project//mlir:mlir_c_runner_utils",], | ||
) | ||
|
||
pybind_extension( | ||
name = "call_kernel", | ||
srcs = ["call_kernel.cc"], | ||
deps = [ | ||
":call_kernel_shared_library", | ||
":libmlir_c_runner_utils.so", | ||
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", | ||
":mlir_lowering_shared_library", | ||
"@com_google_absl//absl/container:flat_hash_map", | ||
"@com_google_absl//absl/log", | ||
"@com_google_absl//absl/status:statusor", | ||
"@com_google_absl//absl/synchronization", | ||
"@llvm-project//mlir:ExecutionEngine", | ||
"@status_macros//:status_macros", | ||
"@pybind11_abseil//pybind11_abseil:import_status_module", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "mlir_lowering_shared_library_deps", | ||
deps = [ | ||
"@com_google_absl//absl/status", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:AffineToStandard", | ||
"@llvm-project//mlir:AllToLLVMIRTranslations", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:ArithToLLVM", | ||
"@llvm-project//mlir:ArithTransforms", | ||
"@llvm-project//mlir:BufferizationDialect", | ||
"@llvm-project//mlir:BufferizationTransforms", | ||
"@llvm-project//mlir:ControlFlowToLLVM", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:FuncToLLVM", | ||
"@llvm-project//mlir:GPUDialect", | ||
"@llvm-project//mlir:GPUToGPURuntimeTransforms", | ||
"@llvm-project//mlir:GPUToNVVMTransforms", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:IndexToLLVM", | ||
"@llvm-project//mlir:LLVMCommonConversion", | ||
"@llvm-project//mlir:LinalgTransforms", | ||
"@llvm-project//mlir:MathToLLVM", | ||
"@llvm-project//mlir:MemRefDialect", | ||
"@llvm-project//mlir:MemRefToLLVM", | ||
"@llvm-project//mlir:MemRefTransforms", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:ReconcileUnrealizedCasts", | ||
"@llvm-project//mlir:SCFToControlFlow", | ||
"@llvm-project//mlir:SerializeToCubin_stub", | ||
"@llvm-project//mlir:SparseTensorTransforms", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:TransformDialectTransforms", | ||
"@llvm-project//mlir:Transforms", | ||
"@llvm-project//mlir:VectorToLLVM", | ||
"@llvm-project//mlir:VectorToSCF", | ||
"@xla//xla/mlir_hlo:mhlo_passes", | ||
], | ||
) | ||
|
||
cc_headers_only( | ||
name = "mlir_lowering_shared_library_deps_headers", | ||
src = "mlir_lowering_shared_library_deps", | ||
) | ||
|
||
cc_binary( | ||
name = "libmlirlowering.so", | ||
linkopts = [ | ||
"-Wl,-soname=libmlirlowering.so", | ||
"-Wl,-rpath='$$ORIGIN'", | ||
], | ||
linkshared = 1, | ||
deps = [":mlir_lowering"], | ||
) | ||
|
||
cc_library( | ||
name = "mlir_lowering_shared_library", | ||
srcs = [":libmlirlowering.so", "mlir_lowering.h"], | ||
deps = [":mlir_lowering_shared_library_deps_headers"], | ||
) | ||
|
||
cc_headers_only( | ||
name = "mlir_lowering_shared_library_headers", | ||
src = "mlir_lowering_shared_library", | ||
) | ||
|
||
cc_library( | ||
name = "mlir_lowering", | ||
srcs = [ | ||
"gpu_lowering_passes.cc", | ||
"mlir_lowering.cc", | ||
], | ||
hdrs = [ | ||
"gpu_lowering_passes.h", | ||
"mlir_lowering.h", | ||
], | ||
data = ["gpu_post_bufferize.mlir"], | ||
deps = [ | ||
"//dialect:jasc_dialect_headers", | ||
"//transform_ops:jasc_transform_ops_headers", | ||
":mlir_lowering_shared_library_deps_headers", | ||
], | ||
alwayslink = True, | ||
) | ||
|
||
cc_binary( | ||
name = "jasc-opt", | ||
srcs = ["jasc_opt.cc"], | ||
deps = [ | ||
":mlir_lowering", | ||
"@llvm-project//mlir:AllExtensions", | ||
"@llvm-project//mlir:AllPassesAndDialects", | ||
"@llvm-project//mlir:AllToLLVMIRTranslations", | ||
"@llvm-project//mlir:MlirOptLib", | ||
"@xla//xla/mlir_hlo:mhlo_passes", | ||
"//dialect", | ||
"//transform_ops:jasc_transform_ops_shared_library", | ||
"@com_google_absl//absl/status:statusor", | ||
], | ||
) |
Oops, something went wrong.