Skip to content

Commit

Permalink
[xla:ffi] Add CallAsync for invoking potentially asynchronous FFI ha…
Browse files Browse the repository at this point in the history
…ndlers

PiperOrigin-RevId: 674358537
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Sep 13, 2024
1 parent 3f8c182 commit 32ebd69
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 50 deletions.
1 change: 1 addition & 0 deletions xla/ffi/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ xla_cc_test(
"//xla/ffi:type_id_registry",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
19 changes: 16 additions & 3 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -1205,10 +1207,21 @@ TEST(FfiTest, AsyncHandler) {
CallOptions options;
options.backend_options = CallOptions::CpuOptions{&device};

auto status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
{ // Synchronous call.
absl::Status status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
EXPECT_EQ(value, 42);
}

value = 0; // reset value between calls

EXPECT_EQ(value, 42);
{ // Asynchronous call.
tsl::AsyncValueRef<tsl::Chain> async_value =
CallAsync(*handler, call_frame, options);
tsl::BlockUntilReady(async_value);
ASSERT_TRUE(async_value.IsConcrete());
EXPECT_EQ(value, 42);
}
}

TEST(FfiTest, Metadata) {
Expand Down
126 changes: 84 additions & 42 deletions xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <exception>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -131,52 +132,59 @@ static XLA_FFI_ExecutionContext CreateExecutionContext(
//===----------------------------------------------------------------------===//

// Takes ownership of `error` and returns the status stored in it. A null
// `error` is reported as OK; otherwise the status is moved out of `error`
// before `error` is deleted.
absl::Status TakeStatus(XLA_FFI_Error* error) {
// NOTE(review): the null check appears twice in a row. The second line is the
// same test wrapped in ABSL_PREDICT_TRUE, so it can never fire once the first
// one has run. Presumably only one of the two is meant to remain - confirm.
if (error == nullptr) return absl::OkStatus();
if (ABSL_PREDICT_TRUE(error == nullptr)) return absl::OkStatus();
// Move the status out first: it lives inside `error`, which is freed next.
absl::Status status = std::move(error->status);
delete error;
return status;
}

absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options, ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(
GetXlaFfiApi(), &ctx, static_cast<XLA_FFI_ExecutionStage>(stage));
XLA_FFI_Error* error = nullptr;
try {
error = handler.Call(&ffi_call_frame);
} catch (std::exception& e) {
return Unknown("XLA FFI call failed: %s", e.what());
}
tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future) {
// Non-reference-counted async value ref for synchronous FFI handlers.
static tsl::AsyncValueOwningRef<tsl::Chain>* chain = [] {
auto* storage = new tsl::internal::AsyncValueStorage<tsl::Chain>();
return new tsl::AsyncValueOwningRef<tsl::Chain>(
tsl::MakeAvailableAsyncValueRef<tsl::Chain>(*storage));
}();

// If FFI handler returned synchronous error, it must not launch any
// asynchronous work that can also return an error.
if (error != nullptr) {
DCHECK_EQ(ffi_call_frame.future, nullptr)
<< "Error must not be used together with a future";
}
if (ABSL_PREDICT_TRUE(future == nullptr)) return chain->AsRef();

// Wait for the completion of asynchronous work launched by the handler.
if (XLA_FFI_Future* future = ffi_call_frame.future;
ABSL_PREDICT_FALSE(future != nullptr)) {
absl::Cleanup delete_future = [&] { delete future; };
tsl::BlockUntilReady(future->async_value);
if (ABSL_PREDICT_FALSE(future->async_value.IsError())) {
return future->async_value.GetError();
}
// Keeps XLA_FFI_Future alive until it is completed.
absl::Cleanup delete_future = [future] { delete future; };

// If the future is already completed, immediately return the underlying async
// value and destroy the XLA_FFI_Future.
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
return std::move(future->async_value);
}

return TakeStatus(error);
// If the future is not completed, return a copy of the underlying async value
// and keep XLA_FFI_Future alive until it is completed.
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
async_value.AndThen([delete_future = std::move(delete_future)] {});
return async_value;
}

absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options, XLA_FFI_ExecutionStage stage) {
template <typename Handler>
static absl::StatusOr<XLA_FFI_Future*> Call(Handler& handler,
CallFrame& call_frame,
const CallOptions& options,
ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame =
call_frame.Build(GetXlaFfiApi(), &ctx, stage);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(
GetXlaFfiApi(), &ctx, static_cast<XLA_FFI_ExecutionStage>(stage));

XLA_FFI_Error* error = nullptr;

// FFI handlers might be defined in external libraries and use exceptions, so
// take extra care to catch them and convert to a status.
try {
error = (*handler)(&ffi_call_frame);
if constexpr (std::is_same_v<Handler, Ffi>) {
error = handler.Call(&ffi_call_frame);
} else if constexpr (std::is_same_v<Handler, XLA_FFI_Handler*>) {
error = (*handler)(&ffi_call_frame);
} else {
static_assert(sizeof(Handler) == 0, "Unsupported handler type");
}
} catch (std::exception& e) {
return Unknown("XLA FFI call failed: %s", e.what());
}
Expand All @@ -186,19 +194,53 @@ absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
if (error != nullptr) {
DCHECK_EQ(ffi_call_frame.future, nullptr)
<< "Error must not be used together with a future";
return TakeStatus(error);
}

// Wait for the completion of asynchronous work launched by the handler.
if (XLA_FFI_Future* future = ffi_call_frame.future;
ABSL_PREDICT_FALSE(future != nullptr)) {
absl::Cleanup delete_future = [&] { delete future; };
tsl::BlockUntilReady(future->async_value);
if (ABSL_PREDICT_FALSE(future->async_value.IsError())) {
return future->async_value.GetError();
}
}
return ffi_call_frame.future;
}

return TakeStatus(error);
// Waits on the calling thread for the work tracked by `future` to finish and
// reports the outcome as a status. A null `future` yields OK immediately.
// Ownership of a non-null `future` is handed to `TakeFuture`.
static absl::Status BlockUntilReady(XLA_FFI_Future* future) {
  // No future was produced by the handler: there is nothing to wait for.
  if (ABSL_PREDICT_TRUE(future == nullptr)) {
    return absl::OkStatus();
  }

  tsl::AsyncValueRef<tsl::Chain> completion = TakeFuture(future);
  tsl::BlockUntilReady(completion);

  // Surface an asynchronous failure as the returned status.
  if (ABSL_PREDICT_FALSE(completion.IsError())) {
    return completion.GetError();
  }
  return absl::OkStatus();
}

// Synchronous entry point for `Ffi` handlers. Dispatches through the
// `Call<Ffi>` helper and then blocks until the future it returned (if any) has
// completed. A dispatch failure is returned as-is; otherwise the result is the
// status reported by `BlockUntilReady`.
absl::Status Call(Ffi& handler, CallFrame& call_frame,
                  const CallOptions& options, ExecutionStage stage) {
  absl::StatusOr<XLA_FFI_Future*> dispatched =
      Call<Ffi>(handler, call_frame, options, stage);
  if (ABSL_PREDICT_FALSE(!dispatched.ok())) {
    return dispatched.status();
  }
  return BlockUntilReady(*dispatched);
}

// Synchronous entry point for raw `XLA_FFI_Handler` pointers. Converts the C
// API stage enum to `ExecutionStage`, dispatches through the
// `Call<XLA_FFI_Handler*>` helper, and then blocks until the future it
// returned (if any) has completed. A dispatch failure is returned as-is.
absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
                  const CallOptions& options, XLA_FFI_ExecutionStage stage) {
  const auto execution_stage = static_cast<ExecutionStage>(stage);

  absl::StatusOr<XLA_FFI_Future*> dispatched =
      Call<XLA_FFI_Handler*>(handler, call_frame, options, execution_stage);
  if (ABSL_PREDICT_FALSE(!dispatched.ok())) {
    return dispatched.status();
  }
  return BlockUntilReady(*dispatched);
}

// Asynchronous entry point for `Ffi` handlers. Dispatches through the
// `Call<Ffi>` helper and hands the resulting future to `TakeFuture` without
// waiting on it. A dispatch failure is converted to the returned async value.
tsl::AsyncValueRef<tsl::Chain> CallAsync(Ffi& handler, CallFrame& call_frame,
                                         const CallOptions& options,
                                         ExecutionStage stage) {
  absl::StatusOr<XLA_FFI_Future*> dispatched =
      Call<Ffi>(handler, call_frame, options, stage);
  if (ABSL_PREDICT_FALSE(!dispatched.ok())) {
    // The status converts to the return type, as in the status macro form.
    return dispatched.status();
  }
  return TakeFuture(*dispatched);
}

// Asynchronous entry point for raw `XLA_FFI_Handler` pointers. Converts the C
// API stage enum to `ExecutionStage`, dispatches through the
// `Call<XLA_FFI_Handler*>` helper, and hands the resulting future to
// `TakeFuture` without waiting on it. A dispatch failure is converted to the
// returned async value.
tsl::AsyncValueRef<tsl::Chain> CallAsync(XLA_FFI_Handler* handler,
                                         CallFrame& call_frame,
                                         const CallOptions& options,
                                         XLA_FFI_ExecutionStage stage) {
  const auto execution_stage = static_cast<ExecutionStage>(stage);

  absl::StatusOr<XLA_FFI_Future*> dispatched =
      Call<XLA_FFI_Handler*>(handler, call_frame, options, execution_stage);
  if (ABSL_PREDICT_FALSE(!dispatched.ok())) {
    // The status converts to the return type, as in the status macro form.
    return dispatched.status();
  }
  return TakeFuture(*dispatched);
}

static XLA_FFI_Metadata BuildMetadata() {
Expand Down
25 changes: 23 additions & 2 deletions xla/ffi/ffi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream.h"
#include "xla/tsl/concurrency/chain.h"

namespace xla::ffi {

Expand Down Expand Up @@ -73,10 +74,18 @@ struct CallOptions {
};

// Takes ownership of the XLA FFI error and returns underlying status. Frees
// `error` if it's not nullptr; returns OK status otherwise.
// `error` if it's not nullptr. If `error` is nullptr, returns OkStatus.
absl::Status TakeStatus(XLA_FFI_Error* error);

// Calls an XLA FFI handler with the given call frame and options.
// Takes ownership of the XLA FFI future and returns underlying AsyncValue.
// Frees `future` if it's not nullptr. If `future` is nullptr, returns an
// available async value.
tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future);

// Calls an XLA FFI handler with the given call frame and options. This is a
// synchronous call and it might block the caller thread if the handler is
// asynchronous. It is unsafe to call it from a thread pool that runs tasks
// scheduled by the handler itself.
absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options = {},
ExecutionStage stage = ExecutionStage::kExecute);
Expand All @@ -86,6 +95,18 @@ absl::Status Call(
const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

// Calls an XLA FFI handler with the given call frame and options. This is an
// asynchronous call and it will not block the caller thread. Returned async
// value will become available when the handler completes execution.
tsl::AsyncValueRef<tsl::Chain> CallAsync(
Ffi& handler, CallFrame& call_frame, const CallOptions& options = {},
ExecutionStage stage = ExecutionStage::kExecute);

tsl::AsyncValueRef<tsl::Chain> CallAsync(
XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

// Gets metadata from the handler by calling it with a special call frame.
absl::StatusOr<XLA_FFI_Metadata> GetMetadata(Ffi& handler);
absl::StatusOr<XLA_FFI_Metadata> GetMetadata(XLA_FFI_Handler* handler);
Expand Down
17 changes: 14 additions & 3 deletions xla/ffi/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,21 @@ TEST(FfiTest, AsyncHandler) {
CallOptions options;
options.backend_options = CallOptions::CpuOptions{&device};

auto status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
{ // Synchronous call.
absl::Status status = Call(*handler, call_frame, options);
TF_ASSERT_OK(status);
EXPECT_EQ(value, 42);
}

value = 0; // reset value between calls

EXPECT_EQ(value, 42);
{ // Asynchronous call.
tsl::AsyncValueRef<tsl::Chain> async_value =
CallAsync(*handler, call_frame, options);
tsl::BlockUntilReady(async_value);
ASSERT_TRUE(async_value.IsConcrete());
EXPECT_EQ(value, 42);
}
}

TEST(FfiTest, Metadata) {
Expand Down

0 comments on commit 32ebd69

Please sign in to comment.