forked from jax-ml/jax
CI: 04/23/25 upstream sync #379
Open
rocm-repo-management-api-2 wants to merge 929 commits into rocm-main from ci-upstream-sync-178_1
Conversation
…functions
About half of the tracing-cache-miss explanations in a large benchmark end up being from JAX-internal functions, such as `jax.numpy` functions. These cache misses are not what the JAX user wants to see, so we filter them out, using the same mechanism used for filtering tracebacks.
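For context, a minimal sketch of how a user surfaces these explanations via JAX's existing `jax_explain_cache_misses` config option (the function and shapes are illustrative); with this change, explanations originating inside JAX-internal functions no longer appear:

```python
# Sketch: enabling tracing-cache-miss explanations.
import jax
import jax.numpy as jnp

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(x):
    return jnp.sin(x)

f(jnp.ones(3))   # first call: traced and compiled
f(jnp.ones(4))   # new shape: retrace, with a cache-miss explanation logged
```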
…4.34.
As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:
* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.
PiperOrigin-RevId: 746388529
PiperOrigin-RevId: 746397395
…ernals PiperOrigin-RevId: 746397452
PiperOrigin-RevId: 746402180
Previously, jax.jit returned a function with extra attributes, e.g., `trace` and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along with `functools.wraps`. Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a function that, when invoked, will invoke `wrapper` and then presumably `jax.jit(f)`. This works as expected if you just call the result, but if you try to use it with `lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function, so when `trace` is invoked the `wrapper` is bypassed entirely. See jax-ml#27829 and jax-ml#27825.

The solution proposed here is to make `trace` and `lower` class attributes, so that they are not copied by `functools.wraps`. Thus, if you try to use `lower` or `trace` on the result of `functools.wraps(jax.jit(f))()` you will get an error. That is better than silently ignoring the wrapper. The workaround is to apply `jax.jit` last among your wrappers.

Fixes: jax-ml#27829
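For illustration, a minimal sketch of the attribute-copying pitfall in plain Python (the `jitted`/`wrapper` names and the toy `trace` attribute are hypothetical, not JAX's actual implementation):

```python
import functools

def jitted(x):
    return x + 1

# Simulate jax.jit's old behavior of attaching instance attributes:
jitted.trace = lambda x: f"traced({x})"

@functools.wraps(jitted)
def wrapper(x):
    print("wrapper ran")
    return jitted(x)

wrapper(1)               # prints "wrapper ran": the plain call path is fine
print(wrapper.trace(1))  # "traced(1)": wraps copied `trace` out of
                         # jitted.__dict__, so the wrapper is bypassed
```

A class attribute on the jitted callable's type would not live in the instance `__dict__`, so `functools.wraps` would not copy it; that is the essence of the fix.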
PiperOrigin-RevId: 746425307
So far Mosaic was implicitly relying on XLA to register the NVPTX target, which caused problems in cases where only a Mosaic kernel gets compiled and XLA didn't initialize the LLVM NVPTX target. PiperOrigin-RevId: 746433654
The skip decorator being used here only worked for test methods, not test classes, so it accidentally had the effect of skipping all the tests. But we don't really need a special decorator here anyway. PiperOrigin-RevId: 746434607
Follow-up from jax-ml#27916. jax-fixit PiperOrigin-RevId: 746442635
stdout redirection is inherently racy; mark test cases doing it as thread unsafe. PiperOrigin-RevId: 746443039
…loop` lowering PiperOrigin-RevId: 746444372
http://github.com/openxla/xla/commit/ca9011742bb84b3d2158feb262ddca221957ccc9. PiperOrigin-RevId: 746448816
…k capsules to jax.dlpack.from_dlpack().
to_dlpack() is not needed in the current version of the DLPack protocol. The from_dlpack() method accepts an object that implements __dlpack__(). In most cases, a JAX array can be passed directly to functions like torch.dlpack.from_dlpack(), and vice versa for other frameworks. The main exception is TensorFlow, which does not implement the current protocol. PiperOrigin-RevId: 746464890
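For illustration, a hedged sketch of the capsule-free flow using NumPy as the producer (assuming NumPy >= 1.23, whose ndarrays implement `__dlpack__`/`__dlpack_device__`; the same pattern applies to other frameworks' from_dlpack functions):

```python
import numpy as np
import jax.dlpack

x = np.arange(4.0)
# from_dlpack accepts any object implementing __dlpack__;
# no explicit capsule handling is needed:
y = jax.dlpack.from_dlpack(x)
print(y)  # a jax.Array on the CPU backend
```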
PiperOrigin-RevId: 746490665
jax-fixit PiperOrigin-RevId: 746496570
PiperOrigin-RevId: 746520758
…s the same as the order of arguments received in `jit` API and make it keyword-only PiperOrigin-RevId: 746527807
Users no longer need this under the newer DLPack protocol; `__dlpack__` is the equivalent. PiperOrigin-RevId: 746530351
PiperOrigin-RevId: 746543312
The new `METADATA` specification disallows underscores in extra names and automatically converts them to dashes. https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use This should prevent the error reported in jax-ml#27874 from appearing in future JAX releases. PiperOrigin-RevId: 746546162
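A quick illustration of the normalization rule, using the `packaging` library (the extra name shown is hypothetical):

```python
# Extra names are normalized like package names: runs of "-", "_", "."
# collapse to a single dash, and the result is lowercased.
from packaging.utils import canonicalize_name

print(canonicalize_name("cuda12_local"))  # -> cuda12-local
```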
PiperOrigin-RevId: 746546870
Use a count of chips (or omit it if 1) rather than specifying an ICI topology. Examples:
* tpu_v5e_1x1 -> tpu_v5e
* tpu_v5e_4x2 -> tpu_v5e_x8

PiperOrigin-RevId: 746547477
PiperOrigin-RevId: 746564071
…thon as a patch, rolling back. Reverts b1c96d4 PiperOrigin-RevId: 746565341
A missing space in '..math::' (it should be '.. math::') meant that the math wasn't rendering correctly.
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users. PiperOrigin-RevId: 746595857
This parameter is available from jax-ml#23040 and documented in https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isin.html. PiperOrigin-RevId: 746606206
PiperOrigin-RevId: 750226360
PiperOrigin-RevId: 750226731
PiperOrigin-RevId: 750229300
PiperOrigin-RevId: 750230068
…e block size is too small. PiperOrigin-RevId: 750244014
PiperOrigin-RevId: 750262738
Co-authored-by: Matthew Johnson <[email protected]>
PiperOrigin-RevId: 750282747
…if they are identical PiperOrigin-RevId: 750284947
PiperOrigin-RevId: 750287933
… pipeline emitter if there's nothing to wait for. Also enforce that `arrival_count` is always > 0. PiperOrigin-RevId: 750294068
PiperOrigin-RevId: 750296702
PiperOrigin-RevId: 750299496
PiperOrigin-RevId: 750302979
PiperOrigin-RevId: 750309885
PiperOrigin-RevId: 750342878
The signature is:

```
jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)
```

This API is a drop-in replacement for the experimental shard_map endpoint, with just two small changes: `check_rep` is renamed to `check_vma`, and all arguments to `shard_map` except `f` are keyword-only, while `f` is positional-only.

**But why are mesh and in_specs optional? And what is the new `axis_names` argument?**

* `mesh` is optional because it can be inferred from the context if the user sets the mesh via `jax.sharding.use_mesh(mesh)`.
* `in_specs` is optional because it can be inferred from the arguments passed to `shard_map` if all mesh axes are `Explicit`.
* `axis_names` tells `shard_map` which axes are `Manual`. If empty, it implies the `shard_map` is `Manual` over all mesh axes. In the experimental endpoint of `shard_map`, this argument was called `auto`, but after the advent of `sharding_in_types`, mesh axes can be `Auto`, `Explicit`, or `Manual`, so `auto` was no longer sufficient since axes can be `Explicit` too. That's why `jax.shard_map` renames the argument to `axis_names`.

**If `in_specs` is optional, why is `out_specs` compulsory?**

Because we still need to know which dimension to concatenate over; it can't be inferred automatically, since the choice can be anything.

PiperOrigin-RevId: 750343135
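A minimal usage sketch of the new endpoint (assuming a host with at least 2 devices; the function body and data are illustrative only):

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2,), ("x",))

def f(block):                 # sees one shard of the input per device
    return block * 2

# All arguments except f are keyword-only under the new signature:
g = jax.shard_map(f, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
print(g(jnp.arange(8.0)))     # shards over "x", concatenates per out_specs
```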
…REG spill.
Taking new factors into account for auto-tuning:
- q_dtype_name
- kv_dtype_name
- num_q_heads_per_blk
- num_kv_heads_per_blk
- head_dim
- page_size
- max_num_batched_tokens
- max_model_len = page_size * pages_per_seq

We only have 32 SREGs in the TensorCore. If the page size is small, we can easily spill SREGs. This CL suggests using `page_size = max_model_len // 16`, which ensures at most 16 SREGs are used for KV page indices per sequence. PiperOrigin-RevId: 750370022
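A worked check of the heuristic (the numbers are illustrative, not from the commit):

```python
# pages_per_seq = max_model_len / page_size, so choosing
# page_size = max_model_len // 16 caps page indices per sequence at 16,
# leaving headroom within the 32 available SREGs.
max_model_len = 4096
page_size = max_model_len // 16           # -> 256
pages_per_seq = max_model_len // page_size
print(page_size, pages_per_seq)           # 256 16
```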
PiperOrigin-RevId: 750374339
PiperOrigin-RevId: 750385718
…ings PiperOrigin-RevId: 750390355
PiperOrigin-RevId: 750400956
…ard_map.py to `jax/_src`. The signature is `jax.shard_map(f, /, *, out_specs, axis_names=set(), in_specs=None, mesh=None, check_vma=True)`; the API description is identical to the `jax.shard_map` commit above. END_PUBLIC PiperOrigin-RevId: 750401402
…oadcast in lower_to_llo.
- To make fold_in non-trivial, in Pallas the key is now represented as a (1, 2)-shaped key.
- Two new primitives were added for wrapping/unwrapping the key from scalars. This is needed because JAX's wrap/unwrap convert to and from vectors, whereas in Pallas we need to return a list of scalars.
PiperOrigin-RevId: 750422791
Daily sync with upstream