Skip to content

Commit

Permalink
Enable libnvjitlink by default in OSS
Browse files Browse the repository at this point in the history
The hermetic CUDA change gained us nvjitlink support, so we can now enable it by default. But since there is a memory leak in CUDA SDK 12.4 and below, we only enable it by default for later versions. Users will still be able to force-enable it through the flag.

I'm also adding the missing tsl::Flag to DebugOptionsFlag. Without that the debug option couldn't be changed on the command line.

PiperOrigin-RevId: 696405923
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Nov 14, 2024
1 parent dde3c51 commit 27b4c50
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 21 deletions.
5 changes: 0 additions & 5 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ load(
# copybara:uncomment "tf_pyclif_proto_library",
)
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)

# copybara:uncomment load("@rules_python//python:proto.bzl", "py_proto_library")
load("//third_party/compute_library:build_defs.bzl", "if_enable_acl")
Expand Down Expand Up @@ -1147,7 +1143,6 @@ cc_library(
],
hdrs = ["debug_options_flags.h"],
copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
visibility = internal_visibility([":friends"]),
deps =
[
Expand Down
14 changes: 12 additions & 2 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false);
opts.set_xla_gpu_enable_libnvptxcompiler(
stream_executor::IsLibNvPtxCompilerSupported());
opts.set_xla_gpu_enable_libnvjitlink(
stream_executor::IsLibNvJitLinkSupported());
opts.set_xla_gpu_libnvjitlink_mode(DebugOptions::LIB_NV_JIT_LINK_MODE_AUTO);

opts.set_xla_gpu_enable_dot_strength_reduction(true);

