Skip to content

Commit

Permalink
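Disable more binary libraries if the disable flag (`xla_gpu_experimental_disable_binary_libraries`) is true: the GEMM fusion autotuner no longer adds cuBLAS or cuDNN configs, and the NVPTX compiler skips adding the conv and GEMM autotuning passes.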
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675567404
  • Loading branch information
klucke authored and Google-ML-Automation committed Sep 17, 2024
1 parent 888981a commit 28f1b88
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 70 deletions.
4 changes: 2 additions & 2 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "b39a100ff4ec16f1f9cafcc48ea7fed920726650"
LLVM_SHA256 = "d9deed58ce9f004a2fb1fe810f19eb4235f560c79b0e24d96b67a09719ac5513"
LLVM_COMMIT = "c23d6df60d62f971d957e730f6fe55ea89541f6b"
LLVM_SHA256 = "0802cc4a9f52a0c3694d508a903da619dbe205005f59635d8e167fbe4008670c"

tf_http_archive(
name = name,
Expand Down
43 changes: 5 additions & 38 deletions third_party/shardy/temporary.patch
Original file line number Diff line number Diff line change
@@ -1,48 +1,15 @@
diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch
index ed6047c..509398d 100644
--- a/third_party/llvm/generated.patch
+++ b/third_party/llvm/generated.patch
@@ -1,28 +1 @@
Auto generated patch. Do not edit or delete it, even if empty.
-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c
---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c
-+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c
-@@ -1,4 +1,4 @@
--// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify
-+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify
-
- #include <immintrin.h>
- #include <stddef.h>
-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c
---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c
-+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c
-@@ -1,4 +1,4 @@
--// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify
-+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify
-
- #include <immintrin.h>
- #include <stddef.h>
-diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c
---- a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c
-+++ b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c
-@@ -1,4 +1,4 @@
--// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -emit-llvm -Wall -Werror -verify
-+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -Wall -Werror -verify
-
- unsigned long long test_mm_cvttssd(unsigned long long __A) {
- return _mm_cvttssd(__A); // expected-error {{call to undeclared function '_mm_cvttssd'}}
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index aaf1f1a..5a83cba 100644
index 5a83cba..3686a35 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
- LLVM_COMMIT = "f0b3287297aeeddcf030e3c1b08d05a69ad465aa"
- LLVM_SHA256 = "3bc65e7a760a389f5ace1146cb2ffde724a272e97e71c8b8509149e827df6c83"
+ LLVM_COMMIT = "b39a100ff4ec16f1f9cafcc48ea7fed920726650"
+ LLVM_SHA256 = "d9deed58ce9f004a2fb1fe810f19eb4235f560c79b0e24d96b67a09719ac5513"
- LLVM_COMMIT = "b39a100ff4ec16f1f9cafcc48ea7fed920726650"
- LLVM_SHA256 = "d9deed58ce9f004a2fb1fe810f19eb4235f560c79b0e24d96b67a09719ac5513"
+ LLVM_COMMIT = "c23d6df60d62f971d957e730f6fe55ea89541f6b"
+ LLVM_SHA256 = "0802cc4a9f52a0c3694d508a903da619dbe205005f59635d8e167fbe4008670c"

tf_http_archive(
name = name,
4 changes: 2 additions & 2 deletions third_party/shardy/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
SHARDY_COMMIT = "db1ae2bba4c609870409119c42323252a85e61a2"
SHARDY_SHA256 = "ee22cdd9893bc0005a1a20ecbf2aebdbc5a731aeedae8d903b0e85af23abb519"
SHARDY_COMMIT = "5f8a058aa5b5a8ae60d2e06dbd0a82551af731b3"
SHARDY_SHA256 = "105054f71ecc1445550c8f81d693300c05fe299817933c53c7e1b829bea8ebf5"

tf_http_archive(
name = "shardy",
Expand Down
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "b39a100ff4ec16f1f9cafcc48ea7fed920726650"
LLVM_SHA256 = "d9deed58ce9f004a2fb1fe810f19eb4235f560c79b0e24d96b67a09719ac5513"
LLVM_COMMIT = "c23d6df60d62f971d957e730f6fe55ea89541f6b"
LLVM_SHA256 = "0802cc4a9f52a0c3694d508a903da619dbe205005f59635d8e167fbe4008670c"

tf_http_archive(
name = name,
Expand Down
54 changes: 28 additions & 26 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,36 +607,38 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) {
const HloDotInstruction* dot =
Cast<HloDotInstruction>(hlo_query::GetFirstInstructionWithOpcode(
*fusion.called_computations().at(0), HloOpcode::kDot));

// Add cuBLAS reference config, if available.
std::vector<BackendConfig> configs;
if (algorithm_util::IsSupportedByCublasOrCublasLt(
dot->precision_config().algorithm()) &&
!dot->sparse_operands() && IsAutotuningEnabled()) {
configs.push_back(CuBlasConfig{});
}

// Add cuDNN plans, if available.
bool is_hopper =
!config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
bool is_cudnn_enabled =
debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
(IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
algorithm_util::IsSupportedByCudnn(
dot->precision_config().algorithm()) &&
!dot->sparse_operands() && IsAutotuningEnabled())) {
const int plan_count = GetCuDnnPlanCount(fusion, config_);
for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
configs.push_back(CuDnnConfig{plan_id});
if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) {
// Add cuBLAS reference config, if available.
if (algorithm_util::IsSupportedByCublasOrCublasLt(
dot->precision_config().algorithm()) &&
!dot->sparse_operands() && IsAutotuningEnabled()) {
configs.push_back(CuBlasConfig{});
}
}
if (IsFusionKind(fusion, kCuDnnFusionKind)) {
if (!IsAutotuningEnabled()) {
configs.push_back(CuDnnConfig{-1});

// Add cuDNN plans, if available.
bool is_hopper =
!config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
bool is_cudnn_enabled =
debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
(IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
algorithm_util::IsSupportedByCudnn(
dot->precision_config().algorithm()) &&
!dot->sparse_operands() && IsAutotuningEnabled())) {
const int plan_count = GetCuDnnPlanCount(fusion, config_);
for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
configs.push_back(CuDnnConfig{plan_id});
}
}
if (IsFusionKind(fusion, kCuDnnFusionKind)) {
if (!IsAutotuningEnabled()) {
configs.push_back(CuDnnConfig{-1});
}
return configs;
}
return configs;
}

// Add triton configs.
Expand Down
5 changes: 5 additions & 0 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ bool NVPTXCompiler::RequiresCollectiveScheduleLinearizer(
absl::Status NVPTXCompiler::AddConvAndGemmAutotuningPasses(
HloPassPipeline* pipeline, HloModule* hlo_module,
AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool) {
if (hlo_module->config()
.debug_options()
.xla_gpu_experimental_disable_binary_libraries()) {
return absl::OkStatus();
}
if (GpuConvAlgorithmPicker::IsEnabled(hlo_module)) {
pipeline->AddPass<GpuConvAlgorithmPicker>(autotune_config);
}
Expand Down

0 comments on commit 28f1b88

Please sign in to comment.