Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786

Merged
merged 26 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
54f7c67
support dynamic shape inputs for hop's
albi3ro Jan 6, 2025
b1303c1
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 8, 2025
1469ad2
explanation doc
albi3ro Jan 8, 2025
a0c9176
Merge branch 'dynamic-capture-hop-2' of https://github.com/PennyLaneA…
albi3ro Jan 8, 2025
2c59a64
add while loop support
albi3ro Jan 8, 2025
8e6a167
adding tests
albi3ro Jan 9, 2025
5bc1a90
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 9, 2025
6fff950
adding testing
albi3ro Jan 14, 2025
b6a7eb2
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 14, 2025
4b04dae
fix marking
albi3ro Jan 14, 2025
0916fba
Merge branch 'dynamic-capture-hop-2' of https://github.com/PennyLaneA…
albi3ro Jan 14, 2025
a5fb9ee
Update pennylane/capture/dynamic_shapes.py
albi3ro Jan 14, 2025
8604e5b
Apply suggestions from code review
albi3ro Jan 14, 2025
47bdf11
black and changelog
albi3ro Jan 14, 2025
a44e7b9
respond to feedback
albi3ro Jan 15, 2025
53d90b7
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 15, 2025
c57ff02
Apply suggestions from code review
albi3ro Jan 15, 2025
8403a58
Apply suggestions from code review
albi3ro Jan 15, 2025
b54d092
Apply suggestions from code review
albi3ro Jan 15, 2025
88361bc
all the dynamic shapes
albi3ro Jan 15, 2025
d4c7a6b
some clarirications and tests
albi3ro Jan 21, 2025
9192ebd
Update pennylane/capture/dynamic_shapes.py
albi3ro Jan 21, 2025
ab090bf
fix failing tests
albi3ro Jan 21, 2025
c9508ab
change fixture usage
albi3ro Jan 22, 2025
bc12b41
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 27, 2025
35b0091
black
albi3ro Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* The higher order primitives in program capture can now accept inputs with abstract shapes.
[(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786)

albi3ro marked this conversation as resolved.
Show resolved Hide resolved
* The coefficients of observables now have improved differentiability.
[(#6598)](https://github.com/PennyLaneAI/pennylane/pull/6598)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~determine_abstracted_axes
~expand_plxpr_transforms
~run_autograph
~make_plxpr
Expand Down Expand Up @@ -170,6 +171,7 @@ def _(*args, **kwargs):
)
from .flatfn import FlatFn
from .make_plxpr import make_plxpr, run_autograph
from .dynamic_shapes import determine_abstracted_axes

# by defining this here, we avoid
# E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module)
Expand Down
48 changes: 38 additions & 10 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,15 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_


@PlxprInterpreter.register_primitive(for_loop_prim)
def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice):
def handle_for_loop(
self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle a for loop primitive."""
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, args[consts_slice], start, *init_state
copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state
)

return for_loop_prim.bind(
Expand All @@ -401,6 +404,7 @@ def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice,
jaxpr_body_fn=new_jaxpr_body_fn,
consts_slice=consts_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand All @@ -424,15 +428,27 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):

@PlxprInterpreter.register_primitive(while_loop_prim)
def handle_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle a while loop primitive."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state)
new_jaxpr_body_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_body_fn, consts_body, *abstract_shapes, *init_state
)
new_jaxpr_cond_fn = jaxpr_to_jaxpr(
copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state
)

return while_loop_prim.bind(
*invals,
Expand All @@ -441,6 +457,7 @@ def handle_while_loop(
body_slice=body_slice,
cond_slice=cond_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)


Expand Down Expand Up @@ -482,16 +499,24 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params):


def flatten_while_loop(
self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
):
"""Handle the while loop by a flattened python strategy."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

fn_res = init_state
while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res)
while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes, *fn_res)[0]:
fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes, *fn_res)

return fn_res

Expand All @@ -515,14 +540,17 @@ def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
FlattenedHigherOrderPrimitives[cond_prim] = flattened_cond


def flattened_for(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice):
def flattened_for(
self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
):
"""Handle the for loop by a flattened python strategy."""
consts = invals[consts_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

res = init_state
for i in range(start, stop, step):
res = copy(self).eval(jaxpr_body_fn, consts, i, *res)
res = copy(self).eval(jaxpr_body_fn, consts, *abstract_shapes, i, *res)

return res

Expand Down
100 changes: 100 additions & 0 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a utility for handling inputs with dynamically shaped arrays.
"""
from string import ascii_lowercase

has_jax = True
try:
import jax
except ImportError: # pragma: no cover
has_jax = False # pragma: no cover


def _get_shape_for_array(x, abstract_shapes: list) -> dict:
"""
Populate the dictionay of abstract axes for a single tensorlike.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

This dictionary has dimensions as keys, and a string marker as the value.

Examples of shape -> abstract axes:

* `(3,4) -> {}`
* `(tracer1, ) -> {0: "a"}`
* `(tracer1, tracer1) -> {0: "a", 1: "a"}`
* `(3, tracer1) -> {1: "a"}`
* `(tracer1, 2, tracer2) -> {0: "a", 2: "b"}`

`abstract_shapes` contains all the tracers found in shapes.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

"""
abstract_axes = {}
for i, s in enumerate(getattr(x, "shape", ())):
if not isinstance(s, int): # if not int, then abstract
found = False
# check if the shape tracer is one we have already encountered
for previous_idx, previous_shape in enumerate(abstract_shapes):
if s is previous_shape:
abstract_axes[i] = ascii_lowercase[previous_idx]
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
found = True
continue
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
# haven't encountered it, so add it to abstract_axes
# and use new letter designation
if not found:
abstract_axes[i] = ascii_lowercase[len(abstract_shapes)]
abstract_shapes.append(s)

return abstract_axes


def determine_abstracted_axes(args):
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
"""Computed the abstracted axes and extracing the abstract shapes from the arguments.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

Args:
args (tuple): the arguments for a higher order primitive

Returns:
tuple, tuple: the corresponding abstracted axes and dynamic shapes

See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work.

To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set.
Then, when calling the jaxpr, variables for the dynamic shapes must be passed.

.. code-block:: python

def f(n):
x = jax.numpy.ones((n,))
abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,))
jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x)
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x)

"""
if not has_jax: # pragma: no cover
raise ImportError("jax must be installed to use determine_abstracted_axes")
if not jax.config.jax_dynamic_shapes: # pylint: disable=no-member
return None, tuple()

args, structure = jax.tree_util.tree_flatten(args)

abstract_shapes = []
# note: this function in-place mutates abstract_shapes
# adding any additional abstract shapes found
abstracted_axes = [_get_shape_for_array(a, abstract_shapes) for a in args]

if not abstract_shapes:
return None, ()
abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes)
return abstracted_axes, abstract_shapes
Loading
Loading