diff --git a/xla/python/ifrt/BUILD b/xla/python/ifrt/BUILD index 78cb7d039bcd0f..7305a4fbecb5b1 100644 --- a/xla/python/ifrt/BUILD +++ b/xla/python/ifrt/BUILD @@ -344,6 +344,7 @@ cc_library( deps = [ ":ifrt", ":test_util", + "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", ], alwayslink = True, diff --git a/xla/python/ifrt/client.h b/xla/python/ifrt/client.h index bf03b857c8b254..c2b214830753d1 100644 --- a/xla/python/ifrt/client.h +++ b/xla/python/ifrt/client.h @@ -205,6 +205,11 @@ class Client : public llvm::RTTIExtends { virtual absl::Span addressable_devices() const = 0; virtual int process_index() const = 0; + // Returns all devices. The result includes primary devices that are included + // in `devices()` as well as any other devices that are associated with + // the primary devices. + virtual absl::Span GetAllDevices() const = 0; + // TODO(hyeontaek): Consider removing this API. This API is potentially not // being used by JAX or will be replaced with explicit device assignment. virtual absl::StatusOr GetDefaultDeviceAssignment( diff --git a/xla/python/ifrt/client_impl_test_lib.cc b/xla/python/ifrt/client_impl_test_lib.cc index 6a0f7e2cd7e27c..38edc163d0204d 100644 --- a/xla/python/ifrt/client_impl_test_lib.cc +++ b/xla/python/ifrt/client_impl_test_lib.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -54,6 +55,18 @@ TEST(ClientImplTest, Devices) { EXPECT_GE(client->process_index(), 0); } +TEST(ClientImplTest, GetAllDevices) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + EXPECT_GE(client->GetAllDevices().size(), client->device_count()); + + for (Device* device : client->GetAllDevices()) { + TF_ASSERT_OK_AND_ASSIGN(auto* looked_up_device, + client->LookupDevice(device->Id())); + EXPECT_EQ(device, looked_up_device); + } +} + TEST(ClientImplTest, DefaultCompiler) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); EXPECT_THAT(client->GetDefaultCompiler(), NotNull()); diff --git a/xla/python/ifrt/mock.cc b/xla/python/ifrt/mock.cc index 9bd006d48753de..09d527052d9933 100644 --- a/xla/python/ifrt/mock.cc +++ b/xla/python/ifrt/mock.cc @@ -172,6 +172,9 @@ MockClient::MockClient(std::unique_ptr delegated) ON_CALL(*this, process_index).WillByDefault([this]() { return delegated_->process_index(); }); + ON_CALL(*this, GetAllDevices).WillByDefault([this]() { + return delegated_->GetAllDevices(); + }); ON_CALL(*this, GetDefaultDeviceAssignment) .WillByDefault([this](int num_replicas, int num_partitions) { return delegated_->GetDefaultDeviceAssignment(num_replicas, diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index c437268b2f7f43..391942d023dab1 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -145,6 +145,8 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::Span, addressable_devices, (), (const, final)); MOCK_METHOD(int, process_index, (), (const, final)); + MOCK_METHOD(absl::Span, GetAllDevices, (), + (const, final)); MOCK_METHOD(absl::StatusOr, GetDefaultDeviceAssignment, (int num_replicas, int num_partitions), (const, final)); MOCK_METHOD(absl::StatusOr, LookupDevice, (DeviceId device_id), diff --git a/xla/python/ifrt_proxy/client/client.cc b/xla/python/ifrt_proxy/client/client.cc index db3a6026ac885d..277d8600e9ee9c 100644 --- a/xla/python/ifrt_proxy/client/client.cc +++ b/xla/python/ifrt_proxy/client/client.cc @@ -67,6 +67,20 @@ absl::StatusOr> Client::Create( absl::flat_hash_set addressable_device_ids( init_response.addressable_device_ids().begin(), init_response.addressable_device_ids().end()); + absl::flat_hash_set primary_device_ids; + if (rpc_helper->version().protocol_version() < 7) { + // Legacy implementation for servers do not support Client::GetAllDevices() + // and thus do not provide device_ids(). Assume that it contains all device + // ids from devices(). + primary_device_ids.reserve(init_response.all_devices().size()); + for (const auto& d : init_response.all_devices()) { + primary_device_ids.insert(d.id()); + } + } else { + primary_device_ids.reserve(init_response.primary_device_ids().size()); + primary_device_ids.insert(init_response.primary_device_ids().begin(), + init_response.primary_device_ids().end()); + } absl::flat_hash_map> memories; for (const auto& m : init_response.memories()) { @@ -77,10 +91,11 @@ absl::StatusOr> Client::Create( } absl::flat_hash_map> devices; - std::vector device_ptrs; + std::vector primary_device_ptrs; std::vector addressable_device_ptrs; + std::vector all_device_ptrs; - for (const auto& d : init_response.devices()) { + for (const auto& d : init_response.all_devices()) { absl::flat_hash_map pjrt_device_attributes; if (rpc_helper->version().protocol_version() <= 3) { @@ -99,14 +114,18 @@ absl::StatusOr> Client::Create( d.device_kind(), d.debug_string(), d.to_string(), std::move(pjrt_device_attributes)); bool is_addressable = addressable_device_ids.contains(d.id()); + bool is_primary = primary_device_ids.contains(d.id()); auto device = std::make_unique(std::move(desc), d.local_device_id(), d.local_hardware_id(), is_addressable); - device_ptrs.push_back(device.get()); + all_device_ptrs.push_back(device.get()); if (is_addressable) { addressable_device_ptrs.push_back(device.get()); } + if (is_primary) { + primary_device_ptrs.push_back(device.get()); + } if (d.has_default_memory_id()) { const auto it = memories.find(d.default_memory_id()); @@ -150,9 +169,10 @@ absl::StatusOr> Client::Create( std::move(rpc_helper), init_response.session_id(), init_response.platform_name(), init_response.platform_version(), init_response.platform_id(), init_response.process_index(), runtime_type, - std::move(devices), device_ptrs, std::move(addressable_device_ptrs), + std::move(devices), std::move(primary_device_ptrs), + std::move(addressable_device_ptrs), all_device_ptrs, std::move(memories))); - for (ifrt::Device* device : device_ptrs) { + for (ifrt::Device* device : all_device_ptrs) { tensorflow::down_cast(device)->client_ = client.get(); } return client; @@ -163,8 +183,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories) : rpc_helper_(rpc_helper), platform_name_(std::move(platform_name)), @@ -175,8 +196,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, // TODO(b/309059940): Forward the backend attributes to the client. attributes_(AttributeMap::Map()), devices_(std::move(devices)), - device_ptrs_(device_ptrs), + primary_device_ptrs_(primary_device_ptrs), addressable_device_ptrs_(std::move(addressable_device_ptrs)), + all_device_ptrs_(all_device_ptrs), memories_(std::move(memories)), default_compiler_(this, rpc_helper) {} @@ -302,6 +324,10 @@ xla::ifrt::Future<> Client::GetReadyFuture( return JoinFutures(futures); } +absl::Span Client::GetAllDevices() const { + return all_device_ptrs_; +} + absl::StatusOr Client::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { auto req = std::make_unique(); diff --git a/xla/python/ifrt_proxy/client/client.h b/xla/python/ifrt_proxy/client/client.h index fd700441eb63ec..d271819a95ac48 100644 --- a/xla/python/ifrt_proxy/client/client.h +++ b/xla/python/ifrt_proxy/client/client.h @@ -110,12 +110,13 @@ class Client final : public llvm::RTTIExtends { return addressable_devices().size(); } absl::Span devices() const override { - return device_ptrs_; + return primary_device_ptrs_; } absl::Span addressable_devices() const override { return addressable_device_ptrs_; } int process_index() const override { return process_index_; } + absl::Span GetAllDevices() const override; absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; absl::StatusOr LookupDevice( @@ -148,8 +149,9 @@ class Client final : public llvm::RTTIExtends { std::string platform_name, std::string platform_version, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories); // rpc_helper_ will be referenced by various IFRT objects whose lifetime is @@ -166,8 +168,9 @@ class Client final : public llvm::RTTIExtends { const AttributeMap attributes_; const absl::flat_hash_map> devices_; - const std::vector device_ptrs_; + const std::vector primary_device_ptrs_; const std::vector addressable_device_ptrs_; + const std::vector all_device_ptrs_; const absl::flat_hash_map> memories_; diff --git a/xla/python/ifrt_proxy/client/client_test.cc b/xla/python/ifrt_proxy/client/client_test.cc index 3f1dbb45c7dea6..03dd43f3c93cfe 100644 --- a/xla/python/ifrt_proxy/client/client_test.cc +++ b/xla/python/ifrt_proxy/client/client_test.cc @@ -83,7 +83,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -94,7 +94,7 @@ class ClientTest : public ::testing::TestWithParam { value { string_value: "device0" } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -120,6 +120,55 @@ class ClientTest : public ::testing::TestWithParam { } )pb", &response)); + } else if (Version().protocol_version() < 7) { + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + kind_id: 0 + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + kind_id: 1 + device_ids: [ 1 ] + } + )pb", + &response)); } else { ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( R"pb( @@ -128,7 +177,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -141,7 +190,7 @@ class ClientTest : public ::testing::TestWithParam { } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -154,6 +203,7 @@ class ClientTest : public ::testing::TestWithParam { } } } + primary_device_ids: [ 0, 1 ] addressable_device_ids: 1 memories { id: 0 diff --git a/xla/python/ifrt_proxy/client/version.h b/xla/python/ifrt_proxy/client/version.h index f9d682243daee2..a2ec9b80d27ff4 100644 --- a/xla/python/ifrt_proxy/client/version.h +++ b/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 3; -inline constexpr int kClientMaxVersion = 6; +inline constexpr int kClientMaxVersion = 7; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/xla/python/ifrt_proxy/common/VERSION.md b/xla/python/ifrt_proxy/common/VERSION.md index 35173c3952caec..29797d5e6decda 100644 --- a/xla/python/ifrt_proxy/common/VERSION.md +++ b/xla/python/ifrt_proxy/common/VERSION.md @@ -35,3 +35,9 @@ * Added date: 2024-09-30. * Changes: * Added `ExecuteOptions::fill_status`. + +## Version 7 + +* Added date: 2024-10-01. +* Changes: + * Added support for `Client::GetAllDevices()`. diff --git a/xla/python/ifrt_proxy/common/ifrt_service.proto b/xla/python/ifrt_proxy/common/ifrt_service.proto index 3f17ee69abbd57..4d94c51a07acc7 100644 --- a/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -207,9 +207,12 @@ message InitResponse { AttributeMapProto attributes = 10; // New in Version 4. } - repeated Device devices = 6; // == ifrt::Client::devices() + repeated Device all_devices = 6; // == ifrt::Client::GetAllDevices() + repeated int32 primary_device_ids = + 10; // == [device.id for device in ifrt::Client::devices()] repeated int32 addressable_device_ids = - 7; // == ifrt::Client::addressable_devices() + 7; // == [device.id for device in ifrt::Client::GetAllDevices() if + // device.IsAddressable()] message Memory { int32 id = 1; diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index 1149e771b8e118..36431c7a9c8a83 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -266,8 +266,14 @@ absl::StatusOr IfrtBackend::HandleInit( init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type())); init_resp->set_process_index(client_->process_index()); - for (auto* device : client_->devices()) { - InitResponse::Device* d = init_resp->add_devices(); + absl::Span all_devices; + if (version_.protocol_version() < 7) { + all_devices = client_->devices(); + } else { + all_devices = client_->GetAllDevices(); + } + for (auto* device : all_devices) { + InitResponse::Device* d = init_resp->add_all_devices(); d->set_id(device->Id().value()); d->set_device_kind(AsProtoStringData(device->Kind())); if (auto default_memory = device->DefaultMemory(); default_memory.ok()) { @@ -289,13 +295,17 @@ absl::StatusOr IfrtBackend::HandleInit( } else { *d->mutable_attributes() = device->Attributes().ToProto(); } + + if (device->IsAddressable()) { + init_resp->add_addressable_device_ids(device->Id().value()); + } } - for (auto* addressable_device : client_->addressable_devices()) { - init_resp->add_addressable_device_ids(addressable_device->Id().value()); + for (auto* device : client_->devices()) { + init_resp->add_primary_device_ids(device->Id().value()); } absl::flat_hash_map memories; - for (auto* device : client_->devices()) { + for (auto* device : all_devices) { for (xla::ifrt::Memory* memory : device->Memories()) { const auto [it, inserted] = memories.insert({memory->Id().value(), memory}); diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 09c74cc68981b9..21f8a59fd26780 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -250,6 +250,8 @@ class IfrtBackendHandlerTest : public IfrtBackendTest { } ON_CALL(*mock_client, devices()).WillByDefault(Return(raw_device_ptrs)); + ON_CALL(*mock_client, GetAllDevices()) + .WillByDefault(Return(raw_device_ptrs)); ON_CALL(*mock_client, LookupDevice(_)) .WillByDefault( Invoke([this](DeviceId id) -> absl::StatusOr { @@ -434,7 +436,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -444,7 +446,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { value { string_value: "device0" } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -466,6 +468,53 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } )pb")))))); + } else if (Version().protocol_version() < 7) { + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + } + )pb")))))); } else { EXPECT_THAT(CallBackend(std::move(request)), IsOkAndHolds(Pointee( @@ -477,7 +526,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -489,7 +538,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -501,6 +550,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } + primary_device_ids: [ 0, 1 ] memories { id: 0 memory_space_kind: "mock" diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index 23900c049f344c..fdf11b8ecc67c5 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -206,6 +206,12 @@ class PjRtClient final return addressable_devices_; } int process_index() const override { return pjrt_client_->process_index(); } + + absl::Span GetAllDevices() const override { + DCHECK(this); + return devices_; + } + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { DCHECK(this); diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index d108e9d9c1e47d..9aa30743c59e29 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -179,6 +179,15 @@ std::vector> PyClient::LocalDevices() { return devices; } +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + absl::StatusOr> PyClient::DeviceFromLocalHardwareId( int local_hardware_id) { TF_ASSIGN_OR_RETURN(ifrt::Device * device, @@ -693,6 +702,7 @@ PyType_Slot PyClient::slots_[] = { .def("local_device_count", &PyClient::addressable_device_count) .def("devices", &PyClient::Devices) .def("local_devices", &PyClient::LocalDevices) + .def("get_all_devices", &PyClient::GetAllDevices) .def("device_from_local_hardware_id", xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) .def("live_executables", &PyClient::LiveExecutables) diff --git a/xla/python/py_client.h b/xla/python/py_client.h index 374b7f6d2e530c..7c95431619afd5 100644 --- a/xla/python/py_client.h +++ b/xla/python/py_client.h @@ -126,6 +126,7 @@ class PyClient { std::vector> Devices(); std::vector> LocalDevices(); + std::vector> GetAllDevices(); absl::StatusOr> DeviceFromLocalHardwareId( int local_hardware_id); diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index fa32529000f6de..9d37a41af1c87e 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -295,6 +295,9 @@ class CompileOnlyIfRtClient final return {}; } int process_index() const override { return 0; } + absl::Span GetAllDevices() const override { + return devices_; + } absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { return Unimplemented( diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 51d879814b2e68..a96b4fecd736f1 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 289 +_version = 290 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py index 441d5fbf450fa4..368ef243da3c19 100644 --- a/xla/python/xla_client_test.py +++ b/xla/python/xla_client_test.py @@ -2669,6 +2669,15 @@ def testScatter(self): class DeviceTest(ComputationTest): + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + + def testGetAllDevices(self): + self.assertNotEmpty(self.backend.get_all_devices()) + def testPlatform(self): for device in self.backend.local_devices(): self.assertEqual(device.platform, self.backend.platform) diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index e363d8d82471cb..845909f076224b 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -499,6 +499,7 @@ class Client: def local_device_count(self) -> int: ... def devices(self) -> List[Device]: ... def local_devices(self) -> List[Device]: ... + def get_all_devices(self) -> List[Device]: ... def device_from_local_hardware_id(self, int) -> Device: ... def live_buffers(self) -> List[Any]: ... def live_arrays(self) -> List[ArrayImpl]: ...