Skip to content

Commit 1323a5e

Browse files
authored
Merge pull request #4 from EnzymeAD/scuda
Statically linked cuda
2 parents 6137f40 + 6136dd2 commit 1323a5e

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

R/Reactant/build_tarballs.jl

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ include(joinpath(YGGDRASIL_DIR, "fancy_toys.jl"))
66

77
name = "Reactant"
88
repo = "https://github.com/EnzymeAD/Reactant.jl.git"
9-
reactant_commit = "d3cee1ac27e3dd9f90779c84e0a6848a802b6878"
10-
version = v"0.0.252"
9+
reactant_commit = "9d0e99d266c712dd0e85b37e87ea358a585f0564"
10+
version = v"0.0.254"
1111

1212
sources = [
1313
GitSource(repo, reactant_commit),
@@ -330,6 +330,10 @@ elif [[ "${target}" == *mingw32* ]]; then
330330
331331
332332
clang @bazel-bin/libReactantExtra.so-2.params
333+
elif [[ "${target}" == aarch64-* ]] && [[ "${HERMETIC_CUDA_VERSION}" == *13.* ]]; then
334+
$BAZEL ${BAZEL_FLAGS[@]} build ${BAZEL_BUILD_FLAGS[@]} :libReactantExtra.so || echo stage1
335+
cp /workspace/srcdir/libnvvm-linux-x86_64-*/nvvm/bin/cicc /workspace/bazel_root/*/external/cuda_nvvm/nvvm/bin/cicc
336+
$BAZEL ${BAZEL_FLAGS[@]} build ${BAZEL_BUILD_FLAGS[@]} :libReactantExtra.so
333337
else
334338
$BAZEL ${BAZEL_FLAGS[@]} build ${BAZEL_BUILD_FLAGS[@]} :libReactantExtra.so
335339
fi
@@ -345,10 +349,11 @@ if [[ "${bb_full_target}" == *gpu+cuda* ]]; then
345349
find bazel-bin
346350
find ${libdir}
347351
352+
# if [[ "${target}" == x86_64-linux-gnu ]] || [[ "${HERMETIC_CUDA_VERSION}" == *13.* ]]; then
348353
if [[ "${target}" == x86_64-linux-gnu ]]; then
349354
NVCC_DIR=(bazel-bin/libReactantExtra.so.runfiles/cuda_nvcc)
350355
else
351-
NVCC_DIR=(/workspace/srcdir/cuda_nvcc-*-archive)
356+
NVCC_DIR=(/workspace/srcdir/cuda_nvcc-linux-sbsa*-archive)
352357
fi
353358
354359
if [ -f "${NVCC_DIR[@]}/nvvm/libdevice/libdevice.10.bc" ]; then
@@ -419,7 +424,7 @@ augment_platform_block="""
419424
"""
420425

421426
# for gpu in ("none", "cuda", "rocm"), mode in ("opt", "dbg"), platform in platforms
422-
for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "12.6", "12.8", "13.0"), platform in platforms
427+
for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "12.9", "13.0"), platform in platforms
423428

424429
augmented_platform = deepcopy(platform)
425430
augmented_platform["mode"] = mode
@@ -469,7 +474,7 @@ for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "1
469474
hermetic_cuda_version_map = Dict(
470475
# Our platform tags use X.Y version scheme, but for some CUDA versions we need to
471476
# pass Bazel a full version number X.Y.Z. See `CUDA_REDIST_JSON_DICT` in
472-
# <https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl>.
477+
# <https://github.com/google-ml-infra/rules_ml_toolchain/blob/main/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl>.
473478
"none" => "none",
474479
"11.8" => "11.8",
475480
"12.1" => "12.1.1",
@@ -478,7 +483,8 @@ for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "1
478483
"12.4" => "12.4.1",
479484
"12.6" => "12.6.3",
480485
"12.8" => "12.8.1",
481-
"13.0" => "13.0.0"
486+
"12.9" => "12.9.1",
487+
"13.0" => "13.0.1"
482488
)
483489

484490
prefix="""
@@ -496,15 +502,28 @@ for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "1
496502
end
497503

498504
if arch(platform) == "aarch64" && gpu == "cuda"
499-
if hermetic_cuda_version_map[cuda_version] == "13.0.0"
500-
# bazel currentlty tries to run external/cuda_nvcc/bin/../nvvm/bin/cicc: line 1: ELF
501-
continue
502-
505+
if hermetic_cuda_version_map[cuda_version] == "13.0.1"
506+
# See https://developer.download.nvidia.com/compute/cuda/redist/redistrib_13.0.0.json
507+
push!(platform_sources,
508+
ArchiveSource("https://developer.download.nvidia.com/compute/cuda/redist/libnvvm/linux-x86_64/libnvvm-linux-x86_64-13.0.88-archive.tar.xz",
509+
"17ef1665b63670887eeba7d908da5669fa8c66bb73b5b4c1367f49929c086353"),
510+
)
511+
push!(platform_sources,
512+
ArchiveSource("https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-sbsa/cuda_nvcc-linux-sbsa-13.0.88-archive.tar.xz",
513+
"01b01e10aa2662ad1b3aeab3317151d7d6d4a650eeade55ded504f6b7fced18e"),
514+
)
515+
elseif hermetic_cuda_version_map[cuda_version] == "13.0.0"
503516
# See https://developer.download.nvidia.com/compute/cuda/redist/redistrib_13.0.0.json
504517
push!(platform_sources,
505518
ArchiveSource("https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-sbsa/cuda_nvcc-linux-sbsa-13.0.48-archive.tar.xz",
506519
"3146cee5148535cb06ea5727b6cc1b0d97a85838d1d98514dc6a589ca38e1495"),
507520
)
521+
elseif hermetic_cuda_version_map[cuda_version] == "12.9.1"
522+
# See https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.8.1.json
523+
push!(platform_sources,
524+
ArchiveSource("https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-sbsa/cuda_nvcc-linux-sbsa-12.9.86-archive.tar.xz",
525+
"0aa1fce92dbae76c059c27eefb9d0ffb58e1291151e44ff7c7f1fc2dd9376c0d"),
526+
)
508527
elseif hermetic_cuda_version_map[cuda_version] == "12.8.1"
509528
# See https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.8.1.json
510529
push!(platform_sources,
@@ -552,24 +571,9 @@ for gpu in ("none", "cuda"), mode in ("opt", "dbg"), cuda_version in ("none", "1
552571
if gpu == "cuda"
553572
for lib in (
554573
"libnccl",
555-
"libcufft",
556-
"libcudnn_engines_precompiled",
557-
"libcudart",
558-
"libcublasLt",
559-
"libcudnn_heuristic",
560-
"libcudnn_cnn",
561-
"libnvrtc",
562-
"libcudnn_adv",
563-
"libcudnn",
564-
"libnvJitLink",
565-
"libcublas",
566-
"libcudnn_ops",
567-
"libnvrtc-builtins",
568-
"libcudnn_graph",
569-
"libcusolver",
570574
# "libcuda",
571-
"libcudnn_engines_runtime_compiled",
572-
"libcusparse",
575+
"libnvrtc",
576+
"libnvrtc-builtins",
573577
"libnvshmem_host",
574578
"nvshmem_bootstrap_uid",
575579
"nvshmem_transport_ibrc"
@@ -614,7 +618,7 @@ for (i,build) in enumerate(builds)
614618
name, version, build.sources, build.script,
615619
build.platforms, build.products, build.dependencies;
616620
preferred_gcc_version=build.preferred_gcc_version, build.preferred_llvm_version, julia_compat="1.10",
617-
compression_format="xz",
621+
# compression_format="xz",
618622
# We use GCC 13, so we can't dlopen the library during audit
619623
augment_platform_block, lazy_artifacts=true, lock_microarchitecture=false, dont_dlopen=true,
620624
# When we're running CI for Enzyme-JAX (i.e. when the commit is

R/Reactant/platform_augmentation.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ else
3030
end
3131

3232
const cuda_version_preference = if haskey(preferences, "cuda_version")
33-
expected = ("none", "12.6", "12.8", "13.0")
33+
expected = ("none", "12.8", "13.0")
3434
if isa(preferences["cuda_version"], String) && preferences["cuda_version"] in expected
3535
preferences["cuda_version"]
3636
else
@@ -93,16 +93,12 @@ function augment_platform!(platform::Platform)
9393
Libdl.dlclose(handle)
9494

9595
if cuda_version_tag == "none" && current_cuda_version isa VersionNumber
96-
if v"12.4" <= current_cuda_version < v"12.6"
97-
cuda_version_tag = "12.6"
98-
elseif v"12.6" <= current_cuda_version < v"12.8"
99-
cuda_version_tag = "12.6"
100-
elseif v"12.8" <= current_cuda_version < v"13"
101-
cuda_version_tag = "12.8"
102-
elseif v"13.0" <= current_cuda_version < v"14" && arch(platform) == "x86_64"
96+
if v"12" <= current_cuda_version < v"13"
97+
cuda_version_tag = "12.9"
98+
elseif v"13.0" <= current_cuda_version < v"14"
10399
cuda_version_tag = "13.0"
104100
else
105-
@debug "CUDA version $(current_cuda_version) in $(path) not supported with this version of Reactant"
101+
@warn "CUDA version $(current_cuda_version) in $(path) not supported with this version of Reactant (min supported: 12)"
106102
end
107103
end
108104

0 commit comments

Comments
 (0)