Skip to content

Commit

Permalink
[Capture] Allow higher order primitives to accept dynamically shaped …
Browse files Browse the repository at this point in the history
…arrays (#6786)

**Context:**

By turning on the experimental `jax_dynamic_shapes` mode, you can
capture and compile jaxpr for a series of different shapes at the same
time. While this expermental feature has issues and isn't fully
supported by jax yet, it is used by catalyst. To continue to support all
of catalyst's features, we need to be able to capture and work with
dynamic shapes as well.

**Description of the Change:**

* Adds a `qml.capture.determine_abstracted_axes` function to determine
the required `abstracted_axes` and the corresponding abstract shapes.
* Use the `determine_abstracted_axes` function in all of our higher
order primitives other than `grad` and `jacobian`, as `grad` and
`jacobian` may prove more complicated.
* Add a document explaining abstract shapes and how we can work with
them.

**Benefits:**

Our higher order primitives can accept inputs with abstract shapes.

**Possible Drawbacks:**

This jax mode is still experimental.

**Related GitHub Issues:**

[sc-81471]

---------

Co-authored-by: lillian542 <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
  • Loading branch information
3 people authored Jan 27, 2025
1 parent 421f345 commit d9b821d
Show file tree
Hide file tree
Showing 17 changed files with 985 additions and 31 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

* `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value.
[(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~determine_abstracted_axes
~expand_plxpr_transforms
~run_autograph
~make_plxpr
Expand Down Expand Up @@ -170,6 +171,7 @@ def _(*args, **kwargs):
)
from .flatfn import FlatFn
from .make_plxpr import make_plxpr, run_autograph
from .dynamic_shapes import determine_abstracted_axes

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down
48 changes: 38 additions & 10 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,15 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_


@PlxprInterpreter.register_primitive(for_loop_prim)
def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice):
def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, args[consts_slice], start, *init_state
copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state
)

return for_loop_prim.bind(
Expand All @@ -400,6 +403,7 @@ def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice,
jaxpr_body_fn=new_jaxpr_body_fn,
consts_slice=consts_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand All @@ -423,15 +427,27 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):

@PlxprInterpreter.register_primitive(while_loop_prim)
def handle_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle a while loop primitive."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state)
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts_body, *abstract_shapes, *init_state
)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state
)

return while_loop_prim.bind(
*invals,
Expand All @@ -440,6 +456,7 @@ def handle_while_loop(
body_slice=body_slice,
cond_slice=cond_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand Down Expand Up @@ -481,16 +498,24 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params):


def flatten_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle the while loop by a flattened python strategy."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

fn_res = init_state
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)
while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes, *fn_res)

return fn_res

Expand All @@ -514,14 +539,17 @@ def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
FlattenedHigherOrderPrimitives[cond_prim] = flattened_cond


def flattened_for(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice):
def flattened_for(
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle the for loop by a flattened python strategy."""
consts = invals[consts_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

res = init_state
for i in range(start, stop, step):
res = copy(self).eval(jaxpr_body_fn, consts, i, *res)
res = copy(self).eval(jaxpr_body_fn, consts, *abstract_shapes, i, *res)

return res

Expand Down
115 changes: 115 additions & 0 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a utility for handling inputs with dynamically shaped arrays.
"""
from functools import lru_cache
from string import ascii_lowercase as letters

has_jax = True
try:
import jax
except ImportError: # pragma: no cover
has_jax = False # pragma: no cover


@lru_cache
def _get_letter(ind: int) -> str:
if ind < 26:
return letters[ind]
if ind < 702:
return letters[ind // 26 - 1] + letters[ind % 26]
raise NotImplementedError("we only support up to 702 dynamic axes") # pragma: no cover


def _get_shape_for_array(x, abstract_shapes: list) -> dict:
"""
Populate the dictionary of abstract axes for a single tensorlike.
This dictionary has dimensions as keys, and a string marker as the value.
Examples of shape -> abstract axes:
* ``(3,4) -> {}``
* ``(tracer1, ) -> {0: "a"}``
* ``(tracer1, tracer1) -> {0: "a", 1: "a"}``
* ``(3, tracer1) -> {1: "a"}``
* ``(tracer1, 2, tracer2) -> {0: "a", 2: "b"}``
``abstract_shapes`` contains all the tracers found in shapes.
"""
abstract_axes = {}
for i, s in enumerate(getattr(x, "shape", ())):
if not isinstance(s, int): # if not int, then abstract
found = False
# check if the shape tracer is one we have already encountered
for previous_idx, previous_shape in enumerate(abstract_shapes):
if s is previous_shape:
abstract_axes[i] = _get_letter(previous_idx)
found = True
break
# haven't encountered it, so add it to abstract_axes
# and use new letter designation
if not found:
abstract_axes[i] = _get_letter(len(abstract_shapes))
abstract_shapes.append(s)

return abstract_axes


def determine_abstracted_axes(args):
"""Computed the abstracted axes and extracting the abstract shapes from the arguments.
Args:
args (tuple): the arguments for a higher order primitive
Returns:
tuple, tuple: the corresponding abstracted axes and dynamic shapes
Note that "dynamic shapes" only refers to the size of dimensions, but not the number of dimensions.
Even with dynamic shapes mode enabled, we cannot change the number of dimensions.
See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work.
To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set.
Then, when calling the jaxpr, variables for the dynamic shapes must be passed.
.. code-block:: python
jax.config.update("jax_dynamic_shapes", True)
def f(n):
x = jax.numpy.ones((n,))
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)
"""
if not has_jax: # pragma: no cover
raise ImportError("jax must be installed to use determine_abstracted_axes")
if not jax.config.jax_dynamic_shapes: # pylint: disable=no-member
return None, tuple()

args, structure = jax.tree_util.tree_flatten(args)

abstract_shapes = []
# note: this function in-place mutates abstract_shapes
# adding any additional abstract shapes found
abstracted_axes = [_get_shape_for_array(a, abstract_shapes) for a in args]

if not abstract_shapes:
return None, ()
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
return abstracted_axes, abstract_shapes
Loading

0 comments on commit d9b821d

Please sign in to comment.