Skip to content

Commit

Permalink
[tfrt:jitrt] Add support for returning failures from JitRt custom cal…
Browse files Browse the repository at this point in the history
…l binding

PiperOrigin-RevId: 445029205
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 28, 2022
1 parent 43a5658 commit eaf4cbb
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 10 deletions.
1 change: 1 addition & 0 deletions backends/jitrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ tfrt_cc_library(
visibility = ["@tf_runtime//:friends"],
deps = [
":custom_call",
"@llvm-project//mlir:Support",
],
)

Expand Down
14 changes: 8 additions & 6 deletions backends/jitrt/include/tfrt/jitrt/custom_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CustomCall {
virtual ~CustomCall() = default;

virtual llvm::StringRef name() const = 0;
virtual void call(void** args) = 0;
virtual mlir::LogicalResult call(void** args) = 0;

static CustomCallBinding<> Bind(std::string callee);
};
Expand Down Expand Up @@ -213,16 +213,17 @@ class CustomCallHandler : public CustomCall {
public:
llvm::StringRef name() const override { return callee_; }

void call(void** args) override {
mlir::LogicalResult call(void** args) override {
// Decode arguments from the opaque pointers.
auto decoded_args = internal::DecodeArgs(args);
assert(decoded_args.size() == kSize);
if (decoded_args.size() != kSize) return mlir::failure();

call(std::move(decoded_args), std::make_index_sequence<kSize>{});
return call(std::move(decoded_args), std::make_index_sequence<kSize>{});
}

template <std::size_t... Is>
void call(internal::DecodedArgs args, std::index_sequence<Is...>) {
mlir::LogicalResult call(internal::DecodedArgs args,
std::index_sequence<Is...>) {
// A helper structure to allow each decoder find the correct offset in the
// arguments.
internal::DecodingOffsets offsets;
Expand All @@ -236,7 +237,8 @@ class CustomCallHandler : public CustomCall {
// Check that all of them were successfully decoded.
std::array<bool, kSize> decoded = {
mlir::succeeded(std::get<Is>(fn_args))...};
assert(llvm::all_of(decoded, [](bool succeeded) { return succeeded; }));
if (llvm::any_of(decoded, [](bool succeeded) { return !succeeded; }))
return mlir::failure();

// Forward unpacked arguments to the callback.
return fn_(std::move(*std::get<Is>(fn_args))...);
Expand Down
15 changes: 12 additions & 3 deletions backends/jitrt/lib/custom_calls/custom_call_testlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@

#include <utility>

#include "mlir/Support/LogicalResult.h"
#include "tfrt/jitrt/custom_call.h"

namespace tfrt {
namespace jitrt {

static void TimesTwo(MemrefDesc input, MemrefDesc output) {
assert(input.dtype == DType::F32 && output.dtype == DType::F32);
assert(input.sizes == output.sizes);
using mlir::failure;
using mlir::LogicalResult;
using mlir::success;

static LogicalResult TimesTwo(MemrefDesc input, MemrefDesc output) {
// TODO(ezhulenev): Support all floating point dtypes.
if (input.dtype != output.dtype || input.sizes != output.sizes ||
input.dtype != DType::F32)
return failure();

int64_t num_elements = 1;
for (int64_t d : input.sizes) num_elements *= d;
Expand All @@ -33,6 +40,8 @@ static void TimesTwo(MemrefDesc input, MemrefDesc output) {

for (int64_t i = 0; i < num_elements; ++i)
output_data[i] = input_data[i] * 2.0;

return success();
}

void RegisterCustomCallTestLib(CustomCallRegistry* registry) {
Expand Down
6 changes: 5 additions & 1 deletion backends/jitrt/lib/jitrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,11 @@ extern "C" void runtimeCustomCall(const char* callee, void** args) {
// TODO(ezhulenev): Return failure if custom call is not registered.
auto* custom_call = registry->Find(callee);
assert(custom_call && "unknown custom call");
custom_call->call(args);

// TODO(ezhulenev): Handle failures in custom calls.
auto result = custom_call->call(args);
assert(mlir::succeeded(result) && "failed custom call");
(void)result;
}

llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle) {
Expand Down

0 comments on commit eaf4cbb

Please sign in to comment.