Skip to content

Commit

Permalink
build xla as static lib
Browse files Browse the repository at this point in the history
  • Loading branch information
sphw committed Nov 7, 2023
1 parent 9de816c commit cd05d0a
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 11 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ jobs:
# an older version of glibc
runs-on: ubuntu-20.04
steps:
- name: Install libtool-bin
run: |
sudo apt-get update && sudo apt-get install -y libtool-bin
- uses: actions/checkout@v3
- uses: erlef/setup-beam@v1
with:
Expand Down Expand Up @@ -149,7 +152,7 @@ jobs:
# Add repository with the latest git version for action/checkout to properly clone the repo
add-apt-repository ppa:git-core/ppa
# We run as root, so sudo is not necessary per se, but some actions (like setup-bazel) make use of it
apt-get update && apt-get install -y ca-certificates curl git sudo unzip wget
apt-get update && apt-get install -y ca-certificates curl git sudo unzip wget libtool-bin
# Install GitHub CLI used by our scripts
curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | gpg --dearmor -o /usr/share/keyrings/githubcli-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null
Expand Down Expand Up @@ -196,6 +199,9 @@ jobs:
# an older version of glibc
runs-on: ubuntu-20.04
steps:
- name: Install libtool-bin
run: |
sudo apt-get update && sudo apt-get install -y libtool-bin
- uses: actions/checkout@v3
- uses: erlef/setup-beam@v1
with:
Expand Down
82 changes: 72 additions & 10 deletions extension/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm",)
load("@tsl//tsl:tsl.bzl", "tsl_grpc_cc_dependencies",)
load("@tsl//tsl:tsl.bzl", "transitive_hdrs",)
load("@rules_pkg//pkg:tar.bzl", "pkg_tar")
load(":static-lib.bzl", "cc_static_library")


package(default_visibility=["//visibility:private"])

