Skip to content

Commit

Permalink
[xla:python] Move registration of xla_python_gpu_callback into GPU cl…
Browse files Browse the repository at this point in the history
…ient target.

This fixes uses of xla_python_gpu_callback outside of the JAX GPU plugin by registering the custom call target in a build unit which is directly linked when including the gpu_support target, instead of hiding the registration behind an ifdef.

PiperOrigin-RevId: 675226878
  • Loading branch information
dfm authored and Google-ML-Automation committed Sep 16, 2024
1 parent 635f65f commit 9850ead
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
3 changes: 3 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ cc_library(
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"@com_google_absl//absl/base",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -1219,6 +1221,7 @@ cc_library(
deps = [
":nb_class_ptr",
":py_client",
":py_client_gpu",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:status_casters",
"//xla/pjrt/distributed:client",
Expand Down
6 changes: 0 additions & 6 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,12 +649,6 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable,
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
&XlaPythonCpuCallback);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"xla_python_gpu_callback", &XlaPythonGpuCallback,
absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));
#endif

/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit,
void* arg) {
PyClient* c = nb::inst_ptr<PyClient>(self);
Expand Down
11 changes: 11 additions & 0 deletions xla/python/py_client_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include "absl/base/casts.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "xla/service/custom_call_status.h"
#include "tsl/platform/errors.h"
Expand All @@ -33,6 +34,8 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/python/callback.h"
#include "xla/python/nb_numpy.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/platform_util.h"

#if TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
Expand Down Expand Up @@ -157,4 +160,12 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
}
}

// TODO(danfm): When compiled as part of a jaxlib plugin, this will register
// the custom call target in the plugin's registry. This won't affect
// registration via the Python API, but we should remove this once we have
// fully migrated to the plugin interface.
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"xla_python_gpu_callback", &XlaPythonGpuCallback,
absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));

} // namespace xla

0 comments on commit 9850ead

Please sign in to comment.