Skip to content

Commit

Permalink
Add a method to get default layout in PyClient.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622000403
  • Loading branch information
Jieying Luo authored and copybara-github committed Apr 5, 2024
1 parent 0960128 commit f246c53
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
16 changes: 16 additions & 0 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep
#include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep
#include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/callback.h"
Expand All @@ -66,6 +68,7 @@ limitations under the License.
#include "xla/python/ifrt/memory.h"
#include "xla/python/nb_absl_span.h" // IWYU pragma: keep
#include "xla/python/nb_class_ptr.h"
#include "xla/python/nb_numpy.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
Expand All @@ -79,6 +82,7 @@ limitations under the License.
#include "xla/python/python_ref_manager.h"
#include "xla/python/traceback.h"
#include "xla/python/transfer_guard_lib.h"
#include "xla/python/types.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/platform_util.h" // IWYU pragma: keep
#include "xla/shape.h"
Expand Down Expand Up @@ -802,6 +806,18 @@ PyType_Slot PyClient::slots_[] = {
nb::arg("result_shapes"), nb::arg("send_channel_ids"),
nb::arg("recv_channel_ids"),
nb::arg("serializer").none() = nb::none())
.def(
"get_default_layout",
[](PyClient& self, nb_dtype dtype,
nb::sequence dims_seq) -> std::unique_ptr<PjRtLayout> {
PrimitiveType element_type =
xla::ValueOrThrow(DtypeToPrimitiveType(dtype));
std::vector<int64_t> dims = SequenceToVector<int64_t>(dims_seq);
xla::Layout layout = xla::ValueOrThrow(
self.pjrt_client()->GetDefaultLayout(element_type, dims));
return std::make_unique<PjRtXlaLayout>(layout);
},
nb::arg("dtype"), nb::arg("dims_seq"))
.def("__getattr__",
[](PyClient& client, std::string_view name) -> nb::object {
const auto& attrs = client.attributes();
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 254
_version = 255

# Version number for MLIR:Python components.
mlir_api_version = 55
Expand Down
3 changes: 3 additions & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ class Client:
recv_channel_ids: Sequence[int],
serializer: Optional[Callable] = ...,
) -> Any: ...
def get_default_layout(
self, dtype: np.dtype, dims_seq: Sequence[int]
) -> PjRtLayout: ...
def __getattr__(self, name: str) -> Any: ...

class CpuCollectives: ...
Expand Down

0 comments on commit f246c53

Please sign in to comment.