Skip to content

Commit

Permalink
Create public entry point for PJRT TPU wrapper and use the C++ PJRT W…
Browse files Browse the repository at this point in the history
…rapper by default

PiperOrigin-RevId: 695506093
  • Loading branch information
changm authored and Google-ML-Automation committed Nov 11, 2024
1 parent 0fd7aac commit 879ce5e
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 0 deletions.
42 changes: 42 additions & 0 deletions xla/pjrt/plugin/xla_tpu/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_libtpu_portable")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)

cc_library(
name = "xla_tpu_pjrt_client",
srcs = [
"xla_tpu_pjrt_client.cc",
],
hdrs = ["xla_tpu_pjrt_client.h"],
compatible_with = get_compatible_with_libtpu_portable(),
deps = [
"//xla/pjrt:pjrt_api",
"//xla/pjrt:pjrt_c_api_client",
"//xla/pjrt:pjrt_client",
"//xla/pjrt/c:pjrt_c_api_hdrs",
"//xla/pjrt/c:pjrt_c_api_tpu_hdrs",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

cc_test(
name = "xla_tpu_pjrt_client_test",
srcs = ["xla_tpu_pjrt_client_test.cc"],
tags = ["no_oss"],
deps = [
":xla_tpu_pjrt_client",
"//xla/tests:xla_internal_test_main",
"@tsl//tsl/platform:test",
],
)
2 changes: 2 additions & 0 deletions xla/pjrt/plugin/xla_tpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Public PJRT entry point for XLA:TPU. Please use PJRT to access XLA:TPU
functionality.
52 changes: 52 additions & 0 deletions xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h"

#include <memory>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_tpu.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "tsl/platform/statusor.h"

namespace xla {

const char kTpuPjrtName[] = "tpu";

absl::StatusOr<std::unique_ptr<PjRtClient>> GetXlaPjrtTpuClient() {
const PJRT_Api* tpu_c_api = GetPjrtApi();
if (!tpu_c_api) {
return absl::InternalError("Failed to get PjrtApi");
}

TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> tpu_client,
xla::WrapClientAroundCApi(tpu_c_api));

if (tpu_client->platform_name() != kTpuPjrtName) {
return absl::InternalError(
absl::StrCat("Expected TPU client, got ", tpu_client->platform_name()));
}

return tpu_client;
}

} // namespace xla
31 changes: 31 additions & 0 deletions xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PJRT_PLUGIN_XLA_TPU_XLA_TPU_PJRT_CLIENT_H_
#define XLA_PJRT_PLUGIN_XLA_TPU_XLA_TPU_PJRT_CLIENT_H_

#include <memory>

#include "absl/status/statusor.h"
#include "xla/pjrt/pjrt_client.h"

namespace xla {

// Public entry point to get an XLA:TPU PjRtClient
absl::StatusOr<std::unique_ptr<PjRtClient>> GetXlaPjrtTpuClient();

} // namespace xla

#endif // XLA_PJRT_PLUGIN_XLA_TPU_XLA_TPU_PJRT_CLIENT_H_
27 changes: 27 additions & 0 deletions xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/plugin/xla_tpu/xla_tpu_pjrt_client.h"

#include "tsl/platform/test.h"

namespace xla {

TEST(XlaCpuPjrtClientTest, GetXlaPjrtTpuClient) {
ASSERT_OK_AND_ASSIGN(auto client, GetXlaPjrtTpuClient());
EXPECT_EQ(client->platform_name(), "tpu");
}

} // namespace xla

0 comments on commit 879ce5e

Please sign in to comment.