diff --git a/backends/jitrt/BUILD b/backends/jitrt/BUILD index 12b5c10bb10..5d5ccf4aac6 100644 --- a/backends/jitrt/BUILD +++ b/backends/jitrt/BUILD @@ -352,6 +352,7 @@ tfrt_cc_library( visibility = ["@tf_runtime//:friends"], deps = [ ":custom_call", + "@llvm-project//mlir:Support", ], ) diff --git a/backends/jitrt/include/tfrt/jitrt/custom_call.h b/backends/jitrt/include/tfrt/jitrt/custom_call.h index 5b87c274b46..ed6983d994b 100644 --- a/backends/jitrt/include/tfrt/jitrt/custom_call.h +++ b/backends/jitrt/include/tfrt/jitrt/custom_call.h @@ -49,7 +49,7 @@ class CustomCall { virtual ~CustomCall() = default; virtual llvm::StringRef name() const = 0; - virtual void call(void** args) = 0; + virtual mlir::LogicalResult call(void** args) = 0; static CustomCallBinding<> Bind(std::string callee); }; @@ -213,16 +213,17 @@ class CustomCallHandler : public CustomCall { public: llvm::StringRef name() const override { return callee_; } - void call(void** args) override { + mlir::LogicalResult call(void** args) override { // Decode arguments from the opaque pointers. auto decoded_args = internal::DecodeArgs(args); - assert(decoded_args.size() == kSize); + if (decoded_args.size() != kSize) return mlir::failure(); - call(std::move(decoded_args), std::make_index_sequence{}); + return call(std::move(decoded_args), std::make_index_sequence{}); } template - void call(internal::DecodedArgs args, std::index_sequence) { + mlir::LogicalResult call(internal::DecodedArgs args, + std::index_sequence) { // A helper structure to allow each decoder find the correct offset in the // arguments. internal::DecodingOffsets offsets; @@ -236,7 +237,8 @@ class CustomCallHandler : public CustomCall { // Check that all of them were successfully decoded. std::array decoded = { mlir::succeeded(std::get(fn_args))...}; - assert(llvm::all_of(decoded, [](bool succeeded) { return succeeded; })); + if (llvm::any_of(decoded, [](bool succeeded) { return !succeeded; })) + return mlir::failure(); // Forward unpacked arguments to the callback. return fn_(std::move(*std::get(fn_args))...); diff --git a/backends/jitrt/lib/custom_calls/custom_call_testlib.cc b/backends/jitrt/lib/custom_calls/custom_call_testlib.cc index 0416fb9b6ed..89926cd04da 100644 --- a/backends/jitrt/lib/custom_calls/custom_call_testlib.cc +++ b/backends/jitrt/lib/custom_calls/custom_call_testlib.cc @@ -16,14 +16,21 @@ #include +#include "mlir/Support/LogicalResult.h" #include "tfrt/jitrt/custom_call.h" namespace tfrt { namespace jitrt { -static void TimesTwo(MemrefDesc input, MemrefDesc output) { - assert(input.dtype == DType::F32 && output.dtype == DType::F32); - assert(input.sizes == output.sizes); +using mlir::failure; +using mlir::LogicalResult; +using mlir::success; + +static LogicalResult TimesTwo(MemrefDesc input, MemrefDesc output) { + // TODO(ezhulenev): Support all floating point dtypes. + if (input.dtype != output.dtype || input.sizes != output.sizes || + input.dtype != DType::F32) + return failure(); int64_t num_elements = 1; for (int64_t d : input.sizes) num_elements *= d; @@ -33,6 +40,8 @@ static void TimesTwo(MemrefDesc input, MemrefDesc output) { for (int64_t i = 0; i < num_elements; ++i) output_data[i] = input_data[i] * 2.0; + + return success(); } void RegisterCustomCallTestLib(CustomCallRegistry* registry) { diff --git a/backends/jitrt/lib/jitrt.cc b/backends/jitrt/lib/jitrt.cc index 6828fc7e594..58dc15e3f47 100644 --- a/backends/jitrt/lib/jitrt.cc +++ b/backends/jitrt/lib/jitrt.cc @@ -1270,7 +1270,11 @@ extern "C" void runtimeCustomCall(const char* callee, void** args) { // TODO(ezhulenev): Return failure if custom call is not registered. auto* custom_call = registry->Find(callee); assert(custom_call && "unknown custom call"); - custom_call->call(args); + + // TODO(ezhulenev): Handle failures in custom calls. + auto result = custom_call->call(args); + assert(mlir::succeeded(result) && "failed custom call"); + (void)result; } llvm::orc::SymbolMap RuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle) {