Add BM_AsyncAnyBufferArgX1 benchmarks for asynchronous FFI handlers.

Removes the TODO in xla/ffi/api/ffi.h that asked for these benchmarks, and adds
a BM_AsyncAnyBufferArgX1 benchmark to xla/ffi/api/ffi_test.cc and to
xla/ffi/ffi_test.cc.

NOTE(review): this copy of the patch had lost its line breaks, indentation and
angle-bracket template arguments. They are reconstructed here from the hunk
line counts and the public XLA FFI API; confirm against the upstream commit
before applying.

diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h
index eb5d33b7bc620..b31da22175333 100644
--- a/xla/ffi/api/ffi.h
+++ b/xla/ffi/api/ffi.h
@@ -1019,9 +1019,6 @@ template <ExecutionStage stage>
 struct ResultEncoding<stage, Future> {
   static std::variant<XLA_FFI_Error*, XLA_FFI_Future*> Encode(
       const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, Future future) {
-    // TODO(ezhulenev): Add benchmarks for asynchronous FFI handlers, and
-    // optimize the fast path for returning completed futures.
-
     // Create XLA_FFI_Future object that will signal completion to the runtime.
     XLA_FFI_Future_Create_Args args;
     args.struct_size = XLA_FFI_Future_Create_Args_STRUCT_SIZE;
diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc
index 89e9a3658f76e..bea5176a560b6 100644
--- a/xla/ffi/api/ffi_test.cc
+++ b/xla/ffi/api/ffi_test.cc
@@ -1307,6 +1307,27 @@ void BM_AnyBufferArgX4(benchmark::State& state) {
 
 BENCHMARK(BM_AnyBufferArgX4);
 
+//===----------------------------------------------------------------------===//
+// BM_AsyncAnyBufferArgX1
+//===----------------------------------------------------------------------===//
+
+void BM_AsyncAnyBufferArgX1(benchmark::State& state) {
+  auto call_frame = WithBufferArgs(1).Build();
+
+  auto handler = Ffi::Bind().Arg<AnyBuffer>().To([](auto buffer) {
+    benchmark::DoNotOptimize(buffer);
+    Promise promise;
+    promise.SetAvailable();  // Set before the Future is returned.
+    return Future(promise);
+  });
+
+  for (auto _ : state) {
+    CHECK_OK(Call(*handler, call_frame));
+  }
+}
+
+BENCHMARK(BM_AsyncAnyBufferArgX1);
+
 //===----------------------------------------------------------------------===//
 // BM_BufferArgX1
 //===----------------------------------------------------------------------===//
diff --git a/xla/ffi/ffi_test.cc b/xla/ffi/ffi_test.cc
index 455180c947468..ef615bc3f9d50 100644
--- a/xla/ffi/ffi_test.cc
+++ b/xla/ffi/ffi_test.cc
@@ -1168,6 +1168,26 @@ void BM_AnyBufferArgX8(benchmark::State& state) {
 
 BENCHMARK(BM_AnyBufferArgX8);
 
+//===----------------------------------------------------------------------===//
+// BM_AsyncAnyBufferArgX1
+//===----------------------------------------------------------------------===//
+
+void BM_AsyncAnyBufferArgX1(benchmark::State& state) {
+  auto call_frame = WithBufferArgs(1).Build();
+
+  auto done = tsl::MakeAvailableAsyncValueRef<tsl::Chain>();
+  auto handler = Ffi::Bind().Arg<AnyBuffer>().To([&](auto buffer) {
+    benchmark::DoNotOptimize(buffer);
+    return done;  // Already-available value, created once above.
+  });
+
+  for (auto _ : state) {
+    CHECK_OK(Call(*handler, call_frame));
+  }
+}
+
+BENCHMARK(BM_AsyncAnyBufferArgX1);
+
 //===----------------------------------------------------------------------===//
 // BM_BufferArgX1
 //===----------------------------------------------------------------------===//