Skip to content

Commit

Permalink
[JAX] Add PyClient::GetAllDevices() and expose it as an internal JAX …
Browse files Browse the repository at this point in the history
…backend API

The JAX backend forwards `xla::ifrt::Client::GetAllDevices()` to
`xla::PyClient::GetAllDevices()`, which is accessible via JAX
`backend.get_all_devices()`. This is an internal JAX API that is used for
building an experimental mesh utils API (finding colocated CPU devices) and
should not be used by user code.

PiperOrigin-RevId: 679748877
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Oct 3, 2024
1 parent 92e3c7a commit 94ebf59
Show file tree
Hide file tree
Showing 20 changed files with 230 additions and 28 deletions.
1 change: 1 addition & 0 deletions xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ cc_library(
deps = [
":ifrt",
":test_util",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
alwayslink = True,
Expand Down
5 changes: 5 additions & 0 deletions xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
virtual absl::Span<Device* const> addressable_devices() const = 0;
virtual int process_index() const = 0;

// Returns all devices. The result includes primary devices that are included
// in `devices()` as well as any other devices that are associated with
// the primary devices.
virtual absl::Span<xla::ifrt::Device* const> GetAllDevices() const = 0;

// TODO(hyeontaek): Consider removing this API. This API is potentially not
// being used by JAX or will be replaced with explicit device assignment.
virtual absl::StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
Expand Down
15 changes: 14 additions & 1 deletion xla/python/ifrt/client_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/test_util.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
Expand Down Expand Up @@ -54,6 +55,18 @@ TEST(ClientImplTest, Devices) {
EXPECT_GE(client->process_index(), 0);
}

// Checks that `GetAllDevices()` reports at least `device_count()` devices and
// that every device it returns can be looked up again by its id.
TEST(ClientImplTest, GetAllDevices) {
  TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());

  const auto all_devices = client->GetAllDevices();
  EXPECT_GE(all_devices.size(), client->device_count());

  for (Device* const candidate : all_devices) {
    // Looking a device up by id must yield the very same pointer.
    TF_ASSERT_OK_AND_ASSIGN(auto* found,
                            client->LookupDevice(candidate->Id()));
    EXPECT_EQ(candidate, found);
  }
}

TEST(ClientImplTest, DefaultCompiler) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
EXPECT_THAT(client->GetDefaultCompiler(), NotNull());
Expand Down
3 changes: 3 additions & 0 deletions xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
ON_CALL(*this, process_index).WillByDefault([this]() {
return delegated_->process_index();
});
ON_CALL(*this, GetAllDevices).WillByDefault([this]() {
return delegated_->GetAllDevices();
});
ON_CALL(*this, GetDefaultDeviceAssignment)
.WillByDefault([this](int num_replicas, int num_partitions) {
return delegated_->GetDefaultDeviceAssignment(num_replicas,
Expand Down
2 changes: 2 additions & 0 deletions xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
MOCK_METHOD(absl::Span<Device* const>, addressable_devices, (),
(const, final));
MOCK_METHOD(int, process_index, (), (const, final));
MOCK_METHOD(absl::Span<xla::ifrt::Device* const>, GetAllDevices, (),
(const, final));
MOCK_METHOD(absl::StatusOr<DeviceAssignment>, GetDefaultDeviceAssignment,
(int num_replicas, int num_partitions), (const, final));
MOCK_METHOD(absl::StatusOr<Device*>, LookupDevice, (DeviceId device_id),
Expand Down
40 changes: 33 additions & 7 deletions xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
absl::flat_hash_set<int> addressable_device_ids(
init_response.addressable_device_ids().begin(),
init_response.addressable_device_ids().end());
absl::flat_hash_set<int> primary_device_ids;
if (rpc_helper->version().protocol_version() < 7) {
// Legacy path for servers that do not support Client::GetAllDevices() and
// thus do not provide primary_device_ids(). Assume that every device listed
// in all_devices() is a primary device.
primary_device_ids.reserve(init_response.all_devices().size());
for (const auto& d : init_response.all_devices()) {
primary_device_ids.insert(d.id());
}
} else {
primary_device_ids.reserve(init_response.primary_device_ids().size());
primary_device_ids.insert(init_response.primary_device_ids().begin(),
init_response.primary_device_ids().end());
}

absl::flat_hash_map<int, std::unique_ptr<Memory>> memories;
for (const auto& m : init_response.memories()) {
Expand All @@ -77,10 +91,11 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
}

absl::flat_hash_map<int, std::unique_ptr<Device>> devices;
std::vector<xla::ifrt::Device*> device_ptrs;
std::vector<xla::ifrt::Device*> primary_device_ptrs;
std::vector<xla::ifrt::Device*> addressable_device_ptrs;
std::vector<xla::ifrt::Device*> all_device_ptrs;

for (const auto& d : init_response.devices()) {
for (const auto& d : init_response.all_devices()) {
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>
pjrt_device_attributes;
if (rpc_helper->version().protocol_version() <= 3) {
Expand All @@ -99,14 +114,18 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
d.device_kind(), d.debug_string(), d.to_string(),
std::move(pjrt_device_attributes));
bool is_addressable = addressable_device_ids.contains(d.id());
bool is_primary = primary_device_ids.contains(d.id());

auto device =
std::make_unique<Device>(std::move(desc), d.local_device_id(),
d.local_hardware_id(), is_addressable);
device_ptrs.push_back(device.get());
all_device_ptrs.push_back(device.get());
if (is_addressable) {
addressable_device_ptrs.push_back(device.get());
}
if (is_primary) {
primary_device_ptrs.push_back(device.get());
}

if (d.has_default_memory_id()) {
const auto it = memories.find(d.default_memory_id());
Expand Down Expand Up @@ -150,9 +169,10 @@ absl::StatusOr<std::unique_ptr<Client>> Client::Create(
std::move(rpc_helper), init_response.session_id(),
init_response.platform_name(), init_response.platform_version(),
init_response.platform_id(), init_response.process_index(), runtime_type,
std::move(devices), device_ptrs, std::move(addressable_device_ptrs),
std::move(devices), std::move(primary_device_ptrs),
std::move(addressable_device_ptrs), all_device_ptrs,
std::move(memories)));
for (ifrt::Device* device : device_ptrs) {
for (ifrt::Device* device : all_device_ptrs) {
tensorflow::down_cast<Device*>(device)->client_ = client.get();
}
return client;
Expand All @@ -163,8 +183,9 @@ Client::Client(std::shared_ptr<RpcHelper> rpc_helper, uint64_t session_id,
uint64_t platform_id, uint64_t process_index,
std::string runtime_type,
absl::flat_hash_map<int, std::unique_ptr<Device>> devices,
std::vector<xla::ifrt::Device*> device_ptrs,
std::vector<xla::ifrt::Device*> primary_device_ptrs,
std::vector<xla::ifrt::Device*> addressable_device_ptrs,
std::vector<xla::ifrt::Device*> all_device_ptrs,
absl::flat_hash_map<int, std::unique_ptr<Memory>> memories)
: rpc_helper_(rpc_helper),
platform_name_(std::move(platform_name)),
Expand All @@ -175,8 +196,9 @@ Client::Client(std::shared_ptr<RpcHelper> rpc_helper, uint64_t session_id,
// TODO(b/309059940): Forward the backend attributes to the client.
attributes_(AttributeMap::Map()),
devices_(std::move(devices)),
device_ptrs_(device_ptrs),
primary_device_ptrs_(primary_device_ptrs),
addressable_device_ptrs_(std::move(addressable_device_ptrs)),
all_device_ptrs_(all_device_ptrs),
memories_(std::move(memories)),
default_compiler_(this, rpc_helper) {}

Expand Down Expand Up @@ -302,6 +324,10 @@ xla::ifrt::Future<> Client::GetReadyFuture(
return JoinFutures(futures);
}

// Returns a view of the device pointers stored in `all_device_ptrs_`.
absl::Span<xla::ifrt::Device* const> Client::GetAllDevices() const {
  return absl::Span<xla::ifrt::Device* const>(all_device_ptrs_);
}

absl::StatusOr<DeviceAssignment> Client::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
auto req = std::make_unique<GetDefaultDeviceAssignmentRequest>();
Expand Down
9 changes: 6 additions & 3 deletions xla/python/ifrt_proxy/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
return addressable_devices().size();
}
absl::Span<xla::ifrt::Device* const> devices() const override {
return device_ptrs_;
return primary_device_ptrs_;
}
absl::Span<xla::ifrt::Device* const> addressable_devices() const override {
return addressable_device_ptrs_;
}
int process_index() const override { return process_index_; }
absl::Span<xla::ifrt::Device* const> GetAllDevices() const override;
absl::StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
absl::StatusOr<xla::ifrt::Device*> LookupDevice(
Expand Down Expand Up @@ -148,8 +149,9 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
std::string platform_name, std::string platform_version,
uint64_t platform_id, uint64_t process_index, std::string runtime_type,
absl::flat_hash_map<int, std::unique_ptr<Device>> devices,
std::vector<xla::ifrt::Device*> device_ptrs,
std::vector<xla::ifrt::Device*> primary_device_ptrs,
std::vector<xla::ifrt::Device*> addressable_device_ptrs,
std::vector<xla::ifrt::Device*> all_device_ptrs,
absl::flat_hash_map<int, std::unique_ptr<Memory>> memories);

// rpc_helper_ will be referenced by various IFRT objects whose lifetime is
Expand All @@ -166,8 +168,9 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
const AttributeMap attributes_;

const absl::flat_hash_map<int, std::unique_ptr<Device>> devices_;
const std::vector<xla::ifrt::Device*> device_ptrs_;
const std::vector<xla::ifrt::Device*> primary_device_ptrs_;
const std::vector<xla::ifrt::Device*> addressable_device_ptrs_;
const std::vector<xla::ifrt::Device*> all_device_ptrs_;

const absl::flat_hash_map<int, std::unique_ptr<Memory>> memories_;

Expand Down
58 changes: 54 additions & 4 deletions xla/python/ifrt_proxy/client/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
devices {
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -94,7 +94,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
value { string_value: "device0" }
}
}
devices {
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -120,6 +120,55 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
)pb",
&response));
} else if (Version().protocol_version() < 7) {
ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString(
R"pb(
platform_name: "ifrt-service"
platform_version: "n/a"
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
default_memory_id: 0
memory_ids: [ 0 ]
attributes {
attributes {
key: "name"
value { string_value: "device0" }
}
}
}
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
default_memory_id: 1
memory_ids: [ 1 ]
attributes {
attributes {
key: "name"
value { string_value: "device1" }
}
}
}
addressable_device_ids: 1
memories {
id: 0
memory_space_kind: "mock"
kind_id: 0
device_ids: [ 0 ]
}
memories {
id: 1
memory_space_kind: "mock"
kind_id: 1
device_ids: [ 1 ]
}
)pb",
&response));
} else {
ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString(
R"pb(
Expand All @@ -128,7 +177,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
platform_id: 42
process_index: 1
runtime_type: "ifrt-service"
devices {
all_devices {
id: 0
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -141,7 +190,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
}
}
devices {
all_devices {
id: 1
local_hardware_id: 1234
device_kind: "mock"
Expand All @@ -154,6 +203,7 @@ class ClientTest : public ::testing::TestWithParam</*protocol_version=*/int> {
}
}
}
primary_device_ids: [ 0, 1 ]
addressable_device_ids: 1
memories {
id: 0
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt_proxy/client/version.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace proxy {
// LINT.IfChange
// TODO(b/296144873): Document the version upgrade policy.
inline constexpr int kClientMinVersion = 3;
inline constexpr int kClientMaxVersion = 6;
inline constexpr int kClientMaxVersion = 7;
// LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md)

} // namespace proxy
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/common/VERSION.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@
* Added date: 2024-09-30.
* Changes:
* Added `ExecuteOptions::fill_status`.

## Version 7

* Added date: 2024-10-01.
* Changes:
* Added support for `Client::GetAllDevices()`.
7 changes: 5 additions & 2 deletions xla/python/ifrt_proxy/common/ifrt_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,12 @@ message InitResponse {
AttributeMapProto attributes = 10; // New in Version 4.
}

repeated Device devices = 6; // == ifrt::Client::devices()
repeated Device all_devices = 6; // == ifrt::Client::GetAllDevices()
repeated int32 primary_device_ids =
10; // == [device.id for device in ifrt::Client::devices()]
repeated int32 addressable_device_ids =
7; // == ifrt::Client::addressable_devices()
7; // == [device.id for device in ifrt::Client::GetAllDevices() if
// device.IsAddressable()]

message Memory {
int32 id = 1;
Expand Down
20 changes: 15 additions & 5 deletions xla/python/ifrt_proxy/server/ifrt_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,14 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleInit(
init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type()));
init_resp->set_process_index(client_->process_index());

for (auto* device : client_->devices()) {
InitResponse::Device* d = init_resp->add_devices();
absl::Span<xla::ifrt::Device* const> all_devices;
if (version_.protocol_version() < 7) {
all_devices = client_->devices();
} else {
all_devices = client_->GetAllDevices();
}
for (auto* device : all_devices) {
InitResponse::Device* d = init_resp->add_all_devices();
d->set_id(device->Id().value());
d->set_device_kind(AsProtoStringData(device->Kind()));
if (auto default_memory = device->DefaultMemory(); default_memory.ok()) {
Expand All @@ -289,13 +295,17 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleInit(
} else {
*d->mutable_attributes() = device->Attributes().ToProto();
}

if (device->IsAddressable()) {
init_resp->add_addressable_device_ids(device->Id().value());
}
}
for (auto* addressable_device : client_->addressable_devices()) {
init_resp->add_addressable_device_ids(addressable_device->Id().value());
for (auto* device : client_->devices()) {
init_resp->add_primary_device_ids(device->Id().value());
}

absl::flat_hash_map<int, xla::ifrt::Memory*> memories;
for (auto* device : client_->devices()) {
for (auto* device : all_devices) {
for (xla::ifrt::Memory* memory : device->Memories()) {
const auto [it, inserted] =
memories.insert({memory->Id().value(), memory});
Expand Down
Loading

0 comments on commit 94ebf59

Please sign in to comment.