From c67fc1b083f391b73e59df65de3fdbeed537c14d Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 8 Jan 2025 08:09:40 -0800 Subject: [PATCH] [xla:python] Add method to get python callback capsule without requiring operand or result shapes / returning capsule descriptor. PiperOrigin-RevId: 713295376 --- xla/python/pjrt_ifrt/pjrt_executable.cc | 14 ++++++++++++-- xla/python/py_client.cc | 13 +++++++++++++ xla/python/py_client.h | 9 +++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/xla/python/pjrt_ifrt/pjrt_executable.cc b/xla/python/pjrt_ifrt/pjrt_executable.cc index 26611a302dcec..3b9ccbdfebee2 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/primitive_util.h" @@ -547,6 +548,15 @@ PjRtLoadedExecutable::Execute( opts.use_major_to_minor_data_layout_for_callbacks = true; opts.non_donatable_input_indices = options.non_donatable_input_indices; + auto context = std::make_shared(); + auto platform_id = pjrt_loaded_executable_->client()->platform_id(); + // Forward callbacks via FFI's ExecutionContext for CPU/GPU platforms only. + if (platform_id == CpuId() || platform_id == CudaId() || + platform_id == RocmId() || platform_id == SyclId()) { + CHECK_OK(context->ffi_context().Insert(all_loaded_host_callbacks_.get())); + opts.context = context.get(); + } + if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) { return Internal( "Host callback not supported without returned future support in " @@ -623,8 +633,8 @@ PjRtLoadedExecutable::Execute( // can use the futures to extend the lifetime of the host callbacks until // the execution finishes. status.OnReady([all_loaded_host_callbacks = all_loaded_host_callbacks_, - host_callback_states = - std::move(host_callback_states)](absl::Status) mutable { + host_callback_states = std::move(host_callback_states), + context = std::move(context)](absl::Status) mutable { all_loaded_host_callbacks.reset(); }); } diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 838323411b865..373d7fef0bb88 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -647,6 +647,16 @@ PyClient::GetEmitPythonCallbackDescriptor( return std::make_pair(descriptor, nb::object(std::move(callback_capsule))); } +absl::StatusOr PyClient::GetEmitPythonCallback( + nb::callable callable) { + absl::Span operand_shapes; + absl::Span result_shapes; + TF_ASSIGN_OR_RETURN(auto descriptor_and_callback, + GetEmitPythonCallbackDescriptor( + std::move(callable), operand_shapes, result_shapes)); + return nb::object(std::move(descriptor_and_callback.second)); +} + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", &XlaPythonCpuCallback); @@ -764,6 +774,9 @@ PyType_Slot PyClient::slots_[] = { xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor), nb::arg("callable"), nb::arg("operand_shapes"), nb::arg("result_shapes").none() = nb::none()) + .def("get_emit_python_callback", + xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallback), + nb::arg("callable")) .def("make_python_callback_from_host_send_and_recv", xla::ValueOrThrowWrapper( &PyClient::MakePythonCallbackUsingHostSendAndRecv), diff --git a/xla/python/py_client.h b/xla/python/py_client.h index a8893a0b41441..a107fa57be749 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -197,6 +197,15 @@ class PyClient { absl::Span operand_shapes, absl::Span result_shapes); + // `GetEmitPythonCallback` takes in an input Python callable. It returns a + // Python object whose reference will keep the Python callback alive. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr GetEmitPythonCallback( + nanobind::callable callable); + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable // that takes in arguments of shapes `operand_shapes` and returns results of // shapes `result_shapes`. The arguments correspond to Send ops in the HLO