From 94ebf59fbd9c93dc89dc6579e6d9db1bee3aedb4 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 27 Sep 2024 15:28:00 -0700 Subject: [PATCH] [JAX] Add PyClient::GetAllDevices() and expose it as an internal JAX backend API JAX backend forwards `xla::ifrt::Client::GetAllDevices()` to `xla::PyClient::GetAllDevices()`, which is accessible via JAX `backend.get_all_devices()`. This API is an internal JAX API that is used for building an experimental mesh utils API (finding colocated CPU devices) and should not be used by the user code. PiperOrigin-RevId: 679748877 --- xla/python/ifrt/BUILD | 1 + xla/python/ifrt/client.h | 5 ++ xla/python/ifrt/client_impl_test_lib.cc | 15 ++++- xla/python/ifrt/mock.cc | 3 + xla/python/ifrt/mock.h | 2 + xla/python/ifrt_proxy/client/client.cc | 40 ++++++++++--- xla/python/ifrt_proxy/client/client.h | 9 ++- xla/python/ifrt_proxy/client/client_test.cc | 58 +++++++++++++++++-- xla/python/ifrt_proxy/client/version.h | 2 +- xla/python/ifrt_proxy/common/VERSION.md | 6 ++ .../ifrt_proxy/common/ifrt_service.proto | 7 ++- xla/python/ifrt_proxy/server/ifrt_backend.cc | 20 +++++-- .../ifrt_proxy/server/ifrt_backend_test.cc | 58 +++++++++++++++++-- xla/python/pjrt_ifrt/pjrt_client.h | 6 ++ xla/python/py_client.cc | 10 ++++ xla/python/py_client.h | 1 + xla/python/py_compile_only_client.cc | 3 + xla/python/xla_client.py | 2 +- xla/python/xla_client_test.py | 9 +++ xla/python/xla_extension/__init__.pyi | 1 + 20 files changed, 230 insertions(+), 28 deletions(-) diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 78cb7d039bcd0f..7305a4fbecb5b1 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -344,6 +344,7 @@ cc_library( deps = [ ":ifrt", ":test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], alwayslink = True, diff --git a/xla/python/ifrt/client.h b/xla/python/ifrt/client.h index bf03b857c8b254..c2b214830753d1 100644 --- a/xla/python/ifrt/client.h +++ b/xla/python/ifrt/client.h @@ -205,6 +205,11 @@ class Client : public llvm::RTTIExtends { virtual absl::Span addressable_devices() const = 0; virtual int process_index() const = 0; + // Returns all devices. The result includes primary devices that are included + // in `devices()` as well as any other devices that are associated with + // the primary devices. + virtual absl::Span GetAllDevices() const = 0; + // TODO(hyeontaek): Consider removing this API. This API is potentially not // being used by JAX or will be replaced with explicit device assignment. virtual absl::StatusOr GetDefaultDeviceAssignment( diff --git a/xla/python/ifrt/client_impl_test_lib.cc b/xla/python/ifrt/client_impl_test_lib.cc index 6a0f7e2cd7e27c..38edc163d0204d 100644 --- a/xla/python/ifrt/client_impl_test_lib.cc +++ b/xla/python/ifrt/client_impl_test_lib.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -54,6 +55,18 @@ TEST(ClientImplTest, Devices) { EXPECT_GE(client->process_index(), 0); } +TEST(ClientImplTest, GetAllDevices) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + EXPECT_GE(client->GetAllDevices().size(), client->device_count()); + + for (Device* device : client->GetAllDevices()) { + TF_ASSERT_OK_AND_ASSIGN(auto* looked_up_device, + client->LookupDevice(device->Id())); + EXPECT_EQ(device, looked_up_device); + } +} + TEST(ClientImplTest, DefaultCompiler) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); EXPECT_THAT(client->GetDefaultCompiler(), NotNull()); diff --git a/xla/python/ifrt/mock.cc b/xla/python/ifrt/mock.cc index 9bd006d48753de..09d527052d9933 100644 --- a/xla/python/ifrt/mock.cc +++ b/xla/python/ifrt/mock.cc @@ -172,6 +172,9 @@ MockClient::MockClient(std::unique_ptr delegated) ON_CALL(*this, process_index).WillByDefault([this]() { return delegated_->process_index(); }); + ON_CALL(*this, GetAllDevices).WillByDefault([this]() { + return delegated_->GetAllDevices(); + }); ON_CALL(*this, GetDefaultDeviceAssignment) .WillByDefault([this](int num_replicas, int num_partitions) { return delegated_->GetDefaultDeviceAssignment(num_replicas, diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index c437268b2f7f43..391942d023dab1 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -145,6 +145,8 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::Span, addressable_devices, (), (const, final)); MOCK_METHOD(int, process_index, (), (const, final)); + MOCK_METHOD(absl::Span, GetAllDevices, (), + (const, final)); MOCK_METHOD(absl::StatusOr, GetDefaultDeviceAssignment, (int num_replicas, int num_partitions), (const, final)); MOCK_METHOD(absl::StatusOr, LookupDevice, (DeviceId device_id), diff --git a/xla/python/ifrt_proxy/client/client.cc b/xla/python/ifrt_proxy/client/client.cc index db3a6026ac885d..277d8600e9ee9c 100644 --- a/xla/python/ifrt_proxy/client/client.cc +++ b/xla/python/ifrt_proxy/client/client.cc @@ -67,6 +67,20 @@ absl::StatusOr> Client::Create( absl::flat_hash_set addressable_device_ids( init_response.addressable_device_ids().begin(), init_response.addressable_device_ids().end()); + absl::flat_hash_set primary_device_ids; + if (rpc_helper->version().protocol_version() < 7) { + // Legacy implementation for servers do not support Client::GetAllDevices() + // and thus do not provide device_ids(). Assume that it contains all device + // ids from devices(). + primary_device_ids.reserve(init_response.all_devices().size()); + for (const auto& d : init_response.all_devices()) { + primary_device_ids.insert(d.id()); + } + } else { + primary_device_ids.reserve(init_response.primary_device_ids().size()); + primary_device_ids.insert(init_response.primary_device_ids().begin(), + init_response.primary_device_ids().end()); + } absl::flat_hash_map> memories; for (const auto& m : init_response.memories()) { @@ -77,10 +91,11 @@ absl::StatusOr> Client::Create( } absl::flat_hash_map> devices; - std::vector device_ptrs; + std::vector primary_device_ptrs; std::vector addressable_device_ptrs; + std::vector all_device_ptrs; - for (const auto& d : init_response.devices()) { + for (const auto& d : init_response.all_devices()) { absl::flat_hash_map pjrt_device_attributes; if (rpc_helper->version().protocol_version() <= 3) { @@ -99,14 +114,18 @@ absl::StatusOr> Client::Create( d.device_kind(), d.debug_string(), d.to_string(), std::move(pjrt_device_attributes)); bool is_addressable = addressable_device_ids.contains(d.id()); + bool is_primary = primary_device_ids.contains(d.id()); auto device = std::make_unique(std::move(desc), d.local_device_id(), d.local_hardware_id(), is_addressable); - device_ptrs.push_back(device.get()); + all_device_ptrs.push_back(device.get()); if (is_addressable) { addressable_device_ptrs.push_back(device.get()); } + if (is_primary) { + primary_device_ptrs.push_back(device.get()); + } if (d.has_default_memory_id()) { const auto it = memories.find(d.default_memory_id()); @@ -150,9 +169,10 @@ absl::StatusOr> Client::Create( std::move(rpc_helper), init_response.session_id(), init_response.platform_name(), init_response.platform_version(), init_response.platform_id(), init_response.process_index(), runtime_type, - std::move(devices), device_ptrs, std::move(addressable_device_ptrs), + std::move(devices), std::move(primary_device_ptrs), + std::move(addressable_device_ptrs), all_device_ptrs, std::move(memories))); - for (ifrt::Device* device : device_ptrs) { + for (ifrt::Device* device : all_device_ptrs) { tensorflow::down_cast(device)->client_ = client.get(); } return client; @@ -163,8 +183,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories) : rpc_helper_(rpc_helper), platform_name_(std::move(platform_name)), @@ -175,8 +196,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, // TODO(b/309059940): Forward the backend attributes to the client. attributes_(AttributeMap::Map()), devices_(std::move(devices)), - device_ptrs_(device_ptrs), + primary_device_ptrs_(primary_device_ptrs), addressable_device_ptrs_(std::move(addressable_device_ptrs)), + all_device_ptrs_(all_device_ptrs), memories_(std::move(memories)), default_compiler_(this, rpc_helper) {} @@ -302,6 +324,10 @@ xla::ifrt::Future<> Client::GetReadyFuture( return JoinFutures(futures); } +absl::Span Client::GetAllDevices() const { + return all_device_ptrs_; +} + absl::StatusOr Client::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { auto req = std::make_unique(); diff --git a/xla/python/ifrt_proxy/client/client.h b/xla/python/ifrt_proxy/client/client.h index fd700441eb63ec..d271819a95ac48 100644 --- a/xla/python/ifrt_proxy/client/client.h +++ b/xla/python/ifrt_proxy/client/client.h @@ -110,12 +110,13 @@ class Client final : public llvm::RTTIExtends { return addressable_devices().size(); } absl::Span devices() const override { - return device_ptrs_; + return primary_device_ptrs_; } absl::Span addressable_devices() const override { return addressable_device_ptrs_; } int process_index() const override { return process_index_; } + absl::Span GetAllDevices() const override; absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; absl::StatusOr LookupDevice( @@ -148,8 +149,9 @@ class Client final : public llvm::RTTIExtends { std::string platform_name, std::string platform_version, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories); // rpc_helper_ will be referenced by various IFRT objects whose lifetime is @@ -166,8 +168,9 @@ class Client final : public llvm::RTTIExtends { const AttributeMap attributes_; const absl::flat_hash_map> devices_; - const std::vector device_ptrs_; + const std::vector primary_device_ptrs_; const std::vector addressable_device_ptrs_; + const std::vector all_device_ptrs_; const absl::flat_hash_map> memories_; diff --git a/xla/python/ifrt_proxy/client/client_test.cc b/xla/python/ifrt_proxy/client/client_test.cc index 3f1dbb45c7dea6..03dd43f3c93cfe 100644 --- a/xla/python/ifrt_proxy/client/client_test.cc +++ b/xla/python/ifrt_proxy/client/client_test.cc @@ -83,7 +83,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -94,7 +94,7 @@ class ClientTest : public ::testing::TestWithParam { value { string_value: "device0" } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -120,6 +120,55 @@ class ClientTest : public ::testing::TestWithParam { } )pb", &response)); + } else if (Version().protocol_version() < 7) { + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + kind_id: 0 + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + kind_id: 1 + device_ids: [ 1 ] + } + )pb", + &response)); } else { ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( R"pb( @@ -128,7 +177,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -141,7 +190,7 @@ class ClientTest : public ::testing::TestWithParam { } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -154,6 +203,7 @@ class ClientTest : public ::testing::TestWithParam { } } } + primary_device_ids: [ 0, 1 ] addressable_device_ids: 1 memories { id: 0 diff --git a/xla/python/ifrt_proxy/client/version.h b/xla/python/ifrt_proxy/client/version.h index f9d682243daee2..a2ec9b80d27ff4 100644 --- a/xla/python/ifrt_proxy/client/version.h +++ b/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 3; -inline constexpr int kClientMaxVersion = 6; +inline constexpr int kClientMaxVersion = 7; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/xla/python/ifrt_proxy/common/VERSION.md b/xla/python/ifrt_proxy/common/VERSION.md index 35173c3952caec..29797d5e6decda 100644 --- a/xla/python/ifrt_proxy/common/VERSION.md +++ b/xla/python/ifrt_proxy/common/VERSION.md @@ -35,3 +35,9 @@ * Added date: 2024-09-30. * Changes: * Added `ExecuteOptions::fill_status`. + +## Version 7 + +* Added date: 2024-10-01. +* Changes: + * Added support for `Client::GetAllDevices()`. diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto index 3f17ee69abbd57..4d94c51a07acc7 100644 --- a/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -207,9 +207,12 @@ message InitResponse { AttributeMapProto attributes = 10; // New in Version 4. } - repeated Device devices = 6; // == ifrt::Client::devices() + repeated Device all_devices = 6; // == ifrt::Client::GetAllDevices() + repeated int32 primary_device_ids = + 10; // == [device.id for device in ifrt::Client::devices()] repeated int32 addressable_device_ids = - 7; // == ifrt::Client::addressable_devices() + 7; // == [device.id for device in ifrt::Client::GetAllDevices() if + // device.IsAddressable()] message Memory { int32 id = 1; diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index 1149e771b8e118..36431c7a9c8a83 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -266,8 +266,14 @@ absl::StatusOr IfrtBackend::HandleInit( init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type())); init_resp->set_process_index(client_->process_index()); - for (auto* device : client_->devices()) { - InitResponse::Device* d = init_resp->add_devices(); + absl::Span all_devices; + if (version_.protocol_version() < 7) { + all_devices = client_->devices(); + } else { + all_devices = client_->GetAllDevices(); + } + for (auto* device : all_devices) { + InitResponse::Device* d = init_resp->add_all_devices(); d->set_id(device->Id().value()); d->set_device_kind(AsProtoStringData(device->Kind())); if (auto default_memory = device->DefaultMemory(); default_memory.ok()) { @@ -289,13 +295,17 @@ absl::StatusOr IfrtBackend::HandleInit( } else { *d->mutable_attributes() = device->Attributes().ToProto(); } + + if (device->IsAddressable()) { + init_resp->add_addressable_device_ids(device->Id().value()); + } } - for (auto* addressable_device : client_->addressable_devices()) { - init_resp->add_addressable_device_ids(addressable_device->Id().value()); + for (auto* device : client_->devices()) { + init_resp->add_primary_device_ids(device->Id().value()); } absl::flat_hash_map memories; - for (auto* device : client_->devices()) { + for (auto* device : all_devices) { for (xla::ifrt::Memory* memory : device->Memories()) { const auto [it, inserted] = memories.insert({memory->Id().value(), memory}); diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 09c74cc68981b9..21f8a59fd26780 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -250,6 +250,8 @@ class IfrtBackendHandlerTest : public IfrtBackendTest { } ON_CALL(*mock_client, devices()).WillByDefault(Return(raw_device_ptrs)); + ON_CALL(*mock_client, GetAllDevices()) + .WillByDefault(Return(raw_device_ptrs)); ON_CALL(*mock_client, LookupDevice(_)) .WillByDefault( Invoke([this](DeviceId id) -> absl::StatusOr { @@ -434,7 +436,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -444,7 +446,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { value { string_value: "device0" } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -466,6 +468,53 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } )pb")))))); + } else if (Version().protocol_version() < 7) { + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + } + )pb")))))); } else { EXPECT_THAT(CallBackend(std::move(request)), IsOkAndHolds(Pointee( @@ -477,7 +526,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -489,7 +538,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -501,6 +550,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } + primary_device_ids: [ 0, 1 ] memories { id: 0 memory_space_kind: "mock" diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index 23900c049f344c..fdf11b8ecc67c5 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -206,6 +206,12 @@ class PjRtClient final return addressable_devices_; } int process_index() const override { return pjrt_client_->process_index(); } + + absl::Span GetAllDevices() const override { + DCHECK(this); + return devices_; + } + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { DCHECK(this); diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index d108e9d9c1e47d..9aa30743c59e29 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -179,6 +179,15 @@ std::vector> PyClient::LocalDevices() { return devices; } +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + absl::StatusOr> PyClient::DeviceFromLocalHardwareId( int local_hardware_id) { TF_ASSIGN_OR_RETURN(ifrt::Device * device, @@ -693,6 +702,7 @@ PyType_Slot PyClient::slots_[] = { .def("local_device_count", &PyClient::addressable_device_count) .def("devices", &PyClient::Devices) .def("local_devices", &PyClient::LocalDevices) + .def("get_all_devices", &PyClient::GetAllDevices) .def("device_from_local_hardware_id", xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) .def("live_executables", &PyClient::LiveExecutables) diff --git a/xla/python/py_client.h b/xla/python/py_client.h index 374b7f6d2e530c..7c95431619afd5 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -126,6 +126,7 @@ class PyClient { std::vector> Devices(); std::vector> LocalDevices(); + std::vector> GetAllDevices(); absl::StatusOr> DeviceFromLocalHardwareId( int local_hardware_id); diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index fa32529000f6de..9d37a41af1c87e 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -295,6 +295,9 @@ class CompileOnlyIfRtClient final return {}; } int process_index() const override { return 0; } + absl::Span GetAllDevices() const override { + return devices_; + } absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { return Unimplemented( diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 51d879814b2e68..a96b4fecd736f1 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 289 +_version = 290 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 441d5fbf450fa4..368ef243da3c19 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -2669,6 +2669,15 @@ def testScatter(self): class DeviceTest(ComputationTest): + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + + def testGetAllDevices(self): + self.assertNotEmpty(self.backend.get_all_devices()) + def testPlatform(self): for device in self.backend.local_devices(): self.assertEqual(device.platform, self.backend.platform) diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index e363d8d82471cb..845909f076224b 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -499,6 +499,7 @@ class Client: def local_device_count(self) -> int: ... def devices(self) -> List[Device]: ... def local_devices(self) -> List[Device]: ... + def get_all_devices(self) -> List[Device]: ... def device_from_local_hardware_id(self, int) -> Device: ... def live_buffers(self) -> List[Any]: ... def live_arrays(self) -> List[ArrayImpl]: ...