diff --git a/xla/backends/cpu/runtime/custom_call_thunk.cc b/xla/backends/cpu/runtime/custom_call_thunk.cc index eb266987013b0..8ce3106213b07 100644 --- a/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -267,15 +267,7 @@ tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( /*called_computation=*/nullptr, custom_call_params->ffi_execution_context, execution_state_.get()}; - // Call the function and check execution status. - auto status = ffi::Call(handler->bundle.execute, call_frame, call_options); - if (!status.ok()) { - // Overwrite the returned error code to kInternal to match the original CPU - // implementation. - // TODO(penporn): Use TF_RETURN_IF_ERROR when thunks is the only runtime. - return Internal("%s", status.message()); - } - return OkExecuteEvent(); + return ffi::CallAsync(handler->bundle.execute, call_frame, call_options); } tsl::AsyncValueRef CustomCallThunk::CallUntypedAPI( diff --git a/xla/tests/custom_call_test.cc b/xla/tests/custom_call_test.cc index f2c8c4a400fa8..7e03bd9f2971a 100644 --- a/xla/tests/custom_call_test.cc +++ b/xla/tests/custom_call_test.cc @@ -1058,11 +1058,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongNumberOfArguments) { module->AddEntryComputation(builder.Build()); auto status = Execute(std::move(module), {}).status(); - // NOTE: In the current CPU implementation, the 'kInternal' status code is - // returned when the argument is invalid. This behavior differs from that of - // the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts - // the thunks runtime, the status code will be unified across both backends. - EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Wrong number of arguments")); } @@ -1085,11 +1081,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongRankOfArgument) { module->AddEntryComputation(builder.Build()); auto status = Execute(std::move(module), {}).status(); - // NOTE: In the current CPU implementation, the 'kInternal' status code is - // returned when the argument is invalid. This behavior differs from that of - // the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts - // the thunks runtime, the status code will be unified across both backends. - EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Wrong buffer rank")); } @@ -1106,11 +1098,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongDTypeOfArgument) { module->AddEntryComputation(builder.Build()); auto status = Execute(std::move(module), {}).status(); - // NOTE: In the current CPU implementation, the 'kInternal' status code is - // returned when the argument is invalid. This behavior differs from that of - // the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts - // the thunks runtime, the status code will be unified across both backends. - EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Wrong buffer dtype")); } @@ -1266,11 +1254,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongEnumType) { module->AddEntryComputation(builder.Build()); auto status = Execute(std::move(module), {}).status(); - // NOTE: In the current CPU implementation, the 'kInternal' status code is - // returned when the argument is invalid. This behavior differs from that of - // the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts - // the thunks runtime, the status code will be unified across both backends. - EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Wrong scalar data type")); }