Skip to content

Commit

Permalink
Initial import of Jasc.
Browse files Browse the repository at this point in the history
This PR imports Jasc, an MLIR-based compiler for JAX kernels, to the
repo. Most functionality works, but some doesn't, and some shortcuts
have been taken. In particular:

* The GPU tests fail, and I haven't investigated why beyond the fact
  that it can't run on my machine without GPU. Maybe we need a better
  mechanism to choose the backend before we can make GPUs work in OSS.
* I have completely removed support for sparse and the related tests.
* I have removed several tools that aren't supported in OSS.
* The BUILD files are quite hacky and need to be cleaned up. They
  currently contain remains of two different approaches I tried, one
  with several shared object files and one with a single one. The latter
  is the one that is currently used but I *think* that the former can
  also work; I just gave up that approach at some point due to a weird
  error only to find out that the same error also occurs with the other
  approach and is unrelated.
* I have also not made any attempt yet to reduce the diff between this
  repository and our internal version but there are a few easy things we
  could do to make potential future syncs easier.
* I have not added copyright and license information yet.

What does work, though, is to compile everything and run the tests (some
of which fail):

    bazel build //...
    bazel test test:*
  • Loading branch information
ingomueller-net committed Dec 5, 2023
1 parent c71cde6 commit 72b5cc7
Show file tree
Hide file tree
Showing 68 changed files with 8,140 additions and 0 deletions.
18 changes: 18 additions & 0 deletions jasc/.bazelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
build --announce_rc

build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
build --define=grpc_no_ares=true
build --define=tsl_link_protobuf=true
build --define open_source_build=true

build --define framework_shared_object=true
build --define tsl_protobuf_header_only=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true

build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.

build -c opt

223 changes: 223 additions & 0 deletions jasc/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Schedules for JAX.

load("@rules_python//python:defs.bzl", "py_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only")

package(
# default_applicable_licenses = ["//third_party/mlir_edge:license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "jasc",
srcs = ["jasc.py", "__init__.py"],
deps = [
":call_kernel",
":primitives",
"//dialect:python",
"//transform_ops",
"@jax1//jax:jax",
"@jax1//jaxlib/mlir:bufferization_dialect",
"@jax1//jaxlib/mlir:core",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
"@jax1//jaxlib/mlir:jasc_dialect",
],
)

py_library(
name = "tuner",
srcs = ["tuner.py"],
deps = [
":jasc",
"@jax1//jax:jax",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:jasc_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
":call_kernel",
"//dialect:python",
"@jax1//jax:jax",
"@jax1//jax:extend",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:stablehlo_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

cc_library(
name = "call_kernel_shared_library_deps",
deps = [
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
":mlir_lowering_shared_library",
# "//third_party/gpus/cuda:cuda_headers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:ExecutionEngine",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11_abseil//pybind11_abseil:import_status_module",
"@pybind11_abseil//pybind11_abseil:status_casters",
],
)

cc_headers_only(
name = "call_kernel_shared_library_deps_headers",
src = "call_kernel_shared_library_deps",
)

cc_binary(
name = "libcallkernel.so",
linkopts = [
"-Wl,-soname=libcallkernel.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
deps = [":call_kernel_shared_library_deps"],
)

cc_library(
name = "call_kernel_shared_library",
srcs = [":libcallkernel.so"],
deps = [":call_kernel_shared_library_deps_headers"],
)

cc_binary(
name = "libmlir_c_runner_utils.so",
linkopts = [
"-Wl,-soname=libmlir_c_runner_utils.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
deps = ["@llvm-project//mlir:mlir_c_runner_utils",],
)

pybind_extension(
name = "call_kernel",
srcs = ["call_kernel.cc"],
deps = [
":call_kernel_shared_library",
":libmlir_c_runner_utils.so",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
":mlir_lowering_shared_library",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@llvm-project//mlir:ExecutionEngine",
"@status_macros//:status_macros",
"@pybind11_abseil//pybind11_abseil:import_status_module",
],
)

cc_library(
name = "mlir_lowering_shared_library_deps",
deps = [
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncToLLVM",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexToLLVM",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathToLLVM",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SCFToControlFlow",
"@llvm-project//mlir:SerializeToCubin_stub",
"@llvm-project//mlir:SparseTensorTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@xla//xla/mlir_hlo:mhlo_passes",
],
)

cc_headers_only(
name = "mlir_lowering_shared_library_deps_headers",
src = "mlir_lowering_shared_library_deps",
)

cc_binary(
name = "libmlirlowering.so",
linkopts = [
"-Wl,-soname=libmlirlowering.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
deps = [":mlir_lowering"],
)

cc_library(
name = "mlir_lowering_shared_library",
srcs = [":libmlirlowering.so", "mlir_lowering.h"],
deps = [":mlir_lowering_shared_library_deps_headers"],
)

cc_headers_only(
name = "mlir_lowering_shared_library_headers",
src = "mlir_lowering_shared_library",
)

cc_library(
name = "mlir_lowering",
srcs = [
"gpu_lowering_passes.cc",
"mlir_lowering.cc",
],
hdrs = [
"gpu_lowering_passes.h",
"mlir_lowering.h",
],
data = ["gpu_post_bufferize.mlir"],
deps = [
"//dialect:jasc_dialect_headers",
"//transform_ops:jasc_transform_ops_headers",
":mlir_lowering_shared_library_deps_headers",
],
alwayslink = True,
)

cc_binary(
name = "jasc-opt",
srcs = ["jasc_opt.cc"],
deps = [
":mlir_lowering",
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:MlirOptLib",
"@xla//xla/mlir_hlo:mhlo_passes",
"//dialect",
"//transform_ops:jasc_transform_ops_shared_library",
"@com_google_absl//absl/status:statusor",
],
)
Loading

0 comments on commit 72b5cc7

Please sign in to comment.