Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:ffi] Add auto-binding for FFI results #11235

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,20 @@ struct ArgBinding {
using Arg = void;
};

// XLA FFI binding for a returned result.
//
// Example: binding for the `MyType` result
//
// template <>
// struct RetBinding<MyType> {
// using Ret = MyType;
// };
//
template <typename T>
struct RetBinding {
using Ret = void;
};

// XLA FFI binding for a named attribute.
//
// Example: binding for the `MyType` attribute
Expand Down Expand Up @@ -394,6 +408,10 @@ template <typename Param>
inline constexpr bool is_arg_binding_v =
!std::is_void_v<typename ArgBinding<Param>::Arg>;

template <typename Param>
inline constexpr bool is_ret_binding_v =
!std::is_void_v<typename RetBinding<Param>::Ret>;

template <typename Param>
inline constexpr bool is_attr_binding_v =
!std::is_void_v<typename AttrBinding<Param>::Attr>;
Expand Down Expand Up @@ -422,6 +440,11 @@ struct BindOne<Fn, Param, Params...> {
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Arg<typename ArgBinding<Param>::Arg>());
} else if constexpr (is_ret_binding_v<Param>) {
// Bind parameter as an FFI handler result.
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Ret<typename RetBinding<Param>::Ret>());

} else if constexpr (is_attr_binding_v<Param>) {
// Bind parameter as a named FFI handler attribute.
Expand Down
14 changes: 14 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ struct ArgBinding<Buffer<dtype, rank>> {
using Arg = Buffer<dtype, rank>;
};

//===----------------------------------------------------------------------===//
// Results binding
//===----------------------------------------------------------------------===//

template <>
struct RetBinding<Result<BufferBase>> {
using Ret = BufferBase;
};

template <DataType dtype, size_t rank>
struct RetBinding<Result<Buffer<dtype, rank>>> {
using Ret = Buffer<dtype, rank>;
};

//===----------------------------------------------------------------------===//
// Arguments decoding
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ TEST(FfiTest, AutoBinding) {
TF_ASSERT_OK(status);
}

TEST(FfiTest, AutoBindingResult) {
auto handler =
Ffi::BindTo(+[](Result<BufferBase> buffer) { return Error::Success(); });

CallFrameBuilder builder;
builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{});
auto call_frame = builder.Build();

auto status = Call(*handler, call_frame);
TF_ASSERT_OK(status);
}

struct I32AndF32 {
int32_t i32;
float f32;
Expand Down
Loading