-
Notifications
You must be signed in to change notification settings - Fork 621
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Capture] Allow higher order primitives to accept dynamically shaped …
…arrays (#6786) **Context:** By turning on the experimental `jax_dynamic_shapes` mode, you can capture and compile jaxpr for a series of different shapes at the same time. While this expermental feature has issues and isn't fully supported by jax yet, it is used by catalyst. To continue to support all of catalyst's features, we need to be able to capture and work with dynamic shapes as well. **Description of the Change:** * Adds a `qml.capture.determine_abstracted_axes` function to determine the required `abstracted_axes` and the corresponding abstract shapes. * Use the `determine_abstracted_axes` function in all of our higher order primitives other than `grad` and `jacobian`, as `grad` and `jacobian` may prove more complicated. * Add a document explaining abstract shapes and how we can work with them. **Benefits:** Our higher order primitives can accept inputs with abstract shapes. **Possible Drawbacks:** This jax mode is still experimental. **Related GitHub Issues:** [sc-81471] --------- Co-authored-by: lillian542 <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
- Loading branch information
1 parent
421f345
commit d9b821d
Showing
17 changed files
with
985 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Contains a utility for handling inputs with dynamically shaped arrays. | ||
""" | ||
from functools import lru_cache | ||
from string import ascii_lowercase as letters | ||
|
||
has_jax = True | ||
try: | ||
import jax | ||
except ImportError: # pragma: no cover | ||
has_jax = False # pragma: no cover | ||
|
||
|
||
@lru_cache | ||
def _get_letter(ind: int) -> str: | ||
if ind < 26: | ||
return letters[ind] | ||
if ind < 702: | ||
return letters[ind // 26 - 1] + letters[ind % 26] | ||
raise NotImplementedError("we only support up to 702 dynamic axes") # pragma: no cover | ||
|
||
|
||
def _get_shape_for_array(x, abstract_shapes: list) -> dict: | ||
""" | ||
Populate the dictionary of abstract axes for a single tensorlike. | ||
This dictionary has dimensions as keys, and a string marker as the value. | ||
Examples of shape -> abstract axes: | ||
* ``(3,4) -> {}`` | ||
* ``(tracer1, ) -> {0: "a"}`` | ||
* ``(tracer1, tracer1) -> {0: "a", 1: "a"}`` | ||
* ``(3, tracer1) -> {1: "a"}`` | ||
* ``(tracer1, 2, tracer2) -> {0: "a", 2: "b"}`` | ||
``abstract_shapes`` contains all the tracers found in shapes. | ||
""" | ||
abstract_axes = {} | ||
for i, s in enumerate(getattr(x, "shape", ())): | ||
if not isinstance(s, int): # if not int, then abstract | ||
found = False | ||
# check if the shape tracer is one we have already encountered | ||
for previous_idx, previous_shape in enumerate(abstract_shapes): | ||
if s is previous_shape: | ||
abstract_axes[i] = _get_letter(previous_idx) | ||
found = True | ||
break | ||
# haven't encountered it, so add it to abstract_axes | ||
# and use new letter designation | ||
if not found: | ||
abstract_axes[i] = _get_letter(len(abstract_shapes)) | ||
abstract_shapes.append(s) | ||
|
||
return abstract_axes | ||
|
||
|
||
def determine_abstracted_axes(args): | ||
"""Computed the abstracted axes and extracting the abstract shapes from the arguments. | ||
Args: | ||
args (tuple): the arguments for a higher order primitive | ||
Returns: | ||
tuple, tuple: the corresponding abstracted axes and dynamic shapes | ||
Note that "dynamic shapes" only refers to the size of dimensions, but not the number of dimensions. | ||
Even with dynamic shapes mode enabled, we cannot change the number of dimensions. | ||
See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work. | ||
To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set. | ||
Then, when calling the jaxpr, variables for the dynamic shapes must be passed. | ||
.. code-block:: python | ||
jax.config.update("jax_dynamic_shapes", True) | ||
def f(n): | ||
x = jax.numpy.ones((n,)) | ||
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) | ||
jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) | ||
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) | ||
""" | ||
if not has_jax: # pragma: no cover | ||
raise ImportError("jax must be installed to use determine_abstracted_axes") | ||
if not jax.config.jax_dynamic_shapes: # pylint: disable=no-member | ||
return None, tuple() | ||
|
||
args, structure = jax.tree_util.tree_flatten(args) | ||
|
||
abstract_shapes = [] | ||
# note: this function in-place mutates abstract_shapes | ||
# adding any additional abstract shapes found | ||
abstracted_axes = [_get_shape_for_array(a, abstract_shapes) for a in args] | ||
|
||
if not abstract_shapes: | ||
return None, () | ||
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) | ||
return abstracted_axes, abstract_shapes |
Oops, something went wrong.