From 72b5cc76f9dab6e46f5cf3608d1c35d9badbc7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 30 Oct 2023 09:44:40 +0000 Subject: [PATCH 01/21] Initial import of Jasc. This PR imports Jasc, an MLIR-based compiler for JAX kernels, to the repo. Most functionality works, but some doesn't, and some shortcuts have been taken. In particular: * The GPU tests fail, and I haven't investigated why beyond the fact that it can't run on my machine without GPU. Maybe we need a better mechanism to choose the backend before we can make GPUs work in OSS. * I have completely removed support for sparse and the related tests. * I have removed several tools that aren't supported in OSS. * The BUILD files are quite hacky and need to be cleaned up. They currently contain remains of two different approaches I tried, one with several shared object files and one with a single one. The latter is the one that is currently used but I *think* that the former can also work; I just gave up that approach at some point due to a weird error only to find out that the same error also occurs with the other approach and is unrelated. * I have also not made any attempt yet to reduce the diff between this repository and our internal version but there are a few easy things we could do to make potential future syncs easier. * I have not added copyright and license information yet. What does work, though, is to compile everything and run the tests (some of which fail): bazel build //... 
bazel test test:* --- jasc/.bazelrc | 18 + jasc/BUILD | 223 ++++ jasc/WORKSPACE | 207 ++++ jasc/__init__.py | 0 jasc/call_kernel.cc | 448 +++++++ jasc/dialect/BUILD | 194 ++++ jasc/dialect/__init__.py | 4 + jasc/dialect/_ods_common.py | 6 + jasc/dialect/bindings.cc | 26 + jasc/dialect/capi.cc | 6 + jasc/dialect/capi.h | 4 + jasc/dialect/dialect.cc | 20 + jasc/dialect/dialect.h | 9 + jasc/dialect/dialect.td | 16 + jasc/dialect/jasc.py | 2 + jasc/dialect/ops.cc | 15 + jasc/dialect/ops.h | 13 + jasc/dialect/ops.td | 33 + jasc/dialect/ops_py.td | 6 + jasc/external/requirements-top-level.txt | 4 + jasc/external/requirements.txt | 16 + jasc/gpu_lowering_passes.cc | 231 ++++ jasc/gpu_lowering_passes.h | 43 + jasc/gpu_post_bufferize.mlir | 9 + jasc/jasc.py | 1358 ++++++++++++++++++++++ jasc/jasc_opt.cc | 30 + jasc/mlir_lowering.cc | 503 ++++++++ jasc/mlir_lowering.h | 21 + jasc/patches/apply.sh | 12 + jasc/patches/clang_macos.patch | 13 + jasc/patches/jax.patch | 525 +++++++++ jasc/patches/jax_workspace.patch | 22 + jasc/patches/llvm_build.patch | 94 ++ jasc/patches/stablehlo_build.patch | 10 + jasc/patches/xla.patch | 100 ++ jasc/primitives.py | 208 ++++ jasc/test/BUILD | 262 +++++ jasc/test/abstractions.py | 1186 +++++++++++++++++++ jasc/test/autotuning.py | 80 ++ jasc/test/batch_matmul_gpu.py | 58 + jasc/test/bindings.py | 83 ++ jasc/test/cpu_integration.py | 111 ++ jasc/test/diagnostics.py | 59 + jasc/test/filecheck_test.sh | 10 + jasc/test/fold_fill_into_pad.mlir | 35 + jasc/test/gpu_integration.py | 30 + jasc/test/jit.py | 73 ++ jasc/test/lit.cfg.py | 39 + jasc/test/lit.site.cfg.in.py | 16 + jasc/test/matmul_cpu.py | 122 ++ jasc/test/matmul_gpu.py | 237 ++++ jasc/test/normalization.py | 302 +++++ jasc/test/parametric_schedule.mlir | 39 + jasc/test/synchronize.mlir | 21 + jasc/test/tag.py | 101 ++ jasc/test/test.mlir | 8 + jasc/test/wrap-in-cpu-launch.mlir | 40 + jasc/transform_ops/BUILD | 135 +++ jasc/transform_ops/_ods_common.py | 6 + 
jasc/transform_ops/_transform_ops_gen.py | 7 + jasc/transform_ops/bindings.cpp | 20 + jasc/transform_ops/dialect_extension.cc | 26 + jasc/transform_ops/dialect_extension.h | 10 + jasc/transform_ops/jasc_transform_ops.cc | 220 ++++ jasc/transform_ops/jasc_transform_ops.h | 13 + jasc/transform_ops/jasc_transform_ops.py | 29 + jasc/transform_ops/jasc_transform_ops.td | 140 +++ jasc/tuner.py | 173 +++ 68 files changed, 8140 insertions(+) create mode 100644 jasc/.bazelrc create mode 100644 jasc/BUILD create mode 100644 jasc/WORKSPACE create mode 100644 jasc/__init__.py create mode 100644 jasc/call_kernel.cc create mode 100644 jasc/dialect/BUILD create mode 100644 jasc/dialect/__init__.py create mode 100644 jasc/dialect/_ods_common.py create mode 100644 jasc/dialect/bindings.cc create mode 100644 jasc/dialect/capi.cc create mode 100644 jasc/dialect/capi.h create mode 100644 jasc/dialect/dialect.cc create mode 100644 jasc/dialect/dialect.h create mode 100644 jasc/dialect/dialect.td create mode 100644 jasc/dialect/jasc.py create mode 100644 jasc/dialect/ops.cc create mode 100644 jasc/dialect/ops.h create mode 100644 jasc/dialect/ops.td create mode 100644 jasc/dialect/ops_py.td create mode 100644 jasc/external/requirements-top-level.txt create mode 100644 jasc/external/requirements.txt create mode 100644 jasc/gpu_lowering_passes.cc create mode 100644 jasc/gpu_lowering_passes.h create mode 100644 jasc/gpu_post_bufferize.mlir create mode 100644 jasc/jasc.py create mode 100644 jasc/jasc_opt.cc create mode 100644 jasc/mlir_lowering.cc create mode 100644 jasc/mlir_lowering.h create mode 100644 jasc/patches/apply.sh create mode 100644 jasc/patches/clang_macos.patch create mode 100644 jasc/patches/jax.patch create mode 100644 jasc/patches/jax_workspace.patch create mode 100644 jasc/patches/llvm_build.patch create mode 100644 jasc/patches/stablehlo_build.patch create mode 100644 jasc/patches/xla.patch create mode 100644 jasc/primitives.py create mode 100644 jasc/test/BUILD create 
mode 100644 jasc/test/abstractions.py create mode 100644 jasc/test/autotuning.py create mode 100644 jasc/test/batch_matmul_gpu.py create mode 100644 jasc/test/bindings.py create mode 100644 jasc/test/cpu_integration.py create mode 100644 jasc/test/diagnostics.py create mode 100755 jasc/test/filecheck_test.sh create mode 100644 jasc/test/fold_fill_into_pad.mlir create mode 100644 jasc/test/gpu_integration.py create mode 100644 jasc/test/jit.py create mode 100644 jasc/test/lit.cfg.py create mode 100644 jasc/test/lit.site.cfg.in.py create mode 100644 jasc/test/matmul_cpu.py create mode 100644 jasc/test/matmul_gpu.py create mode 100644 jasc/test/normalization.py create mode 100644 jasc/test/parametric_schedule.mlir create mode 100644 jasc/test/synchronize.mlir create mode 100644 jasc/test/tag.py create mode 100644 jasc/test/test.mlir create mode 100644 jasc/test/wrap-in-cpu-launch.mlir create mode 100644 jasc/transform_ops/BUILD create mode 100644 jasc/transform_ops/_ods_common.py create mode 100644 jasc/transform_ops/_transform_ops_gen.py create mode 100644 jasc/transform_ops/bindings.cpp create mode 100644 jasc/transform_ops/dialect_extension.cc create mode 100644 jasc/transform_ops/dialect_extension.h create mode 100644 jasc/transform_ops/jasc_transform_ops.cc create mode 100644 jasc/transform_ops/jasc_transform_ops.h create mode 100644 jasc/transform_ops/jasc_transform_ops.py create mode 100644 jasc/transform_ops/jasc_transform_ops.td create mode 100644 jasc/tuner.py diff --git a/jasc/.bazelrc b/jasc/.bazelrc new file mode 100644 index 000000000000..baacef0d21ba --- /dev/null +++ b/jasc/.bazelrc @@ -0,0 +1,18 @@ +build --announce_rc + +build --experimental_repo_remote_exec +build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 +build --cxxopt=-w --host_cxxopt=-w +build --define=grpc_no_ares=true +build --define=tsl_link_protobuf=true +build --define open_source_build=true + +build --define framework_shared_object=true +build --define tsl_protobuf_header_only=true 
+build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. + +build -c opt + diff --git a/jasc/BUILD b/jasc/BUILD new file mode 100644 index 000000000000..d78bceeb2042 --- /dev/null +++ b/jasc/BUILD @@ -0,0 +1,223 @@ +# Schedules for JAX. + +load("@rules_python//python:defs.bzl", "py_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") + +package( + # default_applicable_licenses = ["//third_party/mlir_edge:license"], + default_visibility = ["//visibility:public"], +) + +py_library( + name = "jasc", + srcs = ["jasc.py", "__init__.py"], + deps = [ + ":call_kernel", + ":primitives", + "//dialect:python", + "//transform_ops", + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:bufferization_dialect", + "@jax1//jaxlib/mlir:core", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:pdl_dialect", + "@jax1//jaxlib/mlir:transform_dialect", + "@jax1//jaxlib/mlir:jasc_dialect", + ], +) + +py_library( + name = "tuner", + srcs = ["tuner.py"], + deps = [ + ":jasc", + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:jasc_dialect", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +py_library( + name = "primitives", + srcs = ["primitives.py"], + deps = [ + ":call_kernel", + "//dialect:python", + "@jax1//jax:jax", + "@jax1//jax:extend", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:pdl_dialect", + "@jax1//jaxlib/mlir:stablehlo_dialect", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +cc_library( + name = "call_kernel_shared_library_deps", + deps = [ + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + ":mlir_lowering_shared_library", + # "//third_party/gpus/cuda:cuda_headers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", + ], +) + +cc_headers_only( + name = "call_kernel_shared_library_deps_headers", + src = "call_kernel_shared_library_deps", +) + +cc_binary( + name = "libcallkernel.so", + linkopts = [ + "-Wl,-soname=libcallkernel.so", + "-Wl,-rpath='$$ORIGIN'", + ], + linkshared = 1, + deps = [":call_kernel_shared_library_deps"], +) + +cc_library( + name = "call_kernel_shared_library", + srcs = [":libcallkernel.so"], + deps = [":call_kernel_shared_library_deps_headers"], +) + +cc_binary( + name = "libmlir_c_runner_utils.so", + linkopts = [ + "-Wl,-soname=libmlir_c_runner_utils.so", + "-Wl,-rpath='$$ORIGIN'", + ], + linkshared = 1, + deps = ["@llvm-project//mlir:mlir_c_runner_utils",], +) + +pybind_extension( + name = "call_kernel", + srcs = ["call_kernel.cc"], + deps = [ + ":call_kernel_shared_library", + ":libmlir_c_runner_utils.so", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + ":mlir_lowering_shared_library", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@llvm-project//mlir:ExecutionEngine", + "@status_macros//:status_macros", + "@pybind11_abseil//pybind11_abseil:import_status_module", + ], +) + +cc_library( + name = "mlir_lowering_shared_library_deps", + deps = [ + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AllToLLVMIRTranslations", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BufferizationTransforms", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + 
"@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", + "@llvm-project//mlir:MemRefTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SerializeToCubin_stub", + "@llvm-project//mlir:SparseTensorTransforms", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformDialectTransforms", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorToLLVM", + "@llvm-project//mlir:VectorToSCF", + "@xla//xla/mlir_hlo:mhlo_passes", + ], +) + +cc_headers_only( + name = "mlir_lowering_shared_library_deps_headers", + src = "mlir_lowering_shared_library_deps", +) + +cc_binary( + name = "libmlirlowering.so", + linkopts = [ + "-Wl,-soname=libmlirlowering.so", + "-Wl,-rpath='$$ORIGIN'", + ], + linkshared = 1, + deps = [":mlir_lowering"], +) + +cc_library( + name = "mlir_lowering_shared_library", + srcs = [":libmlirlowering.so", "mlir_lowering.h"], + deps = [":mlir_lowering_shared_library_deps_headers"], +) + +cc_headers_only( + name = "mlir_lowering_shared_library_headers", + src = "mlir_lowering_shared_library", +) + +cc_library( + name = "mlir_lowering", + srcs = [ + "gpu_lowering_passes.cc", + "mlir_lowering.cc", + ], + hdrs = [ + "gpu_lowering_passes.h", + "mlir_lowering.h", + ], + data = ["gpu_post_bufferize.mlir"], + deps = [ + "//dialect:jasc_dialect_headers", + "//transform_ops:jasc_transform_ops_headers", + ":mlir_lowering_shared_library_deps_headers", + ], + alwayslink = True, +) + +cc_binary( + name = "jasc-opt", + srcs = ["jasc_opt.cc"], + deps = [ + ":mlir_lowering", + 
"@llvm-project//mlir:AllExtensions", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllToLLVMIRTranslations", + "@llvm-project//mlir:MlirOptLib", + "@xla//xla/mlir_hlo:mhlo_passes", + "//dialect", + "//transform_ops:jasc_transform_ops_shared_library", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE new file mode 100644 index 000000000000..677e3a66cf48 --- /dev/null +++ b/jasc/WORKSPACE @@ -0,0 +1,207 @@ +workspace(name = "jasc") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") + +# +# @rules_cc. +# + +CCRULES_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c" +CCRULES_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4" + +http_archive( + name = "rules_cc", + sha256 = CCRULES_SHA256, + strip_prefix = "rules_cc-" + CCRULES_COMMIT, + urls = [ + "https://github.com/bazelbuild/rules_cc/archive/{commit}.tar.gz".format(commit = CCRULES_COMMIT), + ], +) + +load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") + +rules_cc_dependencies() + +# +# @llvm-project. +# + +LLVM_COMMIT = "2f17c9f65e7da50a77101431ddf7f6ed7e1ea92c" +LLVM_SHA256 = "a986740933506ebd1127c8abb64c78655a8c329798f37fd466a8e0f7aa7a5578" +LLVM_TARGETS = ["X86", "AArch64", "AMDGPU"] + +http_archive( + name = "llvm-raw", + build_file_content = "# empty", + sha256 = LLVM_SHA256, + strip_prefix = "llvm-project-" + LLVM_COMMIT, + urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], + patch_args = ["-p1"], + patches = ["//:patches/llvm_build.patch"] +) + +load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") + +llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) + +# +# @xla. 
+# + +XLA_COMMIT = "7ab5df624ff1d98804999b03b21abecd14ec57a6" +XLA_SHA256 = "2b6a3ffdb3acf73eaa9b312407400b09c740450ab2222433890712dd4a402a0f" + +http_archive( + name = "xla", + sha256 = XLA_SHA256, + strip_prefix = "xla-" + XLA_COMMIT, + urls = ["https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], + patch_args = ["-p1"], + patches = ["//:patches/xla.patch"], +) + +# Note: Further loading below in conjuction with JAX. + +# +# @rules_python. +# + +PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" +PYRULES_SHA256 = "cfa6957832ae0e0c7ee2ccf455a888a291e8419ed8faf45f4420dd7414d5dd96" + +http_archive( + name = "rules_python", + sha256 = PYRULES_SHA256, + strip_prefix = "rules_python-" + PYRULES_COMMIT, + urls = ["https://github.com/bazelbuild/rules_python/archive/{commit}.tar.gz".format(commit = PYRULES_COMMIT)] +) + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + +load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") + +pip_install_dependencies() + +# +# @jax. +# + +JAX_COMMIT = "32a317f7a43440800e1e39e00ed5f2980e088ab1" +JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1" + +http_archive( + name = "jax1", + sha256 = JAX_SHA256, + strip_prefix = "jax-" + JAX_COMMIT, + urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], + patch_args = ["-p1"], + patches = ["//:patches/jax.patch"], +) + +# +# Initialize @jax, @xla, and dependencies. 
+# + +load("@jax1//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") +jax_xla_workspace() + +load("@xla//:workspace4.bzl", "xla_workspace4") +xla_workspace4() + +load("@xla//:workspace3.bzl", "xla_workspace3") +xla_workspace3() + +load("@xla//:workspace2.bzl", "xla_workspace2") +xla_workspace2() + +load("@xla//:workspace1.bzl", "xla_workspace1") +xla_workspace1() + +load("@xla//:workspace0.bzl", "xla_workspace0") +xla_workspace0() + +load("@jax1//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") +flatbuffers() + +load("@jax1//third_party/robin_map:workspace.bzl", robin_map = "repo") +robin_map() + +load("@jax1//third_party/nanobind:workspace.bzl", nanobind = "repo") +nanobind() + +# +# @pybind and friends. +# + +PYBIND_VERSION = "2.11.1" +PYBIND11_SHA256 = "d475978da0cdc2d43b73f30910786759d593a9d8ee05b1b6846d1eb16c6d2e0c" +PYBINDBZL_SHA256 = "e8355ee56c2ff772334b4bfa22be17c709e5573f6d1d561c7176312156c27bd4" +PYBINDABSL_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + +http_archive( + name = "pybind11_bazel", + sha256 = PYBINDBZL_SHA256, + strip_prefix = "pybind11_bazel-" + PYBIND_VERSION, + urls = ["https://github.com/pybind/pybind11_bazel/archive/refs/tags/v{version}.tar.gz/".format(version = PYBIND_VERSION)], +) + +http_archive( + name = "pybind11", + sha256 = PYBIND11_SHA256, + build_file = "@pybind11_bazel//:pybind11.BUILD", + strip_prefix = "pybind11_bazel-" + PYBIND_VERSION, + urls = ["https://github.com/pybind/pybind11/archive/refs/tags/v{version}.tar.gz/".format(version = PYBIND_VERSION)], +) + +load("@pybind11_bazel//:python_configure.bzl", "python_configure") + +python_configure(name = "local_config_python") + +http_archive( + name = "pybind11_abseil", + sha256 = PYBINDABSL_SHA256, + strip_prefix = "pybind11_abseil-" + PYBIND_VERSION, + urls = ["https://github.com/pybind/pybind11_abseil/archive/refs/tags/v{version}.tar.gz/".format(version = PYBIND_VERSION)], +) + +# +# @com_google_absl and friends. 
+# + +ABSL_COMMIT = "98eb410c93ad059f9bba1bf43f5bb916fc92a5ea" +ABSL_SHA256 = "aabf6c57e3834f8dc3873a927f37eaf69975d4b28117fc7427dfb1c661542a87" + +http_archive( + name = "com_google_absl", + sha256 = ABSL_SHA256, + strip_prefix = "abseil-cpp-" + ABSL_COMMIT, + urls = ["https://github.com/abseil/abseil-cpp/archive/{commit}.zip".format(commit = ABSL_COMMIT)], +) + +STMACROS_COMMIT = "1592ab2d4b4f92976fc3f4a6cb3a1323a4b549c3" +STMACROS_SHA256 = "4317adf5ff551ab3d39af00cd5a3b965d22b266570609bf198166288083f69c0" + +http_archive( + name = "status_macros", + sha256 = STMACROS_SHA256, + strip_prefix = "status_macros-" + STMACROS_COMMIT, + urls = ["https://github.com/jimrogerz/status_macros/archive/{commit}.zip".format(commit = STMACROS_COMMIT)], +) + +# +# Python dependencies via pip +# + +load("@rules_python//python:pip.bzl", "pip_parse") + +pip_parse( + name = "pip_deps", + requirements_lock = ":requirements.txt", +) + +load("@pip_deps//:requirements.bzl", "install_deps") + +install_deps() diff --git a/jasc/__init__.py b/jasc/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/jasc/call_kernel.cc b/jasc/call_kernel.cc new file mode 100644 index 000000000000..0df443749391 --- /dev/null +++ b/jasc/call_kernel.cc @@ -0,0 +1,448 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +// #include "third_party/gpus/cuda/include/cuda.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include 
"mlir/lib/Bindings/Python/IRModule.h" +#include "pybind11/cast.h" +#include "pybind11/pybind11.h" +#include "pybind11_abseil/status_casters.h" +#include "status_macros.h" + +#include "mlir_lowering.h" + +#define VLOG(X) std::cerr + +namespace jasc { +namespace { + +// A "truncated" definition from `struct StridedMemRefType` defined in +// "mlir/include/mlir/ExecutionEngine/CRunnerUtils.h" without `barePtr` and +// `data`. The value of those pointers are not available by the time when the +// metadata is constructed until the CpuKernel is actually invoked. +struct StridedMemRefMD { + int64_t offset; + llvm::SmallVector sizes; + llvm::SmallVector strides; + explicit StridedMemRefMD(size_t rank) : offset(0), sizes(rank, 0), strides(rank, 0) {} +}; + +// A compiled CPU kernel ready to be executed. This class is responsible for +// holding the compiled code. It maintaining a global registry of CPU kernels +// identified by unique ID. +class CpuKernel { + public: + CpuKernel(std::unique_ptr execution_engine, + llvm::SmallVector &&memref_metadata, + int num_inputs, int num_ouputs) + : execution_engine_(std::move(execution_engine)), + memref_metadata_(std::move(memref_metadata)), + num_inputs_(num_inputs), + num_outputs_(num_ouputs) { + absl::WriterMutexLock lock(&global_registry_mutex_); + if (global_registry_ == nullptr) { + global_registry_ = new absl::flat_hash_map(); + } + identifier_ = next_kernel_id_++; + global_registry_->emplace(identifier_, this); + + VLOG(1) << "allocated kernel " << identifier_ << "\n"; + } + + ~CpuKernel() { + absl::WriterMutexLock lock(&global_registry_mutex_); + global_registry_->erase(identifier_); + VLOG(1) << "deallocated kernel " << identifier_ << "\n"; + } + + // A unique identifier for the kernel. + int identifier() const { return identifier_; } + + // Retrieve a kernel given its identifier. 
+ static const CpuKernel *GetKernelById(int id) { + absl::ReaderMutexLock lock(&global_registry_mutex_); + auto it = global_registry_->find(id); + if (it == global_registry_->end()) { + LOG(FATAL) << "unable to find kernel " << id; + } + return it->second; + } + + void Call(void *out, void **ins) const { + std::vector args; + // Each input has one basePtr, one data, and one offset + an array of sizes + // and strides depending on the rank. + size_t flat_args_count = (num_inputs_ + num_outputs_) * 3; + for (const StridedMemRefMD &MD : memref_metadata_) { + flat_args_count += MD.sizes.size(); + flat_args_count += MD.strides.size(); + } + args.reserve(flat_args_count); + + // Reconstructs a memref descriptor structure from a bare pointer. See + // `struct StridedMemRefType` in "mlir/ExecutionEngine/CRunnerUtils.h" + auto pack_args = [&](void *ptr, const StridedMemRefMD& MD) { + args.push_back(ptr); // basePtr + args.push_back(ptr); // data + StridedMemRefMD &tmp_MD = const_cast(MD); + args.push_back(reinterpret_cast(&tmp_MD.offset)); + for (auto &sz : tmp_MD.sizes) { + args.push_back(reinterpret_cast(&sz)); + } + for (auto &sd : tmp_MD.strides) { + args.push_back(reinterpret_cast(&sd)); + } + }; + + for (int i = 0; i < num_inputs_; ++i) { + pack_args(&ins[i], memref_metadata_[i]); + } + + if (num_outputs_ == 1) { + pack_args(&out, memref_metadata_[num_inputs_]); + } else { + void **out_ptrs = reinterpret_cast(out); + for (int i = 0; i < num_outputs_; ++i) { + pack_args(&out_ptrs[i], memref_metadata_[num_inputs_ + i]); + } + } + + assert(args.size() == flat_args_count); + llvm::cantFail(execution_engine_->invokePacked("main", args)); + } + + private: + static absl::Mutex global_registry_mutex_; + static absl::flat_hash_map *global_registry_ + ABSL_GUARDED_BY(global_registry_mutex_); + static int next_kernel_id_ ABSL_GUARDED_BY(global_registry_mutex_); + + std::unique_ptr execution_engine_; + llvm::SmallVector memref_metadata_; + int num_inputs_; + int num_outputs_; + 
int identifier_; +}; + +absl::Mutex CpuKernel::global_registry_mutex_(absl::kConstInit); +absl::flat_hash_map *CpuKernel::global_registry_ = nullptr; +int CpuKernel::next_kernel_id_ = 0; + +llvm::SmallVector PopulateMemrefMetaData( + mlir::FunctionType ftp) { + // Modified from `fill_sizes_and_strides` defined in cpu_executable.cc + auto fill_metadata = [&](mlir::ArrayRef shape) -> StridedMemRefMD { + StridedMemRefMD MD(shape.size()); + size_t multiplier = 1; + for (int i = static_cast(shape.size()); i > 0; --i) { + size_t position = i - 1; + // Payload using `position` instead of `i`. + size_t size = shape[position]; + MD.sizes[position] = size; + MD.strides[position] = multiplier; + multiplier *= size; + } + return MD; + }; + + llvm::SmallVector ret; + for (auto t : + llvm::concat(ftp.getInputs(), ftp.getResults())) { + auto stp = t.cast(); + ret.push_back(fill_metadata(stp.getShape())); + } + return ret; +} + +absl::StatusOr> CreateCpuKernel( + mlir::python::PyModule &py_module, int num_inputs, int num_outputs, + bool dump_ir) { + mlir::ModuleOp module = unwrap(py_module.get()); + // Fills in the memref metadata according to the function type. + auto entry = llvm::cast(module.lookupSymbol("main")); + auto MDs = PopulateMemrefMetaData(entry.getFunctionType()); + assert(MDs.size() == num_inputs + num_outputs); + + RETURN_IF_ERROR(LowerStableHloToCpuLLVM(module, dump_ir)); + mlir::ExecutionEngineOptions engine_opts; + // TODO(ulysse): Select LLVM opt level. 
+ engine_opts.sharedLibPaths = {"libmlir_c_runner_utils.so"}; + engine_opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Default; + auto engineOrError = mlir::ExecutionEngine::create(module, engine_opts); + if (!engineOrError) { + llvm::handleAllErrors( + engineOrError.takeError(), [&](const llvm::StringError &err) { + LOG(FATAL) << "Error while creating execution engine: " + << err.getMessage(); + }); + } + return std::make_unique(llvm::cantFail(std::move(engineOrError)), + std::move(MDs), num_inputs, num_outputs); +} + +// XLA custom call callback that calls a kernel on CPU. The first input is the +// identifier of the kernel. Subsequent inputs are the inputs of the kernel. +void CpuCallback(void *out, void **ins) { + int64_t identifier = *reinterpret_cast(ins[0]); + const CpuKernel *kernel = CpuKernel::GetKernelById(identifier); + kernel->Call(out, (ins + 1)); +} + +// CUstream jasc_cuda_stream = nullptr; + +// class CudaKernel { +// public: +// CudaKernel(std::unique_ptr execution_engine, +// int num_inputs_outputs) +// : execution_engine_(std::move(execution_engine)), +// num_input_outputs_(num_inputs_outputs) {} + +// // Executes the kernel. +// void Call(CUstream stream, void **buffers) const { +// // TODO(ulysse): avoid relying on a global variable. 
+// CHECK_EQ(jasc_cuda_stream, nullptr); +// jasc_cuda_stream = stream; + +// std::vector inputs; +// inputs.reserve(num_input_outputs_); +// for (int i = 0; i < num_input_outputs_; ++i) { +// inputs.push_back(&buffers[i]); +// } +// llvm::cantFail(execution_engine_->invokePacked("main", inputs)); +// jasc_cuda_stream = nullptr; +// } + +// private: +// std::unique_ptr execution_engine_; +// int num_input_outputs_; +// int num_outputs_; +// }; + +// void CheckCudaError(CUresult result) { +// if (result != CUDA_SUCCESS) { +// const char **error_msg = nullptr; +// cuGetErrorString(result, error_msg); +// LOG(FATAL) << *error_msg; +// } +// } + +// extern "C" void JascCudaLaunchKernel(CUfunction function, intptr_t gridX, +// intptr_t gridY, intptr_t gridZ, +// intptr_t blockX, intptr_t blockY, +// intptr_t blockZ, int32_t smem, +// CUstream stream, void **params, +// void **extra) { +// CheckCudaError(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, +// blockZ, smem, stream, params, extra)); +// } + +// extern "C" CUstream JascCudaStreamCreate() { +// // TODO(ulysse): explicitly pass the stream instead of relying on a global +// // variable. +// return jasc_cuda_stream; +// } + +// extern "C" void JascCudaStreamDestroy(CUstream stream) { +// // NO-op as we are reusing the stream given by XLA. +// } + +// extern "C" CUmodule JascCudaModuleLoad(void *data) { +// // TODO(ulysse): investigate the performance implications of loading the +// // module on the fly. +// CUmodule module; +// CheckCudaError(cuModuleLoadData(&module, data)); +// return module; +// } + +// extern "C" void JascCudaModuleUnload(CUmodule module) { +// CheckCudaError(cuModuleUnload(module)); +// } + +// extern "C" CUfunction JascCudaModuleGetFunction(CUmodule module, +// const char *name) { +// // TODO(ulysse): investigate the performance implications of loading the +// // function on the fly. 
+// CUfunction function; +// CheckCudaError(cuModuleGetFunction(&function, module, name)); +// return function; +// } + +// extern "C" void JascCudaStreamSynchronize(CUstream stream) { +// CheckCudaError(cuStreamSynchronize(stream)); +// } + +// extern "C" void *JascCudaMemAlloc(uint64_t size_bytes, CUstream) { +// CUdeviceptr ptr; +// CheckCudaError(cuMemAlloc(&ptr, size_bytes)); +// return reinterpret_cast(ptr); +// } + +// extern "C" void JascCudaMemFree(void *ptr, CUstream) { +// CheckCudaError(cuMemFree(reinterpret_cast(ptr))); +// } + +// extern "C" void JascCudaMemcpy(void *dst, void *src, size_t sizeBytes, +// CUstream stream) { +// CheckCudaError(cuMemcpy(reinterpret_cast(dst), +// reinterpret_cast(src), sizeBytes)); +// } + +// absl::StatusOr> CreateCudaKernel( +// mlir::python::PyModule &py_module, int num_inputs, int num_outputs, +// bool dump_ir) { +// mlir::ModuleOp module = unwrap(py_module.get()); +// RETURN_IF_ERROR(LowerStableHloToGpuLLVM(module, dump_ir)); +// mlir::ExecutionEngineOptions engine_opts; +// // TODO(ulysse): Select LLVM opt level. 
+// engine_opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Default; +// auto engineOrError = mlir::ExecutionEngine::create(module, engine_opts); +// if (!engineOrError) { +// llvm::handleAllErrors( +// engineOrError.takeError(), [&](const llvm::StringError &err) { +// LOG(FATAL) << "Error while creating execution engine: " +// << err.getMessage(); +// }); +// } +// engineOrError.get()->registerSymbols( +// [](llvm::orc::MangleAndInterner interner) { +// auto map = llvm::orc::SymbolMap(); +// auto register_symbol = [&map, &interner](llvm::StringRef name, +// auto *func) { +// auto addr = llvm::orc::ExecutorAddr(reinterpret_cast(func)); +// map[interner(name)] = {addr, llvm::JITSymbolFlags::None}; +// }; +// register_symbol("mgpuLaunchKernel", &JascCudaLaunchKernel); +// register_symbol("mgpuStreamCreate", &JascCudaStreamCreate); +// register_symbol("mgpuStreamDestroy", &JascCudaStreamDestroy); +// register_symbol("mgpuModuleLoad", &JascCudaModuleLoad); +// register_symbol("mgpuModuleUnload", &JascCudaModuleUnload); +// register_symbol("mgpuModuleGetFunction", &JascCudaModuleGetFunction); +// register_symbol("mgpuStreamSynchronize", &JascCudaStreamSynchronize); +// register_symbol("mgpuMemAlloc", &JascCudaMemAlloc); +// register_symbol("mgpuMemFree", &JascCudaMemFree); +// register_symbol("mgpuMemcpy", &JascCudaMemcpy); +// return map; +// }); +// return std::make_unique(llvm::cantFail(std::move(engineOrError)), +// num_inputs + num_outputs); +// } + +// // XLA custom call callback that calls a kernel on GPU. +// void GpuCallback(CUstream stream, void **buffers, const char *opaque, +// size_t opaque_len) { +// CHECK_EQ(opaque_len, sizeof(CudaKernel *)); +// CudaKernel *kernel_call; +// std::memcpy(&kernel_call, opaque, sizeof(CudaKernel *)); +// kernel_call->Call(stream, buffers); +// } + +/// Clears the `PyOperation` (representing Python-level handles to +/// `Operation *`s) that are tracked by the context. 
This function should be +/// called by any entry point that may modify the IR, which could cause above +/// handles to be dangling. +// void clearOperationsInside(mlir::python::PyModule &py_module) { +// llvm::errs() << "clearOperationsInside\n"; +// MlirOperation op = mlirModuleGetOperation(py_module.get()); +// auto py_op = mlir::python::PyOperation::forOperation( +// py_module.getContext(), op, py_module.getCapsule()); +// llvm::errs() << "got py_op\n"; +// py_module.getContext()->clearOperationsInside(py_op->getOperation()); +// } + +namespace py = ::pybind11; + +PYBIND11_MODULE(call_kernel, m) { + pybind11::google::ImportStatusModule(); + + // Initializes LLVM targets. Must be called before CreateCpuKernel. + m.def("init_llvm", []() { + LLVMInitializeNativeTarget(); + LLVMInitializeNativeAsmPrinter(); + LLVMInitializeNativeAsmParser(); + }); + m.def( + "apply_schedule", + [](mlir::python::PyModule &py_module, bool dump_ir) { + // py_module.getContext()->clearOperationsInside(py_module); + mlir::ModuleOp module = unwrap(py_module.get()); + return ApplyTransformScript(module, dump_ir); + }, + py::arg("module"), py::arg("dump_ir") = false); + + py::class_(m, "CpuKernel") + .def_property_readonly("identifier", &CpuKernel::identifier); + + m.def( + "create_cpu_kernel", + [](mlir::python::PyModule &py_module, int num_inputs, int num_outputs, + bool dump_ir) { + // py_module.getContext()->clearOperationsInside(py_module); + return CreateCpuKernel(py_module, num_inputs, num_outputs, dump_ir); + }, + py::arg("module"), py::arg("num_inputs"), py::arg("num_outputs"), + py::arg("dump_ir") = false); + + m.def("get_cpu_callback", []() { + return pybind11::capsule(reinterpret_cast(&CpuCallback), + "xla._CUSTOM_CALL_TARGET"); + }); + + // py::class_(m, "CudaKernel") + // .def_property_readonly("ptr", [](CudaKernel *kernel) { + // union { + // CudaKernel *ptr; + // char bytes[sizeof(CudaKernel *)]; + // } bytes_ptr; + // bytes_ptr.ptr = kernel; + // return 
pybind11::bytes(bytes_ptr.bytes, sizeof(CudaKernel *)); + // }); + + // m.def( + // "create_cuda_kernel", + // [](mlir::python::PyModule &py_module, int num_inputs, int num_outputs, + // bool dump_ir) { + // clearOperationsInside(py_module); + // return CreateCudaKernel(py_module, num_inputs, num_outputs, dump_ir); + // }, + // py::arg("module"), py::arg("num_inputs"), py::arg("num_outputs"), + // py::arg("dump_ir") = false); + + // m.def("get_cuda_callback", []() { + // return pybind11::capsule(reinterpret_cast(&GpuCallback), + // "xla._CUSTOM_CALL_TARGET"); + // }); + + m.def( + "lower_to_linalg", + [](mlir::python::PyModule &py_module, bool dump_ir) { + // clearOperationsInside(py_module); + mlir::ModuleOp module = unwrap(py_module.get()); + return LowerStableHloToLinalg(module, dump_ir); + }, + py::arg("module"), py::arg("dump_ir") = false); +} + +} // namespace +} // namespace jasc diff --git a/jasc/dialect/BUILD b/jasc/dialect/BUILD new file mode 100644 index 000000000000..1196cd4623d7 --- /dev/null +++ b/jasc/dialect/BUILD @@ -0,0 +1,194 @@ +# MLIR Dialect to support Jasc transformations. 
+ +load("@rules_python//python:defs.bzl", "py_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + # default_applicable_licenses = ["//third_party/mlir_edge:license"], + default_visibility = ["//visibility:public"], +) + +td_library( + name = "td_files", + srcs = [ + "dialect.td", + "ops.td", + ], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "ops.h.inc", + ), + ( + ["-gen-op-defs"], + "ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ops.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + +cc_library( + name = "jasc_dialect_shared_library_deps", + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:TransformDialect", + ], +) + +cc_headers_only( + name = "jasc_dialect_shared_library_deps_headers", + src = "jasc_dialect_shared_library_deps", +) + +cc_headers_only( + name = "jasc_dialect_headers", + src = "dialect", +) + +cc_library( + name = "dialect", + srcs = [ + "dialect.cc", + "ops.cc", + ], + hdrs = [ + "dialect.h", + "ops.h", + ], + deps = [ + ":dialect_inc_gen", + ":ops_inc_gen", + ":jasc_dialect_shared_library_deps_headers", + ], + alwayslink = True, +) + +cc_library( + name = "capi", + srcs = [ + "capi.cc", + ], + hdrs = [ + "capi.h", + ], + deps = [ + ":dialect", + 
"@llvm-project//mlir:CAPIIRHeaders", + ], + alwayslink = True, +) + +cc_library( + name = "capi_headers", + hdrs = [ + "capi.h", + ], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + +gentbl_filegroup( + name = "ops_py_gen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=jasc", + ], + "_ops_gen.py", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ops_py.td", + deps = [ + ":td_files", + ], +) + +cc_library( + name = "jasc_dialect_shared_library", + srcs = [ + ":libjascdialect.so", + "dialect.h", + "ops.h", + ], + deps = [ + ":dialect_inc_gen", + ":ops_inc_gen", + ":jasc_dialect_shared_library_deps_headers", + ], +) + +cc_binary( + name = "libjascdialect.so", + linkopts = [ + "-Wl,-soname=libjascdialect.so", + "-Wl,-rpath='$$ORIGIN'", + ], + linkshared = 1, + deps = [":dialect"], +) + +pybind_extension( + name = "bindings", + srcs = ["bindings.cc"], + deps = [ + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + ":jasc_dialect_headers", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "//:mlir_lowering_shared_library", + "//transform_ops:jasc_transform_ops_shared_library", + ], +) + +py_library( + name = "python", + srcs = [ + "jasc.py", + # "_ods_common.py", + ":ops_py_gen", + ], + deps = [ + ":bindings", + # "@jax//jaxlib/mlir:core", + # "@jax//jaxlib/mlir:ir", + # "@jax//jaxlib/mlir:pdl_dialect", + ], +) diff --git a/jasc/dialect/__init__.py b/jasc/dialect/__init__.py new file mode 100644 index 000000000000..fe0bd366ee86 --- /dev/null +++ b/jasc/dialect/__init__.py @@ -0,0 +1,4 @@ +"""Python bindings for Jasc MLIR operations.""" + +from dialect.bindings import * +from ._ops_gen import * # pylint: disable=wildcard-import diff --git a/jasc/dialect/_ods_common.py b/jasc/dialect/_ods_common.py new file mode 100644 index 000000000000..abcb9e8b64cd --- /dev/null +++ b/jasc/dialect/_ods_common.py @@ -0,0 +1,6 @@ +"""Trampoline to run generated MLIR Python 
code. + +Generated tablegen dialects expect to be able to find some symbols from the +mlir.dialects package. +""" +from jaxlib.mlir.dialects._ods_common import _cext, equally_sized_accessor, get_default_loc_context, get_op_result_or_op_results, get_op_result_or_value, get_op_results_or_values, segmented_accessor diff --git a/jasc/dialect/bindings.cc b/jasc/dialect/bindings.cc new file mode 100644 index 000000000000..b2f385712f49 --- /dev/null +++ b/jasc/dialect/bindings.cc @@ -0,0 +1,26 @@ +#include "mlir/CAPI/IR.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/lib/Bindings/Python/IRModule.h" +#include "pybind11/pybind11.h" + +#include "dialect/dialect.h" +#include "mlir_lowering.h" +#include "transform_ops/dialect_extension.h" + +PYBIND11_MODULE(_mlirDialectsJasc, m) { + m.def( + "register_and_load_dialect", + [](MlirContext py_context) { + mlir::MLIRContext *context = unwrap(py_context); + mlir::DialectRegistry registry; + registry.insert(); + jasc::registerTransformDialectExtension(registry); + context->appendDialectRegistry(registry); + context->loadDialect(); + }, + pybind11::arg("context") = pybind11::none()); + + m.def("register_lowering_passes", + []() { jasc::registerMLIRLoweringPasses(); }); +} \ No newline at end of file diff --git a/jasc/dialect/capi.cc b/jasc/dialect/capi.cc new file mode 100644 index 000000000000..9a2737d3e895 --- /dev/null +++ b/jasc/dialect/capi.cc @@ -0,0 +1,6 @@ + +#include "capi.h" +#include "mlir/CAPI/Registration.h" +#include "dialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Jasc, jasc, jasc::JascDialect) diff --git a/jasc/dialect/capi.h b/jasc/dialect/capi.h new file mode 100644 index 000000000000..f6ca29199cf1 --- /dev/null +++ b/jasc/dialect/capi.h @@ -0,0 +1,4 @@ + +#include "mlir-c/IR.h" + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Jasc, jasc); diff --git a/jasc/dialect/dialect.cc b/jasc/dialect/dialect.cc new file mode 100644 index 000000000000..7a8eb3c0a8e0 --- /dev/null +++ 
b/jasc/dialect/dialect.cc @@ -0,0 +1,20 @@ +#include "dialect.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "ops.h" + +// Include code generated from dialect.td. +#include "dialect/dialect.cc.inc" + +namespace jasc { + +void JascDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "dialect/ops.cc.inc" + >(); +} + +} // namespace jasc diff --git a/jasc/dialect/dialect.h b/jasc/dialect/dialect.h new file mode 100644 index 000000000000..e81744d959b2 --- /dev/null +++ b/jasc/dialect/dialect.h @@ -0,0 +1,9 @@ +#ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ + +#include "mlir/IR/Dialect.h" + +// Include code generated from dialect.td. +#include "dialect/dialect.h.inc" + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ diff --git a/jasc/dialect/dialect.td b/jasc/dialect/dialect.td new file mode 100644 index 000000000000..4c53967a8eb1 --- /dev/null +++ b/jasc/dialect/dialect.td @@ -0,0 +1,16 @@ +#ifndef JASC_DIALECT_DIALECT +#define JASC_DIALECT_DIALECT + +include "mlir/IR/DialectBase.td" + +def Jasc_Dialect : Dialect { + let name = "jasc"; + let cppNamespace = "::jasc"; + let dependentDialects = [ + "::mlir::transform::TransformDialect", + "::mlir::gpu::GPUDialect", + "::mlir::memref::MemRefDialect", + ]; +} + +#endif // JASC_DIALECT_DIALECT \ No newline at end of file diff --git a/jasc/dialect/jasc.py b/jasc/dialect/jasc.py new file mode 100644 index 000000000000..5f8d91892f1f --- /dev/null +++ b/jasc/dialect/jasc.py @@ -0,0 +1,2 @@ +from ._ops_gen import * +from .._mlir_libs._mlirDialectsJasc import * \ No newline at end of file diff --git a/jasc/dialect/ops.cc b/jasc/dialect/ops.cc new file mode 100644 index 000000000000..6e9c5dc5a95a --- /dev/null +++ b/jasc/dialect/ops.cc @@ -0,0 +1,15 @@ +#include "ops.h" + +#include "llvm/include/llvm/ADT/SmallVector.h" +#include 
"mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" + +#define GET_OP_CLASSES +#include "dialect/ops.cc.inc" + +namespace jasc { + + +} // namespace jasc \ No newline at end of file diff --git a/jasc/dialect/ops.h b/jasc/dialect/ops.h new file mode 100644 index 000000000000..73226910bed9 --- /dev/null +++ b/jasc/dialect/ops.h @@ -0,0 +1,13 @@ +#ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" + +#define GET_OP_CLASSES +#include "dialect/ops.h.inc" + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ diff --git a/jasc/dialect/ops.td b/jasc/dialect/ops.td new file mode 100644 index 000000000000..96666bbe5d5f --- /dev/null +++ b/jasc/dialect/ops.td @@ -0,0 +1,33 @@ +#ifndef JASC_DIALECT_OPS +#define JASC_DIALECT_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" + +include "dialect.td" + +def Jasc_TagRegionOp : Op +]> { + let summary = "Tags a region for matching it with the transform dialect."; + + let arguments = (ins StrAttr:$name); + let regions = (region SizedRegion<1>:$body); + let results = (outs Variadic:$results); + + let assemblyFormat = "$name $body attr-dict `:` type($results)"; +} + +def Jasc_ReturnOp : Op { + let summary = "Terminates a tag region."; + + let arguments = (ins Variadic:$operands); + let assemblyFormat = "$operands attr-dict `:` type($operands)"; + + let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; +} + +#endif // JASC_DIALECT_OPS \ No newline at end of file diff --git a/jasc/dialect/ops_py.td b/jasc/dialect/ops_py.td new 
file mode 100644 index 000000000000..6abc4a345234 --- /dev/null +++ b/jasc/dialect/ops_py.td @@ -0,0 +1,6 @@ +#ifndef JASC_DIALECT_OPSPY +#define JASC_DIALECT_OPSPY + +include "ops.td" + +#endif // JASC_DIALECT_OPSPY \ No newline at end of file diff --git a/jasc/external/requirements-top-level.txt b/jasc/external/requirements-top-level.txt new file mode 100644 index 000000000000..30c598b8d07d --- /dev/null +++ b/jasc/external/requirements-top-level.txt @@ -0,0 +1,4 @@ +absl-py +chex +pytest +PyYAML \ No newline at end of file diff --git a/jasc/external/requirements.txt b/jasc/external/requirements.txt new file mode 100644 index 000000000000..d6f6f5646e80 --- /dev/null +++ b/jasc/external/requirements.txt @@ -0,0 +1,16 @@ +absl-py==2.0.0 +chex==0.1.85 +iniconfig==2.0.0 +jax==0.4.20 +jaxlib==0.4.20 +ml-dtypes==0.3.1 +numpy==1.26.2 +opt-einsum==3.3.0 +packaging==23.2 +pluggy==1.3.0 +pytest==7.4.3 +PyYAML==6.0.1 +scipy==1.11.4 +setuptools==69.0.2 +toolz==0.12.0 +typing_extensions==4.8.0 diff --git a/jasc/gpu_lowering_passes.cc b/jasc/gpu_lowering_passes.cc new file mode 100644 index 000000000000..26053d81bf98 --- /dev/null +++ b/jasc/gpu_lowering_passes.cc @@ -0,0 +1,231 @@ +#include +#include + +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" 
+#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace jasc { + +mlir::FailureOr CreateGpuAlloc(mlir::OpBuilder& builder, + mlir::Location loc, + mlir::MemRefType memref_type, + mlir::ValueRange dyn_sizes, + unsigned int) { + // TODO(ulysse): See if we can simplify this code. Synchronization tokens + // are only needed because the GpuToLLVM pass expects them. + auto token_type = mlir::gpu::AsyncTokenType::get(builder.getContext()); + auto wait_op = + builder.create(loc, token_type, mlir::ValueRange()); + auto alloc_op = builder.create( + loc, mlir::TypeRange({memref_type, token_type}), wait_op.getResults(), + dyn_sizes, mlir::ValueRange()); + return alloc_op.getMemref(); +} + +mlir::LogicalResult CreateGpuDealloc(mlir::OpBuilder& builder, + mlir::Location loc, mlir::Value memref) { + // TODO(ulysse): See if we can simplify this code. Synchronization tokens + // are only needed because the GpuToLLVM pass expects them. + auto token_type = mlir::gpu::AsyncTokenType::get(builder.getContext()); + auto sync_op = + builder.create(loc, token_type, mlir::ValueRange()); + builder.create(loc, mlir::TypeRange({token_type}), + sync_op.getResults(), memref); + return mlir::success(); +} + +mlir::LogicalResult CreateGpuMemCpy(mlir::OpBuilder& builder, + mlir::Location loc, mlir::Value from, + mlir::Value to) { + // TODO(ulysse): See if we can simplify this code. Synchronization tokens + // are only needed because the GpuToLLVM pass expects them. + auto token_type = mlir::gpu::AsyncTokenType::get(builder.getContext()); + auto sync_op = + builder.create(loc, token_type, mlir::ValueRange()); + builder.create(loc, token_type, sync_op.getResults(), to, + from); + return mlir::success(); +} + +// Convert "memref.alloc" with no deallocation into "memref.alloca". 
+mlir::LogicalResult AllocToAlloca(mlir::memref::AllocOp alloc, + mlir::PatternRewriter& rewriter) { + for (mlir::Operation* user : alloc->getUsers()) { + if (llvm::isa(user)) return mlir::failure(); + } + rewriter.replaceOpWithNewOp(alloc, alloc.getType(), + alloc.getDynamicSizes(), + alloc.getAlignmentAttr()); + return mlir::success(); +} + +namespace { + +// Annotates the tensor alloc operations to use the global memory space. Add +// tensor alloc operations after constants to copy them to the global memory +// space. +class SetDefaultGpuMemorySpace + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SetDefaultGpuMemorySpace) + + // Returns the name of the pass suitable for the pass manager. + llvm::StringRef getArgument() const override { + return "jasc-set-default-gpu-memory-space"; + } + + void runOnOperation() override { + getOperation().walk([](mlir::arith::ConstantOp constant) { + auto tensor_type = constant.getType().dyn_cast(); + if (tensor_type == nullptr) return; + + mlir::OpBuilder builder(constant.getContext()); + builder.setInsertionPointAfter(constant); + auto alloc_op = builder.create( + constant.getLoc(), tensor_type, mlir::ValueRange(), + constant.getResult()); + constant.getResult().replaceAllUsesExcept(alloc_op.getResult(), alloc_op); + }); + getOperation().walk([](mlir::bufferization::AllocTensorOp alloc) { + alloc.setMemorySpaceAttr(mlir::gpu::AddressSpaceAttr::get( + alloc->getContext(), mlir::gpu::AddressSpace::Global)); + }); + } + + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + } +}; + +// Custom version of GpuToLLVMConversionPass to support memory space +// annotations. +class GpuToLLVMConversionPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GpuToLLVMConversionPass) + + // Returns the name of the pass suitable for the pass manager. 
+ llvm::StringRef getArgument() const override { + return "jasc-gpu-to-llvm-conversion"; + } + + void runOnOperation() override { + mlir::LowerToLLVMOptions options(&getContext()); + options.useOpaquePointers = true; + options.useBarePtrCallConv = true; + + mlir::LLVMTypeConverter converter(&getContext(), options); + converter.addTypeAttributeConversion( + [](mlir::BaseMemRefType type, + mlir::gpu::AddressSpaceAttr memory_space) { + // Erase memory space information. + auto int_type = mlir::IntegerType::get(type.getContext(), 64); + return mlir::IntegerAttr::get(int_type, 0); + }); + + mlir::LLVMConversionTarget target(getContext()); + target.addIllegalDialect(); + mlir::RewritePatternSet patterns(&getContext()); + mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + mlir::populateFuncToLLVMConversionPatterns(converter, patterns); + mlir::populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); + + // TODO(ulysse): change some of the patterns to avoid creating new streams. + // Use a different converter for GPU dialect calls as we don't want to use + // bare pointers here. + mlir::LowerToLLVMOptions gpu_options(&getContext()); + gpu_options.useOpaquePointers = true; + gpu_options.useBarePtrCallConv = false; + mlir::LLVMTypeConverter gpu_converter(&getContext(), gpu_options); + gpu_converter.addTypeAttributeConversion( + [](mlir::BaseMemRefType type, + mlir::gpu::AddressSpaceAttr memory_space) { + // Erase memory space information. 
+ auto int_type = mlir::IntegerType::get(type.getContext(), 64); + return mlir::IntegerAttr::get(int_type, 0); + }); + + mlir::populateGpuToLLVMConversionPatterns( + gpu_converter, patterns, mlir::gpu::getDefaultGpuBinaryAnnotation()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +mlir::LogicalResult MemcpyToGpuPattern(mlir::memref::CopyOp copy, + mlir::PatternRewriter& rewriter) { + if (copy->getParentOfType() || + copy->getParentOfType()) + return mlir::failure(); + mlir::LogicalResult result = CreateGpuMemCpy( + rewriter, copy.getLoc(), copy.getSource(), copy.getTarget()); + if (mlir::failed(result)) return mlir::failure(); + rewriter.eraseOp(copy); + return mlir::success(); +} + +class MemcpyToGpuDialect + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemcpyToGpuDialect) + + // Returns the name of the pass suitable for the pass manager. + llvm::StringRef getArgument() const override { + return "jasc-memcpy-to-gpu-dialect"; + } + + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(MemcpyToGpuPattern); + patterns.add(AllocToAlloca); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateSetDefaultGpuMemorySpacePass() { + return std::make_unique(); +} + +std::unique_ptr CreateGpuToLLVMConversionPass() { + return std::make_unique(); +} + +std::unique_ptr CreateMemcpyToGpuDialectPass() { + return std::make_unique(); +} + +void registerGPULoweringPasses() { + mlir::PassRegistration(); + mlir::PassRegistration(); + mlir::PassRegistration(); +} + +} // namespace jasc diff --git a/jasc/gpu_lowering_passes.h b/jasc/gpu_lowering_passes.h new file mode 100644 index 000000000000..97da37d7cefd --- /dev/null +++ b/jasc/gpu_lowering_passes.h @@ -0,0 +1,43 @@ +#ifndef 
THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ + +#include + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" + +namespace jasc { + +// Creates a pass that annotates the tensor alloc operations to use the global +// memory space. The pass also adds tensor alloc operations after constants to +// copy them to the global memory space. +std::unique_ptr CreateSetDefaultGpuMemorySpacePass(); + +// Creates a custom version of GpuToLLVMConversionPass to support memory space +// annotations. +std::unique_ptr CreateGpuToLLVMConversionPass(); + +// Creates a pass that converts memref.copy to the GPU dialect. +std::unique_ptr CreateMemcpyToGpuDialectPass(); + +mlir::FailureOr CreateGpuAlloc(mlir::OpBuilder& builder, + mlir::Location loc, + mlir::MemRefType memref_type, + mlir::ValueRange dyn_sizes, + unsigned int); + +mlir::LogicalResult CreateGpuDealloc(mlir::OpBuilder& builder, + mlir::Location loc, mlir::Value memref); + +mlir::LogicalResult CreateGpuMemCpy(mlir::OpBuilder& builder, + mlir::Location loc, mlir::Value from, + mlir::Value to); + +void registerGPULoweringPasses(); + +} // namespace jasc + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ \ No newline at end of file diff --git a/jasc/gpu_post_bufferize.mlir b/jasc/gpu_post_bufferize.mlir new file mode 100644 index 000000000000..5d1bd4be7740 --- /dev/null +++ b/jasc/gpu_post_bufferize.mlir @@ -0,0 +1,9 @@ +// Transform script for GPU post-bufferization codegen. +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + // Introduce gpu.launch ops for every linalg op. 
# Module-wide switch: when True, Jasc abstractions normalize payload IR
# automatically before applying transforms.
_JASC_AUTO_NORMALIZATION = True


def set_auto_normalization(activate: bool) -> None:
  """Toggles the automatic normalization mode."""
  global _JASC_AUTO_NORMALIZATION
  _JASC_AUTO_NORMALIZATION = activate


@contextlib.contextmanager
def autonormalize(activate: bool = True):
  """Context manager that switches automatic normalization behavior.

  On entry the global auto-normalization flag is set to `activate`; on exit
  (normal or exceptional) the previous value is restored.
  """
  previous = _JASC_AUTO_NORMALIZATION
  set_auto_normalization(activate)
  try:
    yield
  finally:
    # Restore the flag even if the managed block raised.
    set_auto_normalization(previous)
@dataclass(frozen=True)
class Normalform(abc.ABC):
  """Base class for all normal forms.

  A normal form is defined by the sequence of transformations that must be
  applied to a handle so that the underlying program reaches this form.
  """

  # Whether reaching this form is propagated to parent / child handles.
  propagate_up: ClassVar[bool]
  propagate_down: ClassVar[bool]

  @classmethod
  @abc.abstractmethod
  def _impl(cls, handle: Value) -> Value:
    """Defines the transformations required to reach this normal form.

    A normal form may apply arbitrary transforms as long as `handle` is kept
    wrapping a valid MLIR transform handle. In particular, normalization may
    consume the initial handle and rebind `handle` to a different operation
    type. Child handles should be updated where that makes sense semantically
    but may be invalidated during normalization.
    """
    ...
    # TODO(@mluecke): Track handle invalidation so accesses to handles that
    #   were invalidated by normalization can be reported.

  @classmethod
  def apply(cls, handle: Value) -> Value:
    """Brings `handle` into this normal form and returns the new handle."""
    result = cls._impl(handle)
    # Assigning the property also propagates the normal form up/down.
    handle.normalform = cls
    return result


@dataclass(frozen=True)
class AnyForm(Normalform):
  """Weakest normal form: every program trivially satisfies it."""

  propagate_up: ClassVar[bool] = True
  propagate_down: ClassVar[bool] = True

  @classmethod
  def _impl(cls, handle: Value) -> Value:
    # Nothing to do — any program is already in AnyForm.
    return handle
def tuning_param(default_value: Optional[ir.Attribute | int] = None) -> Param:
  """Emits a transform op that provides an "empty", to-be-autotuned param.

  Args:
    default_value: The default value for this parameter. It is used during
      interpretation of the transform IR if no autotuning is performed. If not
      specified, `1 : i32` is used as the default value.

  Returns:
    A `Param` wrapping the result of the emitted `TuningParamOp`.
  """
  i32_type = ir.IntegerType.get_signless(32)
  if default_value is None:
    default_value = ir.IntegerAttr.get(i32_type, 1)
    param_type = transform.ParamType.get(i32_type)
  elif isinstance(default_value, int):
    default_value = ir.IntegerAttr.get(i32_type, default_value)
    param_type = transform.ParamType.get(i32_type)
  elif isinstance(default_value.type, ir.IntegerType):
    # Support explicit param types for int types of different widths and
    # signedness.
    param_type = transform.ParamType.get(default_value.type)
  else:
    # TODO(mluecke): Make this more general once transform.ParamType supports
    #   types beyond IntegerType.
    param_type = transform.AnyParamType.get()
  op = jto.TuningParamOp(param_type, default_value)
  return Param(op.param)


def constant_param(value: ir.Attribute | int) -> Param:
  """Emits a transform op that provides a constant param.

  Args:
    value: The constant value, either an MLIR attribute or a plain int, which
      is wrapped as a `64`-bit signless `IntegerAttr`.

  Returns:
    A `Param` wrapping the result of the emitted `ParamConstantOp`.
  """
  if isinstance(value, int):
    value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
    param_type = transform.ParamType.get(value.type)
  else:
    # TODO(mluecke): Make this more general once transform.ParamType supports
    #   types beyond IntegerType.
    param_type = transform.AnyParamType.get()
  op = transform.ParamConstantOp(param_type, value)
  return Param(op.param)
propagate_up_normalform(self, normalform: Type[Normalform]): + if self.parent: + # Using the property here would trigger infinite propagation for NFs that + # have to be propagated up and down + self.parent._normalform = normalform + self.parent.propagate_up_normalform(normalform) + + def propagate_down_normalform(self, normalform: Type[Normalform]): + for child in self.children: + # Using the property here would trigger infinite propagation for NFs that + # have to be propagated up and down + child._normalform = normalform + child.propagate_down_normalform(normalform) + + def normalize(self, normalform: Type[Normalform]) -> Value: + """Applies transformations to bring this handle into a specific normalform.""" + normalform.apply(self) + return self + + @classmethod + def _unwrap_handles_from_dynamic_index_list( + cls, + indices: Union[DynamicIndexList, ir.ArrayAttr], + ) -> Union[DynamicIndexList, ir.ArrayAttr]: + """Extracts the MLIR value from each OpHandle in the given DynamicIndexList. + + This brings it into the definition of `DynamicIndexList` used by the + upstream op constructors, such that it can be passed as an argument there. + """ + if indices is None: + return None + # ArrayAttr: there are no OpHandles inside, so nothing to do. + if isinstance(indices, ir.ArrayAttr): + return indices + + # It must be a list: process each index at a time. + def extract_handle(index: Any) -> Any: + if isinstance(index, Value): + return index.mlir_value + elif not isinstance(index, (StaticIntLike, ValueLike)): + # If it's not one of these types, it must be a scalable index, which is + # a singleton list of one index. 
+ return [extract_handle(index[0])] + return index + + return [extract_handle(index) for index in indices] + + +@dataclass +class Param(Value): + """Wrapper around a transform Param with methods to chain further transforms.""" + + +@dataclass +class OpHandle(Value): + """Wrapper around an OpHandle with methods to chain further transforms.""" + + def _ensure_op_type(self, type: Union[ir.Type, str]) -> OpHandle: + """Returns a handle to the same payload ops with the given op type. + + If the op type of the given handle already corresponds to the given type, + it is returned as is. Otherwise, a `transform.cast` is inserted and a handle + to that op is returned. If the expected type is given as a `str`, it is used + to construct a `transform.op<...>` with that string for the expected type. + """ + if isinstance(type, str): + type = transform.OperationType.get(type) + + if self.mlir_value.type != type: + return self.cast(type) + else: + return self + + def alloca_to_global(self) -> AllocaToGlobalResult: + """Creates a `MemRefAllocaToGlobalOp` and returns handles with the results. + + This handle will be updated to represent the tiled newly inserted + `memref.global` ops. + """ + alloca = self._ensure_op_type("memref.alloca") + op = memref.MemRefAllocaToGlobalOp(alloca.mlir_value) + self._mlir_value = op.getGlobal + return AllocaToGlobalResult(get_global=self, global_=OpHandle(op.global_)) + + def apply_cse(self) -> OpHandle: + """Creates a `ApplyCommonSubexpressionEliminationOp` and returns `self`.""" + transform.ApplyCommonSubexpressionEliminationOp(self.mlir_value) + return self + + def apply_dce(self) -> OpHandle: + """Creates a `ApplyDeadCodeEliminationOp` and returns `self`.""" + transform.ApplyDeadCodeEliminationOp(self.mlir_value) + return self + + def apply_licm( + self, to: Optional[Sequence[str | OpHandle]] = None + ) -> OpHandle: + """Creates a `ApplyLoopInvariantCodeMotionOp` for each given op and returns `self`. 
+ + For strings in `to`, matches ops with these names. For the resulting matches + and for each `Value` given directly in `to`, a LICM transform op + (`transform.ApplyLoopInvariantCodeMotionOp`) is created. If `to` is `None`, + a LICM is created for `self`. + """ + # Handle `self` case. + if to is None: + to = [self] + + # Create match for string inputs. + op_names = [op for op in to if isinstance(op, str)] + ops = [op for op in to if isinstance(op, OpHandle)] + if op_names: + matched = self.match_ops(op_names) + ops.append(matched) + + # Create LICM ops. + for op in ops: + transform.ApplyLoopInvariantCodeMotionOp(op.mlir_value) + + return self + + @contextlib.contextmanager + def apply_patterns(self, *, apply_cse: Optional[bool] = None): + """Emits a `transform.ApplyPatternsOp`. + + Returns a context manager with an insertion point on the patterns block. + """ + op = transform.ApplyPatternsOp(self.mlir_value) + op.apply_cse = apply_cse + with ir.InsertionPoint(op.patterns): + yield + + def apply_tuning_config( + self, config: Sequence[int | ir.Attribute] + ) -> OpHandle: + """Creates a `ApplyTuningConfigOp` and returns `self`.""" + config_attr = ir.ArrayAttr.get( + [ + ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) + if isinstance(i, int) + else i + for i in config + ] + ) + jto.ApplyTuningConfigOp(self.mlir_value, config=config_attr) + return self + + def auto_normalize_parent_func(self, normalform: Type[Normalform]): + """Auto normalizes the parent function if needed.""" + if self.normalform != normalform and _JASC_AUTO_NORMALIZATION: + func = self.get_parent_op(op_name="func.func", deduplicate=True) + func.normalize(normalform) + + def buffer_loop_hoisting(self) -> OpHandle: + """Creates a `bufferization.BufferLoopHoistingOp` and returns `self`.""" + bufferization.BufferLoopHoistingOp(self.mlir_value) + return self + + def bufferize_to_allocation( + self, + *, + memory_space: Optional[int | str | ir.Attribute] = None, + memcpy_op: Optional[str] = None, 
+ alloc_op: Optional[str] = None, + bufferize_destination_only: Optional[bool] = None, + ) -> BufferizatToAllocationResult: + """Creates a `structured.BufferizeToAllocationOp` op. + + Returns the results as handles in a `BufferizeToAllocationResult`. + """ + op = structured.BufferizeToAllocationOp( + self.mlir_value, + memory_space=memory_space, + memcpy_op=memcpy_op, + alloc_op=alloc_op, + bufferize_destination_only=bufferize_destination_only, + ) + + return BufferizatToAllocationResult( + allocated_buffer=ValueHandle(op.allocated_buffer), + new_ops=OpHandle(op.new_ops), + ) + + def cast(self, type_: ir.Type | str) -> OpHandle: + """Creates a handle from the result of a `CastOp` to the given type.""" + if isinstance(type_, str): + type_ = transform.OperationType.get(type_) + op = transform.CastOp(type_, self.mlir_value) + return OpHandle(op.output) + + def create_async_groups(self) -> OpHandle: + """Creates a handle from the result of a new `CreateAsyncGroupsOp` op.""" + op = nvgpu.CreateAsyncGroupsOp(transform.AnyOpType.get(), self.mlir_value) + # XXX: Should self._mlir_value be updated? + return OpHandle(op.result) + + def eliminate_empty_tensors(self) -> OpHandle: + """Creates a `bufferization.EliminateEmptyTensorsOp` and returns `self`.""" + bufferization.EliminateEmptyTensorsOp(self.mlir_value) + return self + + def foreach( + self, result_types: Optional[Union[ir.Type, Sequence[ir.Type]]] = None + ): + """Emits a `transform.foreach` op. + + The result object gives access to a context manager with an insertion point + on the body block as well as the results of the op. + """ + # TODO(ingomueller): Move boilerplate to upstream and make upstream `body` + # property return first block. 
+ if result_types is None: + result_types = [] + if isinstance(result_types, ir.Type): + result_types = [result_types] + + input_type = self.mlir_value.type + op = transform.ForeachOp(results_=result_types, target=self.mlir_value) + op.body.blocks.append(input_type) + + return ForeachResult(_op=op) + + @jasc_transform(required_normalform=LoopNormalform) + def fuse_into( + self, containing_op: Union[ir.Operation, ir.OpView, ir.Value] + ) -> OpHandle: + """Creates a new `structured.FuseIntoContainingOp`. + + The func.func payload op surrounding the payload this handle represents + will be autonormalized to LoopNormalform if needed. + + This handle will afterwards point to the `fused_op` result. The + `containing_op` handle remains valid. + """ + op = structured.FuseIntoContainingOp( + self.mlir_value, containing_op.mlir_value + ) + self._mlir_value = op.fused_op + return self + + def get_parent_op( + self, + deduplicate: Optional[StaticBoolLike] = None, + isolated_from_above: Optional[StaticBoolLike] = None, + op_name: Optional[str] = None, + ) -> OpHandle: + """Creates a handle from the result of a new `GetParentOp` op.""" + op = transform.GetParentOp( + transform.AnyOpType.get(), + self.mlir_value, + deduplicate=deduplicate, + isolated_from_above=isolated_from_above, + op_name=op_name, + ) + return OpHandle(op.parent) + + def get_producer_of_operand( + self, operand_number: int | ir.Attribute + ) -> OpHandle: + """Creates a handle from the result of a new `GetProducerOfOperand` op.""" + op = transform.GetProducerOfOperand( + transform.AnyOpType.get(), self.mlir_value, operand_number + ) + return OpHandle(op.producer) + + def hoist_pad(self, num_loops: int | ir.Attribute) -> OpHandle: + """Creates a new `structured.HoistPadOp` op. + + This handle will be updated to represent the result of the transform. 
+ """ + op = structured.HoistPadOp( + transform.AnyOpType.get(), self.mlir_value, num_loops + ) + self._mlir_value = op.transformed + return self + + def hoist_redundant_vector_transfers(self) -> OpHandle: + """Creates a new `structured.hoist_redundant_vector_transfers` op. + + This handle will be updated to represent the result of the transform. + """ + op = structured.HoistRedundantVectorTransfersOp( + transform.AnyOpType.get(), self.mlir_value + ) + self._mlir_value = op.transformed + return self + + def insert_slice_to_copy(self) -> OpHandle: + """Creates a new `structured.InsertSliceToCopyOp` op. + + Updates this handle to represent the new linalg.copy operation. + + The transform is a targeted rewrite of a `tensor.insert_slice` or + `tensor.parallel_insert_slice` to `linalg.copy`. If the insert_slice source + is already a linalg.copy, only returns the source op (i.e. does not create + an additional linalg.copy op). + """ + op = structured.InsertSliceToCopyOp( + transform.OperationType.get("linalg.copy"), self.mlir_value + ) + self._mlir_value = op.transformed + return self + + def interchange( + self, iterator_interchange: OptionalIntList = None + ) -> OpHandle: + """Creates a new `structured.interchange` op. + + Updates this handle to represent the transformed linalg operation. 
+ """ + op = structured.InterchangeOp( + self.mlir_value, iterator_interchange=iterator_interchange + ) + return OpHandle(op.transformed) + + def map_forall_to_blocks( + self, + *, + grid_dims: Optional[Union[Sequence[int], ir.Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, ir.Attribute]] = None, + ) -> OpHandle: + """Creates a new `gpu.MapForallToBlocks` op.""" + op = gpu.MapForallToBlocks( + self.mlir_value, + grid_dims=grid_dims, + generate_gpu_launch=generate_gpu_launch, + ) + self._mlir_value = op.result + return self + + def map_copy_to_threads( + self, *, total_num_threads: int, desired_bit_alignment: int + ) -> MapCopyToThreadsResult: + """Creates a new `structured.gpu.MapCopyToThreadsOp` op. + + This handle will be updated to represent the new tiled op. + """ + op = structured.MapCopyToThreadsOp( + self.mlir_value, + total_num_threads=total_num_threads, + desired_bit_alignment=desired_bit_alignment, + ) + self._mlir_value = op.tiled_op + return MapCopyToThreadsResult( + forall_op=OpHandle(op.forall_op), + tiled_op=self, + ) + + def map_nested_forall_to_threads( + self, + *, + block_dims: OptionalIntList = None, + sync_after_distribute: Optional[StaticBoolLike] = None, + warp_size: OptionalIntList = None, + ) -> OpHandle: + """Creates a new `gpu.MapNestedForallToThreads` op.""" + op = gpu.MapNestedForallToThreads( + transform.AnyOpType.get(), + self.mlir_value, + block_dims=block_dims, + sync_after_distribute=sync_after_distribute, + warp_size=warp_size, + ) + self._mlir_value = op.result + return self + + @jasc_transform(required_normalform=LoopNormalform) + def vectorize( + self, + vector_sizes: Optional[Sequence[int | ir.Attribute]] = None, + *, + vectorize_nd_extract: Optional[bool] = None, + ) -> OpHandle: + """Creates a `structured.VectorizeOp` op and returns `self`. + + The func.func payload op surrounding the payload this handle represents + will be autonormalized to LoopNormalform if needed. 
+ """ + if vector_sizes is not None: + vector_sizes = self._unwrap_handles_from_dynamic_index_list(vector_sizes) + structured.VectorizeOp( + self.mlir_value, + vector_sizes=vector_sizes, + vectorize_nd_extract=vectorize_nd_extract, + ) + return self + + def match_ops( + self, + ops: str + | ir.OpView + | structured.MatchInterfaceEnum + | Sequence[str | ir.OpView], + ) -> OpHandle: + """Returns a handle to ops that match the given names, types, or interface. + + If only a single type is given, the value wrapped by the resulting + handle is populated with the respective type. + """ + # Handle interface. + if isinstance(ops, structured.MatchInterfaceEnum) or ( + isinstance(ops, str) + and ops in structured.MatchInterfaceEnum.__members__ + ): + if isinstance(ops, str): + ops = structured.MatchInterfaceEnum[ops] + match_op = structured.MatchOp( + transform.AnyOpType.get(), + self.mlir_value, + interface=ops, + ) + + # Handle op name(s), either given directly as string or given as op. + else: + if isinstance(ops, str): + op_type = transform.OperationType.get(ops) + op_names = [ops] + elif isinstance(ops, Sequence): + op_type = transform.AnyOpType.get() + op_names = [ + op if isinstance(op, str) else op.OPERATION_NAME for op in ops + ] + else: + op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_names = [ops.OPERATION_NAME] + match_op = structured.MatchOp.match_op_names( + op_type, + self.mlir_value, + op_names, + ) + + handle = OpHandle(match_op.results_, parent=self) + self.children.append(handle) + return handle + + def match_sparse_inout_ops(self) -> OpHandle: + op_type = transform.AnyOpType.get() + sparse_op = sparse_tensor.MatchSparseInOut(op_type, self.mlir_value) + handle = OpHandle(sparse_op.result, parent=self) + self.children.append(handle) + return handle + + def match_tag(self, tag_names: str | Sequence[str]) -> OpHandle: + """Returns a handle to linalg operations that match the given tags.""" + if isinstance(tag_names, str): + tag_names = 
[tag_names] + linalg_iface = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + linalg_ops = structured.MatchOp( + pdl.OperationType.get(), self.mlir_value, interface=linalg_iface + ).results + match_tag_op = jto.MatchTagOp(linalg_ops, tags=tag_names) # pylint: disable=no-value-for-parameter + handle = OpHandle(match_tag_op.matched_ops, parent=self) + self.children.append(handle) + return handle + + def one_shot_bufferize( + self, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[ + Enum | str | ir.Attribute + ] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + ) -> OpHandle: + """Creates a new `bufferization.OneShotBufferizeOp` op. + + This handle will be updated to represent the result of the transform. + """ + if isinstance(function_boundary_type_conversion, str): + function_boundary_type_conversion = LayoutMapOption[ + function_boundary_type_conversion + ] + + op = bufferization.OneShotBufferizeOp( + self.mlir_value, + allow_return_allocs_from_loops=allow_return_allocs_from_loops, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + ) + self._mlir_value = op.transformed + return self + + def pad( + self, + *, + copy_back_op: Optional[Union[str, ir.StringAttr, PadCopyBackOp]] = None, + pack_paddings: OptionalIntList = None, + padding_dimensions: OptionalIntList = None, + padding_values: Optional[Sequence[float]] = None, + pad_to_multiple_of: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ir.ArrayAttr, Sequence[Union[ir.ArrayAttr, IntOrAttrList]]] + ] = None, + ) -> PadResult: + """Creates a 
new `structured.PadOp` op. + + This handle will be updated to represent the new padded op. + """ + + if isinstance(copy_back_op, PadCopyBackOp): + copy_back_op = copy_back_op.value + + if padding_values is not None: + padding_values = ir.ArrayAttr.get( + [ir.FloatAttr.get_f32(val) for val in padding_values] + ) + + op = structured.PadOp( # pylint: disable=no-value-for-parameter + self.mlir_value, + copy_back_op=copy_back_op, + pack_paddings=pack_paddings, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pad_to_multiple_of=pad_to_multiple_of, + transpose_paddings=transpose_paddings, + ) + self._mlir_value = op.padded + return PadResult( + padded=self, + pad=OpHandle(op.pad), + copy=OpHandle(op.copy), + ) + + def print(self, name: Optional[str] = None) -> OpHandle: + """Emits a transform op to print this handle and an optional message.""" + transform.PrintOp(target=self.mlir_value, name=name) + return self + + def rewrite_in_destination_passing_style(self) -> OpHandle: + """Creates a new `structured.RewriteInDestinationPassingStyleOp` op. + + This handle will be updated to represent the result of the transform. + """ + op = structured.RewriteInDestinationPassingStyleOp( + transform.AnyOpType.get(), self.mlir_value + ) + self._mlir_value = op.transformed + return self + + def select(self, op_name: str | ir.Attribute) -> OpHandle: + """Returns a handle to the result of a new `transform.SelectOp`.""" + op = transform.SelectOp(transform.AnyOpType.get(), self.mlir_value, op_name) + return OpHandle(op.result) + + def synchronize(self) -> OpHandle: + """Creates a new `SynchronizeOp` op and returns a handle to the barrier. + + self will only be read by this transform and hence stay valid. 
+ """ + op = jto.SynchronizeOp( + transform.OperationType.get("gpu.barrier"), self.mlir_value + ) + return OpHandle(op.barrier) + + @jasc_transform(required_normalform=LoopNormalform) + def take_assumed_branch(self, take_else_branch: bool = None) -> OpHandle: + """Creates a `TakeAssumedBranchOp` and returns `self`. + + The func.func payload op surrounding the payload this handle represents + will be autonormalized to LoopNormalform if needed. + """ + loop.TakeAssumedBranchOp(self.mlir_value, take_else_branch=take_else_branch) + return self + + def _tile_using_for( + self, + *, + tile_sizes: Sequence[int | Param], + interchange: Optional[Sequence[int]] = None, + ): + op = structured.TileUsingForOp( + self.mlir_value, + sizes=self._unwrap_handles_from_dynamic_index_list(tile_sizes), + interchange=interchange, + ) + self._mlir_value = op.tiled_linalg_op + return TileResult( + tiled_op=self, + loops=[OpHandle(loop) for loop in op.loops], + ) + + @jasc_transform(required_normalform=LoopNormalform) + def tile( + self, + *, + loop: TileLoopKind, + tile_sizes: Optional[Sequence[int | Param]] = None, + interchange: Optional[Sequence[int]] = None, + num_threads: Optional[Sequence[int]] = None, + mapping: Optional[ + str | ir.Attribute | Sequence[str | ir.Attribute] + ] = None, + ) -> TileResult: + """Creates a new structured tiling operation. + + Depending on the `loop` kwarg, creates either a `structured.TileUsingFor` or + `structured.TileUsingForall` transform operation. Additional kwargs + parameterize the created op: + + `tile_sizes`: tile sizes to use in the loop, mandatory for `for` loops; + `num_threads`: the number of iterations in the produced loop, only supported + in `forall` tiling at the moment; + `interchange`: interchange of the dimensions, only supported in `for` tiling + at the moment; + `mapping`: mapping of the generated loops to parallelism concepts such as + GPU threads, only supported in `forall` loops (`for` loops are + implicitly sequential). 
+ + This handle will be updated to represent the tiled linalg op. + """ + if loop == TileLoopKind.FOR: + if tile_sizes is None: + raise ValueError("Tile sizes must be provided.") + if num_threads is not None or mapping is not None: + raise ValueError( + "Cannot specify num threads or mapping when tiling to scf.for, use" + " scf.forall instead." + ) + return self._tile_using_for( + tile_sizes=tile_sizes, interchange=interchange + ) + + elif loop == TileLoopKind.FORALL: + if tile_sizes is None and num_threads is None: + raise ValueError("Must specify either tile sizes or num threads.") + if interchange is not None: + raise ValueError( + "Cannot specify interchange when tiling to scf.forall." + ) + if tile_sizes and any( + isinstance(tile_size, Param) for tile_size in tile_sizes + ): + raise ValueError( + "Cannot specify dynamic tile sizes when tiling to scf.forall." + ) + return self._tile_using_forall( + tile_sizes=tile_sizes, num_threads=num_threads, mapping=mapping + ) + + raise ValueError(f"Uknown loop kind {loop}") + + def replace_with_alloc_tensor(self) -> OpHandle: + """Creates a new `bufferization.EmptyTensorToAllocTensorOp` and updates this handle accordingly. + + The payload op in this handle has to be a `tensor.empty`. If the static type + of the MLIR handle indicates that this is the case, the handle is used as + is; otherwise, a `transform.cast` op is inserted that casts this MLIR handle + into the required type. + """ + + tensor_empty = self._ensure_op_type("tensor.empty").mlir_value + op = bufferization.EmptyTensorToAllocTensorOp(tensor_empty) + self._mlir_value = op.transformed + return self + + def _tile_using_forall( + self, + *, + mapping: Optional[ + str | ir.Attribute | Sequence[str | ir.Attribute] + ] = None, + num_threads: Optional[Sequence[int]] = None, + tile_sizes: Optional[Sequence[int]] = None, + ) -> TileResult: + """Creates a new `structured.TileUsingForallOp` op. 
+ + The func.func payload op surrounding the payload this handle represents + will be autonormalized to LoopNormalform if needed. + + This handle will be updated to represent the tiled op. + """ + # TODO(mluecke): Remove string parsing of attributes once builders for GPU + # dialect attributes are available + attr_or_parse = lambda x: ir.Attribute.parse(x) if isinstance(x, str) else x + if isinstance(mapping, (str, ir.Attribute)): + mapping = attr_or_parse(mapping) + elif mapping is not None: + mapping = ir.ArrayAttr.get([attr_or_parse(attr) for attr in mapping]) + + op = structured.TileUsingForallOp( + transform.AnyOpType.get(), + transform.AnyOpType.get(), + self.mlir_value, + num_threads=num_threads, + tile_sizes=tile_sizes, + mapping=mapping, + ) + self._mlir_value = op.tiled_op + return TileResult( + loops=[OpHandle(op.forall_op)], + tiled_op=self, + ) + + def vectorize_children_and_apply_patterns( + self, + disable_multi_reduction_to_contract_patterns: Optional[bool] = None, + disable_transfer_permutation_map_lowering_patterns: Optional[bool] = None, + vectorize_nd_extract: Optional[bool] = None, + vectorize_padding: Optional[bool] = None, + ) -> OpHandle: + """Creates a new `structured.VectorizeChildrenAndApplyPatternsOp` op. + + This handle will be updated to represent the result of the transform. + """ + op = structured.VectorizeChildrenAndApplyPatternsOp( # pylint: disable=no-value-for-parameter + self.mlir_value, + disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, + disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, + vectorize_nd_extract=vectorize_nd_extract, + vectorize_padding=vectorize_padding, + ) + self._mlir_value = op.transformed + return self + + +@dataclass +class ValueHandle(Value): + """Wrapper around a ValueHandle with methods to chain further transforms.""" + + +class Schedule(Protocol): + """A schedule for a Jax computation. 

  Example:
    ```
    def computation(a: jax.Array) -> jax.Array:
      return jasc.tag(lambda x: x + 1, "plus1")(a)

    def schedule(h: OpHandle) -> None:
      h.match_tag("plus1").tile((4,))

    jasc.jit(computation, schedule)
    ```
  """

  def __call__(self, handle: OpHandle) -> None:
    """Builds a schedule for the computation handle to point to."""
    ...


def _flatten_func(
    func: Callable[..., Any], args: Sequence[Any]
) -> tuple[Callable[..., Any], Sequence[jax.Array]]:
  """Flattens a function's inputs and outputs.

  See jax.tree_util.tree_flatten for background information on flattening.

  Args:
    func: A function to flatten.
    args: Arguments that will be passed to func.

  Returns:
    A tuple composed of:
    - a function that has the same semantics as func, but with inputs and
      outputs flattened.
    - args flattened.
  """
  flat_args, in_tree = jax.tree_util.tree_flatten(args)

  def flat(*flat_args: Any) -> Any:
    # Rebuild the original (pytree-structured) arguments, call through, and
    # flatten the result again.
    unflat_args = jax.tree_util.tree_unflatten(in_tree, flat_args)
    out = func(*unflat_args)
    flat_outs, _ = jax.tree_util.tree_flatten(out)
    return flat_outs

  return flat, flat_args


def jit(
    func: Callable[..., Any],
    schedule: Optional[Schedule] = None,
    *,
    module: Optional[ir.Module] = None,
    dump_ir: bool = False,
) -> Callable[..., Any]:
  """Applies scheduling directives inside func.

  Args:
    func: A function to compile using Jasc. The function must use Jasc schedule
      directives to optimize its code.
    schedule: A schedule to apply to the computation.
    module: An already lowered representation of `func` in MLIR. If this is
      supplied it will be used for execution rather than lowering `func`.
    dump_ir: If true, logs intermediate compilation steps.

  Returns:
    A function with the same semantics as func, but compiled using Jasc.
  """

  @jax.jit
  def wrapped(*args: Any) -> Any:
    if schedule is None:
      build_schedule = lambda handle: handle
    else:
      build_schedule = lambda handle: schedule(OpHandle(handle))
    flat_func, flat_args = _flatten_func(func, args)
    # Abstract output shapes/dtypes, needed by the custom jit primitive.
    out_avals = jax.tree_map(
        lambda x: jax.core.ShapedArray(x.shape, x.dtype),
        jax.eval_shape(func, *args),
    )
    flat_out_avals, out_tree = jax.tree_util.tree_flatten(out_avals)
    out_flat = primitives.jit_p.bind(
        *flat_args,
        func=flat_func,
        module=module,
        build_schedule=build_schedule,
        out_avals=flat_out_avals,
        dump_ir=dump_ir,
    )
    return jax.tree_util.tree_unflatten(out_tree, out_flat)

  return wrapped


def apply_schedule(
    module: ir.Module,
    schedule: Schedule | None = None,
    dump_ir: bool = False,
    dump_schedule: bool = False,
) -> None:
  """Applies a schedule to the module.

  Args:
    module: Existing module with payload IR and possibly an existing schedule.
    schedule: The schedule to apply at linalg level. If no schedule is supplied
      it is assumed to already be present in the module.
    dump_ir: Whether to dump the transformed IR after each pass.
    dump_schedule: Whether to dump the schedule after creation. This is only
      supported with a schedule that is not already in the module.
  """

  if schedule is not None:
    insert_schedule(module, schedule, dump_schedule)
  if schedule is None and dump_schedule:
    raise ValueError(
        "dump_schedule is only supported with a schedule that is not already in"
        " the module."
    )
  call_kernel.apply_schedule(module, dump_ir)


def insert_schedule(
    module: ir.Module,
    schedule: Schedule | None = None,
    dump_schedule: bool = False,
) -> None:
  """Inserts the transform script of the schedule into the module.

  Args:
    module: Existing module into which the script should be inserted.
    schedule: The schedule to apply at linalg level.
    dump_schedule: Whether to dump the schedule after creation.
+ """ + if schedule is None: + schedule = lambda x: x + + # Register jasc transform ops so they can be used in the schedule + jasc_dialect.register_and_load_dialect(module.context) + + # Insert the schedule into the IR + with module.context, ir.Location.unknown(module.context): + with ir.InsertionPoint.at_block_begin(module.body): + sequence_op = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + (), + transform.AnyOpType.get(), + ) + with ir.InsertionPoint(sequence_op.body): + schedule(OpHandle(sequence_op.bodyTarget)) + transform.YieldOp([]) + + if dump_schedule: + print(sequence_op) + + +def tag(func: Callable[..., Any], name: str) -> Callable[..., Any]: + """Tags the function so that it can be matched by Jasc schedules. + + Args: + func: The function to tag. + name: Identifier of the tag. If the same name is used in multiple places, + Jasc schedules will match all occurences of the name. + + Returns: + A function with the same semantics as func. + """ + + def wrapped(*args: Any) -> Any: + flat_func, flat_args = _flatten_func(func, args) + out_avals = jax.tree_map( + lambda x: jax.core.ShapedArray(x.shape, x.dtype), + jax.eval_shape(func, *args), + ) + flat_out_avals, out_tree = jax.tree_util.tree_flatten(out_avals) + out_flat = primitives.tag_p.bind( + *flat_args, + func=flat_func, + out_avals=flat_out_avals, + name=name, + ) + return jax.tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped + + +def yield_(values: Optional[Union[Value, Sequence[Value]]] = None) -> None: + if values is None: + values = [] + if isinstance(values, Value): + values = [values] + values = [v.mlir_value for v in values] + transform.YieldOp(values) + + +@overload +def lower_to_linalg( + func: Callable[..., Any], + *args: Any, + schedule: Schedule | None = None, + dump_ir: bool = False, + dump_schedule: bool = False, +) -> ir.Module: + ... 
+ + +@overload +def lower_to_linalg( + module: ir.Module, + *, + schedule: Schedule | None = None, + dump_ir: bool = False, + dump_schedule: bool = False, +) -> ir.Module: + ... + + +def lower_to_linalg( + func_or_module: Union[Callable[..., Any], ir.Module], + *args: Any, + schedule: Schedule | None = None, + dump_ir: bool = False, + dump_schedule: bool = False, +) -> ir.Module: + """Lowers a function to linalg IR and applies a JASC schedule. + + Args: + func_or_module: A JAX function or an MLIR module to be lowered to linalg IR. + *args: Arguments that will be passed to the JAX function. + schedule: The schedule to apply at linalg level + dump_ir: Whether to dump the transformed IR after each pass + dump_schedule: Whether to dump the schedule after creation + + Returns: + An MLIR module with linalg IR with similar semantics to func + """ + if isinstance(func_or_module, ir.Module): + module = func_or_module + else: + module = lower_to_stablehlo(func_or_module, *args) + insert_schedule(module, schedule, dump_schedule) + call_kernel.lower_to_linalg(module, dump_ir) + return module + + +def lower_to_stablehlo( + func: Callable[..., Any], + *args: Any, +) -> ir.Module: + """Lowers a function to StableHLO IR. + + Args: + func: To function to be lowered to linalg IR. + *args: Arguments that will be passed to func. + + Returns: + An MLIR module with StableHLO IR with similar semantics to func. + """ + + with primitives.enable_jasc_lowering(): + ir_module = jax.jit(func).lower(*args).compiler_ir("stablehlo") + + # Make sure this lowering is not cached and possibly used by `jax.jit` later. 
+ jax.clear_caches() + return ir_module diff --git a/jasc/jasc_opt.cc b/jasc/jasc_opt.cc new file mode 100644 index 000000000000..aee9cc224043 --- /dev/null +++ b/jasc/jasc_opt.cc @@ -0,0 +1,30 @@ +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "dialect/dialect.h" +#include "gpu_lowering_passes.h" +#include "mlir_lowering.h" +#include "transform_ops/dialect_extension.h" + +int main(int argc, char **argv) { + mlir::registerAllPasses(); + jasc::registerGPULoweringPasses(); + jasc::registerMLIRLoweringPasses(); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllExtensions(registry); + mlir::registerAllToLLVMIRTranslations(registry); + jasc::registerTransformDialectExtension(registry); + + registry.insert< + // clang-format off + jasc::JascDialect + // clang-format on + >(); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "MLIR modular optimizer driver\n", registry)); +} diff --git a/jasc/mlir_lowering.cc b/jasc/mlir_lowering.cc new file mode 100644 index 000000000000..97f51d9722eb --- /dev/null +++ b/jasc/mlir_lowering.cc @@ -0,0 +1,503 @@ +#include "mlir_lowering.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "llvm/include/llvm/ADT/SmallVector.h" +#include "llvm/include/llvm/Support/SourceMgr.h" +#include "llvm/include/llvm/Support/raw_ostream.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include 
"mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include 
"mlir/Transforms/Passes.h" +#include "mhlo/transforms/passes.h" + +#include "dialect/dialect.h" +#include "dialect/ops.h" +#include "gpu_lowering_passes.h" +#include "transform_ops/dialect_extension.h" + +namespace jasc { +namespace { + +// Runs the pass manager on the model and handles errors. +absl::Status RunPassManager(mlir::PassManager& pm, mlir::ModuleOp module, + bool dump_ir) { + std::string error_message; + llvm::raw_string_ostream os(error_message); + mlir::MLIRContext* context = module.getContext(); + llvm::SourceMgr srcMgr; + mlir::SourceMgrDiagnosticHandler handler(srcMgr, context, os); + + bool multithreaded = context->isMultithreadingEnabled(); + if (dump_ir) { + context->disableMultithreading(); + pm.enableIRPrinting([](auto*, auto*) { return false; }); + } + + mlir::LogicalResult result = pm.run(module); + if (multithreaded && dump_ir) { + context->enableMultithreading(); + } + + if (mlir::succeeded(result)) return absl::OkStatus(); + return absl::InternalError("Failed to apply transformations:\n\n" + + error_message); +} + +// Base class for passes applicable to any operations. This is needed to +// reduce template nesting below. 
+template +class OpPassWrapper : public mlir::PassWrapper> { +}; +class ApplyTransformScriptPass + : public mlir::transform::TransformInterpreterPassBase< + ApplyTransformScriptPass, OpPassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ApplyTransformScriptPass) + + ApplyTransformScriptPass() = default; + ApplyTransformScriptPass(const ApplyTransformScriptPass& pass) + : mlir::transform::TransformInterpreterPassBase(pass) {} + ApplyTransformScriptPass(const mlir::transform::TransformOptions& _options) + : mlir::transform::TransformInterpreterPassBase(_options) { + } + + void runOnOperation() override { + options.enableEnforceSingleToplevelTransformOp( + enforceSingleToplevelTransformOp); + TransformInterpreterPassBase::runOnOperation(); + } + + // Returns the name of the pass suitable for the pass manager. + llvm::StringRef getArgument() const override { + return "jasc-apply-transform-script"; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + jasc::registerTransformDialectExtension(registry); + } + + // MLIR pass options. This MUST use exactly the specified names. + // clang-format off + Option transformFileName{ // NOLINT + *this, "transform-file-name", llvm::cl::init(""), + llvm::cl::desc( + "Optional filename containing a transform dialect specification to " + "apply. If left empty, the IR is assumed to contain one top-level " + "transform dialect operation somewhere in the module.")}; + Option debugPayloadRootTag{ // NOLINT + *this, "debug-payload-root-tag", llvm::cl::init(""), + llvm::cl::desc( + "Select the operation with 'transform.target_tag' attribute having " + "the given value as payload IR root. 
If empty select the pass " + "anchor " + "operation as the payload IR root.")}; + Option debugTransformRootTag{ // NOLINT + *this, "debug-transform-root-tag", llvm::cl::init(""), + llvm::cl::desc( + "Select the operation with 'transform.target_tag' attribute having " + "the given value as container IR for top-level transform ops. This " + "allows user control on what transformation to apply. If empty, " + "select the container of the top-level transform op.")}; + ListOption transformLibraryPaths{ // NOLINT + *this, "transform-library-paths", llvm::cl::ZeroOrMore, + llvm::cl::desc( + "Optional name of the file containing transform dialect symbol " + "definitions to be injected into the transform module.")}; + Option enforceSingleToplevelTransformOp{ + *this, "enforce-single-top-level-transform-op", llvm::cl::init(true), + llvm::cl::desc("Ensure that only a single top-level transform op is " + "present in the IR.")}; + // clang-format on +}; + +std::unique_ptr CreateApplyTransformScriptPass( + llvm::StringRef name) { + auto pass = std::make_unique(); + std::string path = ""; + path.append(name); + pass->transformFileName = path; + return std::move(pass); +} + +class LowerTagRegionsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerTagRegionsPass) + + // Returns the name of the pass suitable for the pass manager. 
+ llvm::StringRef getArgument() const override { + return "jasc-lower-tag-regions"; + } + + void runOnOperation() override { + getOperation().walk([](jasc::TagRegionOp tag_region_op) { + mlir::StringAttr name = tag_region_op.getNameAttr(); + tag_region_op->walk([&](mlir::Operation* op) { + if (llvm::isa(op)) return; + llvm::SmallVector new_array; + auto old_array = op->getAttrOfType("jasc_tags"); + if (old_array != nullptr) { + new_array.append(old_array.begin(), old_array.end()); + } + new_array.push_back(name); + mlir::MLIRContext* ctx = op->getContext(); + op->setAttr("jasc_tags", mlir::ArrayAttr::get(ctx, new_array)); + }); + }); + + getOperation().walk([](jasc::TagRegionOp tag_region_op) { + mlir::Block& body = tag_region_op.getBody().front(); + mlir::Block& parent_block = *tag_region_op->getBlock(); + tag_region_op->replaceAllUsesWith(body.getTerminator()->getOperands()); + parent_block.getOperations().splice( + mlir::Block::iterator(tag_region_op), body.getOperations(), + body.begin(), mlir::Block::iterator(body.getTerminator())); + tag_region_op.erase(); + }); + } +}; + +mlir::LogicalResult AllocRemoval(mlir::memref::CopyOp copy, + mlir::PatternRewriter& rewriter) { + mlir::Value from = copy.getSource(); + mlir::Value to = copy.getTarget(); + if (from.getDefiningOp() == nullptr) return mlir::failure(); + if (!llvm::isa( + from.getDefiningOp())) { + return mlir::failure(); + } + + // Only go up one level to grab the parent function; the match we're looking + // for is at the very end of a function. + auto func = llvm::dyn_cast_or_null(copy->getParentOp()); + if (!func) { + return mlir::failure(); + } + + // If the copy target is a function argument, use it directly. 
+ if (llvm::is_contained(func.getArguments(), to)) { + rewriter.replaceAllUsesWith(from, to); + rewriter.eraseOp(from.getDefiningOp()); + rewriter.eraseOp(copy); + return mlir::success(); + } + return mlir::failure(); +} + +class RemoveCopyToOutParamsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RemoveCopyToOutParamsPass); + + // Returns the name of the pass suitable for the pass manager. + llvm::StringRef getArgument() const override { + return "jasc-remove-copy-to-out-params"; + } + + // Runs the pass. + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(AllocRemoval); + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +void AddStableHLOToLinalgPasses(mlir::PassManager& pm) { + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass( + mlir::mhlo::createLegalizeHloToLinalgPass(true)); + pm.addPass(std::make_unique()); +} + +void AddBufferizationPasses( + mlir::PassManager& pm, + mlir::bufferization::OneShotBufferizationOptions options, + bool run_sparsification) { + // TODO(ulysse): Avoid unnecessary copies introduced by bufferization. + pm.addPass(mlir::createCSEPass()); + + options.bufferizeFunctionBoundaries = true; + mlir::bufferization::BufferResultsToOutParamsOptions out_params_options; + if (run_sparsification) { + // Setup both sparsification and bufferization. + // + // TODO(peiming, ajcbik, springerm): Make sparse compiler compatible with + // one-shot bufferization. At the moment, they have to be intermixed, which + // prevents us from running two passes independently and from sparsifying + // kernel using transform IR. + mlir::SparsificationOptions sparsification_options; + sparsification_options.enableRuntimeLibrary = false; + sparsification_options.enableIndexReduction = true; + // Sparsification set up. 
+ // TODO(peiming, ajcbik): Maybe lift vectorization to transform IR instead? + pm.addPass(mlir::createSparsificationAndBufferizationPass( + options, sparsification_options, + /*createSparseDeallocs=*/false, + /*enableRuntimeLibrary=*/false, + /*enableBufferInitialization=*/false, + /*vectorLength=*/0, + /*enableVLAVectorization=*/false, + /*enableSIMDIndex32*/ false)); + pm.addPass(mlir::createStorageSpecifierToLLVMPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass( + mlir::bufferization::createFinalizingBufferizePass()); + // TODO(peiming, ajcbik): Find a way to avoid generating reallocations. + pm.addNestedPass( + mlir::memref::createExpandReallocPass(false)); + } else { + pm.addPass(mlir::bufferization::createOneShotBufferizePass(options)); + } + // Sparse compiler might insert extra function calls for complex operations. + out_params_options.filterFn = [](mlir::func::FuncOp* func) { + // Only transform the entry point. + return func->getSymName() == "main"; + }; + pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass( + out_params_options)); + pm.addPass(std::make_unique()); + // TODO(mluecke): Add deallocation passes here when upstream problems are + // fixed +} + +// No-op pass to register dialects needed for LLVM lowering. +class RegisterLLVMTranslationDialectsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + RegisterLLVMTranslationDialectsPass) + + // Returns the name of the pass suitable for the pass manager. 
+ llvm::StringRef getArgument() const override { + return "jasc-register-llvm-translation-dialects"; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const override { + mlir::registerAllToLLVMIRTranslations(registry); + } + + void runOnOperation() override {} +}; + +void AddLowerBufferizedLinalgToCF(mlir::PassManager& pm) { + pm.addPass(std::make_unique()); + pm.addNestedPass(mlir::createConvertLinalgToLoopsPass()); + pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createLowerAffinePass()); +} + +void AddLowerCFToLLVMPasses(mlir::PassManager& pm) { + mlir::ConvertVectorToLLVMPassOptions vector_to_llvm_opts; + vector_to_llvm_opts.reassociateFPReductions = true; + vector_to_llvm_opts.useOpaquePointers = true; + pm.addPass(mlir::createConvertVectorToLLVMPass(vector_to_llvm_opts)); + pm.addNestedPass(mlir::createConvertMathToLLVMPass()); + + // Expand complicated MemRef operations before lowering them. + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + // The expansion may create affine expressions. Get rid of them. + pm.addPass(mlir::createLowerAffinePass()); + + mlir::FinalizeMemRefToLLVMConversionPassOptions memref_to_llvm_opts; + // memref_to_llvm_opts.useOpaquePointers = true; + pm.addPass( + mlir::createFinalizeMemRefToLLVMConversionPass(memref_to_llvm_opts)); + + mlir::ConvertFuncToLLVMPassOptions func_to_llvm_opts; + // func_to_llvm_opts.useOpaquePointers = true; + func_to_llvm_opts.useBarePtrCallConv = false; + pm.addPass(mlir::createConvertFuncToLLVMPass(func_to_llvm_opts)); + pm.addPass(mlir::createConvertIndexToLLVMPass()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +mlir::BaseMemRefType ConvertGpuArgType( + mlir::TensorType tensor_type, mlir::Attribute, mlir::func::FuncOp, + const mlir::bufferization::BufferizationOptions&) { + // Override the memory space to global. 
+ auto memory_space = mlir::gpu::AddressSpaceAttr::get( + tensor_type.getContext(), mlir::gpu::AddressSpace::Global); + return mlir::bufferization::getMemRefTypeWithStaticIdentityLayout( + tensor_type, memory_space); +} + +class EraseTransformScriptPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EraseTransformScriptPass) + + // Returns the name of the pass suitable for the pass manager. + llvm::StringRef getArgument() const override { + return "jasc-erase-transform-script"; + } + + void runOnOperation() override { + // Erase the first top-level transform so we can lower normally. + getOperation()->walk( + [](mlir::transform::TransformOpInterface top_level_transform) { + top_level_transform->erase(); + return mlir::WalkResult::interrupt(); + }); + } +}; + +void AddTransformInterpreterPasses(mlir::PassManager& pm) { + pm.addPass(mlir::createLocationSnapshotPass()); + mlir::transform::TransformOptions transformOptions; + transformOptions.enableEnforceSingleToplevelTransformOp(false); + pm.addPass(std::make_unique(transformOptions)); + pm.addPass(std::make_unique()); +} + +} // namespace + +absl::Status ApplyTransformScript(mlir::ModuleOp module, bool dump_ir) { + mlir::PassManager pm(module.getContext()); + AddTransformInterpreterPasses(pm); + return RunPassManager(pm, module, dump_ir); +} + +absl::Status LowerStableHloToCpuLLVM(mlir::ModuleOp module, bool dump_ir) { + mlir::PassManager pm(module.getContext()); + AddStableHLOToLinalgPasses(pm); + AddTransformInterpreterPasses(pm); + // Convert create_empty_tensor to allocs to ensure that they are not touched + // by CSE. Maybe we can create them directly during transformations instead. 
+ pm.addNestedPass( + mlir::bufferization::createEmptyTensorToAllocTensorPass()); + mlir::bufferization::OneShotBufferizationOptions bufferization_options; + bufferization_options.setFunctionBoundaryTypeConversion( + mlir::bufferization::LayoutMapOption::IdentityLayoutMap); + AddBufferizationPasses(pm, bufferization_options, + /*run_sparsification=*/true); + AddLowerBufferizedLinalgToCF(pm); + AddLowerCFToLLVMPasses(pm); + return RunPassManager(pm, module, dump_ir); +} + +absl::Status LowerStableHloToGpuLLVM(mlir::ModuleOp module, bool dump_ir) { +#ifdef MLIR_GPU_TO_CUBIN_PASS_ENABLE + mlir::PassManager pm(module.getContext()); + AddStableHLOToLinalgPasses(pm); + AddTransformInterpreterPasses(pm); + mlir::bufferization::OneShotBufferizationOptions bufferization_options; + bufferization_options.allocationFn = &CreateGpuAlloc; + bufferization_options.memCpyFn = &CreateGpuMemCpy; + bufferization_options.functionArgTypeConverterFn = &ConvertGpuArgType; + bufferization_options.inferFunctionResultLayout = false; + + pm.addNestedPass( + mlir::bufferization::createEmptyTensorToAllocTensorPass()); + pm.addPass(CreateSetDefaultGpuMemorySpacePass()); + + AddBufferizationPasses(pm, bufferization_options, + /*run_sparsification=*/false); + pm.addNestedPass(CreateMemcpyToGpuDialectPass()); + pm.addPass(CreateApplyTransformScriptPass("gpu_post_bufferize.mlir")); + AddLowerBufferizedLinalgToCF(pm); + pm.addPass(mlir::createGpuLauchSinkIndexComputationsPass()); + pm.addPass(mlir::createGpuKernelOutliningPass()); + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + pm.addNestedPass( + mlir::createConvertGpuOpsToNVVMOps()); + + // TODO(ulysse): see how much of the remaining can we share with GPUs. + // Note: a lot of the GPU lowering code is hidden in GPUToLLVM. 
+ pm.addPass(mlir::createConvertIndexToLLVMPass()); + pm.addPass(mlir::createConvertVectorToLLVMPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + pm.addPass(mlir::createLowerAffinePass()); + mlir::ConvertFuncToLLVMPassOptions func_to_llvm_opts; + func_to_llvm_opts.useOpaquePointers = true; + func_to_llvm_opts.useBarePtrCallConv = true; + pm.addPass(mlir::createConvertFuncToLLVMPass(func_to_llvm_opts)); + pm.addPass(mlir::createCanonicalizerPass()); + + pm.addNestedPass(mlir::createGpuSerializeToCubinPass( + "nvptx64-nvidia-cuda", "sm_35", "+ptx60")); + pm.addPass(CreateGpuToLLVMConversionPass()); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + return RunPassManager(pm, module, dump_ir); +#else + return absl::InternalError("MLIR_GPU_TO_CUBIN_PASS_ENABLE not defined"); +#endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE +} + +absl::Status LowerStableHloToLinalg(mlir::ModuleOp module, bool dump_ir) { + mlir::PassManager pm(module.getContext()); + AddStableHLOToLinalgPasses(pm); + AddTransformInterpreterPasses(pm); + return RunPassManager(pm, module, dump_ir); +} + +void registerMLIRLoweringPasses() { + mlir::PassRegistration(); + mlir::PassRegistration(); + mlir::PassRegistration(); + mlir::PassRegistration(); + mlir::PassRegistration(); +} + +} // namespace jasc diff --git a/jasc/mlir_lowering.h b/jasc/mlir_lowering.h new file mode 100644 index 000000000000..4d9dccabef2a --- /dev/null +++ b/jasc/mlir_lowering.h @@ -0,0 +1,21 @@ +#ifndef THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ + +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" + +namespace jasc { + +absl::Status ApplyTransformScript(mlir::ModuleOp module, bool dump_ir); + +absl::Status LowerStableHloToCpuLLVM(mlir::ModuleOp module, bool dump_ir); + +absl::Status LowerStableHloToGpuLLVM(mlir::ModuleOp module, bool dump_ir); + +absl::Status LowerStableHloToLinalg(mlir::ModuleOp module, bool dump_ir); + +void 
registerMLIRLoweringPasses(); + +} // namespace jasc + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ diff --git a/jasc/patches/apply.sh b/jasc/patches/apply.sh new file mode 100644 index 000000000000..5e7cbb022c40 --- /dev/null +++ b/jasc/patches/apply.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +cd llvm-project +patch -p1 < ../patches/llvm_build.patch +# patch -p1 < ../patches/clang_macos.patch +cd - + +cd jax +patch -p1 < ../patches/jax_workspace.patch +touch llvm_dummy.BUILD +cd - + diff --git a/jasc/patches/clang_macos.patch b/jasc/patches/clang_macos.patch new file mode 100644 index 000000000000..3ebce89dbf21 --- /dev/null +++ b/jasc/patches/clang_macos.patch @@ -0,0 +1,13 @@ +diff --git a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel +index 419b2eeca7e1..c99b350f4a9f 100644 +--- a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel +@@ -1615,7 +1615,7 @@ genrule( + outs = [hdr.replace("lib/Headers/", "staging/include/") for hdr in builtin_headers], + cmd = """ + for src in $(SRCS); do +- relsrc=$${src/*"$(WORKSPACE_ROOT)"\\/clang\\/lib\\/Headers} ++ relsrc=$${src/*external\\llvm-project\\/clang\\/lib\\/Headers} + target=$(@D)/staging/include/$$relsrc + mkdir -p $$(dirname $$target) + cp $$src $$target diff --git a/jasc/patches/jax.patch b/jasc/patches/jax.patch new file mode 100644 index 000000000000..e82777956811 --- /dev/null +++ b/jasc/patches/jax.patch @@ -0,0 +1,525 @@ +--- a/jaxlib/cpu/BUILD ++++ a/jaxlib/cpu/BUILD +@@ -79,7 +79,7 @@ cc_library( + ":ducc_fft_flatbuffers_cc", + "@xla//xla/service:custom_call_status", + "@com_github_google_flatbuffers//:flatbuffers", +- "@ducc//:fft", ++ "@ducc//:fft_wrapper", + ], + ) + + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -241,6 +241,7 @@ cc_library( + deps = [ + ":jax_dialects_capi", + "//jaxlib/mosaic:tpu_dialect_capi_objects", ++ 
"@com_google_protobuf//:protobuf", + "@llvm-project//mlir:CAPIArithObjects", + "@llvm-project//mlir:CAPIMathObjects", + "@llvm-project//mlir:CAPIMemRefObjects", +@@ -250,7 +251,11 @@ cc_library( + "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", + "@stablehlo//:chlo_capi_objects", + "@stablehlo//:stablehlo_capi_objects", ++ "@tsl//tsl/platform:env", ++ "@tsl//tsl/platform:env_impl", + "@xla//xla/mlir_hlo:CAPIObjects", ++ "@xla//xla:xla_data_proto_cc", ++ "@xla//xla:xla_data_proto_cc_impl", + ], + ) + + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -139,6 +139,40 @@ py_extension( + ], + ) + ++py_extension( ++ name = "_mlirDialectsTransform", ++ srcs = [ ++ "@llvm-project//mlir:lib/Bindings/Python/DialectTransform.cpp", ++ ], ++ copts = COPTS, ++ linkopts = LINKOPTS, ++ deps = [ ++ ":jax_dialects_capi_headers", ++ ":jaxlib_mlir_capi_shared_library", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPITransformDialect", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", ++ ], ++) ++ ++py_extension( ++ name = "_mlirDialectsPDL", ++ srcs = [ ++ "@llvm-project//mlir:lib/Bindings/Python/DialectPDL.cpp", ++ ], ++ copts = COPTS, ++ linkopts = LINKOPTS, ++ deps = [ ++ ":jax_dialects_capi_headers", ++ ":jaxlib_mlir_capi_shared_library", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPIPDL", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", ++ ], ++) ++ + ##---------------------------------------------------------------------------## + # MHLO Extensions + ##---------------------------------------------------------------------------## + +--- a/jaxlib/mlir/BUILD.bazel ++++ a/jaxlib/mlir/BUILD.bazel +@@ -75,6 +75,75 @@ symlink_inputs( + ) + + symlink_inputs( ++ name = "bufferization_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:BufferizationOpsPyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ 
":ir", ++ ":mlir", ++ ], ++) ++ ++symlink_inputs( ++ name = "pdl_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:PDLPyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ ":pdl_dialect_extension", ++ ], ++) ++ ++symlink_inputs( ++ name = "transform_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:TransformOpsPyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ ":transform_dialect_extensions", ++ ], ++) ++ ++symlink_inputs( ++ name = "transform_dialect_extensions", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/transform": [ ++ "@llvm-project//mlir/python:TransformOpsPackagePyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ "//jaxlib/mlir/_mlir_libs:_mlirDialectsTransform", ++ ], ++) ++ ++symlink_inputs( ++ name = "pdl_dialect_extension", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:PDLPyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ "//jaxlib/mlir/_mlir_libs:_mlirDialectsPDL", ++ ], ++) ++ ++symlink_inputs( + name = "math_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + +--- a/jax/BUILD ++++ a/jax/BUILD +@@ -70,6 +70,7 @@ package_group( + # Intentionally avoid jax dependencies on jax.extend. 
+ # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + "//third_party/py/jax/tests/...", ++ "public", + ] + jax_extend_internal_users, + ) + + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -131,6 +131,8 @@ py_extension( + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIMathHeaders", + "@llvm-project//mlir:CAPIMemRefHeaders", ++ "@llvm-project//mlir:CAPIPDLHeaders", ++ "@llvm-project//mlir:CAPITransformDialectHeaders", + "@llvm-project//mlir:CAPITransformsHeaders", + "@llvm-project//mlir:CAPIVectorHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", +@@ -279,7 +281,9 @@ cc_library( + "@llvm-project//mlir:CAPIArithObjects", + "@llvm-project//mlir:CAPIMathObjects", + "@llvm-project//mlir:CAPIMemRefObjects", ++ "@llvm-project//mlir:CAPIPDLObjects", + "@llvm-project//mlir:CAPISparseTensorObjects", ++ "@llvm-project//mlir:CAPITransformDialectObjects", + "@llvm-project//mlir:CAPITransformsObjects", + "@llvm-project//mlir:CAPIVectorObjects", + "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", + +--- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc ++++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc +@@ -2,9 +2,12 @@ + // This module is called by mlir/__init__.py during initialization. 
+ + #include "mlir-c/Dialect/Arith.h" ++// #include "mlir-c/Dialect/Bufferization.h" + #include "mlir-c/Dialect/Func.h" + #include "mlir-c/Dialect/Math.h" + #include "mlir-c/Dialect/MemRef.h" ++#include "mlir-c/Dialect/PDL.h" ++#include "mlir-c/Dialect/Transform.h" + #include "mlir-c/Dialect/Vector.h" + #include "mlir-c/Transforms.h" + #include "mlir/Bindings/Python/PybindAdaptors.h" +@@ -19,10 +22,13 @@ PYBIND11_MODULE(_site_initialize_0, m) { + + m.def("register_dialects", [](MlirDialectRegistry registry) { + REGISTER_DIALECT(arith); ++ // REGISTER_DIALECT(bufferization); + REGISTER_DIALECT(func); + REGISTER_DIALECT(math); + REGISTER_DIALECT(memref); ++ REGISTER_DIALECT(pdl); + REGISTER_DIALECT(scf); ++ REGISTER_DIALECT(transform); + REGISTER_DIALECT(vector); + mlirRegisterTransformsPasses(); + // Transforms used by JAX. + +--- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc ++++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc +@@ -9,6 +9,7 @@ + #include "mlir-c/Dialect/PDL.h" + #include "mlir-c/Dialect/Transform.h" + #include "mlir-c/Dialect/Vector.h" ++#include "mlir-c/RegisterEverything.h" + #include "mlir-c/Transforms.h" + #include "mlir/Bindings/Python/PybindAdaptors.h" + #include "jaxlib/mlir/_mlir_libs/jax_dialects.h" +@@ -31,6 +32,7 @@ PYBIND11_MODULE(_site_initialize_0, m) { + REGISTER_DIALECT(transform); + REGISTER_DIALECT(vector); + mlirRegisterTransformsPasses(); ++ mlirRegisterAllDialects(registry); + // Transforms used by JAX. 
+ mlirRegisterTransformsStripDebugInfo(); + }); + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -279,18 +279,21 @@ cc_library( + "//jaxlib/mosaic:tpu_dialect_capi_objects", + "@com_google_protobuf//:protobuf", + "@llvm-project//mlir:CAPIArithObjects", ++ "@llvm-project//mlir:CAPIInterfacesObjects", + "@llvm-project//mlir:CAPIMathObjects", + "@llvm-project//mlir:CAPIMemRefObjects", + "@llvm-project//mlir:CAPIPDLObjects", ++ "@llvm-project//mlir:CAPIRegisterEverythingObjects", + "@llvm-project//mlir:CAPISparseTensorObjects", + "@llvm-project//mlir:CAPITransformDialectObjects", + "@llvm-project//mlir:CAPITransformsObjects", + "@llvm-project//mlir:CAPIVectorObjects", +- "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", ++ "@llvm-project//mlir:CAPIDebugObjects", ++ "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPIObjects", + "@stablehlo//:chlo_capi_objects", + "@stablehlo//:stablehlo_capi_objects", + "@tsl//tsl/platform:env", +- "@tsl//tsl/platform:env_impl", ++ "@tsl//tsl/platform:env_impl", + "@xla//xla/mlir_hlo:CAPIObjects", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_data_proto_cc_impl", + +--- a/jaxlib/mlir/BUILD.bazel ++++ b/jaxlib/mlir/BUILD.bazel +@@ -120,12 +120,14 @@ symlink_inputs( + rule = py_library, + symlinked_inputs = {"srcs": {"dialects/transform": [ + "@llvm-project//mlir/python:TransformOpsPackagePyFiles", ++ "@jasc//transform_ops:transform_ops", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsTransform", ++ "//jaxlib/mlir/_mlir_libs:_mlirTransformOpsJasc", + ], + ) + +@@ -250,6 +252,20 @@ symlink_inputs( + ], + ) + ++symlink_inputs( ++ name = "jasc_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@jasc//dialect:python", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ "//jaxlib/mlir/_mlir_libs:_mlirDialectsJasc", ++ ], ++) ++ + symlink_inputs( + name = "mhlo_dialect", + rule = py_library, + +--- 
a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -71,6 +71,39 @@ py_extension( + ], + ) + ++ ++py_extension( ++ name = "_mlirDialectsJasc", ++ srcs = [ ++ "@jasc//dialect:bindings.cc", ++ ], ++ copts = COPTS, ++ linkopts = LINKOPTS, ++ deps = [ ++ ":jaxlib_mlir_capi_shared_library", ++ "@jasc//dialect:jasc_dialect_headers", ++ "@jasc//transform_ops:jasc_transform_ops_shared_library_headers", ++ "@jasc//:mlir_lowering_shared_library_headers", ++ "@llvm-project//mlir:MLIRBindingsPythonHeaders", ++ "@pybind11", ++ ], ++) ++ ++py_extension( ++ name = "_mlirTransformOpsJasc", ++ srcs = [ ++ "@jasc//transform_ops:bindings.cpp", ++ ], ++ copts = COPTS, ++ linkopts = LINKOPTS, ++ deps = [ ++ ":jaxlib_mlir_capi_shared_library", ++ "@jasc//transform_ops:jasc_transform_ops_headers", ++ "@llvm-project//mlir:MLIRBindingsPythonHeaders", ++ "@pybind11", ++ ], ++) ++ + py_extension( + name = "_mlirSparseTensorPasses", + srcs = [ +@@ -126,6 +159,7 @@ py_extension( + linkopts = LINKOPTS, + deps = [ + ":jax_dialects_capi_headers", ++ "@jasc//dialect:capi_headers", + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIArithHeaders", + "@llvm-project//mlir:CAPIIRHeaders", +@@ -276,6 +310,9 @@ cc_library( + name = "jaxlib_mlir_capi_objects", + deps = [ + ":jax_dialects_capi", ++ "@jasc//dialect:capi", ++ "@jasc//transform_ops:jasc_transform_ops", ++ "@jasc//:mlir_lowering", + "//jaxlib/mosaic:tpu_dialect_capi_objects", + "@com_google_protobuf//:protobuf", + "@llvm-project//mlir:CAPIArithObjects", + +--- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc ++++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc +@@ -13,6 +13,7 @@ + #include "mlir-c/Transforms.h" + #include "mlir/Bindings/Python/PybindAdaptors.h" + #include "jaxlib/mlir/_mlir_libs/jax_dialects.h" ++#include "dialect/capi.h" + + #define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ +@@ -25,6 +26,7 @@ PYBIND11_MODULE(_site_initialize_0, 
m) { + REGISTER_DIALECT(arith); + // REGISTER_DIALECT(bufferization); + REGISTER_DIALECT(func); ++ REGISTER_DIALECT(jasc); + REGISTER_DIALECT(math); + REGISTER_DIALECT(memref); + REGISTER_DIALECT(pdl); + + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -209,6 +209,23 @@ py_extension( + ], + ) + ++py_extension( ++ name = "_mlirDialectsLinalg", ++ srcs = [ ++ "@llvm-project//mlir:lib/Bindings/Python/DialectLinalg.cpp", ++ ], ++ copts = COPTS, ++ linkopts = LINKOPTS, ++ deps = [ ++ ":jax_dialects_capi_headers", ++ ":jaxlib_mlir_capi_shared_library", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPILinalg", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", ++ ], ++) ++ + ##---------------------------------------------------------------------------## + # MHLO Extensions + ##---------------------------------------------------------------------------## + +--- a/jaxlib/mlir/BUILD.bazel ++++ b/jaxlib/mlir/BUILD.bazel +@@ -13,6 +13,7 @@ + # limitations under the License. 
+ + load("//jaxlib:symlink_files.bzl", "symlink_inputs") ++load("@pip_deps//:requirements.bzl", "requirement") + + package( + default_visibility = [ +@@ -49,6 +50,19 @@ py_library( + ) + + symlink_inputs( ++ name = "complex_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:ComplexOpsPyFiles", ++ ]}}, ++ deps = [ ++ ":core", ++ ":ir", ++ ":mlir", ++ ], ++) ++ ++symlink_inputs( + name = "func_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ +@@ -102,6 +116,80 @@ symlink_inputs( + ) + + symlink_inputs( ++ name = "linalg_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg": [ ++ "@llvm-project//mlir/python:LinalgOpsPackagePyFiles", ++ ]}}, ++ deps = [ ++ ":complex_dialect", ++ ":core", ++ ":ir", ++ ":mlir", ++ ":linalg_dialect_gen_files", ++ ":linalg_dialect_opdsl_files", ++ ":linalg_dialect_opdsl_lang_files", ++ ":linalg_dialect_opdsl_ops_files", ++ "//jaxlib/mlir/_mlir_libs:_mlirDialectsLinalg", ++ requirement("PyYAML"), ++ ], ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_gen_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:LinalgOpsPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_lang_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/lang": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLLangPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_ops_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/ops": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLOpsPyFiles", ++ ]}}, ++) ++ ++# symlink_files( ++# name = "linalg_package_opdsl_files", ++# srcs = 
["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLPyFiles"], ++# dst = "dialects/linalg/opdsl", ++# flatten = True, ++# ) ++ ++# symlink_files( ++# name = "linalg_package_opdsl_lang_files", ++# srcs = ["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLLangPyFiles"], ++# dst = "dialects/linalg/opdsl/lang", ++# flatten = True, ++# ) ++ ++# symlink_files( ++# name = "linalg_package_opdsl_ops_files", ++# srcs = ["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLOpsPyFiles"], ++# dst = "dialects/linalg/opdsl/ops", ++# flatten = True, ++# ) ++ ++ ++symlink_inputs( + name = "transform_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ diff --git a/jasc/patches/jax_workspace.patch b/jasc/patches/jax_workspace.patch new file mode 100644 index 000000000000..d9b517998a36 --- /dev/null +++ b/jasc/patches/jax_workspace.patch @@ -0,0 +1,22 @@ +--- a/third_party/xla/workspace.bzl ++++ b/third_party/xla/workspace.bzl +@@ -24,12 +24,13 @@ XLA_COMMIT = "8f27d321a86029c336558bfbd6 + XLA_SHA256 = "e8225ee13a8e69c49554d0ec87a0a509c645a1b267c557cd5b9bfe175a4b3f29" + + def repo(): +- tf_http_archive( +- name = "xla", +- sha256 = XLA_SHA256, +- strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), +- urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), +- ) ++ # tf_http_archive( ++ # name = "xla", ++ # sha256 = XLA_SHA256, ++ # strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), ++ # urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), ++ # ) ++ pass + + # For development, one often wants to make changes to the TF repository as well + # as the JAX repository. 
You can override the pinned repository above with a diff --git a/jasc/patches/llvm_build.patch b/jasc/patches/llvm_build.patch new file mode 100644 index 000000000000..d6849fa9b438 --- /dev/null +++ b/jasc/patches/llvm_build.patch @@ -0,0 +1,94 @@ +diff --git a/utils/bazel/llvm-project-overlay/lld/BUILD.bazel b/utils/bazel/llvm-project-overlay/lld/BUILD.bazel +index fb6e2397cc84..db259fffaa63 100644 +--- a/utils/bazel/llvm-project-overlay/lld/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/lld/BUILD.bazel +@@ -108,7 +108,6 @@ cc_library( + "//llvm:TargetParser", + "//llvm:TransformUtils", + "//llvm:config", +- "@llvm_zlib//:zlib", + "@llvm_zstd//:zstd", + ], + ) +diff --git a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +index 0cc28fd856bc..51764826a130 100644 +--- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +@@ -277,11 +277,9 @@ cc_library( + # We unconditionally depend on the custom LLVM zlib wrapper. This will + # be an empty library unless zlib is enabled, in which case it will + # both provide the necessary dependencies and configuration defines. +- "@llvm_zlib//:zlib", + # We unconditionally depend on the custom LLVM zstd wrapper. This will + # be an empty library unless zstd is enabled, in which case it will + # both provide the necessary dependencies and configuration defines. 
+- "@llvm_zstd//:zstd", + ], + ) + + +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -686,12 +686,11 @@ mlir_c_api_cc_library( + hdrs = [ + "include/mlir-c/Dialect/PDL.h", + ], +- header_deps = [ +- ":CAPIIRHeaders", +- ], + includes = ["include"], +- deps = [ ++ capi_deps = [ + ":CAPIIR", ++ ], ++ deps = [ + ":PDLDialect", + ":PDLOpsIncGen", + ":PDLTypesIncGen", +@@ -952,6 +951,27 @@ cc_library( + ], + ) + ++ ++cc_library( ++ name = "MLIRBindingsPythonCoreNoCAPIObjects", ++ srcs = MLIR_PYTHON_BINDINGS_SOURCES, ++ copts = PYBIND11_COPTS, ++ features = PYBIND11_FEATURES, ++ alwayslink = True, ++ deps = [ ++ ":CAPIAsyncHeaders", ++ ":CAPIDebugHeaders", ++ ":CAPIGPUHeaders", ++ ":CAPIIRHeaders", ++ ":CAPIInterfacesHeaders", ++ ":MLIRBindingsPythonHeaders", ++ "//llvm:Support", ++ "@pybind11", ++ "@python_runtime//:headers", ++ ], ++) ++ ++ + # Target that bundles together the CAPI objects needed for + # MLIRBindingsPythonCoreNoCAPI. + cc_library( +@@ -6160,6 +6180,7 @@ cc_library( + ":SideEffectInterfaces", + "//llvm:Support", + ], ++ linkstatic = True, + ) + + td_library( + +--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel ++++ a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +@@ -8907,6 +8907,7 @@ cc_library( + ":mlir_float16_utils", + "//llvm:Support", + ], ++ alwayslink = True, + ) + + # Indirection to avoid 'libmlir_c_runner_utils.so' filename clash. 
diff --git a/jasc/patches/stablehlo_build.patch b/jasc/patches/stablehlo_build.patch new file mode 100644 index 000000000000..d84b465f2190 --- /dev/null +++ b/jasc/patches/stablehlo_build.patch @@ -0,0 +1,10 @@ +--- a/BUILD.bazel ++++ b/BUILD.bazel +@@ -907,6 +907,7 @@ cc_library( + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", ++ "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:QuantOps", diff --git a/jasc/patches/xla.patch b/jasc/patches/xla.patch new file mode 100644 index 000000000000..4d20c9935561 --- /dev/null +++ b/jasc/patches/xla.patch @@ -0,0 +1,100 @@ +--- a/xla/BUILD ++++ b/xla/BUILD +@@ -29,6 +29,7 @@ package_group( + "//third_party/py/tpu_graphs/...", + "//tensorflow/compiler/...", + "//tensorflow/python/tpu/...", ++ "public", + ], + ) + + +--- a/xla/python/BUILD ++++ b/xla/python/BUILD +@@ -1134,6 +1134,10 @@ cc_library( + "//xla:statusor", + "//xla:types", + "//xla:util", ++ "//xla:xla_data_proto_cc", ++ "//xla:xla_data_proto_cc_impl", ++ "//xla:xla_proto_cc", ++ "//xla:xla_proto_cc_impl", + "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt:pjrt_api", + "//xla/pjrt:pjrt_c_api_client", +@@ -1147,10 +1151,47 @@ cc_library( + "//xla/pjrt/distributed:service", + "//xla/python/ifrt", + "//xla/python/pjrt_ifrt", ++ "//xla/service/gpu:backend_configs_cc", ++ "//xla/service/gpu:backend_configs_cc_impl", ++ "//xla/service:buffer_assignment_proto_cc", ++ "//xla/service:buffer_assignment_proto_cc_impl", ++ "//xla/service:hlo_proto_cc", ++ "//xla/service:hlo_proto_cc_impl", ++ "//xla/stream_executor:device_description_proto_cc", ++ "//xla/stream_executor:device_description_proto_cc_impl", ++ "//xla/stream_executor:stream_executor_impl", ++ "//xla:autotune_results_proto_cc", ++ "//xla:autotune_results_proto_cc_impl", ++ "//xla:autotuning_proto_cc", ++ "//xla:autotuning_proto_cc_impl", + 
"@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", + "@tsl//tsl/platform", ++ "@tsl//tsl/platform:env", ++ "@tsl//tsl/platform:env_impl", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform/cloud:gcs_file_system", ++ "@tsl//tsl/profiler/backends/cpu:traceme_recorder_impl", ++ "@tsl//tsl/profiler/lib:profiler_factory_impl", ++ "@tsl//tsl/profiler/lib:profiler_session_impl", ++ "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:profiler_service_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiler_service_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:xplane_proto_cc", ++ "@tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", ++ "@tsl//tsl/profiler/utils:time_utils_impl", ++ "@tsl//tsl/protobuf:coordination_config_proto_cc", ++ "@tsl//tsl/protobuf:coordination_config_proto_cc_impl", ++ "@tsl//tsl/protobuf:bfc_memory_map_proto_cc", ++ "@tsl//tsl/protobuf:bfc_memory_map_proto_cc_impl", ++ "@tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ++ "@tsl//tsl/protobuf:dnn_proto_cc_impl", ++ "@tsl//tsl/protobuf:histogram_proto_cc", ++ "@tsl//tsl/protobuf:histogram_proto_cc_impl", + "@tsl//tsl/python/lib/core:numpy", + "@pybind11", + ] + select({ + +--- a/xla/python/BUILD ++++ b/xla/python/BUILD +@@ -1157,6 +1157,7 @@ cc_library( + "//xla/service:buffer_assignment_proto_cc_impl", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_proto_cc_impl", ++ "//xla/stream_executor/gpu:gpu_init_impl", + "//xla/stream_executor:device_description_proto_cc", + "//xla/stream_executor:device_description_proto_cc_impl", + "//xla/stream_executor:stream_executor_impl", +@@ -1165,6 
+1166,7 @@ cc_library( + "//xla:autotuning_proto_cc", + "//xla:autotuning_proto_cc_impl", + "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", ++ "@tsl//tsl/framework:allocator_registry_impl", + "@tsl//tsl/platform", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:env_impl", +@@ -1175,6 +1177,8 @@ cc_library( + "@tsl//tsl/profiler/lib:profiler_session_impl", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc", diff --git a/jasc/primitives.py b/jasc/primitives.py new file mode 100644 index 000000000000..bb9d78168341 --- /dev/null +++ b/jasc/primitives.py @@ -0,0 +1,208 @@ +"""Jax primitives backing Jasc schedules.""" + +from collections.abc import Callable, Sequence +import contextlib +import itertools +from typing import Any, Optional + +import jax +from jax.extend.linear_util import wrap_init +from jax.interpreters import mlir as jax_mlir +from jax.interpreters import partial_eval as pe +from jax.lib import xla_client +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import pdl +from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.dialects import transform + +import call_kernel +from jaxlib.mlir.dialects import jasc as jasc_dialect + + +_JAX_COMPATIBLE_LOWERING = True + +call_kernel.init_llvm() + + +@contextlib.contextmanager +def enable_jasc_lowering(): + """ContextManager to enable usage of `with enable_jasc_lowering()`.""" + global _JAX_COMPATIBLE_LOWERING + _JAX_COMPATIBLE_LOWERING = False + try: + yield + finally: + _JAX_COMPATIBLE_LOWERING = True + + +def _func_to_mlir_module( + ctx: jax_mlir.LoweringRuleContext, func: 
Callable[..., Any] +) -> ir.Module: + """Compiles func to an MLIR module.""" + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrap_init(func), ctx.avals_in) + closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts) + result = jax_mlir.lower_jaxpr_to_module( + module_name='jasc_jit', + jaxpr=closed_jaxpr, + backend_or_name=ctx.module_context.backend_or_name, + ordered_effects=[], + name_stack=ctx.module_context.name_stack, + donated_args=[False] * len(closed_jaxpr.jaxpr.invars), + axis_context=ctx.module_context.axis_context, + platforms=ctx.module_context.platforms, + lowering_parameters=jax_mlir.LoweringParameters(), + ) + if result.keepalive or result.host_callbacks: + raise NotImplementedError('Jasc does not support callbacks') + return result.module + + +def _generate_schedule(build_schedule: Callable[[ir.Value], None]) -> None: + sequence_op = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, (), pdl.OperationType.get() + ) + with ir.InsertionPoint(sequence_op.body): + build_schedule(sequence_op.bodyTarget) + transform.YieldOp() + + +def _jit_lowering( + ctx: jax_mlir.LoweringRuleContext, + *args: ir.Value, + func: Callable[..., Any], + module: Optional[ir.Module] = None, + build_schedule: Callable[[ir.Value], None], + out_avals: Sequence[jax.core.AbstractValue], + dump_ir: bool, +) -> Sequence[ir.Value]: + """Lowers a call to the jit primitive. + + Args: + ctx: Jax lowering context. + *args: MLIR values that holds the value of the flattened function arguments. + func: Function to lower. + module: Optional already lowered representation of func. If this is supplied + it will be used rather than lowering `func`. + build_schedule: Function that generates an MLIR transform dialect script. + Takes the root transform handle as input and expects the insertion point + to be set. + out_avals: Abstract values of func outputs. + dump_ir: If true, log intermediate steps of the compilation process. 
+ + Returns: + A sequence of MLIR values holding the value of the function outputs. + """ + del out_avals + if module is None: + with enable_jasc_lowering(): + lowered_ir = _func_to_mlir_module(ctx, func) + else: + lowered_ir = module + with lowered_ir.context: + with ir.Location.unknown(lowered_ir.context): + with ir.InsertionPoint(lowered_ir.body): + _generate_schedule(build_schedule) + + backend_config = None + mlir_args = [] + + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError('Multi-platform lowering') + if ctx.module_context.platforms[0] == 'cpu': + compiled_kernel = call_kernel.create_cpu_kernel( + module=lowered_ir, + num_inputs=len(args), + num_outputs=len(ctx.avals_out), + dump_ir=dump_ir, + ) + ctx.module_context.add_keepalive(compiled_kernel) + identifier_attr = jax_mlir.dense_int_elements([compiled_kernel.identifier]) + identifier_op = stablehlo.ConstantOp(identifier_attr) + mlir_args = [identifier_op.result] + # elif ctx.module_context.platforms[0] == 'cuda': + # compiled_kernel = call_kernel.create_cuda_kernel( + # module=lowered_ir, + # num_inputs=len(args), + # num_outputs=len(ctx.avals_out), + # dump_ir=dump_ir, + # ) + # ctx.module_context.add_keepalive(compiled_kernel) + # backend_config = ir.StringAttr.get(compiled_kernel.ptr) + else: + raise NotImplementedError( + f'Jasc does not support platform {ctx.module_context.platforms[0]}' + ) + + mlir_args.extend(args) + out_types = tuple( + itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) + ) + custom_call = stablehlo.CustomCallOp( + out_types, + mlir_args, + call_target_name='jasc.call_kernel', + backend_config=backend_config, + ) + return custom_call.results + + +jit_p = jax.core.Primitive('jasc.jit') +jit_p.multiple_results = True +jit_p.def_impl( + lambda *args, func, module, build_schedule, out_avals, dump_ir: func(*args) +) +jit_p.def_abstract_eval( + lambda *args, func, module, build_schedule, out_avals, dump_ir: out_avals +) 
+jax_mlir.register_lowering(jit_p, _jit_lowering) + + +xla_client.register_custom_call_target( + 'jasc.call_kernel', call_kernel.get_cpu_callback(), platform='cpu' +) +# xla_client.register_custom_call_target( +# 'jasc.call_kernel', call_kernel.get_cuda_callback(), platform='CUDA' +# ) + + +def _tag_lowering( + ctx: jax_mlir.LoweringRuleContext, + *args: ir.Value, + func: Callable[..., Any], + out_avals: Sequence[jax.core.AbstractValue], + name: str, +) -> Sequence[ir.Value]: + """Lowers a call to the tag primitive. + + Args: + ctx: Jax lowering context. + *args: MLIR values that holds the value of the flattened function arguments. + func: Function to tag. + out_avals: Abstract values of func outputs. + name: Tag name. + + Returns: + A sequence of MLIR values holding the value of the function outputs. + """ + del out_avals + if _JAX_COMPATIBLE_LOWERING: + return jax_mlir.lower_fun(func, multiple_results=True)(ctx, *args) + + jasc_dialect.register_and_load_dialect(ctx.module_context.context) + out_types = tuple( + itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) + ) + tag_op = jasc_dialect.TagRegionOp(out_types, name) + tag_op.body.blocks.append() + with ir.InsertionPoint(tag_op.body.blocks[0]): + lower_rule = jax_mlir.lower_fun(func, multiple_results=True) + results = lower_rule(ctx, *args) + jasc_dialect.ReturnOp(sum(results, ())) + return tag_op.results + + +tag_p = jax.core.Primitive('jasc.tag') +tag_p.multiple_results = True +tag_p.def_impl(lambda *args, func, out_avals, name: func(*args)) +tag_p.def_abstract_eval(lambda *args, func, out_avals, name: out_avals) +jax_mlir.register_lowering(tag_p, _tag_lowering) diff --git a/jasc/test/BUILD b/jasc/test/BUILD new file mode 100644 index 000000000000..b8a931ba48dc --- /dev/null +++ b/jasc/test/BUILD @@ -0,0 +1,262 @@ +# JASC filecheck tests + +load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") +load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") 
+load("@pip_deps//:requirements.bzl", "requirement") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +package( + default_visibility = ["//:__subpackages__"], +) + +# The image needs a library target that is not testonly. +# TODO(zinenko): consolidate as filegroups. +py_library( + name = "cpu_integration_lib", + srcs = ["cpu_integration.py"], + tags = ["manual"], + deps = [ + "//:jasc", + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + ], +) + +py_library( + name = "gpu_integration_lib", + srcs = ["gpu_integration.py"], + tags = ["manual"], + deps = [ + ":cpu_integration_lib", + requirement("chex"), + requirement("pytest"), + ], +) + +# The image needs a binary target that is not testonly. +# TODO(zinenko): consolidate as filegroups. +py_binary( + name = "gpu_integration_binary", + srcs = ["gpu_integration.py"], + main = "gpu_integration.py", + tags = ["manual"], + deps = [ + ":gpu_integration_lib", + ], +) + +py_test( + name = "gpu_integration", + srcs = [ + "gpu_integration.py", + ], + tags = ["requires-gpu-nvidia"], + deps = [ + ":gpu_integration_lib", + ], +) + +py_test( + name = "cpu_integration", + srcs = ["cpu_integration.py"], + deps = [ + ":cpu_integration_lib", + ], +) + +py_binary( + name = "abstractions", + srcs = ["abstractions.py"], + deps = [ + "//:jasc", + "//transform_ops", + requirement("absl-py"), + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:jasc_dialect", + "@jax1//jaxlib/mlir:scf_dialect", + ], +) + +py_binary( + name = "diagnostics", + srcs = ["diagnostics.py"], + deps = [ + "//:jasc", + requirement("absl-py"), + requirement("ml_dtypes"), + requirement("opt_einsum"), + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:jasc_dialect", + ], +) + +py_binary( + name = "normalization", + srcs = ["normalization.py"], + deps = [ + "//:jasc", + "//transform_ops", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:scf_dialect", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +py_binary( + 
name = "tag", + srcs = ["tag.py"], + deps = [ + "//:jasc", + requirement("absl-py"), + requirement("chex"), + "@jax1//jax:jax", + ], +) + +py_test( + name = "batch_matmul_gpu", + srcs = ["batch_matmul_gpu.py"], + tags = ["requires-gpu-nvidia"], + deps = [ + "//:jasc", + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +py_test( + name = "matmul_gpu", + srcs = ["matmul_gpu.py"], + tags = ["requires-gpu-nvidia"], + deps = [ + "//:jasc", + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +py_test( + name = "matmul_cpu", + srcs = ["matmul_cpu.py"], + deps = [ + "//:jasc", + requirement("absl-py"), + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +py_test( + name = "autotuning", + srcs = ["autotuning.py"], + deps = [ + "//:jasc", + "//:tuner", + requirement("absl-py"), + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:linalg_dialect", + ], +) + +py_test( + name = "jit", + srcs = ["jit.py"], + deps = [ + "//:jasc", + requirement("chex"), + requirement("pytest"), + "@jax1//jax:jax", + ], +) + +py_binary( + name = "bindings", + srcs = ["bindings.py"], + deps = [ + "//:jasc", + "//transform_ops", + requirement("absl-py"), + "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:pass_manager", + "@jax1//jaxlib/mlir:transform_dialect", + ], +) + +[sh_test( + name = target + "_filecheck_test", + srcs = ["filecheck_test.sh"], + args = [target], + data = [ + ":" + target, + "@llvm-project//llvm:FileCheck", + ], +) for target in [ + "abstractions", + "bindings", + "diagnostics", + "normalization", + "tag", +]] + +LLVM_LIT_PATH_FUNCTION = " " + \ + "# Allow generated file to be relocatable.\n" + \ + "from pathlib import Path\n" + \ + "def path(p):\n" + \ + " if not 
p: return ''\n" + \ + " return str((Path(__file__).parent / p).resolve())\n" + +LIT_SITE_CFG_IN_HEADER = "# Autogenerated, do not edit.\n\n" + LLVM_LIT_PATH_FUNCTION + +expand_template( + name = "lit_site_cfg_py", + testonly = True, + out = "lit.site.cfg.py", + substitutions = { + "@LIT_SITE_CFG_IN_HEADER@": LIT_SITE_CFG_IN_HEADER, + "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@MLIR_TOOLS_DIR@": package_path("@llvm-project//mlir:BUILD"), + "@SHLIBDIR@": package_path("@llvm-project//llvm:BUILD"), + "@JASC_SOURCE_DIR@": package_path("@jasc//:BUILD"), + "@JASC_TOOLS_DIR@": package_path("@jasc//:jasc-opt"), + }, + template = "lit.site.cfg.in.py", +) + +filegroup( + name = "lit_data", + testonly = True, + data = [ + "lit.cfg.py", + "lit.site.cfg.py", + "//:jasc-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + ], +) + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + ":lit_data", + ], + ) + for src in glob( + include = ["*.mlir"], + ) +] diff --git a/jasc/test/abstractions.py b/jasc/test/abstractions.py new file mode 100644 index 000000000000..0ae974930d45 --- /dev/null +++ b/jasc/test/abstractions.py @@ -0,0 +1,1186 @@ +"""Tests for JASC transform op abstractions.""" +from __future__ import annotations + +from typing import Callable, Sequence + +from absl import app +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import scf +from jaxlib.mlir.dialects import transform +from jaxlib.mlir.dialects.transform import structured + +from jasc import jasc +from jaxlib.mlir.dialects.transform import jasc_transform_ops + +tests: list[Callable[[], None]] = [] +jasc.set_auto_normalization(False) + + +def run(f): + def test(): + print("\nTEST:", f.__name__) + f() + + tests.append(test) + return f + + +def print_schedule(schedule: Callable) -> Callable: + def decorated() -> None: + with ir.Context(): + module = ir.Module.parse("") + jasc.insert_schedule(module, schedule=schedule, 
dump_schedule=True) + module.operation.verify() + + decorated.__name__ = schedule.__name__ + return decorated + + +# CHECK-LABEL: TEST: test_alloca_to_global +@run +@print_schedule +def test_alloca_to_global(program: jasc.OpHandle) -> None: + get_global, global_ = program.alloca_to_global() + assert isinstance(get_global, jasc.OpHandle) + assert isinstance(global_, jasc.OpHandle) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-DAG: %[[V0:.*]] = cast + # CHECK-DAG: %{{.*}}, %{{.*}} = transform.memref.alloca_to_global %[[V0]] + # CHECK-SAME: (!transform.op<"memref.alloca">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +# CHECK-LABEL: TEST: test_apply_cse +@run +@print_schedule +def test_apply_cse(program: jasc.OpHandle) -> None: + program.apply_cse() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_cse to %[[ARG0]] : !transform.any_op + + +# CHECK-LABEL: TEST: test_apply_dce +@run +@print_schedule +def test_apply_dce(program: jasc.OpHandle) -> None: + program.apply_dce() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_dce to %[[ARG0]] : !transform.any_op + + +# CHECK-LABEL: TEST: test_apply_licm_self +@run +@print_schedule +def test_apply_licm_self(program: jasc.OpHandle) -> None: + program.apply_licm() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_licm to %[[ARG0]] + + +# CHECK-LABEL: TEST: test_apply_licm_empty +@run +@print_schedule +def test_apply_licm_empty(program: jasc.OpHandle) -> None: + program.apply_licm(to=[]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NOT: transform.structured.match ops + # CHECK-NOT: apply_licm + + +# CHECK-LABEL: TEST: test_apply_licm_single +@run +@print_schedule +def test_apply_licm_single(program: jasc.OpHandle) -> None: + 
program.apply_licm(to=["scf.for"]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[VAL0]] + + + +# CHECK-LABEL: TEST: test_apply_licm_multi +@run +@print_schedule +def test_apply_licm_multi(program: jasc.OpHandle) -> None: + program.apply_licm(to=["scf.for", "scf.forall"]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.structured.match ops{["scf.for", "scf.forall"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[VAL0]] + + + +# CHECK-LABEL: TEST: test_apply_licm_mixed +@run +@print_schedule +def test_apply_licm_mixed(program: jasc.OpHandle) -> None: + scf_for = program.match_ops("scf.for") + program.apply_licm(to=["scf.forall", scf_for]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-DAG: %[[V0:.*]] = trans{{.*}}.match ops{["scf.for"]} in %[[ARG0]] + # CHECK-DAG: apply_licm to %[[V0]] + # CHECK-DAG: %[[V1:.*]] = trans{{.*}}.match ops{["scf.forall"]} in %[[ARG0]] + # CHECK-DAG: apply_licm to %[[V1]] + + + +# CHECK-LABEL: TEST: test_apply_patterns_empty +@run +@print_schedule +def test_apply_patterns_empty(program: jasc.OpHandle) -> None: + with program.apply_patterns(): + pass + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: } : !transform.any_op + + + +# CHECK-LABEL: TEST: test_apply_patterns_args +@run +@print_schedule +def test_apply_patterns_args(program: jasc.OpHandle) -> None: + with program.apply_patterns(apply_cse=True): + pass + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: } {apply_cse} : !transform.any_op + + + +# CHECK-LABEL: TEST: 
test_apply_patterns_multiple +@run +@print_schedule +def test_apply_patterns_multiple(program: jasc.OpHandle) -> None: + with program.apply_patterns(): + transform.ApplyCanonicalizationPatternsOp() + transform.ApplyCanonicalizationPatternsOp() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } : !transform.any_op + + +# CHECK-LABEL: TEST: test_apply_schedule_in_module +@run +def test_apply_schedule_in_module() -> None: + def schedule(program: jasc.OpHandle) -> None: + func = program.match_ops("func.func") + func.apply_dce() + + with ir.Context(): + module = ir.Module.parse(""" + module { + func.func @foo() { + %c0 = arith.constant 0 : i32 + func.return + } + }""") + jasc.insert_schedule(module, schedule, dump_schedule=True) + jasc.apply_schedule(module, dump_ir=False) + print(module) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.structured.match ops{["func.func"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_dce to %[[VAL0]] + # CHECK-NEXT: } + # CHECK-NEXT: module { + # CHECK-NEXT: func.func @foo() { + # CHECK-NEXT: return + # CHECK-NEXT: } + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_apply_schedule_outside_module +@run +def test_apply_schedule_outside_module() -> None: + def schedule(program: jasc.OpHandle) -> None: + func = program.match_ops("func.func") + func.apply_dce() + + with ir.Context(): + module = ir.Module.parse(""" + module { + func.func @foo() { + %c0 = arith.constant 0 : i32 + func.return + } + }""") + jasc.apply_schedule(module, schedule, dump_ir=False, dump_schedule=True) + print(module) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.structured.match 
ops{["func.func"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_dce to %[[VAL0]] + # CHECK-NEXT: } + # CHECK-NEXT: module { + # CHECK-NEXT: func.func @foo() { + # CHECK-NEXT: return + # CHECK-NEXT: } + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_apply_tuning_config +@run +@print_schedule +def test_apply_tuning_config(program: jasc.OpHandle) -> None: + program.apply_tuning_config( + [16, ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 16)] + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.jasc.apply_tuning_config %[[ARG0]] + # CHECK-SAME: {config = [16 : i32, 16 : i32]} + + +# CHECK-LABEL: TEST: test_buffer_loop_hoisting +@run +@print_schedule +def test_buffer_loop_hoisting(program: jasc.OpHandle) -> None: + program.buffer_loop_hoisting() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: buffer_loop_hoisting %[[ARG0]] + + +# CHECK-LABEL: TEST: test_bufferize_to_allocation +@run +@print_schedule +def test_bufferize_to_allocation(op: jasc.OpHandle) -> None: + result = op.bufferize_to_allocation() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %{{.*}}, %{{.*}} = transform.structured.bufferize_to_allocation + # CHECK-SAME: %[[ARG0]] : !transform.any_op + assert isinstance(result.allocated_buffer, jasc.ValueHandle) + assert isinstance(result.new_ops, jasc.OpHandle) + + +# CHECK-LABEL: TEST: test_cast_type +@run +@print_schedule +def test_cast_type(program: jasc.OpHandle): + program.cast(transform.OperationType.get("test.foo_op")) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V1:.*]] = cast %[[V0]] + # CHECK-SAME: !transform.any_op to !transform.op<"test.foo_op"> + + +# CHECK-LABEL: TEST: test_cast_string +@run +@print_schedule +def test_cast_string(program: jasc.OpHandle): + program.cast("test.foo_op") + # CHECK: 
transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V1:.*]] = cast %[[V0]] + # CHECK-SAME: !transform.any_op to !transform.op<"test.foo_op"> + + +# CHECK-LABEL: TEST: test_create_async_groups +@run +@print_schedule +def test_create_async_groups(program: jasc.OpHandle): + program.create_async_groups() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V1:.*]] = transform.nvgpu.create_async_groups %[[V0]] + + +# CHECK-LABEL: TEST: test_custom_default_value_f32_tuning_param +@run +@print_schedule +def test_custom_default_value_f32_tuning_param(_: jasc.OpHandle) -> None: + jasc.tuning_param(default_value=ir.FloatAttr.get_f32(1.0)) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.jasc.tuning_param + # CHECK-SAME: {default_value = 1.000000e+00 : f32} + + +# CHECK-LABEL: TEST: test_custom_default_value_int_tuning_param +@run +@print_schedule +def test_custom_default_value_int_tuning_param(_: jasc.OpHandle) -> None: + jasc.tuning_param( + default_value=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.jasc.tuning_param + # CHECK-SAME: {default_value = 1 : i32} + + +# CHECK-LABEL: TEST: generic_tuning_param +@run +@print_schedule +def generic_tuning_param(_: jasc.OpHandle) -> None: + jasc.tuning_param() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.jasc.tuning_param + # CHECK-SAME: {default_value = 1 : i32} + + +# CHECK-LABEL: TEST: test_handle_hierarchy +@run +@print_schedule +def test_handle_hierarchy(program: jasc.OpHandle) -> None: + foo_op_handle = program.match_ops("test.foo_op") + foo_tag_handle = program.match_tag("foo_tag") + assert foo_op_handle.parent == program + assert 
foo_op_handle in program.children + assert foo_tag_handle.parent == program + assert foo_tag_handle in program.children + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["test.foo_op"]} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-SAME: -> !transform.op<"test.foo_op"> + # CHECK-NEXT: %[[VAL_2:.*]] = transform.structured.match interface{LinalgOp} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-NEXT: transform.jasc.match_tag ["foo_tag"] in %[[VAL_2]] + + +# CHECK-LABEL: TEST: test_eliminate_empty_tensors +@run +@print_schedule +def test_eliminate_empty_tensors(program: jasc.OpHandle) -> None: + program.eliminate_empty_tensors() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.bufferization.eliminate_empty_tensors %[[ARG0]] : !transform.any_op + + +# CHECK-LABEL: TEST: test_fold_fill_into_pad +@run +@print_schedule +def test_fold_fill_into_pad(program: jasc.OpHandle) -> None: + with program.apply_patterns(): + jasc_transform_ops.ApplyFoldFillIntoPadPatternsOp() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + + +# CHECK-LABEL: TEST: test_foreach_empty +@run +@print_schedule +def test_foreach_empty(program: jasc.OpHandle) -> None: + with program.foreach().body: + pass + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: foreach %[[ARG0]] : !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_foreach_only_yield +@run +@print_schedule +def test_foreach_only_yield(program: jasc.OpHandle) -> None: + with program.foreach().body: + jasc.yield_() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: foreach 
%[[ARG0]] : !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_foreach_simple_noyield +@run +@print_schedule +def test_foreach_simple_noyield(program: jasc.OpHandle) -> None: + with program.foreach().body as arg: + assert isinstance(arg, jasc.OpHandle) + arg.print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: foreach %[[ARG0]] : !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: transform.print %[[ARG1]] : !transform.any_op + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_foreach_simple_yield +@run +@print_schedule +def test_foreach_simple_yield(program: jasc.OpHandle) -> None: + with program.foreach().body as arg: + arg.print() + jasc.yield_() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: foreach %[[ARG0]] : !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: transform.print %[[ARG1]] : !transform.any_op + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_foreach_explicit_yield +@run +@print_schedule +def test_foreach_explicit_yield(program: jasc.OpHandle) -> None: + foreach = program.foreach([transform.AnyOpType.get()] * 2) + with foreach.body as arg: + jasc.yield_([arg, arg]) + foreach.results[0].print() + foreach.results[1].print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]]:2 = foreach %[[ARG0]] : !transform.any_op -> !transform.any_op, !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: transform.yield %[[ARG1]], %[[ARG1]] : !transform.any_op, !transform.any_op + # CHECK-NEXT: } + # CHECK-NEXT: print %[[V0]]#0 + # CHECK-NEXT: print %[[V0]]#1 + + +# CHECK-LABEL: TEST: test_foreach_yield_one +@run +@print_schedule +def test_foreach_yield_one(program: jasc.OpHandle) -> None: + 
foreach = program.foreach(transform.AnyOpType.get()) + with foreach.body as arg: + jasc.yield_(arg) + foreach.results[0].print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = foreach %[[ARG0]] : !transform.any_op -> !transform.any_op { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.any_op): + # CHECK-NEXT: transform.yield %[[ARG1]] : !transform.any_op + # CHECK-NEXT: } + # CHECK-NEXT: print %[[V0]] + + +# CHECK-LABEL: TEST: test_foreach_typed_input +@run +@print_schedule +def test_foreach_typed_input(program: jasc.OpHandle) -> None: + with program.cast("test.foo_op").foreach().body: + pass + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[CAST:.*]] = cast %[[ARG0]] + # CHECK-NEXT: foreach %[[CAST]] : !transform.op<"test.foo_op"> { + # CHECK-NEXT: ^{{.*}}(%[[ARG1:.*]]: !transform.op<"test.foo_op">): + # CHECK-NEXT: } + + +# CHECK-LABEL: TEST: test_fuse_into_standalone +@run +@print_schedule +def test_fuse_into_standalone(op: jasc.OpHandle) -> None: + other_op = op.match_ops("scf.for") + op.fuse_into(other_op) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.match + # CHECK-NEXT: transform.structured.fuse_into_containing_op + # CHECK-SAME: %[[ARG0]] into %[[V0]] + # CHECK-SAME: (!transform.any_op, !transform.op<"scf.for">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +# CHECK-LABEL: TEST: test_fuse_into_autonormalized +@run +@print_schedule +def test_fuse_into_autonormalized(op: jasc.OpHandle) -> None: + with jasc.autonormalize(): + other_op = op.match_ops("scf.for") + op.fuse_into(other_op) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.match + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0:.*]] { + # CHECK-SAME: op_name = "func.func" + # 
CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] : !transform.any_op + # CHECK-NEXT: apply_cse to %[[ARG1]] : !transform.any_op + # CHECK-NEXT: transform.structured.fuse_into_containing_op + # CHECK-SAME: %[[ARG0]] into %[[V0]] + # CHECK-SAME: (!transform.any_op, !transform.op<"scf.for">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +# CHECK-LABEL: TEST: test_get_parent_op_noargs +@run +@print_schedule +def test_get_parent_op_noargs(program: jasc.OpHandle) -> None: + program.get_parent_op().print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = get_parent_op %[[ARG0]] + # CHECK-NOT: deduplicate + # CHECK-NOT: isolated_from_above + # CHECK-NOT: op_name + # CHECK-SAME: : (!transform.any_op) -> !transform.any_op + # CHECK-NEXT: print %[[V0]] + + +# CHECK-LABEL: TEST: test_get_parent_op_allargs +@run +@print_schedule +def test_get_parent_op_allargs(program: jasc.OpHandle) -> None: + program.get_parent_op( + deduplicate=True, + isolated_from_above=True, + op_name="test.foo_op", + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = get_parent_op %[[ARG0]] + # CHECK-SAME: deduplicate + # CHECK-SAME: isolated_from_above + # CHECK-SAME: op_name = "test.foo_op" + + +# CHECK-LABEL: TEST: test_get_producer_of_operand +@run +@print_schedule +def test_get_producer_of_operand(program: jasc.OpHandle) -> None: + program.get_producer_of_operand(operand_number=0) + # CHECK: transform.sequence + # CHECK-NEXT: 
^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: get_producer_of_operand %[[ARG0]][0] + + +# CHECK-LABEL: TEST: test_hoist_redundant_vector_transfers +@run +@print_schedule +def test_hoist_redundant_vector_transfers(program: jasc.OpHandle) -> None: + program.hoist_redundant_vector_transfers().print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.hoist_redundant_vector_transfers %[[ARG0]] + # CHECK-DAG: print %[[V0]] + + +# CHECK-LABEL: TEST: test_hoist_pad +@run +@print_schedule +def test_hoist_pad(program: jasc.OpHandle) -> None: + op = program.hoist_pad(num_loops=1) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.hoist_pad %[[ARG0]] + # CHECK-SAME: by 1 loops + + +# CHECK-LABEL: TEST: test_insert_slice_to_copy +@run +@print_schedule +def test_insert_slice_to_copy(program: jasc.OpHandle) -> None: + program.insert_slice_to_copy() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.insert_slice_to_copy %[[ARG0]] + # CHECK-SAME: !transform.op<"linalg.copy"> + + +# CHECK-LABEL: TEST: test_interchange +@run +@print_schedule +def test_interchange(program: jasc.OpHandle) -> None: + program.interchange([0, 2, 1]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.interchange %[[ARG0]] + # CHECK-SAME: iterator_interchange = [0, 2, 1] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +# CHECK-LABEL: TEST: test_map_forall_to_blocks_noargs +@run +@print_schedule +def test_map_forall_to_blocks_noargs(program: jasc.OpHandle) -> None: + program.map_forall_to_blocks() + program.print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.gpu.map_forall_to_blocks 
%[[ARG0]] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + # CHECK-NEXT: print %[[V0:.*]] + + +# CHECK-LABEL: TEST: test_map_forall_to_blocks_args +@run +@print_schedule +def test_map_forall_to_blocks_args(program: jasc.OpHandle) -> None: + program.map_forall_to_blocks(grid_dims=[4, 2, 1], generate_gpu_launch=True) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.gpu.map_forall_to_blocks %[[ARG0]] + # CHECK-SAME: generate_gpu_launch grid_dims = [4, 2, 1] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +# CHECK-LABEL: TEST: test_map_copy_to_threads +@run +@print_schedule +def test_map_copy_to_threads(program: jasc.OpHandle) -> None: + result = program.map_copy_to_threads( + total_num_threads=4, desired_bit_alignment=128 + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.gpu.map_copy_to_threads %[[ARG0]] + # CHECK-SAME: total_num_threads = 4 desired_bit_alignment = 128 + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op) + assert isinstance(result.forall_op, jasc.OpHandle) + assert isinstance(result.tiled_op, jasc.OpHandle) + + +# CHECK-LABEL: TEST: test_map_nested_forall_to_threads_noargs +@run +@print_schedule +def test_map_nested_forall_to_threads_noargs(program: jasc.OpHandle) -> None: + program.map_nested_forall_to_threads().print() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.gpu.map_nested_forall_to_threads %[[ARG0]] + # CHECK-NOT: sync_after_distribute + # CHECK-NOT: warp_size + # CHECK-SAME: block_dims = [] + # CHECK-NOT: sync_after_distribute + # CHECK-NOT: warp_size + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + # CHECK: print %[[V0]] + + +# CHECK-LABEL: TEST: test_map_nested_forall_to_threads_allargs +@run +@print_schedule +def 
test_map_nested_forall_to_threads_allargs(program: jasc.OpHandle) -> None: + program.map_nested_forall_to_threads( + block_dims=[4, 4], sync_after_distribute=False, warp_size=128 + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.gpu.map_nested_forall_to_threads %[[ARG0]] + # CHECK-DAG: block_dims = [4, 4] + # CHECK-DAG: sync_after_distribute = false + # CHECK-DAG: warp_size = 128 + + +# CHECK-LABEL: TEST: test_synchronize +@run +@print_schedule +def test_synchronize(program: jasc.OpHandle) -> None: + barrier = program.synchronize() + assert isinstance(barrier, jasc.OpHandle) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.jasc.synchronize %[[ARG0]] + # CHECK-SAME: -> !transform.op<"gpu.barrier"> + + +# CHECK-LABEL: TEST: test_vectorize_static +@run +@print_schedule +def test_vectorize_static(program: jasc.OpHandle) -> None: + program.vectorize([16, 4]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vector_sizes [16, 4] : !transform.any_op + + +# CHECK-LABEL: TEST: test_vectorize_array +@run +@print_schedule +def test_vectorize_array(program: jasc.OpHandle) -> None: + sizes = ir.Attribute.parse("[16, 4]") + program.vectorize(sizes) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vector_sizes [16, 4] : !transform.any_op + + +# CHECK-LABEL: TEST: test_vectorize_autonormalized +@run +@print_schedule +def test_vectorize_autonormalized(program: jasc.OpHandle) -> None: + with jasc.autonormalize(): + program.vectorize([16, 4]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0:.*]] { + # CHECK-SAME: op_name = 
"func.func" + # CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] : !transform.any_op + # CHECK-NEXT: apply_cse to %[[ARG1]] : !transform.any_op + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vector_sizes [16, 4] : !transform.any_op + + +# CHECK-LABEL: TEST: test_vectorize_mixed +@run +@print_schedule +def test_vectorize_mixed(program: jasc.OpHandle) -> None: + sz1 = program.match_ops("arith.constant") + sz2 = ir.Attribute.parse("4") + program.vectorize([sz1, sz2]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.match + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vector_sizes [%[[V0]] : !transform.op<"arith.constant">, 4] + + +# CHECK-LABEL: TEST: test_vectorize_scalable +@run +@print_schedule +def test_vectorize_scalable(program: jasc.OpHandle) -> None: + sz1 = program.match_ops("arith.constant") + sz2 = ir.Attribute.parse("4") + program.vectorize([16, [sz1], [sz2], [8]]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.structured.match + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vector_sizes [16, + # CHECK-SAME: [%[[V0]] : !transform.op<"arith.constant">], [4], [8]] + + +# CHECK-LABEL: TEST: test_vectorize_args +@run +@print_schedule +def test_vectorize_args(program: jasc.OpHandle) -> None: + program.vectorize([16, 4], vectorize_nd_extract=True) + # CHECK: transform.sequence + # CHECK-NEXT: 
^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.vectorize %[[ARG0]] + # CHECK-SAME: vectorize_nd_extract + + +# CHECK-LABEL: TEST: test_match_ops_single +@run +@print_schedule +def test_match_ops_single(program: jasc.OpHandle): + program.match_ops(scf.ForOp) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-SAME: -> !transform.op<"scf.for"> + + +# CHECK-LABEL: TEST: test_match_ops_string_name +@run +@print_schedule +def test_match_ops_string_name(program: jasc.OpHandle): + program.match_ops("linalg.matmul") + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_string_iface +@run +@print_schedule +def test_match_ops_string_iface(program: jasc.OpHandle): + program.match_ops("LinalgOp") + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_iface +@run +@print_schedule +def test_match_ops_iface(program: jasc.OpHandle): + program.match_ops(structured.MatchInterfaceEnum.LinalgOp) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_match_ops_multiple +@run +@print_schedule +def test_match_ops_multiple(program: jasc.OpHandle): + program.match_ops([scf.ForOp, scf.ForallOp]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["scf.for", 
"scf.forall"]} in %[[VAL_0]] + # CHECK-SAME: -> !transform.any_op + + +# CHECK-LABEL: TEST: test_match_ops_mixed +@run +@print_schedule +def test_match_ops_mixed(program: jasc.OpHandle): + program.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match + # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]] + # CHECK-SAME: -> !transform.any_op + + +# CHECK-LABEL: TEST: test_match_tag_single +@run +@print_schedule +def test_match_tag_single(program: jasc.OpHandle): + program.match_tag("foo_tag") + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match interface{LinalgOp} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-NEXT: transform.jasc.match_tag ["foo_tag"] in %[[VAL_1]] + + +# CHECK-LABEL: TEST: test_match_tag_multiple +@run +@print_schedule +def test_match_tag_multiple(program: jasc.OpHandle): + program.match_tag(["foo_tag_0", "foo_tag_1"]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match interface{LinalgOp} + # CHECK-SAME: in %[[VAL_0]] + # CHECK-NEXT: transform.jasc.match_tag ["foo_tag_0", "foo_tag_1"] + # CHECK-SAME: in %[[VAL_1]] + + +# CHECK-LABEL: TEST: test_one_shot_bufferize_noargs +@run +@print_schedule +def test_one_shot_bufferize_noargs(program: jasc.OpHandle) -> None: + program.one_shot_bufferize() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.bufferization.one_shot_bufferize %[[ARG0]] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +# CHECK-LABEL: TEST: test_one_shot_bufferize_args +@run +@print_schedule +def test_one_shot_bufferize_args(program: jasc.OpHandle) -> None: + program.one_shot_bufferize( + 
allow_return_allocs_from_loops=True, + allow_unknown_ops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion="IdentityLayoutMap", + memcpy_op="linalg.copy", + print_conflicts=True, + test_analysis_only=True, + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V0:.*]] = transform.bufferization.one_shot_bufferize + # CHECK-SAME: layout{IdentityLayoutMap} + # CHECK-SAME: %[[ARG0]] + # CHECK-SAME: allow_return_allocs_from_loops = true + # CHECK-SAME: allow_unknown_ops = true + # CHECK-SAME: bufferize_function_boundaries = true + # CHECK-SAME: memcpy_op = "linalg.copy" + # CHECK-SAME: print_conflicts = true + # CHECK-SAME: test_analysis_only = true + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +# CHECK-LABEL: TEST: test_pad_allargs +# TODO(ingomueller): I think that `padding_values` and `pad_to_multiple_of` +# should be optional but the mix-in currently doesn't support +# that. Add tests once it does. 
+@run +@print_schedule +def test_pad_allargs(program: jasc.OpHandle): + result = program.pad( + padding_values=[0.0, 0.0, 0.0, 0.0], + padding_dimensions=[0, 1, 2, 3], + pack_paddings=[1, 1, 1, 1], + pad_to_multiple_of=[1, 1, 1, 1], + transpose_paddings=[[0, 1, 2, 3]], + copy_back_op=jasc.PadCopyBackOp.NONE, + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.pad %[[VAL_0]] + # CHECK-SAME: copy_back_op = "none" + # CHECK-SAME: pack_paddings = [1, 1, 1, 1] + # CHECK-SAME: pad_to_multiple_of = [1, 1, 1, 1] + # CHECK-SAME: padding_dimensions = [0, 1, 2, 3] + # CHECK-SAME: padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, + # CHECK-SAME: 0.000000e+00 : f32, 0.000000e+00 : f32] + # CHECK-SAME{LITERAL}: transpose_paddings = [[0, 1, 2, 3]] + assert isinstance(result.pad, jasc.OpHandle) + assert isinstance(result.padded, jasc.OpHandle) + assert isinstance(result.copy, jasc.OpHandle) + + +# CHECK-LABEL: TEST: test_print +@run +@print_schedule +def test_print(program: jasc.OpHandle): + program.print("debugMessage") + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: print %[[VAL_0]] {name = "debugMessage"} + + +# CHECK-LABEL: TEST: test_select +@run +@print_schedule +def test_select(program: jasc.OpHandle): + program.select(op_name="test.op") + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: select "test.op" in %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_replace_with_alloc_tensor_explicit_cast +@run +@print_schedule +def test_replace_with_alloc_tensor_explicit_cast(program: jasc.OpHandle): + program.cast("tensor.empty").replace_with_alloc_tensor() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V1:.*]] = cast %[[V0]] : {{.*}} to !transform.op<"tensor.empty"> + # CHECK-NEXT: transform.bufferization.empty_tensor_to_alloc_tensor 
%[[V1]] + + +# CHECK-LABEL: TEST: test_replace_with_alloc_tensor_implicit_cast +@run +@print_schedule +def test_replace_with_alloc_tensor_implicit_cast(program: jasc.OpHandle): + program.replace_with_alloc_tensor() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[V1:.*]] = cast %[[V0]] : {{.*}} to !transform.op<"tensor.empty"> + # CHECK-NEXT: transform.bufferization.empty_tensor_to_alloc_tensor %[[V1]] + + +# CHECK-LABEL: TEST: test_rewrite_in_destination_passing_style +@run +@print_schedule +def test_rewrite_in_destination_passing_style(program: jasc.OpHandle): + program.rewrite_in_destination_passing_style() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.rewrite_in_destination_passing_style + # CHECK-SAME: %[[VAL_0]] + + +# CHECK-LABEL: TEST: test_take_assumed_branch_standalone +@run +@print_schedule +def test_take_assumed_branch_standalone(program: jasc.OpHandle): + program.take_assumed_branch(take_else_branch=True) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG_0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.scf.take_assumed_branch %[[ARG_0]] take_else_branch + + +# CHECK-LABEL: TEST: test_take_assumed_branch_autonormalized +@run +@print_schedule +def test_take_assumed_branch_autonormalized(program: jasc.OpHandle): + with jasc.autonormalize(): + program.take_assumed_branch(take_else_branch=True) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG_0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0:.*]] { + # CHECK-SAME: op_name = "func.func" + # CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # 
CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] : !transform.any_op + # CHECK-NEXT: apply_cse to %[[ARG1]] : !transform.any_op + # CHECK-NEXT: transform.scf.take_assumed_branch %[[ARG_0]] take_else_branch + + +# CHECK-LABEL: TEST: test_tile_to_for_sizes +@run +@print_schedule +def test_tile_to_for_sizes(program: jasc.OpHandle): + result = program.tile(loop=jasc.TileLoopKind.FOR, tile_sizes=(2, 4)) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.tile_using_for %[[VAL_0]][2, 4] + assert isinstance(result.tiled_op, jasc.OpHandle) + for loop in result.loops: + assert isinstance(loop, jasc.OpHandle) + + +# CHECK-LABEL: TEST: test_tile_to_for_parametric +@run +@print_schedule +def test_tile_to_for_parametric(program: jasc.OpHandle) -> None: + tile_size = jasc.tuning_param( + default_value=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + ) + program.tile(loop=jasc.TileLoopKind.FOR, tile_sizes=[tile_size]) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[VAL0:.*]] = transform.jasc.tuning_param + # CHECK-SAME: {default_value = 1 : i32} + # CHECK-NEXT: transform.structured.tile_using_for %[[ARG0]][%[[VAL0]]] + + +# CHECK-LABEL: TEST: test_tile_to_forall_string +@run +@print_schedule +def test_tile_to_forall_string(program: jasc.OpHandle): + result = program.tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64, 1], + mapping="[#gpu.block, #gpu.block, #gpu.block]", + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.tile_using_forall %[[V0]] + # CHECK-SAME: tile_sizes [64, 64, 1] + # CHECK-SAME: (mapping = [#gpu.block, #gpu.block, #gpu.block]) + assert len(result.loops) == 1 + assert isinstance(result.loops[0], jasc.OpHandle) + assert isinstance(result.tiled_op, 
jasc.OpHandle) + + +# CHECK-LABEL: TEST: test_tile_to_forall_nomapping +@run +@print_schedule +def test_tile_to_forall_nomapping(program: jasc.OpHandle): + program.tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64, 1], + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[V0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.tile_using_forall %[[V0]] + # CHECK-NOT: mapping + + +# CHECK-LABEL: TEST: test_tile_to_forall_list +@run +@print_schedule +def test_tile_to_forall_list(program: jasc.OpHandle): + program.tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64, 1], + mapping=["#gpu.block", ir.Attribute.parse("#gpu.block")], + ) + # CHECK: (mapping = [#gpu.block, #gpu.block]) + + +# CHECK-LABEL: TEST: test_tile_to_forall_attr +@run +@print_schedule +def test_tile_to_forall_attr(program: jasc.OpHandle): + program.tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64, 1], + mapping=["#gpu.block", ir.Attribute.parse("#gpu.block")], + ) + # CHECK: (mapping = [#gpu.block, #gpu.block]) + + +# CHECK-LABEL: TEST: test_tile_to_forall_autonormalized +@run +@print_schedule +def test_tile_to_forall_autonormalized(program: jasc.OpHandle): + with jasc.autonormalize(): + program.tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64, 1], + mapping=[], + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0:.*]] { + # CHECK-SAME: op_name = "func.func" + # CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] : !transform.any_op + # CHECK-NEXT: 
apply_cse to %[[ARG1]] : !transform.any_op + # CHECK-NEXT: transform.structured.tile_using_forall %[[ARG0]] + # CHECK-SAME: tile_sizes [64, 64, 1] + # CHECK-SAME: (mapping = []) + + +# CHECK-LABEL: TEST: test_vectorize_children_and_apply_patterns +@run +@print_schedule +def test_vectorize_children_and_apply_patterns(program: jasc.OpHandle) -> None: + program.vectorize_children_and_apply_patterns( + disable_multi_reduction_to_contract_patterns=True, + disable_transfer_permutation_map_lowering_patterns=True, + vectorize_nd_extract=True, + vectorize_padding=True, + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.structured.vectorize_children_and_apply_patterns %[[ARG0]] { + # CHECK-SAME: disable_multi_reduction_to_contract_patterns + # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns + # CHECK-SAME: vectorize_nd_extract + # CHECK-SAME: vectorize_padding} + + +# CHECK-LABEL: TEST: test_match_sparse_inout +@run +@print_schedule +def test_match_sparse_inout(program: jasc.OpHandle): + program.match_sparse_inout_ops() + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: transform.sparse_tensor.match.sparse_inout %[[ARG0]] + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + for test_fun in tests: + test_fun() + + +if __name__ == "__main__": + app.run(main) diff --git a/jasc/test/autotuning.py b/jasc/test/autotuning.py new file mode 100644 index 000000000000..e5963a88505f --- /dev/null +++ b/jasc/test/autotuning.py @@ -0,0 +1,80 @@ +"""Tests for the JASC autotuning utilities.""" + +from typing import Tuple +import sys + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. 
+sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +from jax import numpy as jnp +import pytest + +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import linalg + +from jasc import jasc +from jasc import tuner + + +def _gen_input(shape: Tuple[int, int], dtype=jnp.float64): + return jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype) + + +def test_matmul_1D_tiling_tuning() -> None: + """Tests the tuning of the tile size in a 64x64 matmul.""" + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(module: jasc.OpHandle) -> None: + matmul = module.match_ops(linalg.GenericOp) + tile_size = jasc.tuning_param(1) + with jasc.autonormalize(False): + matmul.tile(loop=jasc.TileLoopKind.FOR, tile_sizes=[tile_size]) + + with ir.Context(): + a = _gen_input((64, 64), dtype=jnp.float32) + b = _gen_input((64, 64), dtype=jnp.float32) + + tuna = tuner.FooTuner(matmul, schedule, inputs=[a, b], budget=10) + best_time, tuned_fun, best_schedule, times = tuna.tune() + + chex.assert_trees_all_close( + jax.jit(matmul)(a, b), tuned_fun(a, b), rtol=1e-5 + ) + print(f"times: \n{times}") + print(f"best time: {best_time}") + + +def testmatmul_2D_tiling_tuning() -> None: + """Tests the tuning of the tile sizes in a 64x64 matmul.""" + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(module: jasc.OpHandle) -> None: + matmul = module.match_ops(linalg.GenericOp) + tile_size = jasc.tuning_param(1) + with jasc.autonormalize(False): + matmul.tile(loop=jasc.TileLoopKind.FOR, tile_sizes=[tile_size, tile_size]) + + with ir.Context(): + a = _gen_input((64, 64), dtype=jnp.float32) + b = _gen_input((64, 64), dtype=jnp.float32) + + tuna = tuner.FooTuner(matmul, schedule, inputs=[a, b], budget=10) + best_time, tuned_fun, best_schedule, times = tuna.tune() + + chex.assert_trees_all_close( + jax.jit(matmul)(a, b), 
tuned_fun(a, b), rtol=1e-5 + ) + + print(f"times: \n{times}") + print(f"best time: {best_time}") + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/batch_matmul_gpu.py b/jasc/test/batch_matmul_gpu.py new file mode 100644 index 000000000000..e62e075b62f0 --- /dev/null +++ b/jasc/test/batch_matmul_gpu.py @@ -0,0 +1,58 @@ +from typing import Tuple + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +import sys +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +from jax import lax +from jax import numpy as jnp +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import transform +import pytest + +from jasc import jasc + +JASC_MLIR_TAG = "bmm" + + +def _gen_input(shape: Tuple[int, int], dtype=jnp.float64): + return jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype) + + +def test_batch_matmul(): + def batch_matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(lax.batch_matmul, JASC_MLIR_TAG)(a, b) + + def schedule(handle: jasc.OpHandle) -> None: + bmm = handle.match_tag(JASC_MLIR_TAG).select("linalg.batch_matmul") + loop = bmm.tile( + loop=jasc.TileLoopKind.FORALL, + num_threads=[1], + mapping=[ir.Attribute.parse("#gpu.block")], + ).loops[0] + fill = handle.match_tag(JASC_MLIR_TAG).select("linalg.fill") + fill.fuse_into(loop) + with handle.match_ops("func.func").apply_patterns(): + transform.ApplyCanonicalizationPatternsOp() + handle.match_ops("tensor.empty").replace_with_alloc_tensor() + handle = handle.one_shot_bufferize( + bufferize_function_boundaries=True, + function_boundary_type_conversion="IdentityLayoutMap", + ) + handle.map_forall_to_blocks(grid_dims=[1, 1, 1], generate_gpu_launch=True) + + a = _gen_input((4, 32, 32)) + b = _gen_input((4, 32, 32)) + + jit_batch_matmul = jasc.jit(batch_matmul, schedule) + chex.assert_gpu_available() + chex.assert_trees_all_close( + jit_batch_matmul(a, b), 
lax.batch_matmul(a, b), rtol=1e-5 + ) + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/bindings.py b/jasc/test/bindings.py new file mode 100644 index 000000000000..07be02128340 --- /dev/null +++ b/jasc/test/bindings.py @@ -0,0 +1,83 @@ +from typing import Callable, Sequence + +from absl import app +from jaxlib.mlir import ir, passmanager +from jaxlib.mlir.dialects import transform + +from jaxlib.mlir.dialects import jasc as jd +from jaxlib.mlir.dialects.transform import jasc_transform_ops as jto + +tests: list[Callable[[], None]] = [] + + +def run(f): + def test(): + print("\nTEST:", f.__name__) + f() + + tests.append(test) + return f + + +# CHECK-LABEL: test_register_transform_dialect_extension +@run +def test_register_transform_dialect_extension() -> None: + with ir.Context() as ctx, ir.Location.unknown(): + jto.register_transform_dialect_extension(ctx) + module = ir.Module.create() + with ir.InsertionPoint(module.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.AnyOpType.get(), + ) + with ir.InsertionPoint(sequence.body): + jto.MatchTagOp(sequence.bodyTarget, ["tag"]) + transform.YieldOp([]) + module.operation.verify() + print(module) + # CHECK: transform.sequence + # CHECK: transform.jasc.match_tag + + +# CHECK-LABEL: test_register_and_load_dialect +@run +def test_register_and_load_dialect() -> None: + with ir.Context(), ir.Location.unknown(): + jd.register_and_load_dialect() + module = ir.Module.create() + with ir.InsertionPoint(module.body): + op = jd.TagRegionOp([], "tag") + op.body.blocks.append() + with ir.InsertionPoint(op.body.blocks[0]): + jd.ReturnOp([]) + module.operation.verify() + print(module) + # CHECK: jasc.tag_region "tag" { + # CHECK-NEXT: } + + +# CHECK-LABEL: test_register_lowering_passes +@run +def test_register_lowering_passes() -> None: + with ir.Context(), ir.Location.unknown(): + 
jd.register_lowering_passes() + module = ir.Module.create() + pm = passmanager.PassManager.parse( + "builtin.module(jasc-remove-copy-to-out-params)" + ) + pm.run(module.operation) + print(module) + # CHECK: module + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + for test_fun in tests: + test_fun() + + +if __name__ == "__main__": + app.run(main) diff --git a/jasc/test/cpu_integration.py b/jasc/test/cpu_integration.py new file mode 100644 index 000000000000..e9ea7c5e3bf2 --- /dev/null +++ b/jasc/test/cpu_integration.py @@ -0,0 +1,111 @@ +"""Jasc tests common to all platforms.""" + +from collections.abc import Mapping +import sys + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +from jax import numpy as jnp +import pytest + +from jasc import jasc + + +def _unit_schedule(handle: jasc.OpHandle) -> None: + """A schedule that does nothing.""" + del handle + + +def test_jit_pass(): + def foo() -> None: + pass + + assert jasc.jit(foo, _unit_schedule)() is None + + +def test_jit_single_input_output(): + def foo(x: jax.Array) -> jax.Array: + return x + + x = jnp.array([1, 2, 3]) + chex.assert_trees_all_equal(jasc.jit(foo, _unit_schedule)(x), x) + + +def test_jit_single_op(): + def foo(x: jax.Array) -> jax.Array: + return x + 1 + + chex.assert_trees_all_equal( + jasc.jit(foo, _unit_schedule)(jnp.array([1, 2, 3])), + jnp.array([2, 3, 4]), + ) + + +def test_jit_multiple_inputs(): + def foo(x: jax.Array, y: jax.Array) -> jax.Array: + return x + y + + x = jnp.array([1, 2, 3]) + y = jnp.array([4, 5, 6]) + chex.assert_trees_all_equal( + jasc.jit(foo, _unit_schedule)(x, y), jnp.array([5, 7, 9]) + ) + + +def test_jit_multiple_outputs(): + def foo(x: jax.Array) -> tuple[jax.Array, jax.Array]: + return x + 1, x + 2 + + chex.assert_trees_all_equal( + jasc.jit(foo, _unit_schedule)(jnp.array([1, 2, 
3])), + (jnp.array([2, 3, 4]), jnp.array([3, 4, 5])), + ) + + +def test_jit_dict_input(): + def foo(x: Mapping[str, jax.Array]) -> jax.Array: + return x["a"] + x["b"] + + x = {"a": jnp.array([1, 2, 3]), "b": jnp.array([4, 5, 6])} + chex.assert_trees_all_equal( + jasc.jit(foo, _unit_schedule)(x), jnp.array([5, 7, 9]) + ) + + +def test_jit_dict_output(): + def foo(x: jax.Array) -> Mapping[str, jax.Array]: + return {"a": x + 1, "b": x + 2} + + chex.assert_trees_all_equal( + jasc.jit(foo, _unit_schedule)(jnp.array([1, 2, 3])), + {"a": jnp.array([2, 3, 4]), "b": jnp.array([3, 4, 5])}, + ) + + +def test_tag_jit(): + def foo(x: jax.Array) -> jax.Array: + return x + 1 + + jit_foo = jasc.jit(jasc.tag(foo, "foo_tag"), _unit_schedule) + chex.assert_trees_all_equal( + jit_foo(jnp.array([1, 2, 3])), + jnp.array([2, 3, 4]), + ) + + +def test_tag_jax_jit(): + def foo(x: jax.Array) -> jax.Array: + return x + 1 + + jit_foo = jax.jit(jasc.tag(foo, "foo_tag")) + chex.assert_trees_all_equal( + jit_foo(jnp.array([1, 2, 3])), jnp.array([2, 3, 4]) + ) + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/diagnostics.py b/jasc/test/diagnostics.py new file mode 100644 index 000000000000..bb80644897ca --- /dev/null +++ b/jasc/test/diagnostics.py @@ -0,0 +1,59 @@ +"""Tests for MLIR diagnostics.""" + +from __future__ import annotations + +from typing import Callable, Sequence + +from absl import app +import jax +from jax import numpy as jnp + +from jasc import jasc + +tests: list[Callable[[], None]] = [] + + +def run(f): + def test(): + print("\nTEST:", f.__name__) + f() + + tests.append(test) + return f + + +# CHECK-LABEL: TEST: test_location_notes +@run +def test_location_notes(): + def foo(a: jax.Array) -> jax.Array: + return a + + def schedule(h: jasc.OpHandle) -> None: + # This is invalid because it applies CSE to the embedded transform script, + # which can't be modified. 
+ h.apply_cse() + + # Check that the exception contains notes and locations. + try: + data = jnp.full((16, 16), 1.23) + jasc.lower_to_linalg(foo, data, schedule=schedule) + except Exception as e: + print(e) + # CHECK: /tmp/mlir_snapshot-{{.*}}.tmp.mlir: + # CHECK-SAME: error: cannot apply transform to itself + # CHECK: apply_cse to %{{.*}} : !transform.any_op + # CHECK-NEXT: ^ + # CHECK: /tmp/mlir_snapshot-{{.*}}.tmp.mlir: + # CHECK-SAME: note: target payload op + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + for test_fun in tests: + test_fun() + + +if __name__ == "__main__": + app.run(main) diff --git a/jasc/test/filecheck_test.sh b/jasc/test/filecheck_test.sh new file mode 100755 index 000000000000..8d184e347d53 --- /dev/null +++ b/jasc/test/filecheck_test.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +die() { echo "$*" 1>&2 ; exit 1; } + +TARGET=$1 +TESTFILE="${TEST_SRCDIR}/jasc/test/$TARGET" +$TESTFILE || die "Failure in Python test: \"$TARGET\"" +$TESTFILE | "$TEST_SRCDIR/llvm-project/llvm/FileCheck" "$TESTFILE.py" || die "Failure in FileCheck test: \"$TARGET\"" + +echo "PASS" diff --git a/jasc/test/fold_fill_into_pad.mlir b/jasc/test/fold_fill_into_pad.mlir new file mode 100644 index 000000000000..2292cae0b029 --- /dev/null +++ b/jasc/test/fold_fill_into_pad.mlir @@ -0,0 +1,35 @@ +// RUN: jasc-opt %s -jasc-apply-transform-script \ +// RUN: | FileCheck %s + + +// Test ported from: +// third_party/iree/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir +// CHECK-LABEL: @pad_fill_to_fill +func.func @pad_fill_to_fill(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { + // Check that a pad of a fill with the same constant is replaced by a + // bigger fill. 
+ // CHECK-DAG: %[[FILL_CST:.*]] = arith.constant 0.0{{0*e\+00}} : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[PADDED_FILL:.*]] = linalg.fill ins(%[[FILL_CST]] : f32) outs(%[[EMPTY]] : tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: return %[[PADDED_FILL]] + %cst = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %fill = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<31x62xf32>) -> tensor<31x62xf32> + %padded = tensor.pad %fill low[%c0, %c0] high[%c1, %c2] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<31x62xf32> to tensor<32x64xf32> + return %padded : tensor<32x64xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.fold_fill_into_pad + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + } : !transform.any_op +} diff --git a/jasc/test/gpu_integration.py b/jasc/test/gpu_integration.py new file mode 100644 index 000000000000..16ddc0fb1555 --- /dev/null +++ b/jasc/test/gpu_integration.py @@ -0,0 +1,30 @@ +"""GPU-specific tests for Jasc.""" + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +import sys +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import pytest + + +# ===----------------------------------------------------------------------=== # +# Import Jasc CPU tests. +# ===----------------------------------------------------------------------=== # + +# XXX: This "imports" the CPU tests but does not run them on a GPU. How to +# achieve that? Maybe we need to rethink the whole GPU vs. CPU mechanism... 
+"""Imports common tests.""" +from jasc.test.cpu_integration import * + + +# ===----------------------------------------------------------------------=== # +# Jasc GPU-specific tests. +# ===----------------------------------------------------------------------=== # +def test_running_on_gpu(): + chex.assert_gpu_available() + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/jit.py b/jasc/test/jit.py new file mode 100644 index 000000000000..860901860329 --- /dev/null +++ b/jasc/test/jit.py @@ -0,0 +1,73 @@ +"""Tests for JASC jit.""" +from __future__ import annotations + +import sys + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +import pytest + +from jasc import jasc + +jasc.set_auto_normalization(False) + + +def _gen_input(shape: tuple[int, int], dtype=jax.numpy.float64): + return jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype) + + +def test_jit_matmul_jax_func() -> None: + """Jasc jit of a jax function without an additional schedule.""" + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + jit_matmul = jasc.jit(matmul) + a = _gen_input((64, 64)) + b = _gen_input((64, 64)) + chex.assert_trees_all_close( + jit_matmul(a, b), jax.numpy.matmul(a, b), rtol=1e-5 + ) + +def test_jit_matmul_jax_func_schedule() -> None: + """Jasc jit of a jax function with a simple schedule.""" + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(handle: jasc.OpHandle) -> None: + handle.match_ops("linalg.generic").tile( + loop=jasc.TileLoopKind.FOR, tile_sizes=[32] + ) + + jit_matmul = jasc.jit(matmul, schedule) + a = _gen_input((64, 64)) + b = _gen_input((64, 64)) + chex.assert_trees_all_close( + jit_matmul(a, b), jax.numpy.matmul(a, b), rtol=1e-5 
+ ) + + +def test_jit_matmul_mlir() -> None: + """Jasc jit of an mlir module that stems from a jax function.""" + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + a = _gen_input((64, 64)) + b = _gen_input((64, 64)) + + module = jasc.lower_to_linalg(matmul, a, b) + jit_matmul = jasc.jit(matmul, module=module) + + chex.assert_trees_all_close( + jit_matmul(a, b), jax.numpy.matmul(a, b), rtol=1e-5 + ) + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) \ No newline at end of file diff --git a/jasc/test/lit.cfg.py b/jasc/test/lit.cfg.py new file mode 100644 index 000000000000..878e1ca09acb --- /dev/null +++ b/jasc/test/lit.cfg.py @@ -0,0 +1,39 @@ +# -*- Python -*- + +import os + +import lit.formats +import lit.util + +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst + +# Configuration file for the 'lit' test runner. + +config.name = "Jasc" + +config.test_format = lit.formats.ShTest(execute_external=False) + +config.suffixes = [ + ".mlir", +] + +config.excludes = [ + "lit.cfg.py", +] + +config.test_source_root = os.path.dirname(__file__) +config.test_exec_root = os.path.join(config.test_source_root) + +llvm_config.use_default_substitutions() + +tool_dirs = [ + config.jasc_tools_dir, + config.mlir_tools_dir, + config.llvm_tools_dir, +] +tools = [ + "jasc-opt", +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/jasc/test/lit.site.cfg.in.py b/jasc/test/lit.site.cfg.in.py new file mode 100644 index 000000000000..a86f9955798a --- /dev/null +++ b/jasc/test/lit.site.cfg.in.py @@ -0,0 +1,16 @@ +@LIT_SITE_CFG_IN_HEADER@ + +import os.path + +config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@") +config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_shlib_dir = lit_config.substitute(path(r"@SHLIBDIR@")) +config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" +config.jasc_src_dir = "@JASC_SOURCE_DIR@" +config.jasc_tools_dir 
= "@JASC_TOOLS_DIR@" + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. +lit_config.load_config(config, os.path.join(config.jasc_src_dir, "test/lit.cfg.py")) diff --git a/jasc/test/matmul_cpu.py b/jasc/test/matmul_cpu.py new file mode 100644 index 000000000000..e386c73c418b --- /dev/null +++ b/jasc/test/matmul_cpu.py @@ -0,0 +1,122 @@ +"""Integration test of full schedules for `jax.numpy.matmul`.""" + +from typing import Tuple +import sys + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +from jax import numpy as jnp +from jaxlib.mlir.dialects import transform +from jaxlib.mlir.dialects.transform import ( + loop, + memref, + structured, + tensor, + vector, +) +import pytest + +from jasc import jasc + + +def _gen_input(shape: Tuple[int, int], dtype=jnp.float64): + return jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype) + + +# ===----------------------------------------------------------------------=== # +# zinenko@ matmul schedules. +# ===----------------------------------------------------------------------=== # +# These schedules reimplements internal schedules from zinenko@ +# ===----------------------------------------------------------------------=== # + +@pytest.mark.parametrize("m,k,n", [ + (4, 768, 2304), + (4, 2304, 768), + (4, 768, 3072), + (4, 3072, 768), +]) +def test_zinenko_matmul_f32(m, k, n): + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(module: jasc.OpHandle) -> None: + # Tile matmul. + # Note: Unlike the original schedule, we tile to `scf.forall` such that we + # can fuse the `linalg.fill`, which the other schedule doesn't have. 
+ tiled_matmul, loops = module.match_ops("linalg.generic").tile( + loop=jasc.TileLoopKind.FORALL, tile_sizes=(0, 16) + ) + module.match_ops("linalg.fill").fuse_into(loops[0]) + + # Tile matmul again, then interchange. + tiled_matmul.tile( + loop=jasc.TileLoopKind.FOR, + tile_sizes=(0, 0, 8), + ).tiled_op.interchange([0, 2, 1]).vectorize() + + # Manual clean-up. + func = module.match_ops("func.func") + with func.apply_patterns(): + transform.ApplyCanonicalizationPatternsOp() + structured.ApplyTilingCanonicalizationPatternsOp() + func.apply_cse() + func.match_ops("LoopLikeInterface").apply_licm() + with func.apply_patterns(): + structured.ApplyFoldUnitExtentDimsViaReshapesPatternsOp() + + # Vectorize function. + func.vectorize_children_and_apply_patterns(vectorize_padding=True) + + # Hoist redundant transforms. + with func.apply_patterns(): + transform.ApplyCanonicalizationPatternsOp() + tensor.ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp() + func.apply_cse() + func.hoist_redundant_vector_transfers() + + # Bufferize. + module.one_shot_bufferize( + bufferize_function_boundaries=True, + function_boundary_type_conversion="IdentityLayoutMap", + ) + + # Turn the `scf.forall` into `scf.for`. + # Note: The original schedule does not do that since it creates `scf.for` + # right away (see above). + forall = module.match_ops("scf.forall") + loop.ForallToForOp([transform.AnyOpType.get()], forall.mlir_value) + + # Lowering of vector ops. + func = module.match_ops("func.func") + with func.apply_patterns(): + transform.ApplyCanonicalizationPatternsOp() + with func.apply_patterns(): + vector.ApplyLowerContractionPatternsOp() + vector.ApplyLowerTransposePatternsOp() + vector.ApplyLowerTransferPatternsOp() + vector.ApplyLowerShapeCastPatternsOp() + with func.apply_patterns(): + vector.ApplyTransferToScfPatternsOp(full_unroll=True) + memref.ApplyAllocToAllocaOp() + + # Hoist buffers. (Does not have any effect on this input). 
+ func.buffer_loop_hoisting() + + # Final foldings and clean-up. + with func.apply_patterns(): + memref.ApplyFoldMemrefAliasOpsPatternsOp() + transform.ApplyCanonicalizationPatternsOp() + func.apply_cse() + + jit_matmul = jasc.jit(matmul, schedule) + a = _gen_input((m, k), dtype=jnp.float32) + b = _gen_input((k, n), dtype=jnp.float32) + chex.assert_trees_all_close(jit_matmul(a, b), jnp.matmul(a, b), rtol=1e-5) + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/matmul_gpu.py b/jasc/test/matmul_gpu.py new file mode 100644 index 000000000000..1da91c236fad --- /dev/null +++ b/jasc/test/matmul_gpu.py @@ -0,0 +1,237 @@ +"""Integration test of full schedules for `jax.numpy.matmul`.""" + +from typing import Tuple + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +import sys +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +import chex +import jax +from jax import numpy as jnp +from jaxlib.mlir.dialects import transform +from jaxlib.mlir.dialects.transform import ( + vector, +) +import pytest + +from jasc import jasc + + +def _gen_input(shape: Tuple[int, int], dtype=jnp.float64): + return jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype) + + +# ===----------------------------------------------------------------------=== # +# ntv@ matmult schedule. +# ===----------------------------------------------------------------------=== # +# This schedule re-implements a schedule originally written by ntv in C++. 
+# ===----------------------------------------------------------------------=== # +def test_ntv_matmul(): + def build_matmul_strategy_block_distribution( + func: jasc.OpHandle, tile_sizes, mapping + ): + tiled, loops = func.match_ops("linalg.generic").tile( + loop=jasc.TileLoopKind.FORALL, tile_sizes=tile_sizes, mapping=mapping + ) + forall = loops[0] + build_canonicalize_and_enabling_transforms(func) + fused_fill = func.match_ops("linalg.fill").fuse_into(forall) + build_canonicalize_and_enabling_transforms(func) + return (fused_fill, tiled, forall) + + def build_map_top_level_forall_to_blocks( + func: jasc.OpHandle, grid_dims + ) -> jasc.OpHandle: + build_canonicalize_and_enabling_transforms(func) + return func.map_forall_to_blocks( + grid_dims=grid_dims, generate_gpu_launch=True + ) + + def build_bufferize(target: jasc.OpHandle): + build_canonicalize_and_enabling_transforms(target) + target.match_ops("tensor.empty").replace_with_alloc_tensor() + # TODO: We have to bufferize on the module in order to bufferize the + # function boundaries but we loose the handle to the function that + # way. Find away around that. 
+ module = target.get_parent_op("builtin.module").one_shot_bufferize( + bufferize_function_boundaries=True, + function_boundary_type_conversion="IdentityLayoutMap", + ) + target = module.match_ops("func.func") # XXX: get our function back + build_canonicalize_and_enabling_transforms(target) + return target + + def build_canonicalize_and_enabling_transforms(target: jasc.OpHandle): + with target.apply_patterns(): + transform.structured.ApplyTilingCanonicalizationPatternsOp() + transform.loop.ApplyForLoopCanonicalizationPatternsOp() + transform.ApplyCanonicalizationPatternsOp() + target.apply_cse() + + def build_transform_strategy(target: jasc.OpHandle): + tsb_x = 16 + tsb_y = 8 + num_blocks_x = (789 + tsb_x + 1) // tsb_x + num_blocks_y = (123 + tsb_y + 1) // tsb_y + mapping = ["#gpu.block", "#gpu.block"] + _, matmul, _ = build_matmul_strategy_block_distribution( + target, [num_blocks_x, num_blocks_y], mapping + ) + + matmul.tile(loop=jasc.TileLoopKind.FOR, tile_sizes=(0, 0, 32)) + build_canonicalize_and_enabling_transforms(target) + + target = build_bufferize(target) + + target = build_map_top_level_forall_to_blocks( + target, [num_blocks_x, num_blocks_y, 1] + ) + build_canonicalize_and_enabling_transforms(target) + + return ((num_blocks_x, num_blocks_y, 1), (1, 1, 1)) + + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(handle: jasc.OpHandle) -> None: + func = handle.match_ops("func.func") + build_transform_strategy(func) + del handle + + a = _gen_input((123, 789)) + b = _gen_input((789, 123)) + jit_matmul = jasc.jit(matmul, schedule) + chex.assert_gpu_available() + chex.assert_trees_all_close(jit_matmul(a, b), jnp.matmul(a, b), rtol=1e-5) + + +# ===----------------------------------------------------------------------=== # +# springerm@ matmul schedule. 
+# ===----------------------------------------------------------------------=== # +# This schedule re-implements and extends the schedule from upstream at +# https://github.com/llvm/llvm-project/blob/96ff0255f/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir +# ===----------------------------------------------------------------------=== # +def test_springerm_matmul(): + def matmul(a: jax.Array, b: jax.Array) -> jax.Array: + return jasc.tag(jax.numpy.matmul, "matmul")(a, b) + + def schedule(handle: jasc.OpHandle) -> None: + # Fuse `linalg.fill` into `linalg.generic` (matmul) and tile across + # blocks. + tiled_matmul, loops = handle.match_ops("linalg.generic").tile( + loop=jasc.TileLoopKind.FORALL, + tile_sizes=[64, 64], + mapping=["#gpu.block", "#gpu.block"], + ) + tiled_fused_fill = handle.match_ops("linalg.fill").fuse_into(loops[0]) + + # Tile matmul a second time. + tiled_matmul, _ = tiled_matmul.tile( + loop=jasc.TileLoopKind.FOR, tile_sizes=(0, 0, 16) + ) + + # Pad matmul. + tiled_padded_matmul, pad, copy_back = tiled_matmul.pad( + padding_values=[0.0, 0.0, 0.0], + padding_dimensions=[0, 1, 2], + pack_paddings=[1, 1, 1], + copy_back_op=jasc.PadCopyBackOp.LINALG_COPY, + ) + + # Tile ops across threads and vectorize. + tiled_padded_matmul.tile( + loop=jasc.TileLoopKind.FORALL, + num_threads=[8, 32], + mapping=["#gpu.thread", "#gpu.thread"], + ).tiled_op.vectorize() + tiled_fused_fill.tile( + loop=jasc.TileLoopKind.FORALL, + num_threads=[8, 32], + mapping=["#gpu.thread", "#gpu.thread"], + ).tiled_op.vectorize() + + # Map `tensor.pad` and copy back to threads. + pad_forall_op, tiled_pad_op = pad.map_copy_to_threads( + total_num_threads=256, desired_bit_alignment=128 + ) + copy_back.map_copy_to_threads( + total_num_threads=256, desired_bit_alignment=128 + ) + + # Vectorize padding ops. + tiled_pad_op.vectorize(vector_sizes=[128, 4]) + + # Assign shared memory buffer to padding. 
+ padding_bufferization = pad_forall_op.bufferize_to_allocation( + memory_space=3, + bufferize_destination_only=True, + alloc_op="memref.alloca", + ) + + # Transform `memref.alloca`s to `memref.global`s in order to work in the + # `gpu.launch`. + foreach = padding_bufferization.new_ops.foreach( + [transform.AnyOpType.get()] + ) + with foreach.body as op: + alloca = op.match_ops("memref.alloca") + jasc.yield_(alloca) + foreach.results[0].alloca_to_global() + + # Bufferize the whole function. + ( + handle.match_ops("func.func") + .eliminate_empty_tensors() + .apply_dce() + .apply_cse() + ) + + handle.match_ops("tensor.empty").replace_with_alloc_tensor() + bufferized = handle.one_shot_bufferize( + bufferize_function_boundaries=True, + function_boundary_type_conversion="IdentityLayoutMap", + ) + + # Apply vectorization to copy back from shared memory. + # TODO: Find a way to retain the handle to linalg.copy throughout + # bufferization. + func = bufferized.match_ops("func.func") + func.match_ops("linalg.copy").vectorize(vector_sizes=[128, 4]) + + # Canonicalize, cleanup, and vector lowering. This step also removes + # buffer self-copies. + with func.apply_patterns(apply_cse=True): + transform.ApplyCanonicalizationPatternsOp() + vector.ApplyVectorReductionToContractPatternsOp() + vector.ApplyLowerMaskedTransfersPatternsOp() + vector.ApplyTransferPermutationPatternsOp() + vector.ApplyVectorReductionToContractPatternsOp() + + # Map the `scf.forall`s to GPU blocks and threads. + func.map_forall_to_blocks( + grid_dims=[16, 16, 1], + generate_gpu_launch=True, + ).map_nested_forall_to_threads(block_dims=[32, 8, 1]) + + # Some more clean-ups. 
+ func = bufferized.match_ops( + "func.func" + ).hoist_redundant_vector_transfers() + with func.apply_patterns(apply_cse=True): + vector.ApplyTransferToScfPatternsOp( + max_transfer_rank=1, full_unroll=True + ) + + del handle + + jit_matmul = jasc.jit(matmul, schedule) + a = _gen_input((1024, 1024), dtype=jnp.float32) + b = _gen_input((1024, 1024), dtype=jnp.float32) + chex.assert_gpu_available() + chex.assert_trees_all_close(jit_matmul(a, b), jnp.matmul(a, b), rtol=1e-5) + + +if __name__ == "__main__": + args = sys.argv[1:] or ["-s", "-v"] + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/normalization.py b/jasc/test/normalization.py new file mode 100644 index 000000000000..9246ff522d01 --- /dev/null +++ b/jasc/test/normalization.py @@ -0,0 +1,302 @@ +"""Tests for JASC transform op abstractions.""" +from __future__ import annotations + +from typing import Callable, Sequence + +from absl import app +from jaxlib.mlir import ir + +from jasc import jasc + +tests: list[Callable[[], None]] = [] +jasc.set_auto_normalization(False) + + +def run(f): + def test(): + print("\nTEST:", f.__name__) + f() + + tests.append(test) + return f + + +def print_schedule(schedule: Callable) -> Callable: + def decorated() -> None: + with ir.Context(): + module = ir.Module.parse("") + jasc.insert_schedule(module, schedule=schedule, dump_schedule=True) + module.operation.verify() + + decorated.__name__ = schedule.__name__ + return decorated + + +# CHECK-LABEL: TEST: test_auto_apply_loop_normalform +@run +@print_schedule +def test_auto_apply_loop_normalform(program: jasc.OpHandle) -> None: + with jasc.autonormalize(): + program.tile( + loop=jasc.TileLoopKind.FORALL, tile_sizes=[64, 64, 1], mapping=[] + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0]] + # CHECK-SAME: {deduplicate, op_name = "func.func"} + # CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: 
transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] + # CHECK-NEXT: apply_cse to %[[ARG1]] + # CHECK-NEXT: transform.structured.tile_using_forall %[[ARG0]] + + +# CHECK-LABEL: TEST: test_autonormalize_contextmanager +@run +@print_schedule +def test_autonormalize_contextmanager(program: jasc.OpHandle) -> None: + # Checks that autonormalization behavior is preserved outside of the + # contextmanager + jasc.set_auto_normalization(True) + with jasc.autonormalize(): + pass + program.auto_normalize_parent_func(jasc.LoopNormalform) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1:.*]] = get_parent_op %[[ARG0]] + # CHECK-SAME: {deduplicate, op_name = "func.func"} + # CHECK-NEXT: apply_patterns to %[[ARG1]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG1]] + # CHECK-NEXT: apply_licm to %[[ARG2]] + # CHECK-NEXT: apply_cse to %[[ARG1]] + + +# CHECK-LABEL: TEST: test_normalforms_autonormalization_decorator_plain +@jasc.jasc_transform +def foo_abstraction_0(handle: jasc.Value) -> jasc.Value: + return jasc.Value(handle.mlir_value, parent=handle) + + +@run +@print_schedule +def test_normalforms_autonormalization_decorator_plain( + program: jasc.OpHandle, +) -> None: + with jasc.autonormalize(): + program.normalize(jasc.LoopNormalform) + 
assert program.normalform == jasc.LoopNormalform + # This will conservatively reset the normalform and also propagate the + # weaker normalform to the parent + new_handle = foo_abstraction_0(program) + assert new_handle.normalform == jasc.AnyForm + assert program.normalform == jasc.AnyForm + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG1:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[ARG1]] + # CHECK-NEXT: apply_cse to %[[ARG0]] + + +# CHECK-LABEL: TEST: test_normalforms_autonormalization_decorator_called +@jasc.jasc_transform() +def foo_abstraction_1(handle: jasc.Value) -> jasc.Value: + return jasc.Value(handle.mlir_value, parent=handle) + + +@run +@print_schedule +def test_normalforms_autonormalization_decorator_called( + program: jasc.OpHandle, +) -> None: + with jasc.autonormalize(): + program.normalize(jasc.LoopNormalform) + assert program.normalform == jasc.LoopNormalform + # This will conservatively reset the normalform and also propagate the + # weaker normalform to the parent + new_handle = foo_abstraction_1(program) + assert new_handle.normalform == jasc.AnyForm + assert program.normalform == jasc.AnyForm + + +# CHECK-LABEL: TEST: test_normalforms_autonormalization_decorator_args_0 +@jasc.jasc_transform(enforced_normalform=jasc.LoopNormalform) +def foo_abstraction_2(handle: jasc.Value) -> jasc.Value: + return jasc.Value(handle.mlir_value, parent=handle) + + +@run +@print_schedule +def test_normalforms_autonormalization_decorator_args_0( + program: jasc.OpHandle, +) -> None: + with jasc.autonormalize(): + 
program.normalize(jasc.LoopNormalform) + assert program.normalform == jasc.LoopNormalform + # This will retain the normalform + new_handle = foo_abstraction_2(program) + assert new_handle.normalform == jasc.LoopNormalform + assert program.normalform == jasc.LoopNormalform + + +# CHECK-LABEL: TEST: test_normalforms_autonormalization_decorator_args_1 +@jasc.jasc_transform(no_propagate=True) +def foo_abstraction_3(handle: jasc.Value) -> jasc.Value: + return jasc.Value(handle.mlir_value, parent=handle) + + +@run +@print_schedule +def test_normalforms_autonormalization_decorator_args_1( + program: jasc.OpHandle, +) -> None: + with jasc.autonormalize(): + program.normalize(jasc.LoopNormalform) + assert program.normalform == jasc.LoopNormalform + # This will change nothing regarding normalforms + new_handle = foo_abstraction_3(program) + assert new_handle.normalform == jasc.AnyForm + assert program.normalform == jasc.LoopNormalform + + +# CHECK-LABEL: TEST: test_loop_normalform +@run +@print_schedule +def test_loop_normalform(program: jasc.OpHandle) -> None: + program.normalize(jasc.LoopNormalform) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG1:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[ARG1]] + # CHECK-NEXT: apply_cse to %[[ARG0]] + + +# CHECK-LABEL: TEST: test_no_duplicated_auto_apply +@run +@print_schedule +def test_no_duplicated_auto_apply( + program: jasc.OpHandle, +) -> None: + """Checks autonormalization doesn't trigger when handle is in the correct form.""" + with jasc.autonormalize(): + 
program.normalize(jasc.LoopNormalform) + program.tile( + loop=jasc.TileLoopKind.FORALL, tile_sizes=[64, 64, 1], mapping=[] + ) + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG1:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[ARG1]] + # CHECK-NEXT: apply_cse to %[[ARG0]] + # CHECK-NOT: apply_patterns + # CHECK-NOT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NOT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NOT: transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NOT: transform.apply_patterns.canonicalization + # CHECK-NOT: apply_licm + # CHECK-NOT: apply_cse + # CHECK-NEXT: transform.structured.tile_using_forall %[[ARG0]] + + +# CHECK-LABEL: TEST: test_propagation +@run +@print_schedule +def test_propagation(program: jasc.OpHandle): + nested_op = program.match_ops("test.foo_op") + program.normalize(jasc.LoopNormalform) + assert program.normalform == jasc.LoopNormalform + assert nested_op.normalform == jasc.LoopNormalform + nested_op.tile( + loop=jasc.TileLoopKind.FORALL, tile_sizes=[64, 64, 1], mapping=[] + ) + assert nested_op.normalform == jasc.AnyForm + assert program.normalform == jasc.AnyForm + # CHECK: transform.sequence + # CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op): + # CHECK-NEXT: %[[ARG1]] = transform.structured.match ops{["test.foo_op"]} + # CHECK-SAME in %[[ARG0]] + # CHECK-NEXT: apply_patterns to %[[ARG0]] { + # CHECK-NEXT: transform.apply_patterns.linalg.tiling_canonicalization + # CHECK-NEXT: transform.apply_patterns.fold_fill_into_pad + # CHECK-NEXT: 
transform.apply_patterns.scf.for_loop_canonicalization + # CHECK-NEXT: transform.apply_patterns.canonicalization + # CHECK-NEXT: } + # CHECK-NEXT: %[[ARG2:.*]] = transform.structured.match ops{["scf.for"]} + # CHECK-SAME: in %[[ARG0]] + # CHECK-NEXT: apply_licm to %[[ARG2]] + # CHECK-NEXT: apply_cse to %[[ARG0]] + # CHECK-NEXT: transform.structured.tile_using_forall %[[ARG1]] + + +# CHECK-LABEL: TEST: test_normalize_parent_func_multiple_payloads +@run +def test_normalize_parent_func_multiple_payloads() -> None: + with ir.Context(): + module = ir.Module.parse(""" + func.func public @main(%arg0: f32, %arg1: tensor<8x2xf32>) + -> (tensor<8x2xf32>, tensor<8x2xf32>) { + %0 = linalg.fill ins(%arg0 : f32) outs(%arg1 : tensor<8x2xf32>) -> tensor<8x2xf32> + %1 = linalg.fill ins(%arg0 : f32) outs(%arg1 : tensor<8x2xf32>) -> tensor<8x2xf32> + return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32> + } + """) + + def schedule(program: jasc.OpHandle) -> None: + with jasc.autonormalize(): + program.match_ops("linalg.fill").auto_normalize_parent_func( + jasc.LoopNormalform + ) + + jasc.lower_to_linalg(module, schedule=schedule, dump_schedule=True) + module.operation.verify() + + # CHECK: transform.sequence + # CHECK: %[[V0:.*]] = transform.structured.match + # CHECK-SAME: "linalg.fill" + # CHECK: %[[V1:.*]] = get_parent_op %[[V0]] + # CHECK-SAME: deduplicate + # CHECK: transform.structured.match {{.*}} in %[[V1]] + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + for test_fun in tests: + test_fun() + + +if __name__ == "__main__": + app.run(main) diff --git a/jasc/test/parametric_schedule.mlir b/jasc/test/parametric_schedule.mlir new file mode 100644 index 000000000000..c9a419fe2b22 --- /dev/null +++ b/jasc/test/parametric_schedule.mlir @@ -0,0 +1,39 @@ +// RUN: jasc-opt %s -jasc-apply-transform-script='enforce-single-top-level-transform-op=0' \ +// RUN: -jasc-erase-transform-script \ +// RUN: 
-jasc-apply-transform-script='enforce-single-top-level-transform-op=0' \ +// RUN: | FileCheck %s + +// This test specializes a parametric transform script and then applies it to a +// computation. +// This is accomplished using two transform interpreter passes. The first to +// apply the first transform script to the second. In this process the value of +// the transform.jasc.tuning_param is set according to the config in the first +// script (16). In the second pass this specialized schedule is applied to the +// computation to tile the `linalg.matmul` with tile size 16. + +// Meta schedule to specialize the parametric schedule +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %parametric = transform.structured.match ops{["transform.sequence"]} in %arg0 : (!transform.any_op) -> !transform.op<"transform.sequence"> + transform.jasc.apply_tuning_config %parametric {config = [16 : i32]} : !transform.op<"transform.sequence"> +} + +// Parametric schedule to tile a matmul +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %0 = transform.jasc.tuning_param {default_value = 0 : i32} -> !transform.param + // CHECK: transform.jasc.tuning_param {default_value = 0 : i32 + // CHECK-SAME: tuned_value = 16 : i32} + %1 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.op<"linalg.matmul"> + %tiled_linalg_op, %loops = transform.structured.tile_using_for %1[%0] : (!transform.op<"linalg.matmul">, !transform.param) -> (!transform.op<"linalg.matmul">, !transform.any_op) +} + +// CHECK-LABEL: @matmul +func.func private @matmul(%lhs : tensor<64x64xi32>, %rhs : tensor<64x64xi32>) -> tensor<64x64xi32> { + %c_0 = arith.constant 0 : i32 + %init_acc_uninitialized = tensor.empty() : tensor<64x64xi32> + %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<64x64xi32>) -> tensor<64x64xi32> + %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<64x64xi32>, tensor<64x64xi32>) 
outs(%zero_acc : tensor<64x64xi32>) -> tensor<64x64xi32> + // CHECK: scf.for + return %matmul_result : tensor<64x64xi32> +} diff --git a/jasc/test/synchronize.mlir b/jasc/test/synchronize.mlir new file mode 100644 index 000000000000..fad63344da35 --- /dev/null +++ b/jasc/test/synchronize.mlir @@ -0,0 +1,21 @@ +// RUN: jasc-opt %s -jasc-apply-transform-script \ +// RUN: | FileCheck %s + + +// CHECK-LABEL: @synchronize +func.func @synchronize(%arg0 : i32, %arg1 : i32, %arg2 : i32) { + // Check that a gpu.barrier is inserted after the loop + // CHECK: scf.for + // CHECK: } + // CHECK: gpu.barrier + // CHECK: return + scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 { + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.jasc.synchronize %0 : (!transform.any_op) -> (!transform.op<"gpu.barrier">) +} diff --git a/jasc/test/tag.py b/jasc/test/tag.py new file mode 100644 index 000000000000..93f6f84e1183 --- /dev/null +++ b/jasc/test/tag.py @@ -0,0 +1,101 @@ +"""Jasc tag primitive based schedule tests.""" +from __future__ import annotations +from typing import Callable, Sequence + +# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. 
+import sys +sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] + +from absl import app +import chex +import jax +from jax import numpy as jnp + +from jasc import jasc + +tests: list[Callable[[], None]] = [] + + +def run(f): + + def test(): + print("\nTEST:", f.__name__) + f() + + tests.append(test) + return f + + +# CHECK-LABEL: TEST: test_match_tag +@run +def test_match_tag(): + """Tests that a schedule based on matching jasc.tag primitives is applied correctly.""" + + def foo(a: jax.Array) -> jax.Array: + b = jax.lax.abs(jasc.tag(lambda x: x + 1, "b")(a)) + c = jasc.tag(lambda x: x * x, "c")(b) + return c + + input_0 = jnp.full((16, 16), 1.23) + + def schedule(h: jasc.OpHandle) -> None: + # Disable autonormalization here as it interferes with tags attached to + # constants. + with jasc.autonormalize(False): + h.match_tag("b").tile(loop=jasc.TileLoopKind.FOR, tile_sizes=(2, 4)) + h.match_tag("c").tile(loop=jasc.TileLoopKind.FOR, tile_sizes=(4, 2)) + + # Check for correct schedule application + foo_linalg = jasc.lower_to_linalg(foo, input_0, schedule=schedule) + + # CHECK-LABEL: func.func public @main( + # CHECK-SAME: %[[VAL_0:.*]]: + # CHECK: %[[VAL_1:.*]] = arith.constant {jasc_tags = ["b"]} + # CHECK: %[[VAL_2:.*]] = scf.for %[[VAL_3:.*]] = {{.*}} to + # CHECK: %[[VAL_4:.*]] = scf.for %[[VAL_5:.*]] = {{.*}} to + # CHECK-SAME: iter_args(%[[VAL_6:.*]] = + # CHECK: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_5]]] [2, 4] [1, 1] + # CHECK: %[[VAL_8:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_5]]] [2, 4] [1, 1] + # CHECK: %[[VAL_9:.*]] = tensor.extract_slice %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_5]]] [2, 4] [1, 1] + # CHECK: %[[VAL_10:.*]] = linalg.map { arith.addf {jasc_tags = ["b"]} } ins(%[[VAL_7]], %[[VAL_8]] : + # CHECK-SAME: outs(%[[VAL_9]] + # CHECK: %[[VAL_11:.*]] = tensor.insert_slice %[[VAL_10]] into %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_5]]] [2, 4] [1, 1] + # CHECK: scf.yield %[[VAL_11]] + # CHECK: } + # 
CHECK: scf.yield %[[VAL_4]] + # CHECK: } + # CHECK: %[[VAL_12:.*]] = linalg.map { math.absf } ins(%[[VAL_2]] + # CHECK: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %{{.*}} to + # CHECK: %[[VAL_15:.*]] = arith.constant 16 : index + # CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %{{.*}} to + # CHECK-SAME: iter_args(%[[VAL_18:.*]] = + # CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_17]]] [4, 2] [1, 1] + # CHECK: %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_17]]] [4, 2] [1, 1] + # CHECK: %[[VAL_21:.*]] = tensor.extract_slice %[[VAL_18]]{{\[}}%[[VAL_14]], %[[VAL_17]]] [4, 2] [1, 1] + # CHECK: %[[VAL_22:.*]] = linalg.map { arith.mulf {jasc_tags = ["c"]} } ins(%[[VAL_19]], %[[VAL_20]] + # CHECK-SAME: outs(%[[VAL_21]] + # CHECK: %[[VAL_23:.*]] = tensor.insert_slice %[[VAL_22]] into %[[VAL_18]]{{\[}}%[[VAL_14]], %[[VAL_17]]] [4, 2] [1, 1] + # CHECK: scf.yield %[[VAL_23]] + # CHECK: } + # CHECK: scf.yield %[[VAL_16]] + # CHECK: } + # CHECK: return %[[VAL_13]] + # CHECK: } + print(foo_linalg) + + # Test for similar results of jasc and jax jit + jax_res = jax.jit(foo)(input_0) + jasc_res = jasc.jit(foo, schedule)(input_0) + chex.assert_trees_all_equal(jasc_res, jax_res) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + for test_fun in tests: + test_fun() + + +if __name__ == "__main__": + app.run(main) diff --git a/jasc/test/test.mlir b/jasc/test/test.mlir new file mode 100644 index 000000000000..d1920ddc74dd --- /dev/null +++ b/jasc/test/test.mlir @@ -0,0 +1,8 @@ +// This file tests the tests infrastructure. It can be removed once we have any +// test case that actually tests something. 
+ +// RUN: jasc-opt %s -jasc-memcpy-to-gpu-dialect \ +// RUN: | jasc-opt \ +// RUN: | FileCheck %s + +// CHECK: module diff --git a/jasc/test/wrap-in-cpu-launch.mlir b/jasc/test/wrap-in-cpu-launch.mlir new file mode 100644 index 000000000000..16a0422b566f --- /dev/null +++ b/jasc/test/wrap-in-cpu-launch.mlir @@ -0,0 +1,40 @@ +// RUN: jasc-opt %s -jasc-apply-transform-script \ +// RUN: | FileCheck %s + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op): + %linalg_ops = transform.structured.match interface {LinalgOp} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.jasc.wrap_in_gpu_launch %linalg_ops + : (!transform.any_op) -> !transform.op<"gpu.launch"> +} + +// CHECK-LABEL: func.func @already_wrapped_op +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: memref<16xf32>) -> memref<16xf32> { +// CHECK: gpu.launch +// CHECK-NEXT: linalg.fill +// CHECK-NEXT: gpu.terminator +// CHECK-NOT: gpu +func.func @already_wrapped_op(%arg0: f32, %arg1: memref<16xf32>) -> memref<16xf32> { + %c1 = arith.constant 1 : index + gpu.launch blocks(%arg2, %arg3, %arg4) + in (%arg8 = %c1, %arg9 = %c1, %arg10 = %c1) + threads(%arg5, %arg6, %arg7) + in (%arg11 = %c1, %arg12 = %c1, %arg13 = %c1) { + linalg.fill ins(%arg0 : f32) outs(%arg1 : memref<16xf32>) + gpu.terminator + } + return %arg1 : memref<16xf32> +} + +// CHECK-LABEL: func.func @simple_fill +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: memref<16xf32>) -> memref<16xf32> { +// CHECK: gpu.launch +// CHECK-NEXT: linalg.fill +// CHECK-NEXT: gpu.terminator +func.func @simple_fill(%arg0: f32, %arg1: memref<16xf32>) -> memref<16xf32> { + linalg.fill ins(%arg0 : f32) outs(%arg1 : memref<16xf32>) + return %arg1 : memref<16xf32> +} \ No newline at end of file diff --git a/jasc/transform_ops/BUILD b/jasc/transform_ops/BUILD new file mode 100644 index 000000000000..9c757d83d0ac --- /dev/null +++ b/jasc/transform_ops/BUILD @@ -0,0 +1,135 @@ +# JASC extension for the MLIR 
transform dialect. + +load("@rules_python//python:defs.bzl", "py_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") +load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_visibility = ["//visibility:public"], +) + +td_library( + name = "td_files", + srcs = glob(["*.td"]), + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "jasc_transform_ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "jasc_transform_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "jasc_transform_ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "jasc_transform_ops.td", + deps = [":td_files"], +) + +gentbl_filegroup( + name = "jasc_transform_ops_py_gen", + tbl_outs = [( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=jasc_transform", + ], + "_jasc_transform_ops_gen.py", + )], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "jasc_transform_ops.td", + deps = [":td_files"], +) + +pybind_extension( + name = "bindings", + srcs = ["bindings.cpp"], + deps = [ + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + ":jasc_transform_ops_shared_library", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + ], +) + +py_library( + name = "transform_ops", + srcs = [ + "_ods_common.py", + "_transform_ops_gen.py", + "jasc_transform_ops.py", + ":jasc_transform_ops_py_gen", + ], + deps = [ + ":bindings", + ], +) + +cc_library( + name = "jasc_transform_ops_shared_library_deps", + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformDialect", + ], +) + +cc_headers_only( + name = 
"jasc_transform_ops_shared_library_deps_headers", + src = "jasc_transform_ops_shared_library_deps", +) + +cc_library( + name = "jasc_transform_ops_shared_library", + srcs = [ + ":libjasctransformops.so", + ], + hdrs = glob(["*.h"]), + deps = [ + "jasc_transform_ops_shared_library_deps_headers", + ], +) + +cc_headers_only( + name = "jasc_transform_ops_shared_library_headers", + src = "jasc_transform_ops_shared_library", +) + +cc_binary( + name = "libjasctransformops.so", + linkopts = [ + "-Wl,-soname=libjasctransformops.so", + "-Wl,-rpath='$$ORIGIN'", + ], + linkshared = 1, + deps = [":jasc_transform_ops"], +) + +cc_headers_only( + name = "jasc_transform_ops_headers", + src = "jasc_transform_ops", +) + +cc_library( + name = "jasc_transform_ops", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + includes = ["."], + deps = [ + ":jasc_transform_ops_inc_gen", + ":jasc_transform_ops_shared_library_deps_headers" + ], + alwayslink = True, +) diff --git a/jasc/transform_ops/_ods_common.py b/jasc/transform_ops/_ods_common.py new file mode 100644 index 000000000000..abcb9e8b64cd --- /dev/null +++ b/jasc/transform_ops/_ods_common.py @@ -0,0 +1,6 @@ +"""Trampoline to run generated MLIR Python code. + +Generated tablegen dialects expect to be able to find some symbols from the +mlir.dialects package. +""" +from jaxlib.mlir.dialects._ods_common import _cext, equally_sized_accessor, get_default_loc_context, get_op_result_or_op_results, get_op_result_or_value, get_op_results_or_values, segmented_accessor diff --git a/jasc/transform_ops/_transform_ops_gen.py b/jasc/transform_ops/_transform_ops_gen.py new file mode 100644 index 000000000000..cf90c6646f5e --- /dev/null +++ b/jasc/transform_ops/_transform_ops_gen.py @@ -0,0 +1,7 @@ +"""Trampoline to run generated MLIR Python code. + +Generated tablegen dialects expect to be able to find some symbols from the +mlir.dialects package. 
+""" + +from jaxlib.mlir.dialects._transform_ops_gen import _Dialect diff --git a/jasc/transform_ops/bindings.cpp b/jasc/transform_ops/bindings.cpp new file mode 100644 index 000000000000..1da7d4f5d6ae --- /dev/null +++ b/jasc/transform_ops/bindings.cpp @@ -0,0 +1,20 @@ +#include "mlir/CAPI/IR.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/lib/Bindings/Python/IRModule.h" +#include "pybind11/attr.h" +#include "pybind11/pybind11.h" + +#include "dialect_extension.h" + +PYBIND11_MODULE(_mlirTransformOpsJasc, m) { + m.def( + "register_transform_dialect_extension", + [](mlir::python::DefaultingPyMlirContext py_context) { + mlir::MLIRContext *context = unwrap(py_context->get()); + mlir::DialectRegistry registry; + jasc::registerTransformDialectExtension(registry); + context->appendDialectRegistry(registry); + }, + "context"_a = pybind11::none()); +} diff --git a/jasc/transform_ops/dialect_extension.cc b/jasc/transform_ops/dialect_extension.cc new file mode 100644 index 000000000000..ca3a383af702 --- /dev/null +++ b/jasc/transform_ops/dialect_extension.cc @@ -0,0 +1,26 @@ +#include "dialect_extension.h" + +#include "jasc_transform_ops.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/DialectRegistry.h" + +namespace { +class JascTransformDialectExtension + : public mlir::transform::TransformDialectExtension< + JascTransformDialectExtension> { + public: + using Base::Base; + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "jasc_transform_ops.cpp.inc" + >(); + } +}; +} // namespace + +void jasc::registerTransformDialectExtension(mlir::DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/jasc/transform_ops/dialect_extension.h b/jasc/transform_ops/dialect_extension.h new file mode 100644 index 000000000000..1e7ae3fe05d6 --- /dev/null +++ b/jasc/transform_ops/dialect_extension.h @@ -0,0 +1,10 @@ +#ifndef 
THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ + +#include "mlir/IR/DialectRegistry.h" + +namespace jasc { +void registerTransformDialectExtension(mlir::DialectRegistry ®istry); +} + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ diff --git a/jasc/transform_ops/jasc_transform_ops.cc b/jasc/transform_ops/jasc_transform_ops.cc new file mode 100644 index 000000000000..de592b27bab8 --- /dev/null +++ b/jasc/transform_ops/jasc_transform_ops.cc @@ -0,0 +1,220 @@ +#include "jasc_transform_ops.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinAttributes.h" + +#define GET_OP_CLASSES +#include "jasc_transform_ops.cpp.inc" + +//===----------------------------------------------------------------------===// +// MatchTagOp +//===----------------------------------------------------------------------===// + +mlir::DiagnosedSilenceableFailure jasc::MatchTagOp::apply( + mlir::transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &results, + mlir::transform::TransformState &state) { + llvm::SmallVector matched_ops; + for (mlir::Operation *op : state.getPayloadOps(getTarget())) { + auto tags = op->getAttrOfType("jasc_tags"); + if (tags == nullptr) continue; + if (tags.size() < getTags().size()) continue; + bool is_match = true; + for (int i = 0; i < getTags().size(); i++) { + if (tags[i] != getTags()[i]) { + is_match = false; + break; + } + } + if (!is_match) continue; + matched_ops.push_back(op); + } + results.set(llvm::cast(getMatchedOps()), matched_ops); + + return mlir::DiagnosedSilenceableFailure::success(); +} + 
+//===----------------------------------------------------------------------===// +// FoldFillIntoPad +//===----------------------------------------------------------------------===// + +// Ported from: +// iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +namespace { +/// Fold `tensor.pad(cst, tensor.extract*(linalg.fill(cst)))` into +/// `linalg.fill(cst, empty)` when the padding constant and the fill constant +/// are the same. +/// This seems generally desirable as a folding but may be too intrusive, so we +/// only apply it selectively for now. +// TODO: atm hardcoded on linalg.fill but we could take any result of any +// generic that yields a constant in that result. +struct FoldFillIntoPad : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + mlir::LogicalResult matchAndRewrite( + mlir::tensor::PadOp padOp, mlir::PatternRewriter &rewriter) const final { + mlir::Operation *currentOp = padOp.getSource().getDefiningOp(); + auto maybeExtractSlice = + mlir::dyn_cast_or_null(currentOp); + while (currentOp && maybeExtractSlice) { + currentOp = maybeExtractSlice.getSource().getDefiningOp(); + maybeExtractSlice = + mlir::dyn_cast_or_null(currentOp); + } + auto fillOp = mlir::dyn_cast_or_null(currentOp); + if (!fillOp) { + return rewriter.notifyMatchFailure( + padOp, "not coming from a linalg.fill op via tensor.extract_slice*"); + } + + mlir::Value padValue = padOp.getConstantPaddingValue(); + mlir::RankedTensorType resultType = padOp.getResultType(); + if (!padValue || + getAsOpFoldResult(padValue) != + getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get())) { + return rewriter.notifyMatchFailure( + padOp, "not a constant value matching the fill value"); + } + + mlir::Location loc = padOp.getLoc(); + auto emptyOp = rewriter.create( + loc, mlir::tensor::getMixedSizes(rewriter, loc, padOp), + resultType.getElementType()); + rewriter.replaceOpWithNewOp(padOp, padValue, + emptyOp.getResult()); + + return 
mlir::success(); + } +}; +} // namespace + +void jasc::ApplyFoldFillIntoPadPatternsOp::populatePatterns( + mlir::RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// SynchronizeOp +//===----------------------------------------------------------------------===// + +void jasc::SynchronizeOp::getEffects( + llvm::SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getOp(), effects); + mlir::transform::producesHandle(getBarrier(), effects); + mlir::transform::modifiesPayload(effects); +} + +mlir::DiagnosedSilenceableFailure jasc::SynchronizeOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, mlir::Operation *operation, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + rewriter.setInsertionPointAfter(operation); + auto barrier = rewriter.create(operation->getLoc()); + results.push_back(barrier); + return mlir::DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// TuningParamOp +//===----------------------------------------------------------------------===// + +mlir::DiagnosedSilenceableFailure jasc::TuningParamOp::apply( + mlir::transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &results, + mlir::transform::TransformState &state) { + if (!getTunedValue().has_value()) { + mlir::emitWarning(getLoc()) + << "tuning param not tuned, falling back to default value"; + } + results.setParams(llvm::cast(getParam()), + {getTunedValue().value_or(getDefaultValue())}); + return mlir::DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyTuningConfigOp +//===----------------------------------------------------------------------===// + +mlir::DiagnosedSilenceableFailure jasc::ApplyTuningConfigOp::applyToOne( + 
mlir::transform::TransformRewriter &rewriter, mlir::Operation *operation, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + size_t configIdx = 0; + auto config = llvm::cast(getConfig()); + // Walk all tuning parameters and set their value according to the config + // attribute + mlir::WalkResult walkResult = operation->walk( + [&](jasc::TuningParamOp tuningParamOp) { + if (configIdx >= config.size()) { + return mlir::WalkResult::interrupt(); + } + auto configVal = config[configIdx++]; + tuningParamOp.setTunedValueAttr(configVal); + return mlir::WalkResult::skip(); + }); + if (configIdx == 0) { + operation->emitWarning() + << "no tuning parameters found, expected " << config.size(); + return mlir::DiagnosedSilenceableFailure::success(); + } + if (walkResult.wasInterrupted() || configIdx != config.size()) { + return mlir::emitSilenceableFailure(getLoc()) + << "size of config has to match the number of tunable variables: " + << config.size() << " vs " << configIdx; + } + return mlir::DiagnosedSilenceableFailure::success(); +} + +void jasc::ApplyTuningConfigOp::getEffects( + llvm::SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getTarget(), effects); +} +//===----------------------------------------------------------------------===// +// WrapInGpuLaunchOp +//===----------------------------------------------------------------------===// + +mlir::DiagnosedSilenceableFailure jasc::WrapInGpuLaunchOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, mlir::Operation *operation, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + mlir::Location loc = operation->getLoc(); + + if (!operation->getUsers().empty()) { + return mlir::emitSilenceableFailure(loc) + << "The operation has users, cannot wrap the operation in a " + "gpu.launch"; + } + + if (auto existingLaunchOp = + operation->getParentOfType()) { + mlir::DiagnosedSilenceableFailure diag = + 
mlir::emitSilenceableFailure(loc) + << "not wrapping this op into a gpu.launch op because it already is " + "contained in one."; + diag.attachNote(existingLaunchOp->getLoc()) + << "contained in this gpu.launch op."; + return diag; + } + + rewriter.setInsertionPoint(operation); + auto one = rewriter.create(loc, 1); + auto launch_op = + rewriter.create(loc, one, one, one, one, one, one); + rewriter.setInsertionPointToEnd(&launch_op.getBody().front()); + auto terminator = rewriter.create(loc); + operation->moveBefore(terminator); + + results.push_back(launch_op); + return mlir::DiagnosedSilenceableFailure::success(); +} + +void jasc::WrapInGpuLaunchOp::getEffects( + llvm::SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getOps(), effects); + mlir::transform::producesHandle(getGpuLaunch(), effects); +} diff --git a/jasc/transform_ops/jasc_transform_ops.h b/jasc/transform_ops/jasc_transform_ops.h new file mode 100644 index 000000000000..bb7147d731ae --- /dev/null +++ b/jasc/transform_ops/jasc_transform_ops.h @@ -0,0 +1,13 @@ +#ifndef THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ +#define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ + +#include + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "jasc_transform_ops.h.inc" + +#endif // THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ diff --git a/jasc/transform_ops/jasc_transform_ops.py b/jasc/transform_ops/jasc_transform_ops.py new file mode 100644 index 000000000000..4b31d48da586 --- /dev/null +++ b/jasc/transform_ops/jasc_transform_ops.py @@ -0,0 +1,29 @@ +from ._jasc_transform_ops_gen import * +from ._jasc_transform_ops_gen import _Dialect +from ..._mlir_libs._mlirTransformOpsJasc import * + +try: + from typing import Sequence + from jaxlib.mlir import ir + from jaxlib.mlir.dialects import pdl + from ._ods_common import _cext as _ods_cext + 
+except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MatchTagOp(MatchTagOp): + """Specialization for the MatchTag op class.""" + + def __init__( + self, + target: ir.Value, + tags: Sequence[str], + *, + ip=None, + loc=None, + ): + result_ty = pdl.OperationType.get() + tags_attr = ir.ArrayAttr.get(list(map(ir.StringAttr.get, tags))) + super().__init__(result_ty, target, tags_attr, ip=ip, loc=loc) diff --git a/jasc/transform_ops/jasc_transform_ops.td b/jasc/transform_ops/jasc_transform_ops.td new file mode 100644 index 000000000000..db1644e11934 --- /dev/null +++ b/jasc/transform_ops/jasc_transform_ops.td @@ -0,0 +1,140 @@ +#ifndef JASC_JASC_TRANSFORM_OPS_TD +#define JASC_JASC_TRANSFORM_OPS_TD + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +class Jasc_TransformOp traits = []> : + Op { + let cppNamespace = "jasc"; +} + +def Jasc_MatchTagOp : Jasc_TransformOp<"match_tag", [ + MemoryEffectsOpInterface, + NavigationTransformOpTrait, + DeclareOpInterfaceMethods +]> { + let summary = "Matches operations with the given tag"; + + let arguments = (ins + TransformHandleTypeInterface:$target, + StrArrayAttr:$tags); + + let results = (outs TransformHandleTypeInterface:$matched_ops); + + let assemblyFormat = [{ + $tags `in` $target attr-dict `:` functional-type($target, results) + }]; +} + +def ApplyFoldFillIntoPadPatternsOp : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Populates a pattern that folds + "tensor.pad(cst, tensor.extract*(linalg.fill(cst)))" into + "linalg.fill(cst, empty)" when the padding constant and the fill constant + are the same. 
+ }]; + + let assemblyFormat = "attr-dict"; + let cppNamespace = "jasc"; +} + +def ApplyTuningConfigOp : Jasc_TransformOp<"apply_tuning_config", [ + DeclareOpInterfaceMethods, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait +]> { + let summary = "Specializes parametric transform IR."; + let description = [{ + Specializes all TuningParamOps nested under `target` using the values in + `config`. This means all unknown parameters will have an explicit value + after applying this transform op. The number of elements in `config` have to + match the number of tunable parameters nested under `target`. If there are + no nested tunable parameters this will not perform any modifications and + return success. + }]; + let arguments = (ins TransformHandleTypeInterface:$target, ArrayAttr:$config); + let assemblyFormat = "$target attr-dict `:` type($target)"; + let cppNamespace = "jasc"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *operation, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def TuningParamOp : Jasc_TransformOp<"tuning_param", [ + MemoryEffectsOpInterface, + DeclareOpInterfaceMethods, + ParamProducerTransformOpTrait +]> { + let summary = "A tunable parameter"; + let description = [{ + Represents an unknown parameter that is to be specified before execution of + the transform script. Requires a default_value that will be used as value in + the case no tuning is performed. 
+ }]; + let arguments = (ins AnyAttr:$default_value, OptionalAttr:$tuned_value); + let results = (outs TransformParamTypeInterface:$param); + let assemblyFormat = "attr-dict `->` type($param)"; + let cppNamespace = "jasc"; +} + +def SynchronizeOp : Op< + Transform_Dialect, "jasc.synchronize", [ + DeclareOpInterfaceMethods, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { + let summary = "Inserts a gpu.barrier after a given operation."; + + let arguments = ( + ins TransformHandleTypeInterface:$op); + let results = (outs TransformHandleTypeInterface:$barrier); + let assemblyFormat = [{ + $op + attr-dict + `:` functional-type(operands, results)}]; + + let cppNamespace = "jasc"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *operation, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def Jasc_WrapInGpuLaunchOp : Jasc_TransformOp<"wrap_in_gpu_launch", [ + DeclareOpInterfaceMethods, + TransformEachOpTrait, + TransformOpInterface, +]> { + let summary = "Wraps operations in a gpu.launch region."; + + let arguments = (ins TransformHandleTypeInterface:$ops); + let results = (outs Transform_ConcreteOpType<"gpu.launch">:$gpu_launch); + + let assemblyFormat = "$ops attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *operation, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // JASC_JASC_TRANSFORM_OPS_TD diff --git a/jasc/tuner.py b/jasc/tuner.py new file mode 100644 index 000000000000..60d06af4f901 --- /dev/null +++ b/jasc/tuner.py @@ -0,0 +1,173 @@ +"""Utilities for custom autotuners for parametric transforms.""" + +from __future__ import annotations 
+ +import abc +from collections.abc import Sequence +import dataclasses +import io +import math +import timeit +from typing import Any, Callable, Optional + +import jax +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import transform + +from jasc import jasc +from jaxlib.mlir.dialects.transform import jasc_transform_ops + +@dataclasses.dataclass +class TunerBase(abc.ABC): + """Base class for custom autotuners. + + Provides a default autotuning loop based on a budget of tuning configurations + to evaluate. Subclasses have to implement `get_tuning_config` to drive the + autotuning process. + Currently evaluation of a tuned function is limited to the metric of execution + time. + """ + + func: Callable[[Any], Any] + parametric_schedule: jasc.Schedule + inputs: Sequence[jax.Array] + budget: int = 10 + tuned_func_evals: int = 100 + dump_ir: bool = False + + @abc.abstractmethod + def get_tuning_config( + self, + tuning_vars: Sequence[jasc_transform_ops.TuningParamOp], + previous_results: Sequence[TuningResult], + ) -> TuningConfig: + """Returns a tuning configuration to specialize a parametric schedule. + + A tuning configuration is a list of explicit values for the set of tuning + variables in a parametric schedule. The type of value is currently limited + to int. + Args: + tuning_vars: The list of tuning variables in a parametric schedule + previous_results: The previously evaluated configurations and respective + result metric for this set of tuning variables. + """ + ... + + def tune( + self, + ) -> tuple[ + float, Optional[Callable[..., Any]], Optional[jasc.Schedule], list[float] + ]: + """Explores different configurations for the tuning parameters in the schedule.""" + + # Lower module to linalg and insert parametric schedule into the IR. 
+ initial_module = jasc.lower_to_linalg(self.func, *self.inputs) + jasc.insert_schedule( + initial_module, self.parametric_schedule, dump_schedule=self.dump_ir + ) + + # Print the initial module to an object so it can be reparsed repeatedly. + f = io.StringIO("") + initial_module.operation.print(f) + tuning_vars = self.get_tuning_vars(initial_module) + + # Tuning loop + times: list[float] = [] + best_time: float = math.inf + best_fun: Optional[Callable[..., Any]] = None + best_schedule: Optional[jasc.Schedule] = None + previous_configs: list[TuningResult] = [] + for _ in range(self.budget): + # Create new copy of the module by reparsing it. + + # TODO(mluecke): Reparsing is a slow way of copying but cloning is not yet + # exposed in the Python bindings. Implement a better + # approach to this. + module = ir.Module.parse( + f.getvalue(), context=initial_module.operation.context + ) + config = self.get_tuning_config(tuning_vars, previous_configs) + + def meta_schedule(module: jasc.OpHandle) -> None: + sequence_op = module.match_ops(transform.SequenceOp) + sequence_op.apply_tuning_config(config.values) # pylint: disable=cell-var-from-loop + + # Apply a tuning configuration to convert the parametric schedule to a + # version with explicit parameters only. + jasc.insert_schedule(module, meta_schedule, dump_schedule=self.dump_ir) + jasc.apply_schedule(module) + + # Evaluate this tuning configuration by applying the schedule the payload + # IR and timing the execution. 
+ try: + tuned_fun = jasc.jit(self.func, module=module, dump_ir=self.dump_ir) + time = timeit.timeit( + lambda: tuned_fun(*self.inputs), number=self.tuned_func_evals # pylint: disable=cell-var-from-loop + ) + if time < best_time: + best_time = time + best_fun = tuned_fun + best_schedule = meta_schedule + except: + time = math.inf + + previous_configs.append(TuningResult(config, time)) + times.append(time) + + return best_time, best_fun, best_schedule, times + + def get_tuning_vars(self, module: ir.Module) -> list[ir.Operation]: + """Returns a list of all tuning parameters in a module.""" + tuning_vars: list[ir.Operation] = [] + _walk( + module, + lambda op: tuning_vars.append(op) + if op.operation.name == "transform.jasc.tuning_param" + else None, + ) + return tuning_vars + + +@dataclasses.dataclass +class FooTuner(TunerBase): + """Example implementation of a custom tuner. + + New tuning configurations are determined by successively adding 1 to the + default config of each tuning variable. 
+ """ + + def get_tuning_config( + self, + tuning_vars: Sequence[jasc_transform_ops.TuningParamOp], + previous_results: Sequence[TuningResult], + ) -> TuningConfig: + config: list[int] = [] + if len(previous_results) == 0: + for op in tuning_vars: + config.append(op.default_value.value) + else: + config = [ + config_val + 1 + for config_val in previous_results[-1].configuration.values + ] + return TuningConfig(config) + + +@dataclasses.dataclass(frozen=True) +class TuningResult: + configuration: TuningConfig + result: float + + +@dataclasses.dataclass +class TuningConfig: + values: list[int] + + +def _walk(op: ir.Operation, callback: Callable[[ir.Operation], None]) -> None: + """Calls the `callback` function on `op` and recurses to all nested ops.""" + callback(op) + for region in op.operation.regions: + for block in region.blocks: + for nested_op in block.operations: + _walk(nested_op, callback) From 7492d5a7d3b5ec8662cd4ea6f2a0e85160eff878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 10:34:37 +0000 Subject: [PATCH 02/21] Fix Bazel version with .bazelversion file. I am using version 6.4.0 because it is the latest current version below 7.0.0, plus it is LTS. Several other minor versions of version 6 have worked as well. Bazel 7.0.0 does not work, probably due to some incompatibility in one of the many dependencies. At least one dependency requires Bazel 5.4.0 or higher. --- jasc/.bazelversion | 1 + 1 file changed, 1 insertion(+) create mode 100644 jasc/.bazelversion diff --git a/jasc/.bazelversion b/jasc/.bazelversion new file mode 100644 index 000000000000..19b860c1872d --- /dev/null +++ b/jasc/.bazelversion @@ -0,0 +1 @@ +6.4.0 From eec85762d7f02e6336aace70b4c561264ab70d8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 08:47:07 +0000 Subject: [PATCH 03/21] Add license information to file headers and BUILD files. 
--- jasc/BUILD | 14 +- jasc/LICENSE | 219 +++++++++++++++++++++++ jasc/WORKSPACE | 20 +++ jasc/__init__.py | 4 + jasc/call_kernel.cc | 8 + jasc/dialect/BUILD | 6 +- jasc/dialect/__init__.py | 4 + jasc/dialect/_ods_common.py | 4 + jasc/dialect/bindings.cc | 8 + jasc/dialect/capi.cc | 8 + jasc/dialect/capi.h | 8 + jasc/dialect/dialect.cc | 8 + jasc/dialect/dialect.h | 8 + jasc/dialect/dialect.td | 4 + jasc/dialect/jasc.py | 4 + jasc/dialect/ops.cc | 8 + jasc/dialect/ops.h | 8 + jasc/dialect/ops.td | 4 + jasc/dialect/ops_py.td | 4 + jasc/gpu_lowering_passes.cc | 8 + jasc/gpu_lowering_passes.h | 8 + jasc/gpu_post_bufferize.mlir | 4 + jasc/jasc.py | 4 + jasc/jasc_opt.cc | 8 + jasc/mlir_lowering.cc | 8 + jasc/mlir_lowering.h | 8 + jasc/primitives.py | 4 + jasc/test/BUILD | 5 + jasc/test/abstractions.py | 4 + jasc/test/autotuning.py | 4 + jasc/test/batch_matmul_gpu.py | 4 + jasc/test/bindings.py | 4 + jasc/test/cpu_integration.py | 4 + jasc/test/diagnostics.py | 4 + jasc/test/gpu_integration.py | 4 + jasc/test/jit.py | 4 + jasc/test/lit.cfg.py | 4 + jasc/test/lit.site.cfg.in.py | 4 + jasc/test/matmul_cpu.py | 4 + jasc/test/matmul_gpu.py | 4 + jasc/test/normalization.py | 4 + jasc/test/tag.py | 4 + jasc/transform_ops/BUILD | 5 + jasc/transform_ops/_ods_common.py | 4 + jasc/transform_ops/_transform_ops_gen.py | 4 + jasc/transform_ops/dialect_extension.cc | 8 + jasc/transform_ops/dialect_extension.h | 8 + jasc/transform_ops/jasc_transform_ops.cc | 8 + jasc/transform_ops/jasc_transform_ops.h | 8 + jasc/transform_ops/jasc_transform_ops.py | 4 + jasc/transform_ops/jasc_transform_ops.td | 4 + jasc/tuner.py | 4 + 52 files changed, 519 insertions(+), 2 deletions(-) create mode 100644 jasc/LICENSE diff --git a/jasc/BUILD b/jasc/BUILD index d78bceeb2042..1ce1a34ed0d3 100644 --- a/jasc/BUILD +++ b/jasc/BUILD @@ -1,11 +1,23 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + # Schedules for JAX. load("@rules_python//python:defs.bzl", "py_library") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") +load("@rules_license//rules:license.bzl", "license") + +license( + name = "license", + package_name = "Jasc", + license_text = "LICENSE", + package_url = "https://github.com/iree-org/iree-llvm-sandbox/tree/main/jasc", +) package( - # default_applicable_licenses = ["//third_party/mlir_edge:license"], + default_applicable_licenses = [":license"], default_visibility = ["//visibility:public"], ) diff --git a/jasc/LICENSE b/jasc/LICENSE new file mode 100644 index 000000000000..f9dc50615d7e --- /dev/null +++ b/jasc/LICENSE @@ -0,0 +1,219 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +--- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index 677e3a66cf48..9f5625842c4b 100644 --- a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -1,8 +1,28 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + workspace(name = "jasc") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") +# +# @rules_license. +# + +LICENSERULES_VERSION = "0.0.7" +LICENSERULES_SHA256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360" + +http_archive( + name = "rules_license", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_license/releases/download/{version}/rules_license-{version}.tar.gz".format(version = LICENSERULES_VERSION), + "https://github.com/bazelbuild/rules_license/releases/download/{version}/rules_license-{version}.tar.gz".format(version = LICENSERULES_VERSION), + ], + sha256 = LICENSERULES_SHA256, +) + # # @rules_cc. # diff --git a/jasc/__init__.py b/jasc/__init__.py index e69de29bb2d1..e146a208b81c 100644 --- a/jasc/__init__.py +++ b/jasc/__init__.py @@ -0,0 +1,4 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + diff --git a/jasc/call_kernel.cc b/jasc/call_kernel.cc index 0df443749391..f73e767e2e68 100644 --- a/jasc/call_kernel.cc +++ b/jasc/call_kernel.cc @@ -1,3 +1,11 @@ +//===-- call_kernel.cc - Runtime glue for JAX kernels -----------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include #include diff --git a/jasc/dialect/BUILD b/jasc/dialect/BUILD index 1196cd4623d7..0e9f60aa1a82 100644 --- a/jasc/dialect/BUILD +++ b/jasc/dialect/BUILD @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + # MLIR Dialect to support Jasc transformations. load("@rules_python//python:defs.bzl", "py_library") @@ -6,7 +10,7 @@ load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") package( - # default_applicable_licenses = ["//third_party/mlir_edge:license"], + default_applicable_licenses = ["//:license"], default_visibility = ["//visibility:public"], ) diff --git a/jasc/dialect/__init__.py b/jasc/dialect/__init__.py index fe0bd366ee86..a0fe3540cad6 100644 --- a/jasc/dialect/__init__.py +++ b/jasc/dialect/__init__.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Python bindings for Jasc MLIR operations.""" from dialect.bindings import * diff --git a/jasc/dialect/_ods_common.py b/jasc/dialect/_ods_common.py index abcb9e8b64cd..596e5b919d8a 100644 --- a/jasc/dialect/_ods_common.py +++ b/jasc/dialect/_ods_common.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Trampoline to run generated MLIR Python code. Generated tablegen dialects expect to be able to find some symbols from the diff --git a/jasc/dialect/bindings.cc b/jasc/dialect/bindings.cc index b2f385712f49..c32d063e3c83 100644 --- a/jasc/dialect/bindings.cc +++ b/jasc/dialect/bindings.cc @@ -1,3 +1,11 @@ +//===-- bindings.cc - Python bindings for Jasc dialect ----------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "mlir/CAPI/IR.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" diff --git a/jasc/dialect/capi.cc b/jasc/dialect/capi.cc index 9a2737d3e895..6c97cf741c29 100644 --- a/jasc/dialect/capi.cc +++ b/jasc/dialect/capi.cc @@ -1,3 +1,11 @@ +//===-- capi.cc - C-API for the Jasc dialect --------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "capi.h" #include "mlir/CAPI/Registration.h" diff --git a/jasc/dialect/capi.h b/jasc/dialect/capi.h index f6ca29199cf1..2a0631800b65 100644 --- a/jasc/dialect/capi.h +++ b/jasc/dialect/capi.h @@ -1,3 +1,11 @@ +//===-- capi.h - C-API for the Jasc dialect ---------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "mlir-c/IR.h" diff --git a/jasc/dialect/dialect.cc b/jasc/dialect/dialect.cc index 7a8eb3c0a8e0..e3a610c410ef 100644 --- a/jasc/dialect/dialect.cc +++ b/jasc/dialect/dialect.cc @@ -1,3 +1,11 @@ +//===-- dialect.cc - Jasc dialect implementation ----------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "dialect.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" diff --git a/jasc/dialect/dialect.h b/jasc/dialect/dialect.h index e81744d959b2..cf8f60c485c8 100644 --- a/jasc/dialect/dialect.h +++ b/jasc/dialect/dialect.h @@ -1,3 +1,11 @@ +//===-- dialect.h - Jasc dialect --------------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ diff --git a/jasc/dialect/dialect.td b/jasc/dialect/dialect.td index 4c53967a8eb1..012eb2ff2ff3 100644 --- a/jasc/dialect/dialect.td +++ b/jasc/dialect/dialect.td @@ -1,3 +1,7 @@ +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + #ifndef JASC_DIALECT_DIALECT #define JASC_DIALECT_DIALECT diff --git a/jasc/dialect/jasc.py b/jasc/dialect/jasc.py index 5f8d91892f1f..3a239e256042 100644 --- a/jasc/dialect/jasc.py +++ b/jasc/dialect/jasc.py @@ -1,2 +1,6 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + from ._ops_gen import * from .._mlir_libs._mlirDialectsJasc import * \ No newline at end of file diff --git a/jasc/dialect/ops.cc b/jasc/dialect/ops.cc index 6e9c5dc5a95a..2db6a029ec33 100644 --- a/jasc/dialect/ops.cc +++ b/jasc/dialect/ops.cc @@ -1,3 +1,11 @@ +//===-- ops.cc - Jasc op implementations ------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "ops.h" #include "llvm/include/llvm/ADT/SmallVector.h" diff --git a/jasc/dialect/ops.h b/jasc/dialect/ops.h index 73226910bed9..d4989f7de544 100644 --- a/jasc/dialect/ops.h +++ b/jasc/dialect/ops.h @@ -1,3 +1,11 @@ +//===-- ops.h - Ops of the Jasc dialect -------------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ diff --git a/jasc/dialect/ops.td b/jasc/dialect/ops.td index 96666bbe5d5f..83cb02ad246e 100644 --- a/jasc/dialect/ops.td +++ b/jasc/dialect/ops.td @@ -1,3 +1,7 @@ +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + #ifndef JASC_DIALECT_OPS #define JASC_DIALECT_OPS diff --git a/jasc/dialect/ops_py.td b/jasc/dialect/ops_py.td index 6abc4a345234..cdf6fda55741 100644 --- a/jasc/dialect/ops_py.td +++ b/jasc/dialect/ops_py.td @@ -1,3 +1,7 @@ +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + #ifndef JASC_DIALECT_OPSPY #define JASC_DIALECT_OPSPY diff --git a/jasc/gpu_lowering_passes.cc b/jasc/gpu_lowering_passes.cc index 26053d81bf98..c5fbfa4b1226 100644 --- a/jasc/gpu_lowering_passes.cc +++ b/jasc/gpu_lowering_passes.cc @@ -1,3 +1,11 @@ +//===-- gpu_lowering_passes.cc - Passes for GPU lowerings -------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include #include diff --git a/jasc/gpu_lowering_passes.h b/jasc/gpu_lowering_passes.h index 97da37d7cefd..636e761f5bcb 100644 --- a/jasc/gpu_lowering_passes.h +++ b/jasc/gpu_lowering_passes.h @@ -1,3 +1,11 @@ +//===-- gpu_lowering_passes.h - Passes for GPU lowering ---------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ diff --git a/jasc/gpu_post_bufferize.mlir b/jasc/gpu_post_bufferize.mlir index 5d1bd4be7740..6a64de8ea1d4 100644 --- a/jasc/gpu_post_bufferize.mlir +++ b/jasc/gpu_post_bufferize.mlir @@ -1,3 +1,7 @@ +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // Transform script for GPU post-bufferization codegen. transform.sequence failures(suppress) { ^bb0(%arg0: !transform.any_op): diff --git a/jasc/jasc.py b/jasc/jasc.py index b68e72edbc0a..aac9bce86c34 100644 --- a/jasc/jasc.py +++ b/jasc/jasc.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Schedules for JAX. To compile a function using Jasc, use Jasc.jit instead of jax.jit. diff --git a/jasc/jasc_opt.cc b/jasc/jasc_opt.cc index aee9cc224043..bb2a83b32e70 100644 --- a/jasc/jasc_opt.cc +++ b/jasc/jasc_opt.cc @@ -1,3 +1,11 @@ +//===-- jasc_opt.cc - jasc-opt executable -----------------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" diff --git a/jasc/mlir_lowering.cc b/jasc/mlir_lowering.cc index 97f51d9722eb..b5ae5cf12b65 100644 --- a/jasc/mlir_lowering.cc +++ b/jasc/mlir_lowering.cc @@ -1,3 +1,11 @@ +//===-- mlir_lowering.cc - Lowering passes for Jasc dialect -----*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "mlir_lowering.h" #include diff --git a/jasc/mlir_lowering.h b/jasc/mlir_lowering.h index 4d9dccabef2a..6efc362282ea 100644 --- a/jasc/mlir_lowering.h +++ b/jasc/mlir_lowering.h @@ -1,3 +1,11 @@ +//===-- mlir_lowering.h - Passes for lowering Jasc --------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ diff --git a/jasc/primitives.py b/jasc/primitives.py index bb9d78168341..71037adddf05 100644 --- a/jasc/primitives.py +++ b/jasc/primitives.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Jax primitives backing Jasc schedules.""" from collections.abc import Callable, Sequence diff --git a/jasc/test/BUILD b/jasc/test/BUILD index b8a931ba48dc..4a3bdaea8ad7 100644 --- a/jasc/test/BUILD +++ b/jasc/test/BUILD @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + # JASC filecheck tests load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") @@ -6,6 +10,7 @@ load("@pip_deps//:requirements.bzl", "requirement") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") package( + default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], ) diff --git a/jasc/test/abstractions.py b/jasc/test/abstractions.py index 0ae974930d45..9986c5bb2ba1 100644 --- a/jasc/test/abstractions.py +++ b/jasc/test/abstractions.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Tests for JASC transform op abstractions.""" from __future__ import annotations diff --git a/jasc/test/autotuning.py b/jasc/test/autotuning.py index e5963a88505f..3be3f10e89c0 100644 --- a/jasc/test/autotuning.py +++ b/jasc/test/autotuning.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Tests for the JASC autotuning utilities.""" from typing import Tuple diff --git a/jasc/test/batch_matmul_gpu.py b/jasc/test/batch_matmul_gpu.py index e62e075b62f0..806affbeff23 100644 --- a/jasc/test/batch_matmul_gpu.py +++ b/jasc/test/batch_matmul_gpu.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + from typing import Tuple # XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. diff --git a/jasc/test/bindings.py b/jasc/test/bindings.py index 07be02128340..0f137a0617d1 100644 --- a/jasc/test/bindings.py +++ b/jasc/test/bindings.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + from typing import Callable, Sequence from absl import app diff --git a/jasc/test/cpu_integration.py b/jasc/test/cpu_integration.py index e9ea7c5e3bf2..e1e1ac5a2287 100644 --- a/jasc/test/cpu_integration.py +++ b/jasc/test/cpu_integration.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Jasc tests common to all platforms.""" from collections.abc import Mapping diff --git a/jasc/test/diagnostics.py b/jasc/test/diagnostics.py index bb80644897ca..63811a14aade 100644 --- a/jasc/test/diagnostics.py +++ b/jasc/test/diagnostics.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Tests for MLIR diagnostics.""" from __future__ import annotations diff --git a/jasc/test/gpu_integration.py b/jasc/test/gpu_integration.py index 16ddc0fb1555..a7788f738731 100644 --- a/jasc/test/gpu_integration.py +++ b/jasc/test/gpu_integration.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """GPU-specific tests for Jasc.""" # XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. diff --git a/jasc/test/jit.py b/jasc/test/jit.py index 860901860329..00997d861d47 100644 --- a/jasc/test/jit.py +++ b/jasc/test/jit.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Tests for JASC jit.""" from __future__ import annotations diff --git a/jasc/test/lit.cfg.py b/jasc/test/lit.cfg.py index 878e1ca09acb..6037bbfbdb31 100644 --- a/jasc/test/lit.cfg.py +++ b/jasc/test/lit.cfg.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + # -*- Python -*- import os diff --git a/jasc/test/lit.site.cfg.in.py b/jasc/test/lit.site.cfg.in.py index a86f9955798a..8231e179ab6e 100644 --- a/jasc/test/lit.site.cfg.in.py +++ b/jasc/test/lit.site.cfg.in.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + @LIT_SITE_CFG_IN_HEADER@ import os.path diff --git a/jasc/test/matmul_cpu.py b/jasc/test/matmul_cpu.py index e386c73c418b..d483ce42a566 100644 --- a/jasc/test/matmul_cpu.py +++ b/jasc/test/matmul_cpu.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Integration test of full schedules for `jax.numpy.matmul`.""" from typing import Tuple diff --git a/jasc/test/matmul_gpu.py b/jasc/test/matmul_gpu.py index 1da91c236fad..591a8c303337 100644 --- a/jasc/test/matmul_gpu.py +++ b/jasc/test/matmul_gpu.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Integration test of full schedules for `jax.numpy.matmul`.""" from typing import Tuple diff --git a/jasc/test/normalization.py b/jasc/test/normalization.py index 9246ff522d01..7c52b9a174ce 100644 --- a/jasc/test/normalization.py +++ b/jasc/test/normalization.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Tests for JASC transform op abstractions.""" from __future__ import annotations diff --git a/jasc/test/tag.py b/jasc/test/tag.py index 93f6f84e1183..6c769742bd3f 100644 --- a/jasc/test/tag.py +++ b/jasc/test/tag.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Jasc tag primitive based schedule tests.""" from __future__ import annotations from typing import Callable, Sequence diff --git a/jasc/transform_ops/BUILD b/jasc/transform_ops/BUILD index 9c757d83d0ac..7adbf46f858c 100644 --- a/jasc/transform_ops/BUILD +++ b/jasc/transform_ops/BUILD @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + # JASC extension for the MLIR transform dialect. 
load("@rules_python//python:defs.bzl", "py_library") @@ -6,6 +10,7 @@ load("@llvm-project//mlir:build_defs.bzl", "cc_headers_only") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") package( + default_applicable_licenses = ["//:license"], default_visibility = ["//visibility:public"], ) diff --git a/jasc/transform_ops/_ods_common.py b/jasc/transform_ops/_ods_common.py index abcb9e8b64cd..596e5b919d8a 100644 --- a/jasc/transform_ops/_ods_common.py +++ b/jasc/transform_ops/_ods_common.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Trampoline to run generated MLIR Python code. Generated tablegen dialects expect to be able to find some symbols from the diff --git a/jasc/transform_ops/_transform_ops_gen.py b/jasc/transform_ops/_transform_ops_gen.py index cf90c6646f5e..de9976501c73 100644 --- a/jasc/transform_ops/_transform_ops_gen.py +++ b/jasc/transform_ops/_transform_ops_gen.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Trampoline to run generated MLIR Python code. Generated tablegen dialects expect to be able to find some symbols from the diff --git a/jasc/transform_ops/dialect_extension.cc b/jasc/transform_ops/dialect_extension.cc index ca3a383af702..b576716c0d4b 100644 --- a/jasc/transform_ops/dialect_extension.cc +++ b/jasc/transform_ops/dialect_extension.cc @@ -1,3 +1,11 @@ +//===-- dialect_extension.cc - TD extension for Jasc ------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "dialect_extension.h" #include "jasc_transform_ops.h" diff --git a/jasc/transform_ops/dialect_extension.h b/jasc/transform_ops/dialect_extension.h index 1e7ae3fe05d6..4f84b5f48bdd 100644 --- a/jasc/transform_ops/dialect_extension.h +++ b/jasc/transform_ops/dialect_extension.h @@ -1,3 +1,11 @@ +//===-- dialect_extension.h - TD extension for Jasc -------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ diff --git a/jasc/transform_ops/jasc_transform_ops.cc b/jasc/transform_ops/jasc_transform_ops.cc index de592b27bab8..cd7627daf266 100644 --- a/jasc/transform_ops/jasc_transform_ops.cc +++ b/jasc/transform_ops/jasc_transform_ops.cc @@ -1,3 +1,11 @@ +//===-- jasc_transform_ops.cc - Transform ops for Jasc dialect --*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "jasc_transform_ops.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" diff --git a/jasc/transform_ops/jasc_transform_ops.h b/jasc/transform_ops/jasc_transform_ops.h index bb7147d731ae..c8f8d1923ff2 100644 --- a/jasc/transform_ops/jasc_transform_ops.h +++ b/jasc/transform_ops/jasc_transform_ops.h @@ -1,3 +1,11 @@ +//===-- jasc_transform_ops.h - Transform ops for Jasc dialect ---*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ #define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ diff --git a/jasc/transform_ops/jasc_transform_ops.py b/jasc/transform_ops/jasc_transform_ops.py index 4b31d48da586..e8c10265a3d5 100644 --- a/jasc/transform_ops/jasc_transform_ops.py +++ b/jasc/transform_ops/jasc_transform_ops.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + from ._jasc_transform_ops_gen import * from ._jasc_transform_ops_gen import _Dialect from ..._mlir_libs._mlirTransformOpsJasc import * diff --git a/jasc/transform_ops/jasc_transform_ops.td b/jasc/transform_ops/jasc_transform_ops.td index db1644e11934..8a97bee8c5c6 100644 --- a/jasc/transform_ops/jasc_transform_ops.td +++ b/jasc/transform_ops/jasc_transform_ops.td @@ -1,3 +1,7 @@ +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + #ifndef JASC_JASC_TRANSFORM_OPS_TD #define JASC_JASC_TRANSFORM_OPS_TD diff --git a/jasc/tuner.py b/jasc/tuner.py index 60d06af4f901..098816cb77b5 100644 --- a/jasc/tuner.py +++ b/jasc/tuner.py @@ -1,3 +1,7 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Utilities for custom autotuners for parametric transforms.""" from __future__ import annotations From 4fd8bfd16d93922316d095b97808fa0de3b1bbce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 09:24:48 +0000 Subject: [PATCH 04/21] Run buildifier on all BUILD files for linting. --- jasc/BUILD | 32 +++++++++++++++++++------------- jasc/WORKSPACE | 30 +++++++++++++++++++----------- jasc/dialect/BUILD | 12 ++++++------ jasc/test/BUILD | 16 ++++++++-------- jasc/transform_ops/BUILD | 4 ++-- 5 files changed, 54 insertions(+), 40 deletions(-) diff --git a/jasc/BUILD b/jasc/BUILD index 1ce1a34ed0d3..10d3ab2e9d92 100644 --- a/jasc/BUILD +++ b/jasc/BUILD @@ -23,19 +23,22 @@ package( py_library( name = "jasc", - srcs = ["jasc.py", "__init__.py"], + srcs = [ + "__init__.py", + "jasc.py", + ], deps = [ ":call_kernel", ":primitives", "//dialect:python", "//transform_ops", - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:bufferization_dialect", "@jax1//jaxlib/mlir:core", "@jax1//jaxlib/mlir:ir", + "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:pdl_dialect", "@jax1//jaxlib/mlir:transform_dialect", - "@jax1//jaxlib/mlir:jasc_dialect", ], ) @@ -44,7 +47,7 @@ py_library( srcs = ["tuner.py"], deps = [ ":jasc", - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:transform_dialect", @@ -57,7 +60,7 @@ py_library( deps = [ ":call_kernel", "//dialect:python", - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jax:extend", 
"@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:pdl_dialect", @@ -111,7 +114,7 @@ cc_binary( "-Wl,-rpath='$$ORIGIN'", ], linkshared = 1, - deps = ["@llvm-project//mlir:mlir_c_runner_utils",], + deps = ["@llvm-project//mlir:mlir_c_runner_utils"], ) pybind_extension( @@ -120,15 +123,15 @@ pybind_extension( deps = [ ":call_kernel_shared_library", ":libmlir_c_runner_utils.so", - "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", ":mlir_lowering_shared_library", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:ExecutionEngine", - "@status_macros//:status_macros", "@pybind11_abseil//pybind11_abseil:import_status_module", + "@status_macros", ], ) @@ -190,7 +193,10 @@ cc_binary( cc_library( name = "mlir_lowering_shared_library", - srcs = [":libmlirlowering.so", "mlir_lowering.h"], + srcs = [ + "mlir_lowering.h", + ":libmlirlowering.so", + ], deps = [":mlir_lowering_shared_library_deps_headers"], ) @@ -211,9 +217,9 @@ cc_library( ], data = ["gpu_post_bufferize.mlir"], deps = [ + ":mlir_lowering_shared_library_deps_headers", "//dialect:jasc_dialect_headers", "//transform_ops:jasc_transform_ops_headers", - ":mlir_lowering_shared_library_deps_headers", ], alwayslink = True, ) @@ -223,13 +229,13 @@ cc_binary( srcs = ["jasc_opt.cc"], deps = [ ":mlir_lowering", + "//dialect", + "//transform_ops:jasc_transform_ops_shared_library", + "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllToLLVMIRTranslations", "@llvm-project//mlir:MlirOptLib", "@xla//xla/mlir_hlo:mhlo_passes", - "//dialect", - "//transform_ops:jasc_transform_ops_shared_library", - "@com_google_absl//absl/status:statusor", ], ) diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index 9f5625842c4b..64c276a37364 100644 --- 
a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -16,11 +16,11 @@ LICENSERULES_SHA256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375 http_archive( name = "rules_license", + sha256 = LICENSERULES_SHA256, urls = [ "https://mirror.bazel.build/github.com/bazelbuild/rules_license/releases/download/{version}/rules_license-{version}.tar.gz".format(version = LICENSERULES_VERSION), "https://github.com/bazelbuild/rules_license/releases/download/{version}/rules_license-{version}.tar.gz".format(version = LICENSERULES_VERSION), ], - sha256 = LICENSERULES_SHA256, ) # @@ -49,21 +49,29 @@ rules_cc_dependencies() LLVM_COMMIT = "2f17c9f65e7da50a77101431ddf7f6ed7e1ea92c" LLVM_SHA256 = "a986740933506ebd1127c8abb64c78655a8c329798f37fd466a8e0f7aa7a5578" -LLVM_TARGETS = ["X86", "AArch64", "AMDGPU"] + +LLVM_TARGETS = [ + "X86", + "AArch64", + "AMDGPU", +] http_archive( name = "llvm-raw", build_file_content = "# empty", + patch_args = ["-p1"], + patches = ["//:patches/llvm_build.patch"], sha256 = LLVM_SHA256, strip_prefix = "llvm-project-" + LLVM_COMMIT, urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], - patch_args = ["-p1"], - patches = ["//:patches/llvm_build.patch"] ) load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") -llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) +llvm_configure( + name = "llvm-project", + targets = LLVM_TARGETS, +) # # @xla. @@ -74,11 +82,11 @@ XLA_SHA256 = "2b6a3ffdb3acf73eaa9b312407400b09c740450ab2222433890712dd4a402a0f" http_archive( name = "xla", + patch_args = ["-p1"], + patches = ["//:patches/xla.patch"], sha256 = XLA_SHA256, strip_prefix = "xla-" + XLA_COMMIT, urls = ["https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_args = ["-p1"], - patches = ["//:patches/xla.patch"], ) # Note: Further loading below in conjuction with JAX. 
@@ -94,7 +102,7 @@ http_archive( name = "rules_python", sha256 = PYRULES_SHA256, strip_prefix = "rules_python-" + PYRULES_COMMIT, - urls = ["https://github.com/bazelbuild/rules_python/archive/{commit}.tar.gz".format(commit = PYRULES_COMMIT)] + urls = ["https://github.com/bazelbuild/rules_python/archive/{commit}.tar.gz".format(commit = PYRULES_COMMIT)], ) load("@rules_python//python:repositories.bzl", "py_repositories") @@ -114,11 +122,11 @@ JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1" http_archive( name = "jax1", + patch_args = ["-p1"], + patches = ["//:patches/jax.patch"], sha256 = JAX_SHA256, strip_prefix = "jax-" + JAX_COMMIT, urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)], - patch_args = ["-p1"], - patches = ["//:patches/jax.patch"], ) # @@ -170,8 +178,8 @@ http_archive( http_archive( name = "pybind11", - sha256 = PYBIND11_SHA256, build_file = "@pybind11_bazel//:pybind11.BUILD", + sha256 = PYBIND11_SHA256, strip_prefix = "pybind11_bazel-" + PYBIND_VERSION, urls = ["https://github.com/pybind/pybind11/archive/refs/tags/v{version}.tar.gz/".format(version = PYBIND_VERSION)], ) diff --git a/jasc/dialect/BUILD b/jasc/dialect/BUILD index 0e9f60aa1a82..76d56f760b5b 100644 --- a/jasc/dialect/BUILD +++ b/jasc/dialect/BUILD @@ -96,8 +96,8 @@ cc_library( ], deps = [ ":dialect_inc_gen", - ":ops_inc_gen", ":jasc_dialect_shared_library_deps_headers", + ":ops_inc_gen", ], alwayslink = True, ) @@ -148,14 +148,14 @@ gentbl_filegroup( cc_library( name = "jasc_dialect_shared_library", srcs = [ - ":libjascdialect.so", "dialect.h", "ops.h", + ":libjascdialect.so", ], deps = [ ":dialect_inc_gen", - ":ops_inc_gen", ":jasc_dialect_shared_library_deps_headers", + ":ops_inc_gen", ], ) @@ -173,12 +173,12 @@ pybind_extension( name = "bindings", srcs = ["bindings.cc"], deps = [ - "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", ":jasc_dialect_headers", - "@llvm-project//mlir:CAPIIRHeaders", - 
"@llvm-project//mlir:MLIRBindingsPythonHeaders", "//:mlir_lowering_shared_library", "//transform_ops:jasc_transform_ops_shared_library", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", ], ) diff --git a/jasc/test/BUILD b/jasc/test/BUILD index 4a3bdaea8ad7..e87846678cd6 100644 --- a/jasc/test/BUILD +++ b/jasc/test/BUILD @@ -24,7 +24,7 @@ py_library( "//:jasc", requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", ], ) @@ -91,7 +91,7 @@ py_binary( requirement("absl-py"), requirement("ml_dtypes"), requirement("opt_einsum"), - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:jasc_dialect", ], @@ -116,7 +116,7 @@ py_binary( "//:jasc", requirement("absl-py"), requirement("chex"), - "@jax1//jax:jax", + "@jax1//jax", ], ) @@ -128,7 +128,7 @@ py_test( "//:jasc", requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:transform_dialect", ], @@ -142,7 +142,7 @@ py_test( "//:jasc", requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:transform_dialect", ], @@ -156,7 +156,7 @@ py_test( requirement("absl-py"), requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:transform_dialect", ], ) @@ -170,7 +170,7 @@ py_test( requirement("absl-py"), requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:linalg_dialect", ], @@ -183,7 +183,7 @@ py_test( "//:jasc", requirement("chex"), requirement("pytest"), - "@jax1//jax:jax", + "@jax1//jax", ], ) diff --git a/jasc/transform_ops/BUILD b/jasc/transform_ops/BUILD index 7adbf46f858c..de88ffbb353b 100644 --- a/jasc/transform_ops/BUILD +++ b/jasc/transform_ops/BUILD @@ -59,8 +59,8 @@ pybind_extension( name = "bindings", srcs = 
["bindings.cpp"], deps = [ - "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", ":jasc_transform_ops_shared_library", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", ], @@ -134,7 +134,7 @@ cc_library( includes = ["."], deps = [ ":jasc_transform_ops_inc_gen", - ":jasc_transform_ops_shared_library_deps_headers" + ":jasc_transform_ops_shared_library_deps_headers", ], alwayslink = True, ) From a9fe0318e674b752cb96787de0dd1a5438b2dc1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 10:57:56 +0000 Subject: [PATCH 05/21] Document options in .bazelrc. --- jasc/.bazelrc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jasc/.bazelrc b/jasc/.bazelrc index baacef0d21ba..e578406db017 100644 --- a/jasc/.bazelrc +++ b/jasc/.bazelrc @@ -1,3 +1,7 @@ +# Except where documented otherwise, the options in this file have been copied +# blindly from @EnzymeAD/Enzyme-JAX. The build still seems to work fine without +# any of the `define`s but they are kept just in case. + build --announce_rc build --experimental_repo_remote_exec @@ -5,14 +9,15 @@ build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 build --cxxopt=-w --host_cxxopt=-w build --define=grpc_no_ares=true build --define=tsl_link_protobuf=true -build --define open_source_build=true +build --define=open_source_build=true -build --define framework_shared_object=true -build --define tsl_protobuf_header_only=true +build --define=framework_shared_object=true +build --define=tsl_protobuf_header_only=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true +# Sets the name of JAX's MLIR native extension. This exact value is expected +# by the Python files of JAX. build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. 
build -c opt - From 7021ceda555a29845f7df05ca4023c8e18aafe38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 11:05:50 +0000 Subject: [PATCH 06/21] Add hacky patch to make lit work in bazel. --- jasc/patches/llvm_build.patch | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/jasc/patches/llvm_build.patch b/jasc/patches/llvm_build.patch index d6849fa9b438..6e8494a0c776 100644 --- a/jasc/patches/llvm_build.patch +++ b/jasc/patches/llvm_build.patch @@ -92,3 +92,17 @@ index 0cc28fd856bc..51764826a130 100644 ) # Indirection to avoid 'libmlir_c_runner_utils.so' filename clash. + +--- a/llvm/utils/lit/lit.py ++++ b/llvm/utils/lit/lit.py +@@ -1,5 +1,10 @@ + #!/usr/bin/env python3 + ++from os import path ++import sys ++ ++sys.path.append(path.dirname(__file__)) ++ + from lit.main import main + + if __name__ == "__main__": From b02b947754270242eeb9edb1556eac3e7bbf14af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 11:12:51 +0000 Subject: [PATCH 07/21] Add missing BUILD dependencies for tests. --- jasc/test/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/jasc/test/BUILD b/jasc/test/BUILD index e87846678cd6..1d1eeb6203ea 100644 --- a/jasc/test/BUILD +++ b/jasc/test/BUILD @@ -77,6 +77,8 @@ py_binary( "//:jasc", "//transform_ops", requirement("absl-py"), + requirement("ml_dtypes"), + requirement("opt_einsum"), "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:scf_dialect", @@ -103,6 +105,9 @@ py_binary( deps = [ "//:jasc", "//transform_ops", + requirement("absl-py"), + requirement("ml_dtypes"), + requirement("opt_einsum"), "@jax1//jaxlib/mlir:ir", "@jax1//jaxlib/mlir:scf_dialect", "@jax1//jaxlib/mlir:transform_dialect", From e3b1e34d7683f407328dab937264e5d4741734a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 14:58:50 +0000 Subject: [PATCH 08/21] Fix crash in initialization of ExecutionEngineOptions. 
--- jasc/call_kernel.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jasc/call_kernel.cc b/jasc/call_kernel.cc index f73e767e2e68..fd5b89adb2f7 100644 --- a/jasc/call_kernel.cc +++ b/jasc/call_kernel.cc @@ -196,7 +196,9 @@ absl::StatusOr> CreateCpuKernel( RETURN_IF_ERROR(LowerStableHloToCpuLLVM(module, dump_ir)); mlir::ExecutionEngineOptions engine_opts; // TODO(ulysse): Select LLVM opt level. - engine_opts.sharedLibPaths = {"libmlir_c_runner_utils.so"}; + static constexpr std::array sharedLibPaths = { + "libmlir_c_runner_utils.so"}; + engine_opts.sharedLibPaths = sharedLibPaths; engine_opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Default; auto engineOrError = mlir::ExecutionEngine::create(module, engine_opts); if (!engineOrError) { From 3e3a5db1ad2d63de471edc12d17b3c6a7f7a165f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 15:33:37 +0000 Subject: [PATCH 09/21] Make path/label to requirements.txt absolute. Otherwise, for some weird reason, the file is expected in `external/`, and putting it there may collide with some Bazel internals. --- jasc/WORKSPACE | 4 ++-- jasc/{external => }/requirements-top-level.txt | 0 jasc/{external => }/requirements.txt | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename jasc/{external => }/requirements-top-level.txt (100%) rename jasc/{external => }/requirements.txt (100%) diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index 64c276a37364..f59b4e13fb3d 100644 --- a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -220,14 +220,14 @@ http_archive( ) # -# Python dependencies via pip +# Python dependencies via pip. 
# load("@rules_python//python:pip.bzl", "pip_parse") pip_parse( name = "pip_deps", - requirements_lock = ":requirements.txt", + requirements_lock = "//:requirements.txt", ) load("@pip_deps//:requirements.bzl", "install_deps") diff --git a/jasc/external/requirements-top-level.txt b/jasc/requirements-top-level.txt similarity index 100% rename from jasc/external/requirements-top-level.txt rename to jasc/requirements-top-level.txt diff --git a/jasc/external/requirements.txt b/jasc/requirements.txt similarity index 100% rename from jasc/external/requirements.txt rename to jasc/requirements.txt From 3101624666b9e502e7c278aae01a8473e95597b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 15:35:38 +0000 Subject: [PATCH 10/21] Add compile command extractor to WORKSPACE. This sets up the following command, which creates/updates the `compile_commands.json` file, which is useful for auto-completion and other tools: bazel run @hedron_compile_commands//:refresh_all --- jasc/WORKSPACE | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index f59b4e13fb3d..7e0efbae7dd7 100644 --- a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -233,3 +233,28 @@ pip_parse( load("@pip_deps//:requirements.bzl", "install_deps") install_deps() + +# +# Hedron's Compile Commands Extractor. 
+# +CCEXTRACT_COMMIT = "ceeb5dbdefb8839a1e29cc242bc1fe755a43609c" +CCEXTRACT_SHA256 = "4e54e689d138462b568b9b3c4f83248eb112dc5b973ef92c190d4c8c2b0a4a9a" + +http_archive( + name = "hedron_compile_commands", + sha256 = CCEXTRACT_SHA256, + strip_prefix = "bazel-compile-commands-extractor-{commit}".format(commit=CCEXTRACT_COMMIT), + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/{commit}.tar.gz".format(commit=CCEXTRACT_COMMIT), +) + +load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +hedron_compile_commands_setup() + +load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive") +hedron_compile_commands_setup_transitive() + +load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive() + +load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive_transitive() From 941e6c057f2cf912e1effe92382f6aee831107d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 15:59:13 +0000 Subject: [PATCH 11/21] Adapt header guards to new source location. 
--- jasc/dialect/dialect.h | 6 +++--- jasc/dialect/ops.h | 6 +++--- jasc/gpu_lowering_passes.h | 6 +++--- jasc/mlir_lowering.h | 6 +++--- jasc/transform_ops/dialect_extension.h | 6 +++--- jasc/transform_ops/jasc_transform_ops.h | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/jasc/dialect/dialect.h b/jasc/dialect/dialect.h index cf8f60c485c8..c9d01da429c6 100644 --- a/jasc/dialect/dialect.h +++ b/jasc/dialect/dialect.h @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ +#ifndef JASC_DIALECT_DIALECT_H_ +#define JASC_DIALECT_DIALECT_H_ #include "mlir/IR/Dialect.h" // Include code generated from dialect.td. #include "dialect/dialect.h.inc" -#endif // THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_DIALECT_H_ +#endif // JASC_DIALECT_DIALECT_H_ diff --git a/jasc/dialect/ops.h b/jasc/dialect/ops.h index d4989f7de544..fff8b9d5502c 100644 --- a/jasc/dialect/ops.h +++ b/jasc/dialect/ops.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ +#ifndef JASC_DIALECT_OPS_H_ +#define JASC_DIALECT_OPS_H_ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" @@ -18,4 +18,4 @@ #define GET_OP_CLASSES #include "dialect/ops.h.inc" -#endif // THIRD_PARTY_MLIR_EDGE_JASC_DIALECT_OPS_H_ +#endif // JASC_DIALECT_OPS_H_ diff --git a/jasc/gpu_lowering_passes.h b/jasc/gpu_lowering_passes.h index 636e761f5bcb..1359a4379c28 100644 --- a/jasc/gpu_lowering_passes.h +++ b/jasc/gpu_lowering_passes.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ +#ifndef 
JASC_GPU_LOWERING_PASSES_H_ +#define JASC_GPU_LOWERING_PASSES_H_ #include @@ -48,4 +48,4 @@ void registerGPULoweringPasses(); } // namespace jasc -#endif // THIRD_PARTY_MLIR_EDGE_JASC_GPU_LOWERING_PASSES_H_ \ No newline at end of file +#endif // JASC_GPU_LOWERING_PASSES_H_ \ No newline at end of file diff --git a/jasc/mlir_lowering.h b/jasc/mlir_lowering.h index 6efc362282ea..056a9d388382 100644 --- a/jasc/mlir_lowering.h +++ b/jasc/mlir_lowering.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ +#ifndef JASC_MLIR_LOWERING_H_ +#define JASC_MLIR_LOWERING_H_ #include "absl/status/status.h" #include "mlir/IR/BuiltinOps.h" @@ -26,4 +26,4 @@ void registerMLIRLoweringPasses(); } // namespace jasc -#endif // THIRD_PARTY_MLIR_EDGE_JASC_MLIR_LOWERING_H_ +#endif // JASC_MLIR_LOWERING_H_ diff --git a/jasc/transform_ops/dialect_extension.h b/jasc/transform_ops/dialect_extension.h index 4f84b5f48bdd..d29ca6f9892c 100644 --- a/jasc/transform_ops/dialect_extension.h +++ b/jasc/transform_ops/dialect_extension.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ +#ifndef JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ +#define JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ #include "mlir/IR/DialectRegistry.h" @@ -15,4 +15,4 @@ namespace jasc { void registerTransformDialectExtension(mlir::DialectRegistry ®istry); } -#endif // THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ +#endif // JASC_TRANSFORM_OPS_DIALECT_EXTENSION_H_ diff --git a/jasc/transform_ops/jasc_transform_ops.h b/jasc/transform_ops/jasc_transform_ops.h index c8f8d1923ff2..66836a5f41ba 100644 --- a/jasc/transform_ops/jasc_transform_ops.h +++ 
b/jasc/transform_ops/jasc_transform_ops.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ -#define THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ +#ifndef JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ +#define JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ #include @@ -18,4 +18,4 @@ #define GET_OP_CLASSES #include "jasc_transform_ops.h.inc" -#endif // THIRD_PARTY_MLIR_EDGE_JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ +#endif // JASC_TRANSFORMOPS_JASCTRANSFORMOPS_H_ From 0ffc5f44cc90ead84bb093524f808e25b5f1e8bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 15:59:38 +0000 Subject: [PATCH 12/21] Update link to external location on Github. --- jasc/test/fold_fill_into_pad.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jasc/test/fold_fill_into_pad.mlir b/jasc/test/fold_fill_into_pad.mlir index 2292cae0b029..e4a6f58e50a8 100644 --- a/jasc/test/fold_fill_into_pad.mlir +++ b/jasc/test/fold_fill_into_pad.mlir @@ -3,7 +3,7 @@ // Test ported from: -// third_party/iree/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir +// https://github.com/openxla/iree/blob/a219cb5008a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir // CHECK-LABEL: @pad_fill_to_fill func.func @pad_fill_to_fill(%arg0: tensor<31x62xf32>) -> tensor<32x64xf32> { // Check that a pad of a fill with the same constant is replaced by a From 65757ea947dc8306381138f3c03ffaee2d896f9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 15 Jan 2024 16:13:55 +0000 Subject: [PATCH 13/21] Add trailing new line to all files that didn't have one. 
--- jasc/dialect/bindings.cc | 2 +- jasc/dialect/dialect.td | 2 +- jasc/dialect/jasc.py | 2 +- jasc/dialect/ops.cc | 2 +- jasc/dialect/ops.td | 2 +- jasc/dialect/ops_py.td | 2 +- jasc/gpu_lowering_passes.h | 2 +- jasc/gpu_post_bufferize.mlir | 2 +- jasc/requirements-top-level.txt | 2 +- jasc/test/jit.py | 2 +- jasc/test/wrap-in-cpu-launch.mlir | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/jasc/dialect/bindings.cc b/jasc/dialect/bindings.cc index c32d063e3c83..8dd5d4fc8f61 100644 --- a/jasc/dialect/bindings.cc +++ b/jasc/dialect/bindings.cc @@ -31,4 +31,4 @@ PYBIND11_MODULE(_mlirDialectsJasc, m) { m.def("register_lowering_passes", []() { jasc::registerMLIRLoweringPasses(); }); -} \ No newline at end of file +} diff --git a/jasc/dialect/dialect.td b/jasc/dialect/dialect.td index 012eb2ff2ff3..cd34f43c8bc6 100644 --- a/jasc/dialect/dialect.td +++ b/jasc/dialect/dialect.td @@ -17,4 +17,4 @@ def Jasc_Dialect : Dialect { ]; } -#endif // JASC_DIALECT_DIALECT \ No newline at end of file +#endif // JASC_DIALECT_DIALECT diff --git a/jasc/dialect/jasc.py b/jasc/dialect/jasc.py index 3a239e256042..2f36fd0008d0 100644 --- a/jasc/dialect/jasc.py +++ b/jasc/dialect/jasc.py @@ -3,4 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._ops_gen import * -from .._mlir_libs._mlirDialectsJasc import * \ No newline at end of file +from .._mlir_libs._mlirDialectsJasc import * diff --git a/jasc/dialect/ops.cc b/jasc/dialect/ops.cc index 2db6a029ec33..9c15d3135905 100644 --- a/jasc/dialect/ops.cc +++ b/jasc/dialect/ops.cc @@ -20,4 +20,4 @@ namespace jasc { -} // namespace jasc \ No newline at end of file +} // namespace jasc diff --git a/jasc/dialect/ops.td b/jasc/dialect/ops.td index 83cb02ad246e..d6513ab95bd2 100644 --- a/jasc/dialect/ops.td +++ b/jasc/dialect/ops.td @@ -34,4 +34,4 @@ def Jasc_ReturnOp : Op { let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; } -#endif // JASC_DIALECT_OPS \ No newline at end of file +#endif // 
JASC_DIALECT_OPS diff --git a/jasc/dialect/ops_py.td b/jasc/dialect/ops_py.td index cdf6fda55741..7335d6d654d4 100644 --- a/jasc/dialect/ops_py.td +++ b/jasc/dialect/ops_py.td @@ -7,4 +7,4 @@ include "ops.td" -#endif // JASC_DIALECT_OPSPY \ No newline at end of file +#endif // JASC_DIALECT_OPSPY diff --git a/jasc/gpu_lowering_passes.h b/jasc/gpu_lowering_passes.h index 1359a4379c28..2b945c585f66 100644 --- a/jasc/gpu_lowering_passes.h +++ b/jasc/gpu_lowering_passes.h @@ -48,4 +48,4 @@ void registerGPULoweringPasses(); } // namespace jasc -#endif // JASC_GPU_LOWERING_PASSES_H_ \ No newline at end of file +#endif // JASC_GPU_LOWERING_PASSES_H_ diff --git a/jasc/gpu_post_bufferize.mlir b/jasc/gpu_post_bufferize.mlir index 6a64de8ea1d4..54649625e2b3 100644 --- a/jasc/gpu_post_bufferize.mlir +++ b/jasc/gpu_post_bufferize.mlir @@ -10,4 +10,4 @@ transform.sequence failures(suppress) { : (!transform.any_op) -> !transform.any_op transform.jasc.wrap_in_gpu_launch %linalg_ops : (!transform.any_op) -> !transform.op<"gpu.launch"> -} \ No newline at end of file +} diff --git a/jasc/requirements-top-level.txt b/jasc/requirements-top-level.txt index 30c598b8d07d..a0a995ffe622 100644 --- a/jasc/requirements-top-level.txt +++ b/jasc/requirements-top-level.txt @@ -1,4 +1,4 @@ absl-py chex pytest -PyYAML \ No newline at end of file +PyYAML diff --git a/jasc/test/jit.py b/jasc/test/jit.py index 00997d861d47..d57fdf4f9a11 100644 --- a/jasc/test/jit.py +++ b/jasc/test/jit.py @@ -74,4 +74,4 @@ def matmul(a: jax.Array, b: jax.Array) -> jax.Array: if __name__ == "__main__": args = sys.argv[1:] or ["-s", "-v"] - sys.exit(pytest.main([__file__] + args)) \ No newline at end of file + sys.exit(pytest.main([__file__] + args)) diff --git a/jasc/test/wrap-in-cpu-launch.mlir b/jasc/test/wrap-in-cpu-launch.mlir index 16a0422b566f..658b3bc1da29 100644 --- a/jasc/test/wrap-in-cpu-launch.mlir +++ b/jasc/test/wrap-in-cpu-launch.mlir @@ -37,4 +37,4 @@ func.func @already_wrapped_op(%arg0: f32, %arg1: 
memref<16xf32>) -> memref<16xf3 func.func @simple_fill(%arg0: f32, %arg1: memref<16xf32>) -> memref<16xf32> { linalg.fill ins(%arg0 : f32) outs(%arg1 : memref<16xf32>) return %arg1 : memref<16xf32> -} \ No newline at end of file +} From fd440ba39dbbd3f1e36ee3e3a3d1a61d2c43cb25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 16 Jan 2024 08:26:12 +0000 Subject: [PATCH 14/21] Document and update Python requirements. --- jasc/requirements-top-level.txt | 10 ++++++++++ jasc/requirements.txt | 21 ++++++++++++++------- jasc/test/autotuning.py | 2 +- jasc/test/batch_matmul_gpu.py | 2 +- jasc/test/cpu_integration.py | 2 +- jasc/test/gpu_integration.py | 2 +- jasc/test/jit.py | 2 +- jasc/test/matmul_cpu.py | 2 +- jasc/test/matmul_gpu.py | 2 +- jasc/test/tag.py | 2 +- 10 files changed, 32 insertions(+), 15 deletions(-) diff --git a/jasc/requirements-top-level.txt b/jasc/requirements-top-level.txt index a0a995ffe622..9eb8c17f027f 100644 --- a/jasc/requirements-top-level.txt +++ b/jasc/requirements-top-level.txt @@ -1,4 +1,14 @@ +# Jasc's dependencies. absl-py chex pytest PyYAML + +# JAX's dependencies. +# +# Obtained by pip installing jax at the same version as we use in the WORKSPACE +# into a virtual environment. +numpy +scipy +opt-einsum +ml-dtypes diff --git a/jasc/requirements.txt b/jasc/requirements.txt index d6f6f5646e80..0a58742fd5cd 100644 --- a/jasc/requirements.txt +++ b/jasc/requirements.txt @@ -1,16 +1,23 @@ absl-py==2.0.0 chex==0.1.85 iniconfig==2.0.0 -jax==0.4.20 -jaxlib==0.4.20 -ml-dtypes==0.3.1 -numpy==1.26.2 +# We actually need to install JAX with pip as well because chex, which is +# installed through @pip_deps, depends on JAX and can't be installed if its +# dependencies are not. In order for that pip installed JAX not to be used, +# we remove the corresponding path from the PYTHONPATH in all of our files +# that import chex. This is a gross hack but the only way I have managed to +# make things run until now. 
I opened an issue on Github here +# https://github.com/bazelbuild/rules_python/issues/1583 but did not get +# any answer. +jax==0.4.23 +jaxlib==0.4.23 +ml-dtypes==0.3.2 +numpy==1.26.3 opt-einsum==3.3.0 packaging==23.2 pluggy==1.3.0 -pytest==7.4.3 +pytest==7.4.4 PyYAML==6.0.1 scipy==1.11.4 -setuptools==69.0.2 toolz==0.12.0 -typing_extensions==4.8.0 +typing_extensions==4.9.0 diff --git a/jasc/test/autotuning.py b/jasc/test/autotuning.py index 3be3f10e89c0..c58885d6a0bf 100644 --- a/jasc/test/autotuning.py +++ b/jasc/test/autotuning.py @@ -7,7 +7,7 @@ from typing import Tuple import sys -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] import chex diff --git a/jasc/test/batch_matmul_gpu.py b/jasc/test/batch_matmul_gpu.py index 806affbeff23..aceaf4c8bd56 100644 --- a/jasc/test/batch_matmul_gpu.py +++ b/jasc/test/batch_matmul_gpu.py @@ -4,7 +4,7 @@ from typing import Tuple -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. import sys sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] diff --git a/jasc/test/cpu_integration.py b/jasc/test/cpu_integration.py index e1e1ac5a2287..2c2762256b74 100644 --- a/jasc/test/cpu_integration.py +++ b/jasc/test/cpu_integration.py @@ -7,7 +7,7 @@ from collections.abc import Mapping import sys -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. 
sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] import chex diff --git a/jasc/test/gpu_integration.py b/jasc/test/gpu_integration.py index a7788f738731..5c13a47e9aae 100644 --- a/jasc/test/gpu_integration.py +++ b/jasc/test/gpu_integration.py @@ -4,7 +4,7 @@ """GPU-specific tests for Jasc.""" -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. import sys sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] diff --git a/jasc/test/jit.py b/jasc/test/jit.py index d57fdf4f9a11..d5aa0329456e 100644 --- a/jasc/test/jit.py +++ b/jasc/test/jit.py @@ -7,7 +7,7 @@ import sys -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] import chex diff --git a/jasc/test/matmul_cpu.py b/jasc/test/matmul_cpu.py index d483ce42a566..8ba22fc43614 100644 --- a/jasc/test/matmul_cpu.py +++ b/jasc/test/matmul_cpu.py @@ -7,7 +7,7 @@ from typing import Tuple import sys -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] import chex diff --git a/jasc/test/matmul_gpu.py b/jasc/test/matmul_gpu.py index 591a8c303337..e0572f6f5b81 100644 --- a/jasc/test/matmul_gpu.py +++ b/jasc/test/matmul_gpu.py @@ -6,7 +6,7 @@ from typing import Tuple -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. 
import sys sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] diff --git a/jasc/test/tag.py b/jasc/test/tag.py index 6c769742bd3f..e4fc342f98ce 100644 --- a/jasc/test/tag.py +++ b/jasc/test/tag.py @@ -6,7 +6,7 @@ from __future__ import annotations from typing import Callable, Sequence -# XXX: Remove paths to `jax*` packages installed from pip by Bazel rules. +# Remove paths to `jax*` packages installed from pip. See requirements.txt. import sys sys.path = [p for p in sys.path if "/pip_deps_jax" not in p] From a5ace9f220d17055c964bfc497513dbdb40327c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 16 Jan 2024 08:45:32 +0000 Subject: [PATCH 15/21] Remove unnecessary WORKSPACE patches. Simplify remaining file names. --- jasc/WORKSPACE | 2 +- jasc/patches/apply.sh | 12 ---------- jasc/patches/clang_macos.patch | 13 ----------- jasc/patches/jax_workspace.patch | 22 ------------------- jasc/patches/{llvm_build.patch => llvm.patch} | 0 jasc/patches/stablehlo_build.patch | 10 --------- 6 files changed, 1 insertion(+), 58 deletions(-) delete mode 100644 jasc/patches/apply.sh delete mode 100644 jasc/patches/clang_macos.patch delete mode 100644 jasc/patches/jax_workspace.patch rename jasc/patches/{llvm_build.patch => llvm.patch} (100%) delete mode 100644 jasc/patches/stablehlo_build.patch diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index 7e0efbae7dd7..1ed86dc026ca 100644 --- a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -60,7 +60,7 @@ http_archive( name = "llvm-raw", build_file_content = "# empty", patch_args = ["-p1"], - patches = ["//:patches/llvm_build.patch"], + patches = ["//:patches/llvm.patch"], sha256 = LLVM_SHA256, strip_prefix = "llvm-project-" + LLVM_COMMIT, urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)], diff --git a/jasc/patches/apply.sh b/jasc/patches/apply.sh deleted file mode 100644 index 5e7cbb022c40..000000000000 --- a/jasc/patches/apply.sh +++ /dev/null @@ -1,12 +0,0 @@ 
-#!/bin/sh - -cd llvm-project -patch -p1 < ../patches/llvm_build.patch -# patch -p1 < ../patches/clang_macos.patch -cd - - -cd jax -patch -p1 < ../patches/jax_workspace.patch -touch llvm_dummy.BUILD -cd - - diff --git a/jasc/patches/clang_macos.patch b/jasc/patches/clang_macos.patch deleted file mode 100644 index 3ebce89dbf21..000000000000 --- a/jasc/patches/clang_macos.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel -index 419b2eeca7e1..c99b350f4a9f 100644 ---- a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel -@@ -1615,7 +1615,7 @@ genrule( - outs = [hdr.replace("lib/Headers/", "staging/include/") for hdr in builtin_headers], - cmd = """ - for src in $(SRCS); do -- relsrc=$${src/*"$(WORKSPACE_ROOT)"\\/clang\\/lib\\/Headers} -+ relsrc=$${src/*external\\llvm-project\\/clang\\/lib\\/Headers} - target=$(@D)/staging/include/$$relsrc - mkdir -p $$(dirname $$target) - cp $$src $$target diff --git a/jasc/patches/jax_workspace.patch b/jasc/patches/jax_workspace.patch deleted file mode 100644 index d9b517998a36..000000000000 --- a/jasc/patches/jax_workspace.patch +++ /dev/null @@ -1,22 +0,0 @@ ---- a/third_party/xla/workspace.bzl -+++ b/third_party/xla/workspace.bzl -@@ -24,12 +24,13 @@ XLA_COMMIT = "8f27d321a86029c336558bfbd6 - XLA_SHA256 = "e8225ee13a8e69c49554d0ec87a0a509c645a1b267c557cd5b9bfe175a4b3f29" - - def repo(): -- tf_http_archive( -- name = "xla", -- sha256 = XLA_SHA256, -- strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), -- urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), -- ) -+ # tf_http_archive( -+ # name = "xla", -+ # sha256 = XLA_SHA256, -+ # strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), -+ # urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), -+ # ) -+ 
pass - - # For development, one often wants to make changes to the TF repository as well - # as the JAX repository. You can override the pinned repository above with a diff --git a/jasc/patches/llvm_build.patch b/jasc/patches/llvm.patch similarity index 100% rename from jasc/patches/llvm_build.patch rename to jasc/patches/llvm.patch diff --git a/jasc/patches/stablehlo_build.patch b/jasc/patches/stablehlo_build.patch deleted file mode 100644 index d84b465f2190..000000000000 --- a/jasc/patches/stablehlo_build.patch +++ /dev/null @@ -1,10 +0,0 @@ ---- a/BUILD.bazel -+++ b/BUILD.bazel -@@ -907,6 +907,7 @@ cc_library( - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", -+ "@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:QuantOps", From afbfeb7541c9d14437a5a1b392073da0d5c8460e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 16 Jan 2024 14:54:01 +0000 Subject: [PATCH 16/21] Remove all references to CUDA. I tried half a day to make CUDA integrate with bazel -- in vain. We can try again in some future iteration. 
--- jasc/BUILD | 1 - jasc/call_kernel.cc | 174 -------------------------------------------- jasc/primitives.py | 12 --- 3 files changed, 187 deletions(-) diff --git a/jasc/BUILD b/jasc/BUILD index 10d3ab2e9d92..9d3fffd112cf 100644 --- a/jasc/BUILD +++ b/jasc/BUILD @@ -74,7 +74,6 @@ cc_library( deps = [ "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", ":mlir_lowering_shared_library", - # "//third_party/gpus/cuda:cuda_headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:ExecutionEngine", diff --git a/jasc/call_kernel.cc b/jasc/call_kernel.cc index fd5b89adb2f7..a2620225c0d0 100644 --- a/jasc/call_kernel.cc +++ b/jasc/call_kernel.cc @@ -21,7 +21,6 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -// #include "third_party/gpus/cuda/include/cuda.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Error.h" @@ -220,154 +219,6 @@ void CpuCallback(void *out, void **ins) { kernel->Call(out, (ins + 1)); } -// CUstream jasc_cuda_stream = nullptr; - -// class CudaKernel { -// public: -// CudaKernel(std::unique_ptr execution_engine, -// int num_inputs_outputs) -// : execution_engine_(std::move(execution_engine)), -// num_input_outputs_(num_inputs_outputs) {} - -// // Executes the kernel. -// void Call(CUstream stream, void **buffers) const { -// // TODO(ulysse): avoid relying on a global variable. 
-// CHECK_EQ(jasc_cuda_stream, nullptr); -// jasc_cuda_stream = stream; - -// std::vector inputs; -// inputs.reserve(num_input_outputs_); -// for (int i = 0; i < num_input_outputs_; ++i) { -// inputs.push_back(&buffers[i]); -// } -// llvm::cantFail(execution_engine_->invokePacked("main", inputs)); -// jasc_cuda_stream = nullptr; -// } - -// private: -// std::unique_ptr execution_engine_; -// int num_input_outputs_; -// int num_outputs_; -// }; - -// void CheckCudaError(CUresult result) { -// if (result != CUDA_SUCCESS) { -// const char **error_msg = nullptr; -// cuGetErrorString(result, error_msg); -// LOG(FATAL) << *error_msg; -// } -// } - -// extern "C" void JascCudaLaunchKernel(CUfunction function, intptr_t gridX, -// intptr_t gridY, intptr_t gridZ, -// intptr_t blockX, intptr_t blockY, -// intptr_t blockZ, int32_t smem, -// CUstream stream, void **params, -// void **extra) { -// CheckCudaError(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, -// blockZ, smem, stream, params, extra)); -// } - -// extern "C" CUstream JascCudaStreamCreate() { -// // TODO(ulysse): explicitly pass the stream instead of relying on a global -// // variable. -// return jasc_cuda_stream; -// } - -// extern "C" void JascCudaStreamDestroy(CUstream stream) { -// // NO-op as we are reusing the stream given by XLA. -// } - -// extern "C" CUmodule JascCudaModuleLoad(void *data) { -// // TODO(ulysse): investigate the performance implications of loading the -// // module on the fly. -// CUmodule module; -// CheckCudaError(cuModuleLoadData(&module, data)); -// return module; -// } - -// extern "C" void JascCudaModuleUnload(CUmodule module) { -// CheckCudaError(cuModuleUnload(module)); -// } - -// extern "C" CUfunction JascCudaModuleGetFunction(CUmodule module, -// const char *name) { -// // TODO(ulysse): investigate the performance implications of loading the -// // function on the fly. 
-// CUfunction function; -// CheckCudaError(cuModuleGetFunction(&function, module, name)); -// return function; -// } - -// extern "C" void JascCudaStreamSynchronize(CUstream stream) { -// CheckCudaError(cuStreamSynchronize(stream)); -// } - -// extern "C" void *JascCudaMemAlloc(uint64_t size_bytes, CUstream) { -// CUdeviceptr ptr; -// CheckCudaError(cuMemAlloc(&ptr, size_bytes)); -// return reinterpret_cast(ptr); -// } - -// extern "C" void JascCudaMemFree(void *ptr, CUstream) { -// CheckCudaError(cuMemFree(reinterpret_cast(ptr))); -// } - -// extern "C" void JascCudaMemcpy(void *dst, void *src, size_t sizeBytes, -// CUstream stream) { -// CheckCudaError(cuMemcpy(reinterpret_cast(dst), -// reinterpret_cast(src), sizeBytes)); -// } - -// absl::StatusOr> CreateCudaKernel( -// mlir::python::PyModule &py_module, int num_inputs, int num_outputs, -// bool dump_ir) { -// mlir::ModuleOp module = unwrap(py_module.get()); -// RETURN_IF_ERROR(LowerStableHloToGpuLLVM(module, dump_ir)); -// mlir::ExecutionEngineOptions engine_opts; -// // TODO(ulysse): Select LLVM opt level. 
-// engine_opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Default; -// auto engineOrError = mlir::ExecutionEngine::create(module, engine_opts); -// if (!engineOrError) { -// llvm::handleAllErrors( -// engineOrError.takeError(), [&](const llvm::StringError &err) { -// LOG(FATAL) << "Error while creating execution engine: " -// << err.getMessage(); -// }); -// } -// engineOrError.get()->registerSymbols( -// [](llvm::orc::MangleAndInterner interner) { -// auto map = llvm::orc::SymbolMap(); -// auto register_symbol = [&map, &interner](llvm::StringRef name, -// auto *func) { -// auto addr = llvm::orc::ExecutorAddr(reinterpret_cast(func)); -// map[interner(name)] = {addr, llvm::JITSymbolFlags::None}; -// }; -// register_symbol("mgpuLaunchKernel", &JascCudaLaunchKernel); -// register_symbol("mgpuStreamCreate", &JascCudaStreamCreate); -// register_symbol("mgpuStreamDestroy", &JascCudaStreamDestroy); -// register_symbol("mgpuModuleLoad", &JascCudaModuleLoad); -// register_symbol("mgpuModuleUnload", &JascCudaModuleUnload); -// register_symbol("mgpuModuleGetFunction", &JascCudaModuleGetFunction); -// register_symbol("mgpuStreamSynchronize", &JascCudaStreamSynchronize); -// register_symbol("mgpuMemAlloc", &JascCudaMemAlloc); -// register_symbol("mgpuMemFree", &JascCudaMemFree); -// register_symbol("mgpuMemcpy", &JascCudaMemcpy); -// return map; -// }); -// return std::make_unique(llvm::cantFail(std::move(engineOrError)), -// num_inputs + num_outputs); -// } - -// // XLA custom call callback that calls a kernel on GPU. -// void GpuCallback(CUstream stream, void **buffers, const char *opaque, -// size_t opaque_len) { -// CHECK_EQ(opaque_len, sizeof(CudaKernel *)); -// CudaKernel *kernel_call; -// std::memcpy(&kernel_call, opaque, sizeof(CudaKernel *)); -// kernel_call->Call(stream, buffers); -// } - /// Clears the `PyOperation` (representing Python-level handles to /// `Operation *`s) that are tracked by the context. 
This function should be /// called by any entry point that may modify the IR, which could cause above @@ -419,31 +270,6 @@ PYBIND11_MODULE(call_kernel, m) { "xla._CUSTOM_CALL_TARGET"); }); - // py::class_(m, "CudaKernel") - // .def_property_readonly("ptr", [](CudaKernel *kernel) { - // union { - // CudaKernel *ptr; - // char bytes[sizeof(CudaKernel *)]; - // } bytes_ptr; - // bytes_ptr.ptr = kernel; - // return pybind11::bytes(bytes_ptr.bytes, sizeof(CudaKernel *)); - // }); - - // m.def( - // "create_cuda_kernel", - // [](mlir::python::PyModule &py_module, int num_inputs, int num_outputs, - // bool dump_ir) { - // clearOperationsInside(py_module); - // return CreateCudaKernel(py_module, num_inputs, num_outputs, dump_ir); - // }, - // py::arg("module"), py::arg("num_inputs"), py::arg("num_outputs"), - // py::arg("dump_ir") = false); - - // m.def("get_cuda_callback", []() { - // return pybind11::capsule(reinterpret_cast(&GpuCallback), - // "xla._CUSTOM_CALL_TARGET"); - // }); - m.def( "lower_to_linalg", [](mlir::python::PyModule &py_module, bool dump_ir) { diff --git a/jasc/primitives.py b/jasc/primitives.py index 71037adddf05..8ae4eeb200c0 100644 --- a/jasc/primitives.py +++ b/jasc/primitives.py @@ -123,15 +123,6 @@ def _jit_lowering( identifier_attr = jax_mlir.dense_int_elements([compiled_kernel.identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) mlir_args = [identifier_op.result] - # elif ctx.module_context.platforms[0] == 'cuda': - # compiled_kernel = call_kernel.create_cuda_kernel( - # module=lowered_ir, - # num_inputs=len(args), - # num_outputs=len(ctx.avals_out), - # dump_ir=dump_ir, - # ) - # ctx.module_context.add_keepalive(compiled_kernel) - # backend_config = ir.StringAttr.get(compiled_kernel.ptr) else: raise NotImplementedError( f'Jasc does not support platform {ctx.module_context.platforms[0]}' @@ -164,9 +155,6 @@ def _jit_lowering( xla_client.register_custom_call_target( 'jasc.call_kernel', call_kernel.get_cpu_callback(), platform='cpu' 
) -# xla_client.register_custom_call_target( -# 'jasc.call_kernel', call_kernel.get_cuda_callback(), platform='CUDA' -# ) def _tag_lowering( From 892e39fa1d47b00cc6fe58300238c9b2a33b140a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 19 Jan 2024 09:26:45 +0000 Subject: [PATCH 17/21] Reduce diff to internal version The reductions are related to: * abseil's VLOG, * LLVM's dropping of opaque pointers, * random new lines. --- jasc/call_kernel.cc | 7 ++++--- jasc/jasc.py | 1 + jasc/mlir_lowering.cc | 3 --- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/jasc/call_kernel.cc b/jasc/call_kernel.cc index a2620225c0d0..7281001b4398 100644 --- a/jasc/call_kernel.cc +++ b/jasc/call_kernel.cc @@ -40,7 +40,8 @@ #include "mlir_lowering.h" -#define VLOG(X) std::cerr +// Work-around while `VLOG` is still missing in public abseil. +#define VLOG(X) LOG(INFO) namespace jasc { namespace { @@ -75,13 +76,13 @@ class CpuKernel { identifier_ = next_kernel_id_++; global_registry_->emplace(identifier_, this); - VLOG(1) << "allocated kernel " << identifier_ << "\n"; + VLOG(1) << "allocated kernel " << identifier_; } ~CpuKernel() { absl::WriterMutexLock lock(&global_registry_mutex_); global_registry_->erase(identifier_); - VLOG(1) << "deallocated kernel " << identifier_ << "\n"; + VLOG(1) << "deallocated kernel " << identifier_; } // A unique identifier for the kernel. 
diff --git a/jasc/jasc.py b/jasc/jasc.py index aac9bce86c34..cdeeae8da6cb 100644 --- a/jasc/jasc.py +++ b/jasc/jasc.py @@ -58,6 +58,7 @@ def schedule(h: OpHandle) -> None: _JASC_AUTO_NORMALIZATION = True + def set_auto_normalization(activate: bool): """Toggles the automatic normalization mode.""" global _JASC_AUTO_NORMALIZATION diff --git a/jasc/mlir_lowering.cc b/jasc/mlir_lowering.cc index b5ae5cf12b65..cdd3ce530035 100644 --- a/jasc/mlir_lowering.cc +++ b/jasc/mlir_lowering.cc @@ -368,12 +368,10 @@ void AddLowerCFToLLVMPasses(mlir::PassManager& pm) { pm.addPass(mlir::createLowerAffinePass()); mlir::FinalizeMemRefToLLVMConversionPassOptions memref_to_llvm_opts; - // memref_to_llvm_opts.useOpaquePointers = true; pm.addPass( mlir::createFinalizeMemRefToLLVMConversionPass(memref_to_llvm_opts)); mlir::ConvertFuncToLLVMPassOptions func_to_llvm_opts; - // func_to_llvm_opts.useOpaquePointers = true; func_to_llvm_opts.useBarePtrCallConv = false; pm.addPass(mlir::createConvertFuncToLLVMPass(func_to_llvm_opts)); pm.addPass(mlir::createConvertIndexToLLVMPass()); @@ -478,7 +476,6 @@ absl::Status LowerStableHloToGpuLLVM(mlir::ModuleOp module, bool dump_ir) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); pm.addPass(mlir::createLowerAffinePass()); mlir::ConvertFuncToLLVMPassOptions func_to_llvm_opts; - func_to_llvm_opts.useOpaquePointers = true; func_to_llvm_opts.useBarePtrCallConv = true; pm.addPass(mlir::createConvertFuncToLLVMPass(func_to_llvm_opts)); pm.addPass(mlir::createCanonicalizerPass()); From a9286a1e0eddfc1bd38e6543a2a9ff3622ea4754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 17 Jan 2024 10:21:57 +0000 Subject: [PATCH 18/21] Massively reduce changes to BUILD files. I put all new constructs into question again. Turns out that many were not needed and some were even plain unused. 
--- jasc/BUILD | 129 ++++++++++++----------- jasc/WORKSPACE | 8 ++ jasc/dialect/BUILD | 74 ++++--------- jasc/dialect/bindings.cc | 2 +- jasc/dialect/jasc.py | 2 +- jasc/jasc.py | 4 +- jasc/patches/jax.patch | 54 ++++++++++ jasc/primitives.py | 2 +- jasc/test/BUILD | 2 - jasc/test/abstractions.py | 2 +- jasc/test/bindings.py | 4 +- jasc/transform_ops/BUILD | 53 ++-------- jasc/transform_ops/bindings.cpp | 2 +- jasc/transform_ops/jasc_transform_ops.py | 2 +- jasc/tuner.py | 2 +- 15 files changed, 165 insertions(+), 177 deletions(-) diff --git a/jasc/BUILD b/jasc/BUILD index 9d3fffd112cf..9c568e876c35 100644 --- a/jasc/BUILD +++ b/jasc/BUILD @@ -36,7 +36,6 @@ py_library( "@jax1//jaxlib/mlir:bufferization_dialect", "@jax1//jaxlib/mlir:core", "@jax1//jaxlib/mlir:ir", - "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:pdl_dialect", "@jax1//jaxlib/mlir:transform_dialect", ], @@ -47,9 +46,9 @@ py_library( srcs = ["tuner.py"], deps = [ ":jasc", + "//transform_ops", "@jax1//jax", "@jax1//jaxlib/mlir:ir", - "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:transform_dialect", ], ) @@ -69,43 +68,6 @@ py_library( ], ) -cc_library( - name = "call_kernel_shared_library_deps", - deps = [ - "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", - ":mlir_lowering_shared_library", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:ExecutionEngine", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MLIRBindingsPythonHeaders", - "@pybind11_abseil//pybind11_abseil:import_status_module", - "@pybind11_abseil//pybind11_abseil:status_casters", - ], -) - -cc_headers_only( - name = "call_kernel_shared_library_deps_headers", - src = "call_kernel_shared_library_deps", -) - -cc_binary( - name = "libcallkernel.so", - linkopts = [ - "-Wl,-soname=libcallkernel.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":call_kernel_shared_library_deps"], -) - -cc_library( - name = 
"call_kernel_shared_library", - srcs = [":libcallkernel.so"], - deps = [":call_kernel_shared_library_deps_headers"], -) - cc_binary( name = "libmlir_c_runner_utils.so", linkopts = [ @@ -120,7 +82,6 @@ pybind_extension( name = "call_kernel", srcs = ["call_kernel.cc"], deps = [ - ":call_kernel_shared_library", ":libmlir_c_runner_utils.so", ":mlir_lowering_shared_library", "@com_google_absl//absl/container:flat_hash_map", @@ -128,14 +89,25 @@ pybind_extension( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", "@status_macros", ], ) +# +# `mlir_lowering` library. +# +# 1. Dependencies only. This allows to get the headers of all dependencies. cc_library( name = "mlir_lowering_shared_library_deps", + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/status", "@llvm-project//llvm:Support", @@ -178,8 +150,52 @@ cc_library( cc_headers_only( name = "mlir_lowering_shared_library_deps_headers", src = "mlir_lowering_shared_library_deps", + visibility = ["//visibility:private"], +) + +# 2. The main library. This shouldn't be used directly in `py_extension`s. 
+cc_library( + name = "mlir_lowering", + srcs = [ + "gpu_lowering_passes.cc", + "mlir_lowering.cc", + ], + hdrs = [ + "gpu_lowering_passes.h", + "mlir_lowering.h", + ], + data = ["gpu_post_bufferize.mlir"], + visibility = [ + # `jaxlib_mlir_capi_shared_library` needs to depend on `mlir_lowering` + # because (1) it depends on other targets that need symbols from this + # target and (2) that target cannot depend on + # `mlir_lowering_shared_library` because the reverse dependency must + # exist (since, otherwise, `mlir_lowering_shared_library` would + # duplicate symbols from `jaxlib_mlir_capi_shared_library`). + "@jax1//jaxlib/mlir/_mlir_libs:__pkg__", + ], + deps = [ + ":mlir_lowering_shared_library_deps_headers", + # Only depend on the headers here to avoid duplicate symbols. + "//dialect:dialect_headers", + "//transform_ops:jasc_transform_ops_headers", + ], + # This is important since it makes sure that the symbols of the library are + # exported by the `.so` target below even though they aren't used directly. + alwayslink = True, +) + +cc_headers_only( + name = "mlir_lowering_headers", + src = "mlir_lowering", + visibility = ["//visibility:private"], ) +# 3. Shared object file. This forces to create a shared library, which dependent +# targets can link against, instead of using the default static linking. This +# ensures that the symbols in that library exist only once instead of once for +# each time it is linked statically. +# This pattern is copied from JAX. A platform independent version exists there. cc_binary( name = "libmlirlowering.so", linkopts = [ @@ -187,40 +203,25 @@ cc_binary( "-Wl,-rpath='$$ORIGIN'", ], linkshared = 1, - deps = [":mlir_lowering"], + visibility = ["//visibility:private"], + deps = [ + ":mlir_lowering", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + ], ) +# 4. A `cc_library` wrapper of the shared library. This is the main target. 
cc_library( name = "mlir_lowering_shared_library", srcs = [ "mlir_lowering.h", ":libmlirlowering.so", ], - deps = [":mlir_lowering_shared_library_deps_headers"], -) - -cc_headers_only( - name = "mlir_lowering_shared_library_headers", - src = "mlir_lowering_shared_library", -) - -cc_library( - name = "mlir_lowering", - srcs = [ - "gpu_lowering_passes.cc", - "mlir_lowering.cc", - ], - hdrs = [ - "gpu_lowering_passes.h", - "mlir_lowering.h", - ], - data = ["gpu_post_bufferize.mlir"], deps = [ + ":mlir_lowering_headers", ":mlir_lowering_shared_library_deps_headers", - "//dialect:jasc_dialect_headers", - "//transform_ops:jasc_transform_ops_headers", + "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", ], - alwayslink = True, ) cc_binary( @@ -229,7 +230,7 @@ cc_binary( deps = [ ":mlir_lowering", "//dialect", - "//transform_ops:jasc_transform_ops_shared_library", + "//transform_ops:jasc_transform_ops", "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", diff --git a/jasc/WORKSPACE b/jasc/WORKSPACE index 1ed86dc026ca..f2c8b4cbff79 100644 --- a/jasc/WORKSPACE +++ b/jasc/WORKSPACE @@ -120,6 +120,14 @@ pip_install_dependencies() JAX_COMMIT = "32a317f7a43440800e1e39e00ed5f2980e088ab1" JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1" +# We import JAX as `jax1` since `import jax` otherwise imports the *containing* +# folder of the JAX Python module rather than the folder of the module. The +# problem is that Bazel puts empty `__init__.py` files essentially everywhere; +# See https://github.com/bazelbuild/bazel/issues/7653 and +# https://github.com/bazelbuild/bazel/issues/3998. That behaviour can be +# changed with `--incompatible_default_to_explicit_init_py` but then JAX +# *misses* some empty `__init__.py` files and I have no ambition in fixing that +# for them currently. 
http_archive( name = "jax1", patch_args = ["-p1"], diff --git a/jasc/dialect/BUILD b/jasc/dialect/BUILD index 76d56f760b5b..972d285966b3 100644 --- a/jasc/dialect/BUILD +++ b/jasc/dialect/BUILD @@ -63,27 +63,6 @@ gentbl_cc_library( ], ) -cc_library( - name = "jasc_dialect_shared_library_deps", - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:TransformDialect", - ], -) - -cc_headers_only( - name = "jasc_dialect_shared_library_deps_headers", - src = "jasc_dialect_shared_library_deps", -) - -cc_headers_only( - name = "jasc_dialect_headers", - src = "dialect", -) - cc_library( name = "dialect", srcs = [ @@ -96,12 +75,26 @@ cc_library( ], deps = [ ":dialect_inc_gen", - ":jasc_dialect_shared_library_deps_headers", ":ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:TransformDialect", ], - alwayslink = True, ) +cc_headers_only( + name = "dialect_headers", + src = "dialect", +) + +# +# CAPI library. +# +# We patch this into +# `@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library` +# such that that library (and only that) contains all symbols of the dialect. cc_library( name = "capi", srcs = [ @@ -145,39 +138,15 @@ gentbl_filegroup( ], ) -cc_library( - name = "jasc_dialect_shared_library", - srcs = [ - "dialect.h", - "ops.h", - ":libjascdialect.so", - ], - deps = [ - ":dialect_inc_gen", - ":jasc_dialect_shared_library_deps_headers", - ":ops_inc_gen", - ], -) - -cc_binary( - name = "libjascdialect.so", - linkopts = [ - "-Wl,-soname=libjascdialect.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":dialect"], -) - pybind_extension( name = "bindings", srcs = ["bindings.cc"], + # Only depend on headers or shared (!) libraries to avoid duplicate symbols. 
deps = [ - ":jasc_dialect_headers", + ":dialect_headers", "//:mlir_lowering_shared_library", - "//transform_ops:jasc_transform_ops_shared_library", + "//transform_ops:jasc_transform_ops_headers", "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", ], ) @@ -185,14 +154,11 @@ pybind_extension( py_library( name = "python", srcs = [ + "_ods_common.py", "jasc.py", - # "_ods_common.py", ":ops_py_gen", ], deps = [ ":bindings", - # "@jax//jaxlib/mlir:core", - # "@jax//jaxlib/mlir:ir", - # "@jax//jaxlib/mlir:pdl_dialect", ], ) diff --git a/jasc/dialect/bindings.cc b/jasc/dialect/bindings.cc index 8dd5d4fc8f61..e35baf61aef0 100644 --- a/jasc/dialect/bindings.cc +++ b/jasc/dialect/bindings.cc @@ -16,7 +16,7 @@ #include "mlir_lowering.h" #include "transform_ops/dialect_extension.h" -PYBIND11_MODULE(_mlirDialectsJasc, m) { +PYBIND11_MODULE(bindings, m) { m.def( "register_and_load_dialect", [](MlirContext py_context) { diff --git a/jasc/dialect/jasc.py b/jasc/dialect/jasc.py index 2f36fd0008d0..59786343a74c 100644 --- a/jasc/dialect/jasc.py +++ b/jasc/dialect/jasc.py @@ -3,4 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._ops_gen import * -from .._mlir_libs._mlirDialectsJasc import * +from .bindings import * diff --git a/jasc/jasc.py b/jasc/jasc.py index cdeeae8da6cb..061e59551be0 100644 --- a/jasc/jasc.py +++ b/jasc/jasc.py @@ -52,9 +52,9 @@ def schedule(h: OpHandle) -> None: ) import call_kernel -from jaxlib.mlir.dialects import jasc as jasc_dialect +from dialect import jasc as jasc_dialect import primitives -from jaxlib.mlir.dialects.transform import jasc_transform_ops as jto +from transform_ops import jasc_transform_ops as jto _JASC_AUTO_NORMALIZATION = True diff --git a/jasc/patches/jax.patch b/jasc/patches/jax.patch index e82777956811..d4985be80168 100644 --- a/jasc/patches/jax.patch +++ b/jasc/patches/jax.patch @@ -523,3 +523,57 @@ name = 
"transform_dialect", rule = py_library, symlinked_inputs = {"srcs": {"dialects": [ + +--- a/jaxlib/mlir/BUILD.bazel ++++ b/jaxlib/mlir/BUILD.bazel +@@ -215,7 +215,6 @@ symlink_inputs( + ":ir", + ":mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsTransform", +- "//jaxlib/mlir/_mlir_libs:_mlirTransformOpsJasc", + ], + ) + + +--- a/jaxlib/mlir/_mlir_libs/BUILD.bazel ++++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel +@@ -71,39 +71,6 @@ py_extension( + ], + ) + +- +-py_extension( +- name = "_mlirDialectsJasc", +- srcs = [ +- "@jasc//dialect:bindings.cc", +- ], +- copts = COPTS, +- linkopts = LINKOPTS, +- deps = [ +- ":jaxlib_mlir_capi_shared_library", +- "@jasc//dialect:jasc_dialect_headers", +- "@jasc//transform_ops:jasc_transform_ops_shared_library_headers", +- "@jasc//:mlir_lowering_shared_library_headers", +- "@llvm-project//mlir:MLIRBindingsPythonHeaders", +- "@pybind11", +- ], +-) +- +-py_extension( +- name = "_mlirTransformOpsJasc", +- srcs = [ +- "@jasc//transform_ops:bindings.cpp", +- ], +- copts = COPTS, +- linkopts = LINKOPTS, +- deps = [ +- ":jaxlib_mlir_capi_shared_library", +- "@jasc//transform_ops:jasc_transform_ops_headers", +- "@llvm-project//mlir:MLIRBindingsPythonHeaders", +- "@pybind11", +- ], +-) +- + py_extension( + name = "_mlirSparseTensorPasses", + srcs = [ diff --git a/jasc/primitives.py b/jasc/primitives.py index 8ae4eeb200c0..326ee38e00b5 100644 --- a/jasc/primitives.py +++ b/jasc/primitives.py @@ -20,7 +20,7 @@ from jaxlib.mlir.dialects import transform import call_kernel -from jaxlib.mlir.dialects import jasc as jasc_dialect +from dialect import jasc as jasc_dialect _JAX_COMPATIBLE_LOWERING = True diff --git a/jasc/test/BUILD b/jasc/test/BUILD index 1d1eeb6203ea..adfa09a351a5 100644 --- a/jasc/test/BUILD +++ b/jasc/test/BUILD @@ -80,7 +80,6 @@ py_binary( requirement("ml_dtypes"), requirement("opt_einsum"), "@jax1//jaxlib/mlir:ir", - "@jax1//jaxlib/mlir:jasc_dialect", "@jax1//jaxlib/mlir:scf_dialect", ], ) @@ -95,7 +94,6 @@ py_binary( 
requirement("opt_einsum"), "@jax1//jax", "@jax1//jaxlib/mlir:ir", - "@jax1//jaxlib/mlir:jasc_dialect", ], ) diff --git a/jasc/test/abstractions.py b/jasc/test/abstractions.py index 9986c5bb2ba1..91d8ba1e76e6 100644 --- a/jasc/test/abstractions.py +++ b/jasc/test/abstractions.py @@ -14,7 +14,7 @@ from jaxlib.mlir.dialects.transform import structured from jasc import jasc -from jaxlib.mlir.dialects.transform import jasc_transform_ops +from transform_ops import jasc_transform_ops tests: list[Callable[[], None]] = [] jasc.set_auto_normalization(False) diff --git a/jasc/test/bindings.py b/jasc/test/bindings.py index 0f137a0617d1..4838e93a53d4 100644 --- a/jasc/test/bindings.py +++ b/jasc/test/bindings.py @@ -8,8 +8,8 @@ from jaxlib.mlir import ir, passmanager from jaxlib.mlir.dialects import transform -from jaxlib.mlir.dialects import jasc as jd -from jaxlib.mlir.dialects.transform import jasc_transform_ops as jto +from dialect import jasc as jd +from transform_ops import jasc_transform_ops as jto tests: list[Callable[[], None]] = [] diff --git a/jasc/transform_ops/BUILD b/jasc/transform_ops/BUILD index de88ffbb353b..76d69cc96196 100644 --- a/jasc/transform_ops/BUILD +++ b/jasc/transform_ops/BUILD @@ -58,10 +58,10 @@ gentbl_filegroup( pybind_extension( name = "bindings", srcs = ["bindings.cpp"], + # Only depend on headers or shared (!) libraries to avoid duplicate symbols. 
deps = [ - ":jasc_transform_ops_shared_library", + ":jasc_transform_ops_headers", "@jax1//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", ], ) @@ -80,8 +80,12 @@ py_library( ) cc_library( - name = "jasc_transform_ops_shared_library_deps", + name = "jasc_transform_ops", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + includes = ["."], deps = [ + ":jasc_transform_ops_inc_gen", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", @@ -91,50 +95,7 @@ cc_library( ], ) -cc_headers_only( - name = "jasc_transform_ops_shared_library_deps_headers", - src = "jasc_transform_ops_shared_library_deps", -) - -cc_library( - name = "jasc_transform_ops_shared_library", - srcs = [ - ":libjasctransformops.so", - ], - hdrs = glob(["*.h"]), - deps = [ - "jasc_transform_ops_shared_library_deps_headers", - ], -) - -cc_headers_only( - name = "jasc_transform_ops_shared_library_headers", - src = "jasc_transform_ops_shared_library", -) - -cc_binary( - name = "libjasctransformops.so", - linkopts = [ - "-Wl,-soname=libjasctransformops.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":jasc_transform_ops"], -) - cc_headers_only( name = "jasc_transform_ops_headers", src = "jasc_transform_ops", ) - -cc_library( - name = "jasc_transform_ops", - srcs = glob(["*.cc"]), - hdrs = glob(["*.h"]), - includes = ["."], - deps = [ - ":jasc_transform_ops_inc_gen", - ":jasc_transform_ops_shared_library_deps_headers", - ], - alwayslink = True, -) diff --git a/jasc/transform_ops/bindings.cpp b/jasc/transform_ops/bindings.cpp index 1da7d4f5d6ae..dca7193127c6 100644 --- a/jasc/transform_ops/bindings.cpp +++ b/jasc/transform_ops/bindings.cpp @@ -7,7 +7,7 @@ #include "dialect_extension.h" -PYBIND11_MODULE(_mlirTransformOpsJasc, m) { +PYBIND11_MODULE(bindings, m) { m.def( "register_transform_dialect_extension", [](mlir::python::DefaultingPyMlirContext py_context) 
{ diff --git a/jasc/transform_ops/jasc_transform_ops.py b/jasc/transform_ops/jasc_transform_ops.py index e8c10265a3d5..602e205eda6f 100644 --- a/jasc/transform_ops/jasc_transform_ops.py +++ b/jasc/transform_ops/jasc_transform_ops.py @@ -4,7 +4,7 @@ from ._jasc_transform_ops_gen import * from ._jasc_transform_ops_gen import _Dialect -from ..._mlir_libs._mlirTransformOpsJasc import * +from .bindings import * try: from typing import Sequence diff --git a/jasc/tuner.py b/jasc/tuner.py index 098816cb77b5..0265ce0ddec3 100644 --- a/jasc/tuner.py +++ b/jasc/tuner.py @@ -19,7 +19,7 @@ from jaxlib.mlir.dialects import transform from jasc import jasc -from jaxlib.mlir.dialects.transform import jasc_transform_ops +from transform_ops import jasc_transform_ops @dataclasses.dataclass class TunerBase(abc.ABC): From df0110b1f71c7d1f0359a5af54b345f294db9545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 19 Jan 2024 11:28:45 +0000 Subject: [PATCH 19/21] Simplify patches: remove unnecessary changes and squash. The patches still added things to several dependencies that are actually not needed. This commit removes those. I also created a fresh patch that "squashes" several `diff` sections to the same file into a single section (i.e., I recreated the patch files from scratch). --- jasc/patches/jax.patch | 596 +++++++++++----------------------------- jasc/patches/llvm.patch | 14 +- jasc/patches/xla.patch | 34 +-- 3 files changed, 176 insertions(+), 468 deletions(-) diff --git a/jasc/patches/jax.patch b/jasc/patches/jax.patch index d4985be80168..f28df7748c44 100644 --- a/jasc/patches/jax.patch +++ b/jasc/patches/jax.patch @@ -1,5 +1,16 @@ +--- a/jax/BUILD ++++ b/jax/BUILD +@@ -70,6 +70,7 @@ package_group( + # Intentionally avoid jax dependencies on jax.extend. 
+ # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + "//third_party/py/jax/tests/...", ++ "public", + ] + jax_extend_internal_users, + ) + + --- a/jaxlib/cpu/BUILD -+++ a/jaxlib/cpu/BUILD ++++ b/jaxlib/cpu/BUILD @@ -79,7 +79,7 @@ cc_library( ":ducc_fft_flatbuffers_cc", "@xla//xla/service:custom_call_status", @@ -10,83 +21,52 @@ ) ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -241,6 +241,7 @@ cc_library( - deps = [ - ":jax_dialects_capi", - "//jaxlib/mosaic:tpu_dialect_capi_objects", -+ "@com_google_protobuf//:protobuf", - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", -@@ -250,7 +251,11 @@ cc_library( - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", -+ "@tsl//tsl/platform:env", -+ "@tsl//tsl/platform:env_impl", - "@xla//xla/mlir_hlo:CAPIObjects", -+ "@xla//xla:xla_data_proto_cc", -+ "@xla//xla:xla_data_proto_cc_impl", - ], - ) +--- a/jaxlib/mlir/BUILD.bazel ++++ b/jaxlib/mlir/BUILD.bazel +@@ -13,6 +13,7 @@ + # limitations under the License. 
- ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -139,6 +139,40 @@ py_extension( + load("//jaxlib:symlink_files.bzl", "symlink_inputs") ++load("@pip_deps//:requirements.bzl", "requirement") + + package( + default_visibility = [ +@@ -74,6 +75,141 @@ symlink_inputs( ], ) -+py_extension( -+ name = "_mlirDialectsTransform", -+ srcs = [ -+ "@llvm-project//mlir:lib/Bindings/Python/DialectTransform.cpp", -+ ], -+ copts = COPTS, -+ linkopts = LINKOPTS, ++symlink_inputs( ++ name = "bufferization_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:BufferizationOpsPyFiles", ++ ]}}, + deps = [ -+ ":jax_dialects_capi_headers", -+ ":jaxlib_mlir_capi_shared_library", -+ "@llvm-project//mlir:CAPIIRHeaders", -+ "@llvm-project//mlir:CAPITransformDialect", -+ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", -+ "@pybind11", ++ ":core", ++ ":ir", ++ ":mlir", + ], +) + -+py_extension( -+ name = "_mlirDialectsPDL", -+ srcs = [ -+ "@llvm-project//mlir:lib/Bindings/Python/DialectPDL.cpp", -+ ], -+ copts = COPTS, -+ linkopts = LINKOPTS, ++symlink_inputs( ++ name = "pdl_dialect", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:PDLPyFiles", ++ ]}}, + deps = [ -+ ":jax_dialects_capi_headers", -+ ":jaxlib_mlir_capi_shared_library", -+ "@llvm-project//mlir:CAPIIRHeaders", -+ "@llvm-project//mlir:CAPIPDL", -+ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", -+ "@pybind11", ++ ":core", ++ ":ir", ++ ":mlir", ++ ":pdl_dialect_extension", + ], +) + - ##---------------------------------------------------------------------------## - # MHLO Extensions - ##---------------------------------------------------------------------------## - ---- a/jaxlib/mlir/BUILD.bazel -+++ a/jaxlib/mlir/BUILD.bazel -@@ -75,6 +75,75 @@ symlink_inputs( - ) - - symlink_inputs( -+ name = "bufferization_dialect", ++symlink_inputs( ++ name = "complex_dialect", + rule = 
py_library, + symlinked_inputs = {"srcs": {"dialects": [ -+ "@llvm-project//mlir/python:BufferizationOpsPyFiles", ++ "@llvm-project//mlir/python:ComplexOpsPyFiles", + ]}}, + deps = [ + ":core", @@ -96,20 +76,58 @@ +) + +symlink_inputs( -+ name = "pdl_dialect", ++ name = "linalg_dialect", + rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects": [ -+ "@llvm-project//mlir/python:PDLPyFiles", ++ symlinked_inputs = {"srcs": {"dialects/linalg": [ ++ "@llvm-project//mlir/python:LinalgOpsPackagePyFiles", + ]}}, + deps = [ ++ ":complex_dialect", + ":core", + ":ir", + ":mlir", -+ ":pdl_dialect_extension", ++ ":linalg_dialect_gen_files", ++ ":linalg_dialect_opdsl_files", ++ ":linalg_dialect_opdsl_lang_files", ++ ":linalg_dialect_opdsl_ops_files", ++ "//jaxlib/mlir/_mlir_libs:_mlirDialectsLinalg", ++ requirement("PyYAML"), + ], +) + +symlink_inputs( ++ name = "linalg_dialect_gen_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects": [ ++ "@llvm-project//mlir/python:LinalgOpsPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_lang_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/lang": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLLangPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( ++ name = "linalg_dialect_opdsl_ops_files", ++ rule = py_library, ++ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/ops": [ ++ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLOpsPyFiles", ++ ]}}, ++) ++ ++symlink_inputs( + name = "transform_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ @@ -128,6 +146,7 @@ + rule = py_library, + symlinked_inputs = {"srcs": {"dialects/transform": [ + "@llvm-project//mlir/python:TransformOpsPackagePyFiles", ++ 
"@jasc//transform_ops:transform_ops", + ]}}, + deps = [ + ":core", @@ -151,25 +170,19 @@ + ], +) + -+symlink_inputs( + symlink_inputs( name = "math_dialect", rule = py_library, - symlinked_inputs = {"srcs": {"dialects": [ - ---- a/jax/BUILD -+++ a/jax/BUILD -@@ -70,6 +70,7 @@ package_group( - # Intentionally avoid jax dependencies on jax.extend. - # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html - "//third_party/py/jax/tests/...", -+ "public", - ] + jax_extend_internal_users, - ) - --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -131,6 +131,8 @@ py_extension( +@@ -126,11 +126,14 @@ py_extension( + linkopts = LINKOPTS, + deps = [ + ":jax_dialects_capi_headers", ++ "@jasc//dialect:capi_headers", + ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIArithHeaders", "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:CAPIMathHeaders", "@llvm-project//mlir:CAPIMemRefHeaders", @@ -178,183 +191,65 @@ "@llvm-project//mlir:CAPITransformsHeaders", "@llvm-project//mlir:CAPIVectorHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", -@@ -279,7 +281,9 @@ cc_library( - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", -+ "@llvm-project//mlir:CAPIPDLObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", -+ "@llvm-project//mlir:CAPITransformDialectObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - ---- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc -+++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc -@@ -2,9 +2,12 @@ - // This module is called by mlir/__init__.py during initialization. 
- - #include "mlir-c/Dialect/Arith.h" -+// #include "mlir-c/Dialect/Bufferization.h" - #include "mlir-c/Dialect/Func.h" - #include "mlir-c/Dialect/Math.h" - #include "mlir-c/Dialect/MemRef.h" -+#include "mlir-c/Dialect/PDL.h" -+#include "mlir-c/Dialect/Transform.h" - #include "mlir-c/Dialect/Vector.h" - #include "mlir-c/Transforms.h" - #include "mlir/Bindings/Python/PybindAdaptors.h" -@@ -19,10 +22,13 @@ PYBIND11_MODULE(_site_initialize_0, m) { - - m.def("register_dialects", [](MlirDialectRegistry registry) { - REGISTER_DIALECT(arith); -+ // REGISTER_DIALECT(bufferization); - REGISTER_DIALECT(func); - REGISTER_DIALECT(math); - REGISTER_DIALECT(memref); -+ REGISTER_DIALECT(pdl); - REGISTER_DIALECT(scf); -+ REGISTER_DIALECT(transform); - REGISTER_DIALECT(vector); - mlirRegisterTransformsPasses(); - // Transforms used by JAX. - ---- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc -+++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc -@@ -9,6 +9,7 @@ - #include "mlir-c/Dialect/PDL.h" - #include "mlir-c/Dialect/Transform.h" - #include "mlir-c/Dialect/Vector.h" -+#include "mlir-c/RegisterEverything.h" - #include "mlir-c/Transforms.h" - #include "mlir/Bindings/Python/PybindAdaptors.h" - #include "jaxlib/mlir/_mlir_libs/jax_dialects.h" -@@ -31,6 +32,7 @@ PYBIND11_MODULE(_site_initialize_0, m) { - REGISTER_DIALECT(transform); - REGISTER_DIALECT(vector); - mlirRegisterTransformsPasses(); -+ mlirRegisterAllDialects(registry); - // Transforms used by JAX. 
- mlirRegisterTransformsStripDebugInfo(); - }); - ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -279,18 +279,21 @@ cc_library( - "//jaxlib/mosaic:tpu_dialect_capi_objects", - "@com_google_protobuf//:protobuf", - "@llvm-project//mlir:CAPIArithObjects", -+ "@llvm-project//mlir:CAPIInterfacesObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", - "@llvm-project//mlir:CAPIPDLObjects", -+ "@llvm-project//mlir:CAPIRegisterEverythingObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", - "@llvm-project//mlir:CAPITransformDialectObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", -- "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", -+ "@llvm-project//mlir:CAPIDebugObjects", -+ "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPIObjects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", - "@tsl//tsl/platform:env", -- "@tsl//tsl/platform:env_impl", -+ "@tsl//tsl/platform:env_impl", - "@xla//xla/mlir_hlo:CAPIObjects", - "@xla//xla:xla_data_proto_cc", - "@xla//xla:xla_data_proto_cc_impl", - ---- a/jaxlib/mlir/BUILD.bazel -+++ b/jaxlib/mlir/BUILD.bazel -@@ -120,12 +120,14 @@ symlink_inputs( - rule = py_library, - symlinked_inputs = {"srcs": {"dialects/transform": [ - "@llvm-project//mlir/python:TransformOpsPackagePyFiles", -+ "@jasc//transform_ops:transform_ops", - ]}}, - deps = [ - ":core", - ":ir", - ":mlir", - "//jaxlib/mlir/_mlir_libs:_mlirDialectsTransform", -+ "//jaxlib/mlir/_mlir_libs:_mlirTransformOpsJasc", +@@ -139,6 +142,57 @@ py_extension( ], ) -@@ -250,6 +252,20 @@ symlink_inputs( - ], - ) - -+symlink_inputs( -+ name = "jasc_dialect", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects": [ -+ "@jasc//dialect:python", -+ ]}}, ++py_extension( ++ name = "_mlirDialectsTransform", ++ srcs = [ ++ "@llvm-project//mlir:lib/Bindings/Python/DialectTransform.cpp", ++ ], ++ copts = COPTS, ++ linkopts = 
LINKOPTS, + deps = [ -+ ":core", -+ ":ir", -+ ":mlir", -+ "//jaxlib/mlir/_mlir_libs:_mlirDialectsJasc", ++ ":jax_dialects_capi_headers", ++ ":jaxlib_mlir_capi_shared_library", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPITransformDialect", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", ++ "@pybind11", + ], +) -+ - symlink_inputs( - name = "mhlo_dialect", - rule = py_library, - ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -71,6 +71,39 @@ py_extension( - ], - ) - + +py_extension( -+ name = "_mlirDialectsJasc", ++ name = "_mlirDialectsPDL", + srcs = [ -+ "@jasc//dialect:bindings.cc", ++ "@llvm-project//mlir:lib/Bindings/Python/DialectPDL.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ ++ ":jax_dialects_capi_headers", + ":jaxlib_mlir_capi_shared_library", -+ "@jasc//dialect:jasc_dialect_headers", -+ "@jasc//transform_ops:jasc_transform_ops_shared_library_headers", -+ "@jasc//:mlir_lowering_shared_library_headers", -+ "@llvm-project//mlir:MLIRBindingsPythonHeaders", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPIPDL", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@pybind11", + ], +) + +py_extension( -+ name = "_mlirTransformOpsJasc", ++ name = "_mlirDialectsLinalg", + srcs = [ -+ "@jasc//transform_ops:bindings.cpp", ++ "@llvm-project//mlir:lib/Bindings/Python/DialectLinalg.cpp", + ], + copts = COPTS, + linkopts = LINKOPTS, + deps = [ ++ ":jax_dialects_capi_headers", + ":jaxlib_mlir_capi_shared_library", -+ "@jasc//transform_ops:jasc_transform_ops_headers", -+ "@llvm-project//mlir:MLIRBindingsPythonHeaders", ++ "@llvm-project//mlir:CAPIIRHeaders", ++ "@llvm-project//mlir:CAPILinalg", ++ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@pybind11", + ], +) + - py_extension( - name = "_mlirSparseTensorPasses", - srcs = [ -@@ -126,6 +159,7 @@ py_extension( - linkopts = LINKOPTS, - deps = [ - ":jax_dialects_capi_headers", -+ 
"@jasc//dialect:capi_headers", - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIIRHeaders", -@@ -276,6 +310,9 @@ cc_library( + ##---------------------------------------------------------------------------## + # MHLO Extensions + ##---------------------------------------------------------------------------## +@@ -240,17 +294,30 @@ cc_library( name = "jaxlib_mlir_capi_objects", deps = [ ":jax_dialects_capi", @@ -362,12 +257,41 @@ + "@jasc//transform_ops:jasc_transform_ops", + "@jasc//:mlir_lowering", "//jaxlib/mosaic:tpu_dialect_capi_objects", - "@com_google_protobuf//:protobuf", ++ "@com_google_protobuf//:protobuf", "@llvm-project//mlir:CAPIArithObjects", ++ "@llvm-project//mlir:CAPIInterfacesObjects", + "@llvm-project//mlir:CAPIMathObjects", + "@llvm-project//mlir:CAPIMemRefObjects", ++ "@llvm-project//mlir:CAPIPDLObjects", ++ "@llvm-project//mlir:CAPIRegisterEverythingObjects", + "@llvm-project//mlir:CAPISparseTensorObjects", ++ "@llvm-project//mlir:CAPITransformDialectObjects", + "@llvm-project//mlir:CAPITransformsObjects", + "@llvm-project//mlir:CAPIVectorObjects", +- "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", ++ "@llvm-project//mlir:CAPIDebugObjects", ++ "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPIObjects", + "@stablehlo//:chlo_capi_objects", + "@stablehlo//:stablehlo_capi_objects", ++ "@tsl//tsl/platform:env", ++ "@tsl//tsl/platform:env_impl", + "@xla//xla/mlir_hlo:CAPIObjects", ++ "@xla//xla:xla_data_proto_cc", ++ "@xla//xla:xla_data_proto_cc_impl", + ], + ) + --- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc +++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc -@@ -13,6 +13,7 @@ +@@ -5,10 +5,14 @@ + #include "mlir-c/Dialect/Func.h" + #include "mlir-c/Dialect/Math.h" + #include "mlir-c/Dialect/MemRef.h" ++#include "mlir-c/Dialect/PDL.h" ++#include "mlir-c/Dialect/Transform.h" + #include "mlir-c/Dialect/Vector.h" ++#include "mlir-c/RegisterEverything.h" #include "mlir-c/Transforms.h" 
#include "mlir/Bindings/Python/PybindAdaptors.h" #include "jaxlib/mlir/_mlir_libs/jax_dialects.h" @@ -375,205 +299,19 @@ #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ -@@ -25,6 +26,7 @@ PYBIND11_MODULE(_site_initialize_0, m) { +@@ -20,11 +24,15 @@ PYBIND11_MODULE(_site_initialize_0, m) { + m.def("register_dialects", [](MlirDialectRegistry registry) { REGISTER_DIALECT(arith); - // REGISTER_DIALECT(bufferization); REGISTER_DIALECT(func); + REGISTER_DIALECT(jasc); REGISTER_DIALECT(math); REGISTER_DIALECT(memref); - REGISTER_DIALECT(pdl); - - ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -209,6 +209,23 @@ py_extension( - ], - ) - -+py_extension( -+ name = "_mlirDialectsLinalg", -+ srcs = [ -+ "@llvm-project//mlir:lib/Bindings/Python/DialectLinalg.cpp", -+ ], -+ copts = COPTS, -+ linkopts = LINKOPTS, -+ deps = [ -+ ":jax_dialects_capi_headers", -+ ":jaxlib_mlir_capi_shared_library", -+ "@llvm-project//mlir:CAPIIRHeaders", -+ "@llvm-project//mlir:CAPILinalg", -+ "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", -+ "@pybind11", -+ ], -+) -+ - ##---------------------------------------------------------------------------## - # MHLO Extensions - ##---------------------------------------------------------------------------## - ---- a/jaxlib/mlir/BUILD.bazel -+++ b/jaxlib/mlir/BUILD.bazel -@@ -13,6 +13,7 @@ - # limitations under the License. 
- - load("//jaxlib:symlink_files.bzl", "symlink_inputs") -+load("@pip_deps//:requirements.bzl", "requirement") - - package( - default_visibility = [ -@@ -49,6 +50,19 @@ py_library( - ) - - symlink_inputs( -+ name = "complex_dialect", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects": [ -+ "@llvm-project//mlir/python:ComplexOpsPyFiles", -+ ]}}, -+ deps = [ -+ ":core", -+ ":ir", -+ ":mlir", -+ ], -+) -+ -+symlink_inputs( - name = "func_dialect", - rule = py_library, - symlinked_inputs = {"srcs": {"dialects": [ -@@ -102,6 +116,80 @@ symlink_inputs( - ) - - symlink_inputs( -+ name = "linalg_dialect", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects/linalg": [ -+ "@llvm-project//mlir/python:LinalgOpsPackagePyFiles", -+ ]}}, -+ deps = [ -+ ":complex_dialect", -+ ":core", -+ ":ir", -+ ":mlir", -+ ":linalg_dialect_gen_files", -+ ":linalg_dialect_opdsl_files", -+ ":linalg_dialect_opdsl_lang_files", -+ ":linalg_dialect_opdsl_ops_files", -+ "//jaxlib/mlir/_mlir_libs:_mlirDialectsLinalg", -+ requirement("PyYAML"), -+ ], -+) -+ -+symlink_inputs( -+ name = "linalg_dialect_gen_files", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects": [ -+ "@llvm-project//mlir/python:LinalgOpsPyFiles", -+ ]}}, -+) -+ -+symlink_inputs( -+ name = "linalg_dialect_opdsl_files", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl": [ -+ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLPyFiles", -+ ]}}, -+) -+ -+symlink_inputs( -+ name = "linalg_dialect_opdsl_lang_files", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/lang": [ -+ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLLangPyFiles", -+ ]}}, -+) -+ -+symlink_inputs( -+ name = "linalg_dialect_opdsl_ops_files", -+ rule = py_library, -+ symlinked_inputs = {"srcs": {"dialects/linalg/opdsl/ops": [ -+ "@llvm-project//mlir/python:LinalgOpsPackageOpDSLOpsPyFiles", -+ ]}}, -+) -+ -+# symlink_files( -+# name = "linalg_package_opdsl_files", -+# srcs = 
["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLPyFiles"], -+# dst = "dialects/linalg/opdsl", -+# flatten = True, -+# ) -+ -+# symlink_files( -+# name = "linalg_package_opdsl_lang_files", -+# srcs = ["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLLangPyFiles"], -+# dst = "dialects/linalg/opdsl/lang", -+# flatten = True, -+# ) -+ -+# symlink_files( -+# name = "linalg_package_opdsl_ops_files", -+# srcs = ["//third_party/llvm/llvm-project/mlir/python:LinalgOpsPackageOpDSLOpsPyFiles"], -+# dst = "dialects/linalg/opdsl/ops", -+# flatten = True, -+# ) -+ -+ -+symlink_inputs( - name = "transform_dialect", - rule = py_library, - symlinked_inputs = {"srcs": {"dialects": [ - ---- a/jaxlib/mlir/BUILD.bazel -+++ b/jaxlib/mlir/BUILD.bazel -@@ -215,7 +215,6 @@ symlink_inputs( - ":ir", - ":mlir", - "//jaxlib/mlir/_mlir_libs:_mlirDialectsTransform", -- "//jaxlib/mlir/_mlir_libs:_mlirTransformOpsJasc", - ], - ) - - ---- a/jaxlib/mlir/_mlir_libs/BUILD.bazel -+++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel -@@ -71,39 +71,6 @@ py_extension( - ], - ) - -- --py_extension( -- name = "_mlirDialectsJasc", -- srcs = [ -- "@jasc//dialect:bindings.cc", -- ], -- copts = COPTS, -- linkopts = LINKOPTS, -- deps = [ -- ":jaxlib_mlir_capi_shared_library", -- "@jasc//dialect:jasc_dialect_headers", -- "@jasc//transform_ops:jasc_transform_ops_shared_library_headers", -- "@jasc//:mlir_lowering_shared_library_headers", -- "@llvm-project//mlir:MLIRBindingsPythonHeaders", -- "@pybind11", -- ], --) -- --py_extension( -- name = "_mlirTransformOpsJasc", -- srcs = [ -- "@jasc//transform_ops:bindings.cpp", -- ], -- copts = COPTS, -- linkopts = LINKOPTS, -- deps = [ -- ":jaxlib_mlir_capi_shared_library", -- "@jasc//transform_ops:jasc_transform_ops_headers", -- "@llvm-project//mlir:MLIRBindingsPythonHeaders", -- "@pybind11", -- ], --) -- - py_extension( - name = "_mlirSparseTensorPasses", - srcs = [ ++ REGISTER_DIALECT(pdl); + REGISTER_DIALECT(scf); ++ REGISTER_DIALECT(transform); 
+ REGISTER_DIALECT(vector); + mlirRegisterTransformsPasses(); ++ mlirRegisterAllDialects(registry); + // Transforms used by JAX. + mlirRegisterTransformsStripDebugInfo(); + }); diff --git a/jasc/patches/llvm.patch b/jasc/patches/llvm.patch index 6e8494a0c776..f02a1d1dfbd0 100644 --- a/jasc/patches/llvm.patch +++ b/jasc/patches/llvm.patch @@ -1,8 +1,6 @@ -diff --git a/utils/bazel/llvm-project-overlay/lld/BUILD.bazel b/utils/bazel/llvm-project-overlay/lld/BUILD.bazel -index fb6e2397cc84..db259fffaa63 100644 --- a/utils/bazel/llvm-project-overlay/lld/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/lld/BUILD.bazel -@@ -108,7 +108,6 @@ cc_library( +@@ -109,7 +109,6 @@ cc_library( "//llvm:TargetParser", "//llvm:TransformUtils", "//llvm:config", @@ -10,11 +8,10 @@ index fb6e2397cc84..db259fffaa63 100644 "@llvm_zstd//:zstd", ], ) -diff --git a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -index 0cc28fd856bc..51764826a130 100644 + --- a/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/llvm/BUILD.bazel -@@ -277,11 +277,9 @@ cc_library( +@@ -288,11 +288,9 @@ cc_library( # We unconditionally depend on the custom LLVM zlib wrapper. This will # be an empty library unless zlib is enabled, in which case it will # both provide the necessary dependencies and configuration defines. 
@@ -81,10 +78,7 @@ index 0cc28fd856bc..51764826a130 100644 ) td_library( - ---- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -+++ a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -8907,6 +8907,7 @@ cc_library( +@@ -8886,6 +8907,7 @@ cc_library( ":mlir_float16_utils", "//llvm:Support", ], diff --git a/jasc/patches/xla.patch b/jasc/patches/xla.patch index 4d20c9935561..d57bb8d14182 100644 --- a/jasc/patches/xla.patch +++ b/jasc/patches/xla.patch @@ -22,7 +22,7 @@ "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_api", "//xla/pjrt:pjrt_c_api_client", -@@ -1147,10 +1151,47 @@ cc_library( +@@ -1147,10 +1151,51 @@ cc_library( "//xla/pjrt/distributed:service", "//xla/python/ifrt", "//xla/python/pjrt_ifrt", @@ -32,6 +32,7 @@ + "//xla/service:buffer_assignment_proto_cc_impl", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_proto_cc_impl", ++ "//xla/stream_executor/gpu:gpu_init_impl", + "//xla/stream_executor:device_description_proto_cc", + "//xla/stream_executor:device_description_proto_cc_impl", + "//xla/stream_executor:stream_executor_impl", @@ -40,6 +41,7 @@ + "//xla:autotuning_proto_cc", + "//xla:autotuning_proto_cc_impl", "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", ++ "@tsl//tsl/framework:allocator_registry_impl", "@tsl//tsl/platform", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:env_impl", @@ -50,6 +52,8 @@ + "@tsl//tsl/profiler/lib:profiler_session_impl", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", + "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", ++ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc", ++ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", + "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", + "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc", @@ -70,31 +74,3 @@ "@tsl//tsl/python/lib/core:numpy", "@pybind11", ] + select({ - ---- a/xla/python/BUILD -+++ 
b/xla/python/BUILD -@@ -1157,6 +1157,7 @@ cc_library( - "//xla/service:buffer_assignment_proto_cc_impl", - "//xla/service:hlo_proto_cc", - "//xla/service:hlo_proto_cc_impl", -+ "//xla/stream_executor/gpu:gpu_init_impl", - "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:device_description_proto_cc_impl", - "//xla/stream_executor:stream_executor_impl", -@@ -1165,6 +1166,7 @@ cc_library( - "//xla:autotuning_proto_cc", - "//xla:autotuning_proto_cc_impl", - "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", -+ "@tsl//tsl/framework:allocator_registry_impl", - "@tsl//tsl/platform", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:env_impl", -@@ -1175,6 +1177,8 @@ cc_library( - "@tsl//tsl/profiler/lib:profiler_session_impl", - "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", - "@tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", -+ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc", -+ "@tsl//tsl/profiler/protobuf:profiler_analysis_proto_cc_impl", - "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc", - "@tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", - "@tsl//tsl/profiler/protobuf:profiler_service_monitor_result_proto_cc", From 7073c15cc623f9844ffd474cb83f66f3a7836523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 19 Jan 2024 13:24:08 +0000 Subject: [PATCH 20/21] Deactivate GPU-related tests. 
--- jasc/test/BUILD | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/jasc/test/BUILD b/jasc/test/BUILD index adfa09a351a5..1a11b65c3251 100644 --- a/jasc/test/BUILD +++ b/jasc/test/BUILD @@ -56,7 +56,10 @@ py_test( srcs = [ "gpu_integration.py", ], - tags = ["requires-gpu-nvidia"], + tags = [ + "manual", # currently not supported + "requires-gpu-nvidia", + ], deps = [ ":gpu_integration_lib", ], @@ -126,7 +129,10 @@ py_binary( py_test( name = "batch_matmul_gpu", srcs = ["batch_matmul_gpu.py"], - tags = ["requires-gpu-nvidia"], + tags = [ + "manual", # currently not supported + "requires-gpu-nvidia", + ], deps = [ "//:jasc", requirement("chex"), @@ -140,7 +146,10 @@ py_test( py_test( name = "matmul_gpu", srcs = ["matmul_gpu.py"], - tags = ["requires-gpu-nvidia"], + tags = [ + "manual", # currently not supported + "requires-gpu-nvidia", + ], deps = [ "//:jasc", requirement("chex"), From 57b215cbf941c68ad1f17b7373fac5f0c85515e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 8 Feb 2024 08:35:17 +0000 Subject: [PATCH 21/21] Remove `-c opt` from .bazelrc. --- jasc/.bazelrc | 2 -- 1 file changed, 2 deletions(-) diff --git a/jasc/.bazelrc b/jasc/.bazelrc index e578406db017..f996814e4f41 100644 --- a/jasc/.bazelrc +++ b/jasc/.bazelrc @@ -19,5 +19,3 @@ build --define=allow_oversize_protos=true # Sets the name of JAX's MLIR native extension. This exact value is expected # by the Python files of JAX. build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. - -build -c opt