Skip to content

Commit

Permalink
[JAX] Add PyClient::GetAllDevices() and expose it as a private JAX ba…
Browse files Browse the repository at this point in the history
…ckend API

JAX backend forwards `xla::ifrt::Client::GetAllDevices()` to
`xla::PyClient::GetAllDevices()`, which is accessible via JAX
`backend._get_all_devices()`. This API is a transitional private backend API
that is used for building an experimental API (finding colocated CPU devices)
and should not be used by any other code.

PiperOrigin-RevId: 679748877
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Oct 8, 2024
1 parent 364b54f commit 48bb5d1
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
12 changes: 12 additions & 0 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ std::vector<nb_class_ptr<PyDevice>> PyClient::LocalDevices() {
return devices;
}

std::vector<nb_class_ptr<PyDevice>> PyClient::GetAllDevices() {
std::vector<nb_class_ptr<PyDevice>> devices;
devices.reserve(ifrt_client_->GetAllDevices().size());
for (ifrt::Device* device : ifrt_client_->GetAllDevices()) {
devices.push_back(GetPyDevice(device));
}
return devices;
}

absl::StatusOr<nb_class_ptr<PyDevice>> PyClient::DeviceFromLocalHardwareId(
int local_hardware_id) {
TF_ASSIGN_OR_RETURN(ifrt::Device * device,
Expand Down Expand Up @@ -693,6 +702,9 @@ PyType_Slot PyClient::slots_[] = {
.def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
// TODO(hyeontaek): Remove this method once we have a unified API for
// enumerating devices with different criteria.
.def("_get_all_devices", &PyClient::GetAllDevices)
.def("device_from_local_hardware_id",
xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId))
.def("live_executables", &PyClient::LiveExecutables)
Expand Down
5 changes: 5 additions & 0 deletions xla/python/py_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class PyClient {

std::vector<nb_class_ptr<PyDevice>> Devices();
std::vector<nb_class_ptr<PyDevice>> LocalDevices();
// Returns all devices in the client. Private API; only use this method for
// implementing backend._get_all_devices().
// TODO(hyeontaek): Remove this method once we have a unified API for
// enumerating devices with different criteria.
std::vector<nb_class_ptr<PyDevice>> GetAllDevices();
absl::StatusOr<nb_class_ptr<PyDevice>> DeviceFromLocalHardwareId(
int local_hardware_id);

Expand Down
11 changes: 11 additions & 0 deletions xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2726,6 +2726,17 @@ def testScatter(self):

class DeviceTest(ComputationTest):

def testDevices(self):
self.assertNotEmpty(self.backend.devices())

def testLocalDevices(self):
self.assertNotEmpty(self.backend.local_devices())

def testGetAllDevices(self):
# TODO(hyeontaek): Remove this method once we have a unified API for
# enumerating devices with different criteria.
self.assertNotEmpty(self.backend._get_all_devices()) # pylint: disable=protected-access

def testPlatform(self):
for device in self.backend.local_devices():
self.assertEqual(device.platform, self.backend.platform)
Expand Down
1 change: 1 addition & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ class Client:
def local_device_count(self) -> int: ...
def devices(self) -> List[Device]: ...
def local_devices(self) -> List[Device]: ...
def _get_all_devices(self) -> List[Device]: ...
def device_from_local_hardware_id(self, int) -> Device: ...
def live_buffers(self) -> List[Any]: ...
def live_arrays(self) -> List[ArrayImpl]: ...
Expand Down

0 comments on commit 48bb5d1

Please sign in to comment.