diff --git a/xla/python/BUILD b/xla/python/BUILD index ffb3b8dc4a60d..a9081b5d63f1c 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -497,6 +497,8 @@ cc_library( "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", "//xla/service:custom_call_status", + "//xla/service:custom_call_target_registry", + "//xla/service:platform_util", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1219,6 +1221,7 @@ cc_library( deps = [ ":nb_class_ptr", ":py_client", + ":py_client_gpu", "//xla/pjrt:pjrt_client", "//xla/pjrt:status_casters", "//xla/pjrt/distributed:client", diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 0e36a346f67e3..d108e9d9c1e47 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -649,12 +649,6 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable, XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", &XlaPythonCpuCallback); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "xla_python_gpu_callback", &XlaPythonGpuCallback, - absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value())); -#endif - /* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, void* arg) { PyClient* c = nb::inst_ptr(self); diff --git a/xla/python/py_client_gpu.cc b/xla/python/py_client_gpu.cc index b045bfb8ca606..d1c01a62d16a7 100644 --- a/xla/python/py_client_gpu.cc +++ b/xla/python/py_client_gpu.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "xla/service/custom_call_status.h" #include "tsl/platform/errors.h" @@ -33,6 +34,8 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/python/callback.h" #include "xla/python/nb_numpy.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/platform_util.h" #if TENSORFLOW_USE_ROCM #define gpuSuccess hipSuccess @@ -157,4 +160,12 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers, } } +// TODO(danfm): When compiled as part of a jaxlib plugin, this will register +// the custom call target in the plugin's registry. This won't affect +// registration via the Python API, but we should remove this once we have +// fully migrated to the plugin interface. +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( + "xla_python_gpu_callback", &XlaPythonGpuCallback, + absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value())); + } // namespace xla