Expand Down Expand Up @@ -1852,6 +1851,17 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_enable_libnvptxcompiler(),
"Use libnvptxcompiler for PTX-to-GPU-assembly compilation instead of "
"calling ptxas."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_libnvjitlink",
[debug_options](bool enabled) {
debug_options->set_xla_gpu_libnvjitlink_mode(
enabled ? DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED
: DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED);
return true;
},
stream_executor::IsLibNvJitLinkSupported(),
"Use libnvjitlink for PTX-to-GPU-assembly compilation instead of "
"calling ptxas."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_dot_strength_reduction",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_dot_strength_reduction),
Expand Down
3 changes: 1 addition & 2 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1861,7 +1861,6 @@ cc_library(
"//xla/service/gpu/transforms:triangular_solve_rewriter",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:dnn",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream_executor_h",
Expand All @@ -1870,13 +1869,13 @@ cc_library(
"//xla/stream_executor/cuda:cuda_driver_version",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/cuda:nvjitlink",
"//xla/stream_executor/cuda:nvjitlink_known_issues",
"//xla/stream_executor/cuda:nvjitlink_support",
"//xla/stream_executor/cuda:ptx_compilation_method",
"//xla/stream_executor/cuda:ptx_compiler",
"//xla/stream_executor/cuda:ptx_compiler_support",
"//xla/stream_executor/cuda:ptx_linking_method",
"//xla/stream_executor/gpu:gpu_asm_opts",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/tsl/util:env_var",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
Expand Down
43 changes: 40 additions & 3 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_driver_version.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/cuda/nvjitlink.h"
#include "xla/stream_executor/cuda/nvjitlink_known_issues.h"
#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "xla/stream_executor/cuda/ptx_compilation_method.h"
#include "xla/stream_executor/cuda/ptx_compiler.h"
Expand Down Expand Up @@ -670,8 +671,35 @@ absl::StatusOr<PtxCompilationMethod> ChooseCompilationMethod(
}
};

if (!debug_options.xla_gpu_enable_libnvjitlink()) {
VLOG(3) << "Discarding NvJitLink since it is disabled.";
// This is true if the user explicitly requested the use of libNvJitLink
// through the command line flag. In that case we bypass all the sanity checks
// and enable its usage. It means compilation might fail which is a better
// diagnostic to the user instead of silently discarding NvJitLink.
const bool libnvjitlink_force_enabled =
debug_options.xla_gpu_libnvjitlink_mode() ==
DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED;

if (!stream_executor::IsLibNvJitLinkSupported() &&
!libnvjitlink_force_enabled) {
VLOG(3) << "Discarding NvJitLink since it is not supported in this build.";
remove_compilation_method(PtxCompilationMethod::kNvJitLink);
} else if (stream_executor::LoadedNvJitLinkHasKnownIssues() &&
!libnvjitlink_force_enabled) {
auto formatted_version = [&]() -> std::string {
absl::StatusOr<stream_executor::NvJitLinkVersion> version =
stream_executor::GetNvJitLinkVersion();
if (version.ok()) {
return absl::StrCat(std::get<0>(*version), ".", std::get<1>(*version));
}
return "unknown";
}();

VLOG(3) << "Discarding NvJitLink since the loaded library version ("
<< formatted_version << ") has known issues.";
remove_compilation_method(PtxCompilationMethod::kNvJitLink);
} else if (debug_options.xla_gpu_libnvjitlink_mode() ==
DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED) {
VLOG(3) << "Discarding NvJitLink since it was explicitly disabled.";
remove_compilation_method(PtxCompilationMethod::kNvJitLink);
}
if (!debug_options.xla_gpu_enable_libnvptxcompiler()) {
Expand Down Expand Up @@ -903,8 +931,17 @@ absl::StatusOr<se::PtxLinkingMethod> NVPTXCompiler::ChooseLinkingMethod(

using LinkingMethod = se::PtxLinkingMethod;

// If the user has explicitly requested NvJitLink we will try to use it and
// fail later during linking if it is not available or has known issues.
if (debug_options.xla_gpu_libnvjitlink_mode() ==
DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED) {
return LinkingMethod::kNvJitLink;
}

if (stream_executor::IsLibNvJitLinkSupported() &&
debug_options.xla_gpu_enable_libnvjitlink()) {
!stream_executor::LoadedNvJitLinkHasKnownIssues() &&
debug_options.xla_gpu_libnvjitlink_mode() !=
DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED) {
return se::PtxLinkingMethod::kNvJitLink;
}

Expand Down
8 changes: 5 additions & 3 deletions xla/service/gpu/ptx_compilation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,11 @@ class NVPTXCompilationTests
debug_options->set_xla_gpu_enable_libnvptxcompiler(
compilation_method == PtxCompilationMethod::kNvPtxCompiler);

debug_options->set_xla_gpu_enable_libnvjitlink(
compilation_method == PtxCompilationMethod::kNvJitLink ||
linking_method == PtxLinkingMethod::kNvJitLink);
debug_options->set_xla_gpu_libnvjitlink_mode(
(compilation_method == PtxCompilationMethod::kNvJitLink ||
linking_method == PtxLinkingMethod::kNvJitLink)
? DebugOptions::LIB_NV_JIT_LINK_MODE_ENABLED
: DebugOptions::LIB_NV_JIT_LINK_MODE_DISABLED);

debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism(
linking_method != PtxLinkingMethod::kNone);
Expand Down
29 changes: 25 additions & 4 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ config_setting(

bool_flag(
name = "enable_libnvjitlink_support",
build_setting_default = if_google(
True,
oss_value = False,
),
build_setting_default = True,
)

config_setting(
Expand Down Expand Up @@ -912,6 +909,30 @@ xla_cc_test(
],
)

cc_library(
name = "nvjitlink_known_issues",
srcs = ["nvjitlink_known_issues.cc"],
hdrs = ["nvjitlink_known_issues.h"],
deps = [":nvjitlink"],
)

xla_cc_test(
name = "nvjitlink_known_issues_test",
srcs = ["nvjitlink_known_issues_test.cc"],
# LibNvJitLink is a binary-only library. Therefore it is not compatible with msan/tsan.
tags = [
"nomsan",
"notsan",
],
deps = [
":nvjitlink_known_issues",
":nvjitlink_support",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "cuda_asm_compiler",
srcs = ["cuda_asm_compiler.cc"],
Expand Down
37 changes: 37 additions & 0 deletions xla/stream_executor/cuda/nvjitlink_known_issues.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/nvjitlink_known_issues.h"

#include "xla/stream_executor/cuda/nvjitlink.h"

namespace stream_executor {

bool LoadedNvJitLinkHasKnownIssues() {
  // There is a memory leak in libnvjitlink from version 12.0 to 12.4.
  // The memory leak was fixed in CUDA Toolkit 12.4 Update 1, but we can't
  // distinguish between NvJitLink coming from CUDA Toolkit 12.4 and 12.4
  // Update 1. Therefore we only treat 12.5 and higher as issue-free, to be
  // on the safe side.
  constexpr NvJitLinkVersion kMinVersionWithoutKnownIssues{12, 5};

  // Note that this needs to be a runtime version test because we load
  // LibNvJitLink as a dynamic library and the version might vary and not be
  // the same that we saw at compile time.
  //
  // Versions strictly below the minimum have known issues. If no version can
  // be determined (library not available), we default to the minimum safe
  // version so this reports false, matching the contract in the header.
  return GetNvJitLinkVersion().value_or(kMinVersionWithoutKnownIssues) <
         kMinVersionWithoutKnownIssues;
}

} // namespace stream_executor
28 changes: 28 additions & 0 deletions xla/stream_executor/cuda/nvjitlink_known_issues.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_STREAM_EXECUTOR_CUDA_NVJITLINK_KNOWN_ISSUES_H_
#define XLA_STREAM_EXECUTOR_CUDA_NVJITLINK_KNOWN_ISSUES_H_

namespace stream_executor {

// Returns true if the loaded NvJitLink library is known to have bugs and
// shouldn't be used unconditionally. Returns false otherwise - also returns
// false if NvJitLink is not available.
bool LoadedNvJitLinkHasKnownIssues();

} // namespace stream_executor

#endif // XLA_STREAM_EXECUTOR_CUDA_NVJITLINK_KNOWN_ISSUES_H_
33 changes: 33 additions & 0 deletions xla/stream_executor/cuda/nvjitlink_known_issues_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/nvjitlink_known_issues.h"

#include <gtest/gtest.h>
#include "xla/stream_executor/cuda/nvjitlink_support.h"
#include "tsl/platform/test.h"

namespace {

// Verifies the documented contract that an unavailable NvJitLink library is
// reported as having no known issues (rather than, e.g., failing or
// defaulting to "buggy").
TEST(NvJitLinkKnownIssuesTest, ReturnsFalseWhenNvJitLinkIsNotAvailable) {
  // When the library is available, the result depends on the loaded version,
  // so there is nothing version-independent to assert; skip in that case.
  if (stream_executor::IsLibNvJitLinkSupported()) {
    GTEST_SKIP();
  }
  // This is the only invariant we can test without writing a pointless change
  // detector test.
  EXPECT_FALSE(stream_executor::LoadedNvJitLinkHasKnownIssues());
}

} // namespace
16 changes: 14 additions & 2 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,20 @@ message DebugOptions {
// but potentially higher the performance.
int32 xla_gpu_cudnn_gemm_max_plans = 318;

enum LibNvJitLinkMode {
// LibNvJitLink is used if it is available and no buggy version has been
// detected.
LIB_NV_JIT_LINK_MODE_AUTO = 0;
// LibNvJitLink is never used.
LIB_NV_JIT_LINK_MODE_DISABLED = 1;
// LibNvJitLink is used always. If it is not available, compilation will
// fail.
LIB_NV_JIT_LINK_MODE_ENABLED = 2;
}

// If enabled, uses the libnvjitlink library for PTX compilation and linking
bool xla_gpu_enable_libnvjitlink = 319;
LibNvJitLinkMode xla_gpu_libnvjitlink_mode = 343;
reserved 319; // was xla_gpu_enable_libnvjitlink with boolean type.

// If true, XLA will wrap `dot` operations into async computations in an
// effort to parallelize matrix operations.
Expand Down Expand Up @@ -1030,7 +1042,7 @@ message DebugOptions {
}
PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341;

// Next id: 343
// Next id: 344

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit 27b4c50

Please sign in to comment.