Skip to content

Commit

Permalink
[xla:cpu] Use ffi::CallAsync in custom call thunk
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675721439
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Sep 17, 2024
1 parent bc1aad8 commit d3a787c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 29 deletions.
10 changes: 1 addition & 9 deletions xla/backends/cpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> CustomCallThunk::CallTypedFFI(
/*called_computation=*/nullptr, custom_call_params->ffi_execution_context,
execution_state_.get()};

// Call the function and check execution status.
auto status = ffi::Call(handler->bundle.execute, call_frame, call_options);
if (!status.ok()) {
// Overwrite the returned error code to kInternal to match the original CPU
// implementation.
// TODO(penporn): Use TF_RETURN_IF_ERROR when thunks is the only runtime.
return Internal("%s", status.message());
}
return OkExecuteEvent();
return ffi::CallAsync(handler->bundle.execute, call_frame, call_options);
}

tsl::AsyncValueRef<Thunk::ExecuteEvent> CustomCallThunk::CallUntypedAPI(
Expand Down
24 changes: 4 additions & 20 deletions xla/tests/custom_call_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1058,11 +1058,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongNumberOfArguments) {
module->AddEntryComputation(builder.Build());

auto status = Execute(std::move(module), {}).status();
// NOTE: In the current CPU implementation, the 'kInternal' status code is
// returned when the argument is invalid. This behavior differs from that of
// the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts
// the thunks runtime, the status code will be unified across both backends.
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Wrong number of arguments"));
}

Expand All @@ -1085,11 +1081,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongRankOfArgument) {
module->AddEntryComputation(builder.Build());

auto status = Execute(std::move(module), {}).status();
// NOTE: In the current CPU implementation, the 'kInternal' status code is
// returned when the argument is invalid. This behavior differs from that of
// the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts
// the thunks runtime, the status code will be unified across both backends.
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Wrong buffer rank"));
}

Expand All @@ -1106,11 +1098,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongDTypeOfArgument) {
module->AddEntryComputation(builder.Build());

auto status = Execute(std::move(module), {}).status();
// NOTE: In the current CPU implementation, the 'kInternal' status code is
// returned when the argument is invalid. This behavior differs from that of
// the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts
// the thunks runtime, the status code will be unified across both backends.
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Wrong buffer dtype"));
}

Expand Down Expand Up @@ -1266,11 +1254,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiWrongEnumType) {
module->AddEntryComputation(builder.Build());

auto status = Execute(std::move(module), {}).status();
// NOTE: In the current CPU implementation, the 'kInternal' status code is
// returned when the argument is invalid. This behavior differs from that of
// the GPU, which returns 'kInvalidArgument' in such case. When the CPU adopts
// the thunks runtime, the status code will be unified across both backends.
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Wrong scalar data type"));
}

Expand Down

0 comments on commit d3a787c

Please sign in to comment.