Skip to content

Commit

Permalink
Initial import of Jasc. (#790)
Browse files Browse the repository at this point in the history
This PR imports Jasc, an MLIR-based compiler for JAX kernels, to the
repo. Most functionality works, but some doesn't, and some shortcuts
have been taken. In particular:

* The GPU code are not part of the import and the corresponding
  tests are deactivate. Maybe we need a better mechanism to choose
  the backend before we can make GPUs work in OSS.
* I have completely removed support for sparse and the related tests.
* I have removed several tools that aren't supported in OSS.
* I have also not spend a lot of effort to reduce the diff between this
  repository and our internal version but there are a few easy things we
  could do to make potential future syncs easier.

What does work, though, is to compile everything and run the tests
(some of which are deactivated):

    bazel build //...
    bazel test test:*
  • Loading branch information
ingomueller-net authored Feb 8, 2024
1 parent c71cde6 commit 43c5bfd
Show file tree
Hide file tree
Showing 66 changed files with 8,198 additions and 0 deletions.
21 changes: 21 additions & 0 deletions jasc/.bazelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Except where documented otherwise, the options in this file have been copied
# blindly from @EnzymeAD/Enzyme-JAX. The build still seems to work fine without
# any of the `define`s but they are kept just in case.

build --announce_rc

build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
build --define=grpc_no_ares=true
build --define=tsl_link_protobuf=true
build --define=open_source_build=true

build --define=framework_shared_object=true
build --define=tsl_protobuf_header_only=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true

# Sets the name of JAX's MLIR native extension. This exact value is expected
# by the Python files of JAX.
build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
1 change: 1 addition & 0 deletions jasc/.bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.4.0
241 changes: 241 additions & 0 deletions jasc/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Schedules for JAX.

load("@rules_python//python:defs.bzl", "py_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only")
load("@rules_license//rules:license.bzl", "license")

license(
name = "license",
package_name = "Jasc",
license_text = "LICENSE",
package_url = "https://github.com/iree-org/iree-llvm-sandbox/tree/main/jasc",
)

package(
default_applicable_licenses = [":license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "jasc",
srcs = [
"__init__.py",
"jasc.py",
],
deps = [
":call_kernel",
":primitives",
"//dialect:python",
"//transform_ops",
"@jax1//jax",
"@jax1//jaxlib/mlir:bufferization_dialect",
"@jax1//jaxlib/mlir:core",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

py_library(
name = "tuner",
srcs = ["tuner.py"],
deps = [
":jasc",
"//transform_ops",
"@jax1//jax",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
":call_kernel",
"//dialect:python",
"@jax1//jax",
"@jax1//jax:extend",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:stablehlo_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

cc_binary(
name = "libmlir_c_runner_utils.so",
linkopts = [
"-Wl,-soname=libmlir_c_runner_utils.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
deps = ["@llvm-project//mlir:mlir_c_runner_utils"],
)

pybind_extension(
name = "call_kernel",
srcs = ["call_kernel.cc"],
deps = [
":libmlir_c_runner_utils.so",
":mlir_lowering_shared_library",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:ExecutionEngine",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11_abseil//pybind11_abseil:import_status_module",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@status_macros",
],
)

#
# `mlir_lowering` library.
#
# 1. Dependencies only. This allows to get the headers of all dependencies.
cc_library(
name = "mlir_lowering_shared_library_deps",
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncToLLVM",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexToLLVM",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathToLLVM",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SCFToControlFlow",
"@llvm-project//mlir:SerializeToCubin_stub",
"@llvm-project//mlir:SparseTensorTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@xla//xla/mlir_hlo:mhlo_passes",
],
)

cc_headers_only(
name = "mlir_lowering_shared_library_deps_headers",
src = "mlir_lowering_shared_library_deps",
visibility = ["//visibility:private"],
)

# 2. The main library. This shouldn't be used directly in `py_extension`s.
cc_library(
name = "mlir_lowering",
srcs = [
"gpu_lowering_passes.cc",
"mlir_lowering.cc",
],
hdrs = [
"gpu_lowering_passes.h",
"mlir_lowering.h",
],
data = ["gpu_post_bufferize.mlir"],
visibility = [
# `jaxlib_mlir_capi_shared_library` needs to depend on `mlir_lowering`
# because (1) it depends on other targets that need symbols from this
# target and (2) that target cannot depend on
# `mlir_lowering_shared_library` because the reverse dependency must
# exist (since, otherwise, `mlir_lowering_shared_library` would
# duplicate symbols from `jaxlib_mlir_capi_shared_library`).
"@jax1//jaxlib/mlir/_mlir_libs:__pkg__",
],
deps = [
":mlir_lowering_shared_library_deps_headers",
# Only depend on the headers here to avoid duplicate symbols.
"//dialect:dialect_headers",
"//transform_ops:jasc_transform_ops_headers",
],
# This is important since it makes sure that the symbols of the library are
# exported by the `.so` target below even though they aren't used directly.
alwayslink = True,
)

cc_headers_only(
name = "mlir_lowering_headers",
src = "mlir_lowering",
visibility = ["//visibility:private"],
)

# 3. Shared object file. This forces to create a shared library, which dependent
# targets can link against, instead of using the default static linking. This
# ensures that the symbols in that library exist only once instead of once for
# each time it is linked statically.
# This pattern is copied from JAX. A platform independent version exists there.
cc_binary(
name = "libmlirlowering.so",
linkopts = [
"-Wl,-soname=libmlirlowering.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
visibility = ["//visibility:private"],
deps = [
":mlir_lowering",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
],
)

# 4. A `cc_library` wrapper of the shared library. This is the main target.
cc_library(
name = "mlir_lowering_shared_library",
srcs = [
"mlir_lowering.h",
":libmlirlowering.so",
],
deps = [
":mlir_lowering_headers",
":mlir_lowering_shared_library_deps_headers",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
],
)

cc_binary(
name = "jasc-opt",
srcs = ["jasc_opt.cc"],
deps = [
":mlir_lowering",
"//dialect",
"//transform_ops:jasc_transform_ops",
"@com_google_absl//absl/status:statusor",
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:MlirOptLib",
"@xla//xla/mlir_hlo:mhlo_passes",
],
)
Loading

0 comments on commit 43c5bfd

Please sign in to comment.