CI: 04/22/25 upstream sync #377


Open
wants to merge 887 commits into base: rocm-main

Conversation

rocm-repo-management-api-2[bot]

Daily sync with upstream

Google-ML-Automation and others added 30 commits April 10, 2025 09:16
If the mesh axes are empty, we were setting the mesh to None, resulting in an
error in this test.

This fix provides an empty mesh when the mesh axes in the dumped module are empty.

PiperOrigin-RevId: 746058506
…rying with the correct vma as the operands were.

PiperOrigin-RevId: 746065965
PiperOrigin-RevId: 746117643
This fixes some non-intuitive errors where scalar-shaped values in VREGs were being used in operations that expected SREGs.

PiperOrigin-RevId: 746146037
…pes being enabled by default

PiperOrigin-RevId: 746146834
Adds a new WarpMesh object which, when used in conjunction with core_map, allows the user to drop into warp-level code rather than programming at the warpgroup level.

PiperOrigin-RevId: 746163942
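The warpgroup-to-warp decomposition can be pictured with a small pure-Python model. This is only a conceptual sketch of mapping a per-warp body over the warps in a warpgroup; it is not the actual Pallas/Mosaic WarpMesh or core_map API.

```python
# Conceptual model: a warpgroup of 128 threads splits into 4 warps of 32,
# and a "warp mesh" maps a per-warp body over those warps. Names here are
# illustrative stand-ins, not the real API.
WARPGROUP_SIZE = 128
WARP_SIZE = 32

def core_map_over_warps(body):
    """Run body once per warp, passing the warp index (0..3)."""
    n_warps = WARPGROUP_SIZE // WARP_SIZE
    return [body(warp_id) for warp_id in range(n_warps)]

# Each warp handles its own 32-element slice of a 128-wide row.
results = core_map_over_warps(lambda w: (w * WARP_SIZE, (w + 1) * WARP_SIZE))
print(results)  # [(0, 32), (32, 64), (64, 96), (96, 128)]
```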
This change primarily reduces sharding, although in a few cases it also increases it. Oversharding tests hurts performance, since each test run incurs a startup and teardown cost.

In a few cases, change tests to be non-accelerator tests.

PiperOrigin-RevId: 746164539
Partially addresses jax-ml#18246. If the compile result can also be a future, this code can be used to safely block on that as well.

PiperOrigin-RevId: 746189742
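The "safely block on a value that may be a future" pattern can be sketched with the standard library. The helper name `resolve` is illustrative, not the JAX-internal function.

```python
# Sketch: block on a value that may or may not be a Future, so callers can
# treat eager and asynchronous results uniformly.
from concurrent.futures import Future, ThreadPoolExecutor

def resolve(value):
    """Return value, blocking first if it is a Future."""
    if isinstance(value, Future):
        return value.result()  # blocks until the future completes
    return value

with ThreadPoolExecutor() as pool:
    fut = pool.submit(lambda: 42)
    print(resolve(fut))   # 42, after blocking on the future
    print(resolve("ok"))  # "ok", returned immediately
```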
yashk2810 and others added 28 commits April 18, 2025 13:11
…rly by adding the explicit mesh axis on dim 0.

PiperOrigin-RevId: 749125322
Array serialization in array_serialization.py contains a mixture of
JAX-specific serialization logic and the tensorstore driver. This change
separates the JAX and tensorstore methods, (a) making serialization more
modular and (b) potentially allowing for alternative array serialization
backends in the future.

Additional clean-up changes include:
- making the ocdbt kvstore driver the default in tensorstore
- making array serialization tests more robust, especially on multi-host setups
- explicit tensorstore array chunking to ensure chunk file sizes do not blow up

PiperOrigin-RevId: 749175295
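The separation described above can be sketched as an abstract storage backend behind the array-level logic, with explicit chunking bounding each write. All names here are illustrative, not the actual array_serialization or tensorstore API.

```python
# Sketch: array serialization logic talks to an abstract backend, so the
# tensorstore driver could be swapped for an alternative backend later.
import abc
import math

class StorageBackend(abc.ABC):
    @abc.abstractmethod
    def write_chunk(self, key: str, data: bytes) -> None: ...

class InMemoryBackend(StorageBackend):
    def __init__(self):
        self.store = {}
    def write_chunk(self, key, data):
        self.store[key] = data

def serialize_array(data: bytes, backend: StorageBackend, max_chunk_bytes: int):
    """Write data in explicitly sized chunks so no single chunk blows up."""
    n_chunks = max(1, math.ceil(len(data) / max_chunk_bytes))
    for i in range(n_chunks):
        chunk = data[i * max_chunk_bytes:(i + 1) * max_chunk_bytes]
        backend.write_chunk(f"chunk/{i}", chunk)
    return n_chunks

backend = InMemoryBackend()
print(serialize_array(b"x" * 1000, backend, max_chunk_bytes=256))  # 4
```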
…t to tracing cache after the sharding_in_types config was turned on, which led to `sharding` always being available on `ShapedArray`

PiperOrigin-RevId: 749206500
PiperOrigin-RevId: 749464614
PiperOrigin-RevId: 749779206
Description:
- Copy mlir module before adding new attributes

Fixes jax-ml#27991
…utation-27991

PiperOrigin-RevId: 749811476
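The fix is an instance of the copy-before-mutate pattern. The real change copies an MLIR module before annotating it; in this sketch a plain dict stands in for the module, and the function names are hypothetical.

```python
# Sketch: add attributes to a copy of a shared object instead of mutating
# the caller's instance in place, so repeated lowerings don't interfere.
import copy

def lower_with_attributes(module, new_attrs):
    """Return a lowered module carrying new_attrs, leaving the input untouched."""
    lowered = copy.deepcopy(module)   # copy first, then annotate
    lowered["attributes"].update(new_attrs)
    return lowered

original = {"body": "func @main() ...", "attributes": {}}
lowered = lower_with_attributes(original, {"donated_args": (0,)})
print(original["attributes"])  # {} -- the caller's module is unchanged
```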
…ed op instead of multiple .at[] calls.

PiperOrigin-RevId: 749818535
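The idea of collapsing several indexed updates into one batched operation can be shown with numpy's `np.add.at`, which applies all updates in a single call and handles repeated indices correctly. This mirrors the single-op-instead-of-multiple-`.at[]`-calls change above, though it is not the JAX code in question.

```python
# Sketch: one batched scatter-add instead of a loop of indexed updates.
import numpy as np

x = np.zeros(5)
indices = np.array([0, 2, 2, 4])
updates = np.array([1.0, 1.0, 1.0, 1.0])

# Instead of: for i, u in zip(indices, updates): x[i] += u
np.add.at(x, indices, updates)
print(x)  # [1. 0. 2. 0. 1.]  -- index 2 accumulates both updates
```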
Amend the scheme format and top-level domain.
…bstract eval

This can happen if a user forgets to unwrap a ref!

@asabne had this happen to him today and was confused about what was going on. The prior error was unclear:

AssertionError: (MemRef<None>{float32[2,1024,1024]}, MemRef<None>{float32[1,1024,1024]})
PiperOrigin-RevId: 749979253
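The improvement is replacing a bare assertion with a descriptive error when a ref reaches a primitive that expects a plain array. The sketch below uses illustrative stand-in classes, not JAX internals.

```python
# Sketch: raise a descriptive TypeError when a ref is passed unwrapped,
# instead of a bare AssertionError listing abstract values.
class MemRef:
    def __init__(self, shape):
        self.shape = shape
    def __repr__(self):
        return f"MemRef{self.shape}"

def abstract_eval(*operands):
    for op in operands:
        if isinstance(op, MemRef):
            raise TypeError(
                f"primitive received a ref {op!r}; did you forget to "
                f"unwrap it with ref[...] before passing it in?")
    return operands

try:
    abstract_eval(MemRef((2, 1024, 1024)))
except TypeError as e:
    print(e)
```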
…ray creation when possible

This change makes use of the new
`xla::ifrt::Client::MakeArraysFromHostBufferShards()` API when possible. This
API needs a single call to create a multi-shard IFRT Array (to be wrapped as a
JAX `PyArray`), which provides more optimization opportunities for the runtime
than creating single-device IFRT Arrays and then assembling them. Note
that the `xla::ifrt::Client::MakeArraysFromHostBufferShards()` implementation
in PjRt-IFRT is not yet optimized, so there are no immediate performance
benefits for McJAX.

As an exception, it takes the previous array-assembly path if any shard for
`BatchedDevicePut` is not a host buffer but already a single-device array,
because `xla::ifrt::Client::MakeArraysFromHostBufferShards()` requires all
sharded inputs to be host buffers.

With batching possible at the IFRT level, we now skip the `DevicePutResultFn`
step; `DevicePut` (now `DevicePutWithDevice` and `DevicePutWithSharding`)
internally calls per-shard functions (with the GIL released) and returns a
final IFRT Array.

This change includes a code cleanup for
`xla::DevicePutResult::owning_pybuffer`, which was originally intended to hold
a Python object to keep an IFRT Array valid when it is created from
`DevicePut()` implementations, but this role has been entirely covered by
`on_done_with_host_buffer` function supplied at IFRT Array creation time.

PiperOrigin-RevId: 749989229
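The dispatch between the batched path and the assembly fallback can be sketched in plain Python. Function and class names here are illustrative stand-ins (with numpy arrays playing the role of buffers), not the actual IFRT API.

```python
# Sketch: build a multi-shard array in one batched call when every shard is
# a host buffer; fall back to per-shard assembly when some shard is already
# a single-device array.
import numpy as np

class DeviceArray:  # stand-in for an already-materialized device array
    def __init__(self, data):
        self.data = np.asarray(data)

def make_array(shards):
    if all(not isinstance(s, DeviceArray) for s in shards):
        # Fast path: one batched call over all host buffers.
        return np.stack([np.asarray(s) for s in shards]), "batched"
    # Fallback: assemble from per-shard arrays.
    parts = [s.data if isinstance(s, DeviceArray) else np.asarray(s)
             for s in shards]
    return np.stack(parts), "assembled"

_, path = make_array([[1, 2], [3, 4]])
print(path)  # "batched"
_, path = make_array([[1, 2], DeviceArray([3, 4])])
print(path)  # "assembled"
```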
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner April 22, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) April 22, 2025 06:02