Skip to content

Commit

Permalink
[xla:ffi] Optimize returning tsl::AsyncValueRef from FFI handler
Browse files Browse the repository at this point in the history
BEFORE: BM_AsyncAnyBufferArgX1       55.2 ns         55.2 ns     12514333
AFTER:  BM_AsyncAnyBufferArgX1       31.6 ns         31.6 ns     21924178
PiperOrigin-RevId: 674518452
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Sep 14, 2024
1 parent accb49a commit 7ea247e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
3 changes: 1 addition & 2 deletions xla/ffi/api/c_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ extern "C" {
typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status);

// Forwards `tsl::AsyncValue` object pointed to by `async_value` to XLA FFI
// future. Pointer ownership stays with the caller. Constructed XLA_FFI_Future
// adds +1 reference to the underlying async value object.
// future. Ownership of the async value is transferred to the XLA FFI future.
typedef XLA_FFI_Future* XLA_FFI_INTERNAL_Future_Forward(void* async_value);

// Returns a pointer to main compute stream (`se::Stream` pointer). In
Expand Down
2 changes: 1 addition & 1 deletion xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ struct ResultEncoding<stage, tsl::AsyncValueRef<tsl::Chain>> {
XLA_FFI_ExecutionContext* ctx,
tsl::AsyncValueRef<tsl::Chain> async_value) {
return api->internal_api->XLA_FFI_INTERNAL_Future_Forward(
async_value.GetAsyncValue());
async_value.release());
}
};

Expand Down
13 changes: 6 additions & 7 deletions xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,18 @@ tsl::AsyncValueRef<tsl::Chain> TakeFuture(XLA_FFI_Future* future) {

if (ABSL_PREDICT_TRUE(future == nullptr)) return chain->AsRef();

// Keeps XLA_FFI_Future alive until it is completed.
absl::Cleanup delete_future = [future] { delete future; };

// If the future is already completed, immediately return the underlying async
// value and destroy the XLA_FFI_Future.
// value and delete the XLA_FFI_Future.
if (ABSL_PREDICT_TRUE(future->async_value.IsAvailable())) {
return std::move(future->async_value);
tsl::AsyncValueRef<tsl::Chain> async_value = std::move(future->async_value);
delete future;
return async_value;
}

// If the future is not completed, return a copy of the underlying async value
// and keep XLA_FFI_Future alive until it is completed.
tsl::AsyncValueRef<tsl::Chain> async_value = future->async_value;
async_value.AndThen([delete_future = std::move(delete_future)] {});
async_value.AndThen([future] { delete future; });
return async_value;
}

Expand Down Expand Up @@ -799,7 +798,7 @@ static XLA_FFI_Future* XLA_FFI_INTERNAL_Future_Forward(void* async_value) {
DCHECK(tsl_async_value) << "Async value must not be null";

return new XLA_FFI_Future{
tsl::AsyncValueRef<tsl::Chain>(tsl::FormRef(tsl_async_value))};
tsl::AsyncValueRef<tsl::Chain>(tsl::TakeRef(tsl_async_value))};
}

static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) {
Expand Down
9 changes: 7 additions & 2 deletions xla/ffi/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1175,10 +1175,15 @@ BENCHMARK(BM_AnyBufferArgX8);
void BM_AsyncAnyBufferArgX1(benchmark::State& state) {
auto call_frame = WithBufferArgs(1).Build();

auto done = tsl::MakeAvailableAsyncValueRef<tsl::Chain>();
static tsl::AsyncValueOwningRef<tsl::Chain>* done = [] {
auto* storage = new tsl::internal::AsyncValueStorage<tsl::Chain>();
return new tsl::AsyncValueOwningRef<tsl::Chain>(
tsl::MakeAvailableAsyncValueRef<tsl::Chain>(*storage));
}();

auto handler = Ffi::Bind().Arg<AnyBuffer>().To([&](auto buffer) {
benchmark::DoNotOptimize(buffer);
return done;
return done->AsRef();
});

for (auto _ : state) {
Expand Down

0 comments on commit 7ea247e

Please sign in to comment.