Skip to content

Commit

Permalink
[xla:python] Add method to get python callback capsule without requir…
Browse files Browse the repository at this point in the history
…ing operand or result shapes / returning capsule descriptor.

PiperOrigin-RevId: 713295376
  • Loading branch information
danielsuo authored and Google-ML-Automation committed Feb 1, 2025
1 parent ece956f commit c67fc1b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
14 changes: 12 additions & 2 deletions xla/python/pjrt_ifrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/primitive_util.h"
Expand Down Expand Up @@ -547,6 +548,15 @@ PjRtLoadedExecutable::Execute(
opts.use_major_to_minor_data_layout_for_callbacks = true;
opts.non_donatable_input_indices = options.non_donatable_input_indices;

auto context = std::make_shared<xla::ExecuteContext>();
auto platform_id = pjrt_loaded_executable_->client()->platform_id();
// Forward callbacks via FFI's ExecutionContext for CPU/GPU platforms only.
if (platform_id == CpuId() || platform_id == CudaId() ||
platform_id == RocmId() || platform_id == SyclId()) {
CHECK_OK(context->ffi_context().Insert(all_loaded_host_callbacks_.get()));
opts.context = context.get();
}

if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) {
return Internal(
"Host callback not supported without returned future support in "
Expand Down Expand Up @@ -623,8 +633,8 @@ PjRtLoadedExecutable::Execute(
// can use the futures to extend the lifetime of the host callbacks until
// the execution finishes.
status.OnReady([all_loaded_host_callbacks = all_loaded_host_callbacks_,
host_callback_states =
std::move(host_callback_states)](absl::Status) mutable {
host_callback_states = std::move(host_callback_states),
context = std::move(context)](absl::Status) mutable {
all_loaded_host_callbacks.reset();
});
}
Expand Down
13 changes: 13 additions & 0 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,16 @@ PyClient::GetEmitPythonCallbackDescriptor(
return std::make_pair(descriptor, nb::object(std::move(callback_capsule)));
}

absl::StatusOr<nb::object> PyClient::GetEmitPythonCallback(
nb::callable callable) {
absl::Span<const Shape> operand_shapes;
absl::Span<const Shape> result_shapes;
TF_ASSIGN_OR_RETURN(auto descriptor_and_callback,
GetEmitPythonCallbackDescriptor(
std::move(callable), operand_shapes, result_shapes));
return nb::object(std::move(descriptor_and_callback.second));
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
&XlaPythonCpuCallback);

Expand Down Expand Up @@ -764,6 +774,9 @@ PyType_Slot PyClient::slots_[] = {
xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallbackDescriptor),
nb::arg("callable"), nb::arg("operand_shapes"),
nb::arg("result_shapes").none() = nb::none())
.def("get_emit_python_callback",
xla::ValueOrThrowWrapper(&PyClient::GetEmitPythonCallback),
nb::arg("callable"))
.def("make_python_callback_from_host_send_and_recv",
xla::ValueOrThrowWrapper(
&PyClient::MakePythonCallbackUsingHostSendAndRecv),
Expand Down
9 changes: 9 additions & 0 deletions xla/python/py_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ class PyClient {
absl::Span<Shape const> operand_shapes,
absl::Span<Shape const> result_shapes);

// `GetEmitPythonCallback` takes in an input Python callable. It returns a
// Python object whose reference will keep the Python callback alive.
//
// The callable receives as arguments NumPy arrays for arguments with array
// types, and None for Token argument. The callable must return a tuple of
// either arrays or None values.
absl::StatusOr<nanobind::object> GetEmitPythonCallback(
nanobind::callable callable);

// `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable
// that takes in arguments of shapes `operand_shapes` and returns results of
// shapes `result_shapes`. The arguments correspond to Send ops in the HLO
Expand Down

0 comments on commit c67fc1b

Please sign in to comment.