Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786

Merged
merged 26 commits into from
Jan 27, 2025

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Jan 8, 2025

Context:

By turning on the experimental jax_dynamic_shapes mode, you can capture and compile jaxpr for a series of different shapes at the same time. While this expermental feature has issues and isn't fully supported by jax yet, it is used by catalyst. To continue to support all of catalyst's features, we need to be able to capture and work with dynamic shapes as well.

Description of the Change:

  • Adds a qml.capture.determine_abstracted_axes function to determine the required abstracted_axes and the corresponding abstract shapes.
  • Use the determine_abstracted_axes function in all of our higher order primitives other than grad and jacobian, as grad and jacobian may prove more complicated.
  • Add a document explaining abstract shapes and how we can work with them.

Benefits:

Our higher order primitives can accept inputs with abstract shapes.

Possible Drawbacks:

This jax mode is still experimental.

Related GitHub Issues:

[sc-81471]

Copy link
Contributor

github-actions bot commented Jan 8, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@albi3ro albi3ro changed the title [Draft][Capture] Allow higher order primitives to accept dynamically shaped arrays [Capture] Allow higher order primitives to accept dynamically shaped arrays Jan 9, 2025
@albi3ro albi3ro marked this pull request as ready for review January 9, 2025 22:32
Copy link

codecov bot commented Jan 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.54%. Comparing base (afec979) to head (35b0091).
Report is 2 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6786   +/-   ##
=======================================
  Coverage   99.54%   99.54%           
=======================================
  Files         477      478    +1     
  Lines       45246    45296   +50     
=======================================
+ Hits        45042    45092   +50     
  Misses        204      204           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@lillian542 lillian542 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few minor comments, but I have questions about determine_abstract_axes 😅

I need some help understanding what is going on, it would probably be easiest just to chat.

pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Show resolved Hide resolved
pennylane/compiler/qjit_api.py Outdated Show resolved Hide resolved
pennylane/capture/dynamic_shapes.py Show resolved Hide resolved
Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. It seems clear to me that "dynamic shapes" at this point refers to dynamic axis sizes, but the number of dimensions cannot be dynamic? Let me know if I'm misinterpreting. If not, probably worth noting that in the markdown file.

pennylane/capture/dynamic_shapes.py Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/intro_to_dynamic_shapes.md Outdated Show resolved Hide resolved
pennylane/capture/dynamic_shapes.py Outdated Show resolved Hide resolved
pennylane/capture/dynamic_shapes.py Outdated Show resolved Hide resolved
pennylane/compiler/qjit_api.py Show resolved Hide resolved
tests/capture/test_dynamic_shapes.py Outdated Show resolved Hide resolved
tests/capture/test_nested_plxpr.py Outdated Show resolved Hide resolved
@albi3ro
Copy link
Contributor Author

albi3ro commented Jan 15, 2025

Looks good. It seems clear to me that "dynamic shapes" at this point refers to dynamic axis sizes, but the number of dimensions cannot be dynamic? Let me know if I'm misinterpreting. If not, probably worth noting that in the markdown file.

@mudit2812 Yes, the number of dimensions cannot be dynamic. Only the size of the dimensions.

Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just a few final clean up related comments, otherwise I'm happy to approve.

pennylane/capture/dynamic_shapes.py Show resolved Hide resolved
pennylane/capture/dynamic_shapes.py Outdated Show resolved Hide resolved
tests/capture/test_dynamic_shapes.py Outdated Show resolved Hide resolved
tests/capture/test_dynamic_shapes.py Outdated Show resolved Hide resolved
tests/capture/test_nested_plxpr.py Outdated Show resolved Hide resolved
tests/capture/test_capture_cond.py Outdated Show resolved Hide resolved
tests/capture/test_dynamic_shapes.py Outdated Show resolved Hide resolved
@albi3ro albi3ro enabled auto-merge (squash) January 27, 2025 15:02
@albi3ro albi3ro merged commit d9b821d into master Jan 27, 2025
46 checks passed
@albi3ro albi3ro deleted the dynamic-capture-hop-2 branch January 27, 2025 15:23
albi3ro added a commit to PennyLaneAI/pennylane-lightning that referenced this pull request Jan 27, 2025
**Context:**

In [PL PR #6786](PennyLaneAI/pennylane#6786), I
added additional kwargs for the for, while, and cond primitives for
handling inputs with abstract shapes. Those inputs broken the custom
registrations for the lightning interpreter.

**Description of the Change:**

Import the `FlattenedHigherOrderPrimitives` to get default behavior for
for, while, and cond. These should better stay updated with changes to
the primitives and the base class.

**Benefits:**

Lightning works again.

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: ringo-but-quantum <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants