Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial import of Jasc. #790

Merged
merged 21 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
72b5cc7
Initial import of Jasc.
ingomueller-net Oct 30, 2023
7492d5a
Fix Bazel version with .bazelversion file.
ingomueller-net Jan 15, 2024
eec8576
Add license information to file headers and BUILD files.
ingomueller-net Jan 15, 2024
4fd8bfd
Run buildifier on all BUILD files for linting.
ingomueller-net Jan 15, 2024
a9fe031
Document options in .bazelrc.
ingomueller-net Jan 15, 2024
7021ced
Add hacky patch to make lit work in bazel.
ingomueller-net Jan 15, 2024
b02b947
Add missing BUILD dependencies for tests.
ingomueller-net Jan 15, 2024
e3b1e34
Fix crash in initialization of ExecutionEngineOptions.
ingomueller-net Jan 15, 2024
3e3a5db
Make path/label to requirements.txt absolute.
ingomueller-net Jan 15, 2024
3101624
Add compile command extractor to WORKSPACE.
ingomueller-net Jan 15, 2024
941e6c0
Adapt header guards to new source location.
ingomueller-net Jan 15, 2024
0ffc5f4
Update link to external location on Github.
ingomueller-net Jan 15, 2024
65757ea
Add trailing new line to all files that didn't have one.
ingomueller-net Jan 15, 2024
fd440ba
Document and update Python requirements.
ingomueller-net Jan 16, 2024
a5ace9f
Remove unnecessary WORKSPACE patches. Simplify remaining file names.
ingomueller-net Jan 16, 2024
afbfeb7
Remove all references to CUDA.
ingomueller-net Jan 16, 2024
892e39f
Reduce diff to internal version
ingomueller-net Jan 19, 2024
a9286a1
Massively reduce changes to BUILD files.
ingomueller-net Jan 17, 2024
df0110b
Simplify patches: remove unnecessary changes and squash.
ingomueller-net Jan 19, 2024
7073c15
Deactivate GPU-related tests.
ingomueller-net Jan 19, 2024
57b215c
Remove `-c opt` from .bazelrc.
ingomueller-net Feb 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions jasc/.bazelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Except where documented otherwise, the options in this file have been copied
# blindly from @EnzymeAD/Enzyme-JAX. The build still seems to work fine without
# any of the `define`s but they are kept just in case.

build --announce_rc

build --experimental_repo_remote_exec
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
build --cxxopt=-w --host_cxxopt=-w
build --define=grpc_no_ares=true
build --define=tsl_link_protobuf=true
build --define=open_source_build=true

build --define=framework_shared_object=true
build --define=tsl_protobuf_header_only=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true

# Sets the name of JAX's MLIR native extension. This exact value is expected
# by the Python files of JAX.
build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
1 change: 1 addition & 0 deletions jasc/.bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.4.0
241 changes: 241 additions & 0 deletions jasc/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Schedules for JAX.

load("@rules_python//python:defs.bzl", "py_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only")
load("@rules_license//rules:license.bzl", "license")

license(
name = "license",
package_name = "Jasc",
license_text = "LICENSE",
package_url = "https://github.com/iree-org/iree-llvm-sandbox/tree/main/jasc",
)

package(
default_applicable_licenses = [":license"],
default_visibility = ["//visibility:public"],
)

