
Jax example cleanup and replace pjit with jit. #1107

Merged
merged 9 commits into from
Aug 23, 2024

Conversation

nouiz
Collaborator

@nouiz nouiz commented Aug 15, 2024

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Remove a deprecation warning in
  • Convert one JAX example from pjit to jit with sharding.
    This is a minimal conversion; further refactoring might simplify it.
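The pjit-to-jit conversion can be sketched roughly as follows. This sketch assumes a JAX version in which `jax.jit` accepts `in_shardings`/`out_shardings` directly. The toy `forward` function, the single `"data"` mesh axis, and the array shapes are hypothetical stand-ins, not code from the Transformer Engine example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D device mesh with a single "data" axis over all devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the batch (first) dimension across "data"; replicate the weights.
data_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P())

def forward(w, x):
    # Toy stand-in for the example's step function (hypothetical).
    return jnp.tanh(x @ w)

# Previously: pjit(forward, in_shardings=..., out_shardings=...).
# jax.jit takes the same sharding arguments, so the call maps over directly.
forward_jit = jax.jit(
    forward,
    in_shardings=(replicated, data_sharding),
    out_shardings=data_sharding,
)

n_dev = len(jax.devices())
w = jnp.ones((4, 3))
x = jnp.ones((8 * n_dev, 4))  # batch size divisible by the device count
y = forward_jit(w, x)
print(y.shape, y.sharding)
```

Using `NamedSharding` objects means no `with mesh:` context is needed around the call, which is one of the simplifications `jit` makes possible compared with older `pjit` code.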

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng
Collaborator

/te-ci jax

nouiz and others added 5 commits August 21, 2024 00:56
Signed-off-by: Frederic Bastien <[email protected]>
/opt/transformer-engine/examples/jax/encoder/test_multigpu_encoder.py:214: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  params_axes_sharding = jax.tree_map(to_device_axis, nn_partitioning.get_axis_names(params_axes))

Signed-off-by: Frederic Bastien <[email protected]>
Signed-off-by: Frederic Bastien <[email protected]>
Signed-off-by: Frederic Bastien <[email protected]>
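The deprecation warning above suggests swapping `jax.tree_map` for `jax.tree_util.tree_map`. The following minimal sketch shows that replacement. The `params_axes` tree, the `AXIS_RULES` table, and this `to_device_axis` body are hypothetical stand-ins for what the example builds with `nn_partitioning.get_axis_names(params_axes)`, chosen so the snippet runs with plain JAX and no Flax dependency.

```python
import jax
from jax.sharding import PartitionSpec as P

# Hypothetical logical-axis names per parameter (stand-in for the output of
# nn_partitioning.get_axis_names(params_axes) in the example).
params_axes = {
    "dense": {"kernel": ("embed", "mlp"), "bias": ("mlp",)},
    "layernorm": {"scale": ("embed",)},
}

# Hypothetical mapping from logical axes to mesh axes (None = replicated).
AXIS_RULES = {"embed": None, "mlp": "model"}

def to_device_axis(logical_axes):
    # Translate each logical axis name into a mesh axis of a PartitionSpec.
    return P(*(AXIS_RULES[name] for name in logical_axes))

# jax.tree_util.tree_map works on any JAX version (jax.tree.map needs
# v0.4.25+). is_leaf stops traversal at the tuples of axis names, which
# would otherwise be treated as pytree nodes rather than leaves.
params_axes_sharding = jax.tree_util.tree_map(
    to_device_axis, params_axes, is_leaf=lambda t: isinstance(t, tuple)
)
print(params_axes_sharding)
```

On JAX v0.4.25 or newer, `jax.tree.map` accepts the same arguments and can be used instead.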
@nouiz
Collaborator Author

nouiz commented Aug 21, 2024

/te-ci jax

@nouiz
Collaborator Author

nouiz commented Aug 21, 2024

/te-ci jax

Signed-off-by: Frederic Bastien <[email protected]>
@phu0ngng
Copy link
Collaborator

/te-ci jax

@phu0ngng phu0ngng merged commit 309c6d4 into NVIDIA:main Aug 23, 2024
14 checks passed
BeingGod pushed a commit to BeingGod/TransformerEngine that referenced this pull request Aug 30, 2024
* Use jit instead of pjit

---------

Signed-off-by: Frederic Bastien <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Signed-off-by: beinggod <[email protected]>