Skip to content

CI: 04/23/25 upstream sync #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 929 commits into
base: rocm-main
Choose a base branch
from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

gnecula and others added 30 commits April 11, 2025 12:53
…functions

About half of the tracing-cache-miss explanations in a large benchmark
end up being from JAX-internal functions, such as `jax.numpy` functions.
These cache misses are not what the JAX user wants to see, so we filter
them out, using the same mechanism used for filtering tracebacks.
…4.34.

As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:

* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.

PiperOrigin-RevId: 746388529
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See jax-ml#27829 and jax-ml#27825.

The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.

Fixes: jax-ml#27829
So far Mosaic was implicitly relying on XLA to register the NVPTX target which made problems in cases where only a Mosaic kernel gets compiled and XLA didn't initialize the LLVM NVPTX target.

PiperOrigin-RevId: 746433654
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests.
But we don't really need a special decorator here anyway.

PiperOrigin-RevId: 746434607
Follow-up from jax-ml#27916.
jax-fixit

PiperOrigin-RevId: 746442635
stdout redirection is inherently racy; mark test cases doing it as thread unsafe.

PiperOrigin-RevId: 746443039
…k capsules to jax.dlpack.from_dlpack().

to_dlpack() is not needed in the current version of the dlpack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow which does not implement the current protocol.

PiperOrigin-RevId: 746464890
PiperOrigin-RevId: 746520758
…s the same as the order of arguments received in `jit` API and make it keyword-only

PiperOrigin-RevId: 746527807
This is not needed under the newer DLPack protocol for users, and there's an equivalent (`__dlpack__`).

PiperOrigin-RevId: 746530351
PiperOrigin-RevId: 746543312
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: jax-ml#27874  from appearing in future JAX releases

PiperOrigin-RevId: 746546162
PiperOrigin-RevId: 746546870
Use a count of chips (or omit it if 1) rather than specifying an ICI topology.

Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8
PiperOrigin-RevId: 746547477
PiperOrigin-RevId: 746554582
…thon as a patch, rolling back.

Reverts b1c96d4

PiperOrigin-RevId: 746565341
Missing space in '..math::' meant that the math wasn't rendering correctly.
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.

PiperOrigin-RevId: 746595857
This parameter is available from jax-ml#23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html.

PiperOrigin-RevId: 746606206
Google-ML-Automation and others added 28 commits April 22, 2025 09:43
PiperOrigin-RevId: 750226731
…e block size is too small.

PiperOrigin-RevId: 750244014
…if they are identical

PiperOrigin-RevId: 750284947
PiperOrigin-RevId: 750287933
… pipeline emitter if there's nothing to wait for.

Also enforce that `arrival_count` is always > 0.

PiperOrigin-RevId: 750294068
The signature is:

```
jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)
```

This API is a drop-in replacement for the experimental shard_map endpoint with just two small changes: `check_rep` is renamed to `check_vma` and all arguments (except `f`) to `shard_map` are keyword only and `f` is positional only.

**But why are mesh and in_specs optional? And what is the new `axis_names` argument?**

* `mesh` is optional because it can be inferred from the context if user sets the mesh via `jax.sharding.use_mesh(mesh)`.

* `in_specs` is optional because it can be inferred from the arguments passed to `shard_map` if all mesh axes are `Explicit`.

* `axis_names`: axis_names tells `shard_map` which axes are `Manual`. If empty, it implies the `shard_map` is `Manual` over all mesh axes.
Before in the experimental endpoint of `shard_map`, this argument was called `auto`. But after the advent of `sharding_in_types`, mesh axes can be `Auto`, `Explicit` or `Manual`. So `auto` was not enough since axes can be `Explicit` too. That's why `jax.shard_map` flips the argument to `axis_names`.

**If `in_specs` is optional, why is `out_specs` compulsory?**

This is because, we still need to know which dimension to concat over. It can't be inferred automatically since the choice can be anything.

PiperOrigin-RevId: 750343135
…REG spill.

Taking new factors into account for auto tunning:
- q_dtype_name
- kv_dtype_name
- num_q_heads_per_blk
- num_kv_heads_per_blk
- head_dim
- page_size
- max_num_batched_tokens
- max_model_len = page_size * pages_per_seq

We only has 32 SREGs in TensorCore. If the page size is small, we can easily spill SREGs. This cl suggests using `page_size = max_model_len // 16` which will make sure at most 16 SREGs will be used for KV page indices per sequence.

PiperOrigin-RevId: 750370022
…ard_map.py to `jax/_src`

The signature is:

`jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)`

This API is a drop-in replacement for the experimental shard_map endpoint with just two small changes: check_rep is renamed to check_vma and all arguments (except f) to shard_map are keyword only and f is positional only.

**But why are mesh and in_specs optional? And what is the new axis_names argument?**

mesh is optional because it can be inferred from the context if user sets the mesh via jax.sharding.use_mesh(mesh).

in_specs is optional because it can be inferred from the arguments passed to shard_map if all mesh axes are Explicit.

axis_names: axis_names tells shard_map which axes are Manual. If empty, it implies the shard_map is Manual over all mesh axes. Before in the experimental endpoint of shard_map, this argument was called auto. But after the advent of sharding_in_types, mesh axes can be Auto, Explicit or Manual. So auto was not enough since axes can be Explicit too. That's why jax.shard_map flips the argument to axis_names.

**If in_specs is optional, why is out_specs compulsory?**

This is because, we still need to know which dimension to concat over. It can't be inferred automatically since the choice can be anything.

END_PUBLIC

PiperOrigin-RevId: 750401402
…oadcast in lower_to_llo.

- To make fold_in non-trivial, in Pallas the key is now represented as a (1, 2)-shaped key.
- 2 new primitives were added for wrapping/unwrapping the key from scalars. This is needed because JAX's wrap/unwrap return to and from vectors, whereas in Pallas we need to return a list of scalars.

PiperOrigin-RevId: 750422791
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner April 23, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) April 23, 2025 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.