# Static library which contains dependencies necessary for building on
# top of XLA
cc_binary(
name = "libxla_extension.so",
cc_static_library(
name = "libxla_extension",
deps = [
"//xla:xla_proto_cc_impl",
"//xla:xla_data_proto_cc_impl",
Expand Down Expand Up @@ -89,10 +91,6 @@ cc_binary(
+ if_rocm([
"//xla/stream_executor:rocm_platform"
]),
copts = ["-fvisibility=default"],
linkopts = ["-shared"],
features = ["-use_header_modules"],
linkshared = 1,
)

# Transitive hdrs gets all headers required by deps, including
Expand All @@ -101,7 +99,71 @@ cc_binary(
transitive_hdrs(
name = "xla_extension_dep_headers",
deps = [
":libxla_extension.so",
"//xla:xla_proto_cc_impl",
"//xla:xla_data_proto_cc_impl",
"//xla/service:hlo_proto_cc_impl",
"//xla/service:memory_space_assignment_proto_cc_impl",
"//xla/service/gpu:backend_configs_cc_impl",
"//xla/stream_executor:dnn_proto_cc_impl",
"//xla:literal",
"//xla:shape_util",
"//xla:status",
"//xla:statusor",
"//xla:types",
"//xla:util",
"//xla/client:xla_computation",
"//xla/mlir/utils:error_util",
"//xla/mlir_hlo",
"//xla/mlir_hlo:all_passes",
"//xla/pjrt:mlir_to_hlo",
"//xla/client/lib:lu_decomposition",
"//xla/client/lib:math",
"//xla/client/lib:qr",
"//xla/client/lib:svd",
"//xla/client/lib:self_adjoint_eig",
"//xla/client/lib:sorting",
"//xla/mlir_hlo:mhlo_passes",
"//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/pjrt:interpreter_device",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:tfrt_cpu_pjrt_client",
"//xla/pjrt:pjrt_c_api_client",
"//xla/pjrt:tpu_client",
"//xla/pjrt:pjrt_plugin_device_client",
"//xla/pjrt:pjrt_plugin_device_client_headers",
"//xla/pjrt/distributed",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:service",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/base:log_severity",
"@com_google_protobuf//:protobuf",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:SparseTensorDialect",
"@tf_runtime//:core_runtime",
"@tf_runtime//:hostcontext",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:fingerprint",
"@tsl//tsl/platform:float8",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:env_impl",
"@tsl//tsl/platform:tensor_float_32_utils",
"@tsl//tsl/profiler/utils:time_utils_impl",
"@tsl//tsl/profiler/backends/cpu:annotation_stack_impl",
"@tsl//tsl/profiler/backends/cpu:traceme_recorder_impl",
"@tsl//tsl/protobuf:autotuning_proto_cc_impl",
"@tsl//tsl/protobuf:protos_all_cc_impl",
"@tsl//tsl/protobuf:dnn_proto_cc_impl",
"@tsl//tsl/framework:allocator",
"@tsl//tsl/framework:allocator_registry_impl",
"@tsl//tsl/util:determinism",
]
)

Expand Down Expand Up @@ -154,16 +216,16 @@ genrule(
""",
)

# This genrule remaps libxla_extension.so to lib/libxla_extension.so
# This genrule remaps libxla_extension.a to lib/libxla_extension.a
genrule(
name = "xla_extension_lib",
srcs = [
":libxla_extension.so",
":libxla_extension",
],
outs = ["lib"],
cmd = """
mkdir $@
mv $(location :libxla_extension.so) $@
cp $(locations :libxla_extension) $@
"""
)

Expand Down
72 changes: 72 additions & 0 deletions extension/static-lib.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Provides a rule that outputs a monolithic static library."""

# Reference: https://gist.github.com/oquenchil/3f88a39876af2061f8aad6cdc9d7c045

load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")

TOOLS_CPP_REPO = "@bazel_tools"

def _cc_static_library_impl(ctx):
output_lib = ctx.actions.declare_file("{}.a".format(ctx.attr.name))
output_flags = ctx.actions.declare_file("{}.link".format(ctx.attr.name))

cc_toolchain = find_cpp_toolchain(ctx)

# Aggregate linker inputs of all dependencies
lib_sets = []
for dep in ctx.attr.deps:
lib_sets.append(dep[CcInfo].linking_context.linker_inputs)
input_depset = depset(transitive = lib_sets)

# Collect user link flags and make sure they are unique
unique_flags = {}
for inp in input_depset.to_list():
unique_flags.update({
flag: None
for flag in inp.user_link_flags
})
link_flags = unique_flags.keys()

# Collect static libraries
libs = []
for inp in input_depset.to_list():
for lib in inp.libraries:
if lib.pic_static_library:
libs.append(lib.pic_static_library)
elif lib.static_library:
libs.append(lib.static_library)

lib_paths = [lib.path for lib in libs]

ar_path = cc_toolchain.ar_executable
# FIXME ar_executable returned llvm-lib.exe on my system, but we want llvm-ar.exe
ar_path = ar_path.replace("llvm-lib.exe", "llvm-ar.exe")
ar_path = ""

ctx.actions.run_shell(
command = "libtool -static -o {0} {1}".format(output_lib.path, " ".join(lib_paths)),
#command = "\"{0}\" rcT {1} {2} && echo -e 'create {1}\naddlib {1}\nsave\nend' | \"{0}\" -M".format(ar_path, output_lib.path, " ".join(lib_paths)),
inputs = libs + cc_toolchain.all_files.to_list(),
outputs = [output_lib],
mnemonic = "ArMerge",
progress_message = "Merging static library {}".format(output_lib.path),
)
ctx.actions.write(
output = output_flags,
content = "\n".join(link_flags) + "\n",
)
return [
DefaultInfo(files = depset([output_flags, output_lib])),
]

cc_static_library = rule(
implementation = _cc_static_library_impl,
attrs = {
"deps": attr.label_list(),
"_cc_toolchain": attr.label(
default = TOOLS_CPP_REPO + "//tools/cpp:current_cc_toolchain",
),
},
toolchains = [TOOLS_CPP_REPO + "//tools/cpp:toolchain_type"],
incompatible_use_toolchain_transition = True,
)

0 comments on commit cd05d0a

Please sign in to comment.