From 6fd82343e4ce3729680931fd6cc240cbe0fbe237 Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Wed, 2 Oct 2024 13:43:33 -0700 Subject: [PATCH] Make rocm files use tsl DsoLoader functions instead of the stream_executor wrappers. PiperOrigin-RevId: 681577126 --- xla/stream_executor/rocm/BUILD | 10 ++++ xla/stream_executor/rocm/hipblaslt_wrapper.h | 32 ++++++------ xla/stream_executor/rocm/hipsolver_wrapper.h | 32 ++++++------ xla/stream_executor/rocm/hipsparse_wrapper.h | 50 +++++++++---------- xla/stream_executor/rocm/rocblas_wrapper.h | 3 +- xla/stream_executor/rocm/rocm_dnn.cc | 3 +- .../rocm/rocm_driver_wrapper.h | 33 ++++++------ xla/stream_executor/rocm/rocm_fft.cc | 3 +- xla/stream_executor/rocm/rocsolver_wrapper.h | 32 ++++++------ xla/stream_executor/rocm/roctracer_wrapper.h | 32 ++++++------ 10 files changed, 121 insertions(+), 109 deletions(-) diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index f6e8a867e2b48..5a510dc1908ad 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -86,6 +86,7 @@ cc_library( "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", @@ -353,6 +354,7 @@ cc_library( "//xla/tsl/util:determinism_for_kernels", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", ], alwayslink = True, @@ -445,6 +447,7 @@ cc_library( "//xla/stream_executor/gpu:scoped_activate_context", "//xla/stream_executor/platform", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", "@tsl//tsl/platform:logging", ], @@ -510,6 +513,7 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", "@tsl//tsl/platform:env_impl", "@tsl//tsl/platform:errors", @@ -564,6 +568,7 @@ cc_library( ":rocm_platform_id", "//xla/stream_executor/platform", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", ], alwayslink = True, @@ -600,6 +605,7 @@ cc_library( ":rocsolver_if_static", "//xla/stream_executor/platform", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", ], alwayslink = True, @@ -635,6 +641,7 @@ cc_library( ":rocm_platform_id", "//xla/stream_executor/platform", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", ], alwayslink = True, @@ -691,6 +698,7 @@ cc_library( "//xla/stream_executor/platform", "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", @@ -725,6 +733,7 @@ cc_library( "//xla/stream_executor/platform", "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:status", @@ -785,6 +794,7 @@ cc_library( "//xla/stream_executor/platform", "@local_config_rocm//rocm:rocm_headers", "@tsl//tsl/platform", + "@tsl//tsl/platform:dso_loader", "@tsl//tsl/platform:env", ], alwayslink = True, diff --git a/xla/stream_executor/rocm/hipblaslt_wrapper.h b/xla/stream_executor/rocm/hipblaslt_wrapper.h index 8a20e7543f1ab..e5c7c22e4a945 100644 --- a/xla/stream_executor/rocm/hipblaslt_wrapper.h +++ b/xla/stream_executor/rocm/hipblaslt_wrapper.h @@ -28,6 +28,7 @@ limitations under the License. #include "rocm/include/hipblaslt.h" #endif #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -46,22 +47,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define HIPBLASLT_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipblasltDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipblaslt lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define HIPBLASLT_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipblasltDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in hipblaslt lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/xla/stream_executor/rocm/hipsolver_wrapper.h b/xla/stream_executor/rocm/hipsolver_wrapper.h index 6f05fcddee3fe..ec84b86b50d6d 100644 --- a/xla/stream_executor/rocm/hipsolver_wrapper.h +++ b/xla/stream_executor/rocm/hipsolver_wrapper.h @@ -30,6 +30,7 @@ limitations under the License. #include "rocm/include/hipsolver.h" #endif #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -48,22 +49,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define HIPSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipsolverDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define HIPSOLVER_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipsolverDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in hipsolver lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/xla/stream_executor/rocm/hipsparse_wrapper.h b/xla/stream_executor/rocm/hipsparse_wrapper.h index 3cea716263ba4..df214bfc5bb19 100644 --- a/xla/stream_executor/rocm/hipsparse_wrapper.h +++ b/xla/stream_executor/rocm/hipsparse_wrapper.h @@ -29,6 +29,7 @@ limitations under the License. #endif #include "xla/stream_executor/platform/platform.h" #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -47,31 +48,30 @@ namespace wrap { #else -#define HIPSPARSE_API_WRAPPER(__name) \ - static struct DynLoadShim__##__name { \ - constexpr static const char* kName = #__name; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = \ - stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - hipsparseStatus_t operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ +#define HIPSPARSE_API_WRAPPER(__name) \ + static struct DynLoadShim__##__name { \ + constexpr static const char* kName = #__name; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = tsl::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in miopen DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + hipsparseStatus_t operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ } __name; #endif diff --git a/xla/stream_executor/rocm/rocblas_wrapper.h b/xla/stream_executor/rocm/rocblas_wrapper.h index 2b74fc2458761..82afa866bf30c 100644 --- a/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/xla/stream_executor/rocm/rocblas_wrapper.h @@ -26,6 +26,7 @@ limitations under the License. #include "rocm/include/rocblas/rocblas.h" #include "rocm/rocm_config.h" #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/platform.h" @@ -43,7 +44,7 @@ namespace wrap { } __name; #else -using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; +using tsl::internal::CachedDsoLoader::GetRocblasDsoHandle; #define ROCBLAS_API_WRAPPER(__name) \ static struct DynLoadShim__##__name { \ diff --git a/xla/stream_executor/rocm/rocm_dnn.cc b/xla/stream_executor/rocm/rocm_dnn.cc index 6de2a3e349fde..b97151517d804 100644 --- a/xla/stream_executor/rocm/rocm_dnn.cc +++ b/xla/stream_executor/rocm/rocm_dnn.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/util/determinism.h" #include "xla/tsl/util/env_var.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/hash.h" @@ -248,7 +249,7 @@ namespace wrap { static const char* kName; \ using FuncPtrT = std::add_pointer::type; \ static void* GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ + auto s = tsl::internal::CachedDsoLoader::GetMiopenDsoHandle(); \ return s.value(); \ } \ static FuncPtrT LoadOrDie() { \ diff --git a/xla/stream_executor/rocm/rocm_driver_wrapper.h b/xla/stream_executor/rocm/rocm_driver_wrapper.h index 89207e9c440db..901d220d8e2f8 100644 --- a/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -24,7 +24,7 @@ limitations under the License. #include "rocm/include/hip/hip_runtime.h" #include "rocm/rocm_config.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -46,22 +46,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ - template \ - auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char *kName = TO_STR(hipSymbolName); \ - void *f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in HIP DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ + template \ + auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char *kName = TO_STR(hipSymbolName); \ + void *f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \ + &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in HIP DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/xla/stream_executor/rocm/rocm_fft.cc b/xla/stream_executor/rocm/rocm_fft.cc index e62786a5adf38..dc80077a3f74c 100644 --- a/xla/stream_executor/rocm/rocm_fft.cc +++ b/xla/stream_executor/rocm/rocm_fft.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_complex_converters.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" @@ -60,7 +61,7 @@ namespace wrap { static const char *kName; \ using FuncPtrT = std::add_pointer::type; \ static void *GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetHipfftDsoHandle(); \ + auto s = tsl::internal::CachedDsoLoader::GetHipfftDsoHandle(); \ return s.value(); \ } \ static FuncPtrT LoadOrDie() { \ diff --git a/xla/stream_executor/rocm/rocsolver_wrapper.h b/xla/stream_executor/rocm/rocsolver_wrapper.h index cfeea5f54b667..7d5e6de5e9ac7 100644 --- a/xla/stream_executor/rocm/rocsolver_wrapper.h +++ b/xla/stream_executor/rocm/rocsolver_wrapper.h @@ -28,6 +28,7 @@ limitations under the License. #endif #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -46,22 +47,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define ROCSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetRocsolverDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define ROCSOLVER_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetRocsolverDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocsolver lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/xla/stream_executor/rocm/roctracer_wrapper.h b/xla/stream_executor/rocm/roctracer_wrapper.h index 49f1411712b15..9e97beb80cdf7 100644 --- a/xla/stream_executor/rocm/roctracer_wrapper.h +++ b/xla/stream_executor/rocm/roctracer_wrapper.h @@ -29,6 +29,7 @@ limitations under the License. #include "rocm/include/roctracer/roctracer_hcc.h" #endif #include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/platform.h" @@ -45,22 +46,21 @@ namespace wrap { #else -#define ROCTRACER_API_WRAPPER(API_NAME) \ - template \ - auto API_NAME(Args... args) -> decltype(::API_NAME(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = #API_NAME; \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetRoctracerDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in roctracer DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define ROCTRACER_API_WRAPPER(API_NAME) \ + template \ + auto API_NAME(Args... args) -> decltype(::API_NAME(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = #API_NAME; \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetRoctracerDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in roctracer DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif // PLATFORM_GOOGLE