py_library(
name = "jasc",
srcs = [
"__init__.py",
"jasc.py",
],
deps = [
":call_kernel",
":primitives",
"//dialect:python",
"//transform_ops",
"@jax1//jax",
"@jax1//jaxlib/mlir:bufferization_dialect",
"@jax1//jaxlib/mlir:core",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

py_library(
name = "tuner",
srcs = ["tuner.py"],
deps = [
":jasc",
"//transform_ops",
"@jax1//jax",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

py_library(
name = "primitives",
srcs = ["primitives.py"],
deps = [
":call_kernel",
"//dialect:python",
"@jax1//jax",
"@jax1//jax:extend",
"@jax1//jaxlib/mlir:ir",
"@jax1//jaxlib/mlir:pdl_dialect",
"@jax1//jaxlib/mlir:stablehlo_dialect",
"@jax1//jaxlib/mlir:transform_dialect",
],
)

cc_binary(
name = "libmlir_c_runner_utils.so",
linkopts = [
"-Wl,-soname=libmlir_c_runner_utils.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
deps = ["@llvm-project//mlir:mlir_c_runner_utils"],
)

pybind_extension(
name = "call_kernel",
srcs = ["call_kernel.cc"],
deps = [
":libmlir_c_runner_utils.so",
":mlir_lowering_shared_library",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:ExecutionEngine",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@pybind11_abseil//pybind11_abseil:import_status_module",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@status_macros",
],
)

#
# `mlir_lowering` library.
#
# 1. Dependencies only. This allows to get the headers of all dependencies.
cc_library(
name = "mlir_lowering_shared_library_deps",
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineToStandard",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncToLLVM",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexToLLVM",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathToLLVM",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToLLVM",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SCFToControlFlow",
"@llvm-project//mlir:SerializeToCubin_stub",
"@llvm-project//mlir:SparseTensorTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialectTransforms",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorToSCF",
"@xla//xla/mlir_hlo:mhlo_passes",
],
)

cc_headers_only(
name = "mlir_lowering_shared_library_deps_headers",
src = "mlir_lowering_shared_library_deps",
visibility = ["//visibility:private"],
)

# 2. The main library. This shouldn't be used directly in `py_extension`s.
cc_library(
name = "mlir_lowering",
srcs = [
"gpu_lowering_passes.cc",
"mlir_lowering.cc",
],
hdrs = [
"gpu_lowering_passes.h",
"mlir_lowering.h",
],
data = ["gpu_post_bufferize.mlir"],
visibility = [
# `jaxlib_mlir_capi_shared_library` needs to depend on `mlir_lowering`
# because (1) it depends on other targets that need symbols from this
# target and (2) that target cannot depend on
# `mlir_lowering_shared_library` because the reverse dependency must
# exist (since, otherwise, `mlir_lowering_shared_library` would
# duplicate symbols from `jaxlib_mlir_capi_shared_library`).
"@jax1//jaxlib/mlir/_mlir_libs:__pkg__",
],
deps = [
":mlir_lowering_shared_library_deps_headers",
# Only depend on the headers here to avoid duplicate symbols.
"//dialect:dialect_headers",
"//transform_ops:jasc_transform_ops_headers",
],
# This is important since it makes sure that the symbols of the library are
# exported by the `.so` target below even though they aren't used directly.
alwayslink = True,
)

cc_headers_only(
name = "mlir_lowering_headers",
src = "mlir_lowering",
visibility = ["//visibility:private"],
)

# 3. Shared object file. This forces to create a shared library, which dependent
# targets can link against, instead of using the default static linking. This
# ensures that the symbols in that library exist only once instead of once for
# each time it is linked statically.
# This pattern is copied from JAX. A platform independent version exists there.
cc_binary(
name = "libmlirlowering.so",
linkopts = [
"-Wl,-soname=libmlirlowering.so",
"-Wl,-rpath='$$ORIGIN'",
],
linkshared = 1,
visibility = ["//visibility:private"],
deps = [
":mlir_lowering",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
],
)

# 4. A `cc_library` wrapper of the shared library. This is the main target.
cc_library(
name = "mlir_lowering_shared_library",
srcs = [
"mlir_lowering.h",
":libmlirlowering.so",
],
deps = [
":mlir_lowering_headers",
":mlir_lowering_shared_library_deps_headers",
"@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
],
)

cc_binary(
name = "jasc-opt",
srcs = ["jasc_opt.cc"],
deps = [
":mlir_lowering",
"//dialect",
"//transform_ops:jasc_transform_ops",
"@com_google_absl//absl/status:statusor",
"@llvm-project//mlir:AllExtensions",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllToLLVMIRTranslations",
"@llvm-project//mlir:MlirOptLib",
"@xla//xla/mlir_hlo:mhlo_passes",
],
)
Loading
Loading