diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index efce0ebc1e5898..e2065b014a4493 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -49,12 +49,14 @@ limitations under the License. #include "nanobind/stl/unique_ptr.h" // from @nanobind // IWYU pragma: keep #include "nanobind/stl/variant.h" // from @nanobind // IWYU pragma: keep #include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep +#include "xla/layout.h" #include "xla/literal.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/callback.h" @@ -66,6 +68,7 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" +#include "xla/python/nb_numpy.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" @@ -79,6 +82,7 @@ limitations under the License. #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" #include "xla/python/transfer_guard_lib.h" +#include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep #include "xla/shape.h" @@ -802,6 +806,18 @@ PyType_Slot PyClient::slots_[] = { nb::arg("result_shapes"), nb::arg("send_channel_ids"), nb::arg("recv_channel_ids"), nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, + nb::sequence dims_seq) -> std::unique_ptr { + PrimitiveType element_type = + xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + xla::Layout layout = xla::ValueOrThrow( + self.pjrt_client()->GetDefaultLayout(element_type, dims)); + return std::make_unique(layout); + }, + nb::arg("dtype"), nb::arg("dims_seq")) .def("__getattr__", [](PyClient& client, std::string_view name) -> nb::object { const auto& attrs = client.attributes(); diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 115c27989a38e5..d2ede223efc3c2 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -49,7 +49,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 254 +_version = 255 # Version number for MLIR:Python components. mlir_api_version = 55 diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 0315b9bb19d166..4fa7e5913b6860 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -539,6 +539,9 @@ class Client: recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ..., ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, dims_seq: Sequence[int] + ) -> PjRtLayout: ... def __getattr__(self, name: str) -> Any: ... class CpuCollectives: ...