From 54f7c67c47f99bfb0a2441541ff9586620287335 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 6 Jan 2025 13:24:12 -0500 Subject: [PATCH 01/19] support dynamic shape inputs for hop's --- pennylane/capture/__init__.py | 1 + pennylane/capture/base_interpreter.py | 15 +++++--- pennylane/capture/dynamic_shapes.py | 51 +++++++++++++++++++++++++++ pennylane/compiler/qjit_api.py | 18 +++++++--- pennylane/ops/op_math/adjoint.py | 10 ++++-- pennylane/ops/op_math/controlled.py | 6 +++- pennylane/workflow/_capture_qnode.py | 4 ++- 7 files changed, 93 insertions(+), 12 deletions(-) create mode 100644 pennylane/capture/dynamic_shapes.py diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 28a7a0edb35..8bd4d358ce0 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -170,6 +170,7 @@ def _(*args, **kwargs): ) from .flatfn import FlatFn from .make_plxpr import make_plxpr, run_autograph +from .dynamic_shapes import determine_abstracted_axes # by defining this here, we avoid # E0611: No name 'AbstractOperator' in module 'pennylane.capture' (no-name-in-module) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index b23f6be3e6b..8aad8708095 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -382,12 +382,15 @@ def handle_ctrl_transform(self, *invals, n_control, jaxpr, control_values, work_ @PlxprInterpreter.register_primitive(for_loop_prim) -def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice): +def handle_for_loop( + self, start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice +): """Handle a for loop primitive.""" init_state = args[args_slice] + abstract_shapes = args[abstract_shapes_slice] new_jaxpr_body_fn = jaxpr_to_jaxpr( - copy(self), jaxpr_body_fn, args[consts_slice], start, *init_state + copy(self), jaxpr_body_fn, args[consts_slice], *abstract_shapes, start, *init_state ) return for_loop_prim.bind( @@ -398,6 +401,7 @@ def handle_for_loop(self, start, stop, step, *args, jaxpr_body_fn, consts_slice, jaxpr_body_fn=new_jaxpr_body_fn, consts_slice=consts_slice, args_slice=args_slice, + abstract_shapes_slice=abstract_shapes_slice, ) @@ -512,14 +516,17 @@ def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice): FlattenedHigherOrderPrimitives[cond_prim] = flattened_cond -def flattened_for(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice): +def flattened_for( + self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice +): """Handle the for loop by a flattened python strategy.""" consts = invals[consts_slice] init_state = invals[args_slice] + abstract_shapes = invals[abstract_shapes_slice] res = init_state for i in range(start, stop, step): - res = copy(self).eval(jaxpr_body_fn, consts, i, *res) + res = copy(self).eval(jaxpr_body_fn, consts, *abstract_shapes, i, *res) return res diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py new file mode 100644 index 00000000000..ea566724a6c --- /dev/null +++ b/pennylane/capture/dynamic_shapes.py @@ -0,0 +1,51 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains a utility for handling inputs with dynamically shaped arrays. +""" +from string import ascii_lowercase + +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +def determine_abstracted_axes(args, structure=None): + """Computed the abstracted axes and extracing the abstract shapes from the arguments.""" + if not has_jax: + raise ImportError("jax must be installed to use determine_abstracted_axes") + if not jax.config.jax_dynamic_shapes: + return None, tuple() + if structure is None: + args, structure = jax.tree_util.tree_flatten(args) + abstracted_axes = [] + abstract_shapes = [] + for l in args: + l_shape = [] + for s in getattr(l, "shape", ()): + if isinstance(s, int): # not abstract + l_shape.append(()) + else: + l_shape.append(ascii_lowercase[len(abstract_shapes)]) + if all(s is not x for x in abstract_shapes): + # not already added + abstract_shapes.append(s) + abstracted_axes.append(tuple(l_shape)) + + if not abstract_shapes: + return None, () + abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) + return abstracted_axes, abstract_shapes diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 08d88988b79..f585d4f2437 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -633,16 +633,17 @@ def _get_for_loop_qfunc_prim(): # pylint: disable=too-many-arguments @for_loop_prim.def_impl - def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice): + def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice): consts = args[consts_slice] init_state = args[args_slice] + abstract_shapes = args[abstract_shapes_slice] # in case start >= stop, return the initial state fn_res = init_state for i in range(start, stop, step): - fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, consts, i, *fn_res) + fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, consts, *abstract_shapes, i, *fn_res) return fn_res @@ -692,11 +693,18 @@ def _call_capture_enabled(self, *init_state): for_loop_prim = _get_for_loop_qfunc_prim() + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((0, *init_state)) + flat_fn = FlatFn(self.body_fn) - jaxpr_body_fn = jax.make_jaxpr(flat_fn)(0, *init_state) + jaxpr_body_fn = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(0, *init_state) consts_slice = slice(0, len(jaxpr_body_fn.consts)) - args_slice = slice(len(jaxpr_body_fn.consts), None) + args_slice = slice( + len(jaxpr_body_fn.consts), -len(abstract_shapes) if abstract_shapes else None + ) + abstract_shapes_slice = ( + slice(-len(abstract_shapes), None) if abstract_shapes else slice(0, 0) + ) flat_args, _ = jax.tree_util.tree_flatten(init_state) @@ -706,9 +714,11 @@ def _call_capture_enabled(self, *init_state): self.step, *jaxpr_body_fn.consts, *flat_args, + *abstract_shapes, jaxpr_body_fn=jaxpr_body_fn.jaxpr, consts_slice=consts_slice, args_slice=args_slice, + abstract_shapes_slice=abstract_shapes_slice, ) assert flat_fn.out_tree is not None return jax.tree_util.tree_unflatten(flat_fn.out_tree, results) diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index 400f2fc83c0..5e0366a5fa0 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -222,9 +222,15 @@ def _capture_adjoint_transform(qfunc: Callable, lazy=True) -> Callable: @wraps(qfunc) def new_qfunc(*args, **kwargs): - jaxpr = jax.make_jaxpr(partial(qfunc, **kwargs))(*args) + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args) + jaxpr = jax.make_jaxpr(partial(qfunc, **kwargs), abstracted_axes=abstracted_axes)(*args) adjoint_prim.bind( - *jaxpr.consts, *args, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=len(jaxpr.consts) + *jaxpr.consts, + *abstract_shapes, + *args, + jaxpr=jaxpr.jaxpr, + lazy=lazy, + n_consts=len(jaxpr.consts), ) return new_qfunc diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index d49209660c6..d768ec0a00e 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -268,10 +268,14 @@ def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires @wraps(qfunc) def new_qfunc(*args, **kwargs): - jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs))(*args) + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args) + jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs), abstracted_axes=abstracted_axes)( + *args + ) control_wires = qml.wires.Wires(control) # make sure is iterable ctrl_prim.bind( *jaxpr.consts, + *abstract_shapes, *args, *control_wires, jaxpr=jaxpr.jaxpr, diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 41b8b37e175..b5eefdc1161 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -382,9 +382,10 @@ def f(x): if not qnode.device.wires: raise NotImplementedError("devices must specify wires for integration with plxpr capture.") + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args) qfunc = partial(qnode.func, **kwargs) if kwargs else qnode.func flat_fn = FlatFn(qfunc) - qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args) + qfunc_jaxpr = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(*args) execute_kwargs = copy(qnode.execute_kwargs) mcm_config = asdict(execute_kwargs.pop("mcm_config")) @@ -395,6 +396,7 @@ def f(x): res = qnode_prim.bind( *qfunc_jaxpr.consts, + *abstract_shapes, *flat_args, shots=shots, qnode=qnode, From 1469ad231299e276d8306cee2a5dc9079a6a19b6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 11:26:48 -0500 Subject: [PATCH 02/19] explanation doc --- pennylane/capture/intro_to_dynamic_shapes.md | 440 +++++++++++++++++++ 1 file changed, 440 insertions(+) create mode 100644 pennylane/capture/intro_to_dynamic_shapes.md diff --git a/pennylane/capture/intro_to_dynamic_shapes.md b/pennylane/capture/intro_to_dynamic_shapes.md new file mode 100644 index 00000000000..1e1e4f40adb --- /dev/null +++ b/pennylane/capture/intro_to_dynamic_shapes.md @@ -0,0 +1,440 @@ +# Introduction to dynamic shapes in jax + + +```python +import jax +``` + +Dynamic shapes are experimental feature of jax with limited support and feature coverage. + + +```python +jax.config.update("jax_dynamic_shapes", False) +``` + +Without this setup, we can't create arrays whose size depends on an abstract value. + + +```python +%xmode Minimal +def f(n): + return jax.numpy.ones((n,)) + +jax.make_jaxpr(f)(3) +``` + + Exception reporting mode: Minimal + + + + TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Tracedwith,). + If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions. + The error occurred while tracing the function f at /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_27275/1980236754.py:2 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument n. + + + + +```python +jax.make_jaxpr(f, static_argnums=0)(3) +``` + + + + + { lambda ; . let + a:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0 + in (a,) } + + + +But if we make `3` as a static argnum with a fixed value, we can now produce jaxpr. But now we can see that it has `3` hardcoded into the jaxpr and the jaxpr could not be reused with a different input. + +Once we enable the experimental `"jax_dynamic_shapes"` mode we can capture such a function into jaxpr. + +Now the shapes of an array can themselves contain dynamic tracers. + + +```python +jax.config.update("jax_dynamic_shapes", True) +``` + + +```python +jax.make_jaxpr(f)(3) +``` + + { lambda ; a:i32[]. let + b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a + in (b,) } + + + +With the use of the `abstracted_axes` keyword argument, we can also produce jaxpr for an input with a dynamic shape. + +By using the `abstracted_axes` to make the first dimension of our input as dynamic, we can reuse the same jaxpr for different sizes of arrays. + + +```python +def g(x): + return jax.numpy.sum(x) + +jax.make_jaxpr(g, abstracted_axes=("x",))(jax.numpy.array([1,2,3])) +``` + + + + + { lambda ; a:i32[] b:i32[a]. let c:i32[] = reduce_sum[axes=(0,)] b in (c,) } + +## Limitations of dynamic shapes and numerical manipulations + +1. Slicing into a dynamically sized array. + +Erick has an open PR to fix this issue on the jax github. Catalyst currently patches this bug on their side by patching the jax source code. + + +```python +def h(x): + return x[0] + +jax.make_jaxpr(h, abstracted_axes=("x", ) )(jax.numpy.array([0, 1,2])) +``` + + + TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. + The error occurred while tracing the function h at /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_27275/2165410745.py:1 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument x. + See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError + + + +Executing with `eval_jaxpr`: + +No idea how to fix this right now. + + +```python +def k(n): + return jax.numpy.ones((n,)) + +jaxpr = jax.make_jaxpr(k)(3) +jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3) +``` + + + XlaRuntimeError: UNKNOWN: /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_27275/1615670206.py:2:11: error: 'mhlo.dynamic_broadcast_in_dim' op can't be translated to XLA HLO + /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_27275/1615670206.py:4:8: note: called from + IPython/core/interactiveshell.py:3577:20: note: called from + exec(code_obj, self.user_global_ns, self.user_ns) + ^ + IPython/core/interactiveshell.py:3517:19: note: called from + if await self.run_code(code, result, async_=asy): + ^ + IPython/core/interactiveshell.py:3334:29: note: called from + has_raised = await self.run_ast_nodes(code_ast.body, cell_name, + ^ + IPython/core/async_helpers.py:128:8: note: called from + coro.send(None) + ^ + IPython/core/interactiveshell.py:3130:21: note: called from + result = runner(coro) + ^ + IPython/core/interactiveshell.py:3075:21: note: called from + result = self._run_cell( + ^ + ipykernel/zmqshell.py:549:15: note: called from + return super().run_cell(*args, **kwargs) + ^ + ipykernel/ipkernel.py:449:26: note: called from + res = shell.run_cell( + ^ + /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_27275/1615670206.py:2:11: note: see current operation: %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<1xi32>) -> tensor + + + + +## Extending support to PLXPR HOP's + +When capturing higher order primitives, we call `jax.make_jaxpr(f)` with arguments whose shapes are tracers. + +When calling `jax.make_jaxpr` inside a traced function, such as we do when using HOP's, we still need to specify the `abstracted_axes`. Failing to do so leads in an error: + + +```python +def f(n): + x = jax.numpy.ones((n,)) + jaxpr = jax.make_jaxpr(jax.numpy.sum)(x) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, n, x) + +jax.make_jaxpr(f)(3) +``` + + + AssertionError + + + + +```python +def f(n): + x = jax.numpy.ones((n,)) + jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=("n",))(x) + print("inner jaxpr: ", jaxpr) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, n, x) + +jax.make_jaxpr(f)(3) +``` +``` +inner jaxpr: { lambda ; a:i32[] b:f32[a]. let c:f32[] = reduce_sum[axes=(0,)] b in (c,) } +{ lambda ; a:i32[]. let + b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a + c:f32[] = reduce_sum[axes=(0,)] b + in (c,) } +``` + + + +Note in this case that I am passing `n` when evaluating the jaxpr, even though `n` wasn't an argument that produced the jaxpr. + +`n` was an implicit argument contained inside of `x`, so `make_jaxpr` promotes it explicit input. We can see this in the "inner jaxpr" printed out inside the function. Even though the function that produced it only had `x` as an input, the jaxpr has `a:i32[], b:f32[a]` as two arguments. When re-evaluating the jaxpr later, we need to make sure to pass the value for `n` as well. + +To handle generic functions, we must then be able to determine which axes are dynamic from the arguments, and extract the tracer values for all the abstract dimensions. + + +```python +alphabet = "abcdefghijklmnop" +def determine_abstracted_axes(args): + + leaves, structure = jax.tree_util.tree_flatten(args) + abstracted_axes = [] + abstract_shapes = [] + + for l in leaves: + l_shape = [] + for s in l.shape: + if isinstance(s, int): # not abstract + l_shape.append(()) + else: + l_shape.append(alphabet[len(abstract_shapes)]) + abstract_shapes.append(s) + abstracted_axes.append(tuple(l_shape) if len(l_shape) != 1 else l_shape[0]) # maybe ? + abstracted_axes = jax.tree_util.tree_unflatten(structure, abstracted_axes) + return abstracted_axes, abstract_shapes +``` + + +```python +def f(n): + x = jax.numpy.ones((n,)) + abstracted_axes, abstract_shapes = determine_abstracted_axes((x,)) + jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) + +jax.make_jaxpr(f)(3) +``` +``` +{ lambda ; a:i32[]. let + b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a + c:f32[] = reduce_sum[axes=(0,)] b + in (c,) } +``` + + + +We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs: + + +```python +prim = jax.core.Primitive("hop") +prim.multiple_results = True + +@prim.def_impl +def _(*args, jaxpr, n_consts): + return jax.core.eval_jaxpr(jaxpr, args[:n_consts], *args[n_consts:]) + +@prim.def_abstract_eval +def _(*args, jaxpr, n_consts): + return [v.aval for v in jaxpr.outvars] + +def bind_prim(f, *args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args) + return prim.bind(*jaxpr.consts, *abstract_shapes, *args, jaxpr=jaxpr.jaxpr, n_consts=len(jaxpr.consts)) +``` + + +```python +def workflow(x): + return bind_prim(jax.numpy.sum, x) + +jaxpr = jax.make_jaxpr(workflow, abstracted_axes=("a", ))(jax.numpy.array([1,2,3])) +jaxpr +``` + + +``` +{ lambda ; a:i32[] b:i32[a]. let + c:i32[] = hop[ + jaxpr={ lambda ; d:i32[] e:i32[d]. let + f:i32[] = reduce_sum[axes=(0,)] e + in (f,) } + n_consts=0 + ] a b + in (c,) } +``` + + + + +```python +jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.array([1,1])) +``` + + + + + [Array(2, dtype=int32)] + + + +Great! It's working! + +At least for that example with `jax.numpy.sum`. + +What happens when the higher order primitive returns a dynamic shaped array too? + + +```python +def workflow2(x): + return bind_prim(lambda x: 2*x, x) + +jaxpr = jax.make_jaxpr(workflow2, abstracted_axes=("a", ))(jax.numpy.array([1,2,3])) +jaxpr +``` + + + KeyError: Var(id=4694049536):int32[] + + + +It no longer works ;( + +The output shape for the primitive contains a variable that is not in the local environment. It lived in the environment of the inner jaxpr, and is not present in the outer jaxpr. + +Do we have any workarounds? + +If we enforce that the HOP has a return shape that *matches* one of the inputs, we are home free. + +For example, with for loops and while loops, we can insist that the output shapes are the same as the input shapes: + + +```python +prim2 = jax.core.Primitive("hop") +prim2.multiple_results = True + +@prim2.def_impl +def _(*args, jaxpr, n_consts, in_abstract_inds): + return jax.core.eval_jaxpr(jaxpr, args[:n_consts], *args[n_consts:]) + +@prim2.def_abstract_eval +def _(*args, jaxpr, n_consts, n_abstract_inds): + return args[n_consts+n_abstract_inds:] + +def bind_prim2(f, *args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args) + return prim2.bind(*jaxpr.consts, *abstract_shapes, *args, + jaxpr=jaxpr.jaxpr, + n_consts=len(jaxpr.consts), + n_abstract_inds=len(abstract_shapes) + ) +``` + + +```python +def workflow3(x): + return bind_prim2(lambda x: 2*x, x) + +jaxpr = jax.make_jaxpr(workflow3, abstracted_axes=("a", ))(jax.numpy.array([1,2,3])) +jaxpr +``` + +``` +{ lambda ; a:i32[] b:i32[a]. let + c:i32[a] = hop[ + jaxpr={ lambda ; d:i32[] e:i32[d]. let + f:i32[d] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 2 d + g:i32[d] = mul f e + in (g,) } + n_abstract_inds=1 + n_consts=0 + ] a b + in (c,) } +``` + + + +So once again we are good! Our primitive accepted something of shape `i32[a]` and returned something of shape `i32[a]`. The value of `a` was already present in the local environment, so it could continue to exist in the jaxpr. + +What if the shape isn't accessible? What if we wanted to resize one of the inputs, or create a fully new dimension like is done with `jax.numpy.ones`? + +That now gets a bit trickier. The solution has several issues: + +1) A bit more difficult to read and follow +2) Relies on unstable componets of jax internals + +But why let those concerns stop us now! Let's do it. + +What we need to do in this case in hi-jack how `DynamicJaxTracer` creates an equation for the relevant primitive. It will no longer use the default logic relying on the `abstract_eval`, but our own pipeline. + +Here we are going to create a primitive that accepts an argument `n`, and returns an array of shape `f32[n,2]`. + + +```python +prim3 = jax.core.Primitive("dynamic_output") +prim3.multiple_results = True +``` + + +```python + +``` + + +```python +from jax._src.interpreters import partial_eval as pe + +def custom_staging_rule(jaxpr_trace, *invars, **params): + new_shapes = [jax.core.DShapedArray((invars[0],2), jax.numpy.float32.dtype)] + out_tracers = [pe.DynamicJaxprTracer(jaxpr_trace, o) for o in new_shapes] + eqn = pe.new_jaxpr_eqn( + [jaxpr_trace.getvar(x) for x in invars], + [jaxpr_trace.makevar(o) for o in out_tracers], + prim3, + params, + jax.core.no_effects, + ) + jaxpr_trace.frame.add_eqn(eqn) + return out_tracers + +pe.custom_staging_rules[prim3] = custom_staging_rule +``` + + +```python + +``` + + +```python +def workflow4(n): + return prim3.bind(n) + +jax.make_jaxpr(workflow4)(2) +``` + +``` +{ lambda ; a:i32[]. let b:f32[a,2] = dynamic_output a in (b,) } +``` + + +This custom staging rule route will be most useful for allowing the shape of `sample` to depend on a dynamic number of shots. From 2c59a64c57fbbaef39e029113ae9bc5383f7d8ed Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 8 Jan 2025 17:04:13 -0500 Subject: [PATCH 03/19] add while loop support --- pennylane/capture/base_interpreter.py | 33 ++++++++++++++++++++++----- pennylane/compiler/qjit_api.py | 28 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index 8aad8708095..08bc002c6a6 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -425,15 +425,27 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice): @PlxprInterpreter.register_primitive(while_loop_prim) def handle_while_loop( - self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice + self, + *invals, + jaxpr_body_fn, + jaxpr_cond_fn, + body_slice, + cond_slice, + args_slice, + abstract_shapes_slice, ): """Handle a while loop primitive.""" consts_body = invals[body_slice] consts_cond = invals[cond_slice] init_state = invals[args_slice] + abstract_shapes = invals[abstract_shapes_slice] - new_jaxpr_body_fn = jaxpr_to_jaxpr(copy(self), jaxpr_body_fn, consts_body, *init_state) - new_jaxpr_cond_fn = jaxpr_to_jaxpr(copy(self), jaxpr_cond_fn, consts_cond, *init_state) + new_jaxpr_body_fn = jaxpr_to_jaxpr( + copy(self), jaxpr_body_fn, consts_body, *abstract_shapes, *init_state + ) + new_jaxpr_cond_fn = jaxpr_to_jaxpr( + copy(self), jaxpr_cond_fn, consts_cond, *abstract_shapes, *init_state + ) return while_loop_prim.bind( *invals, @@ -442,6 +454,7 @@ def handle_while_loop( body_slice=body_slice, cond_slice=cond_slice, args_slice=args_slice, + abstract_shapes_slice=abstract_shapes_slice, ) @@ -483,16 +496,24 @@ def handle_jacobian(self, *invals, jaxpr, n_consts, **params): def flatten_while_loop( - self, *invals, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice + self, + *invals, + jaxpr_body_fn, + jaxpr_cond_fn, + body_slice, + cond_slice, + args_slice, + abstract_shapes_slice, ): """Handle the while loop by a flattened python strategy.""" consts_body = invals[body_slice] consts_cond = invals[cond_slice] init_state = invals[args_slice] + abstract_shapes_slice = invals[abstract_shapes_slice] fn_res = init_state - while copy(self).eval(jaxpr_cond_fn, consts_cond, *fn_res)[0]: - fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *fn_res) + while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes_slice, *fn_res)[0]: + fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes_slice, *fn_res) return fn_res diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index f585d4f2437..6f2ca052e95 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -411,16 +411,26 @@ def _get_while_loop_qfunc_prim(): while_loop_prim.multiple_results = True @while_loop_prim.def_impl - def _(*args, jaxpr_body_fn, jaxpr_cond_fn, body_slice, cond_slice, args_slice): + def _( + *args, + jaxpr_body_fn, + jaxpr_cond_fn, + body_slice, + cond_slice, + args_slice, + abstract_shapes_slice, + ): jaxpr_consts_body = args[body_slice] jaxpr_consts_cond = args[cond_slice] init_state = args[args_slice] - + abstract_shapes = args[abstract_shapes_slice] # If cond_fn(*init_state) is False, return the initial state fn_res = init_state while jax.core.eval_jaxpr(jaxpr_cond_fn, jaxpr_consts_cond, *fn_res)[0]: - fn_res = jax.core.eval_jaxpr(jaxpr_body_fn, jaxpr_consts_body, *fn_res) + fn_res = jax.core.eval_jaxpr( + jaxpr_body_fn, jaxpr_consts_body, *abstract_shapes, *fn_res + ) return fn_res @@ -461,26 +471,32 @@ def _call_capture_enabled(self, *init_state): while_loop_prim = _get_while_loop_qfunc_prim() + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(init_state) + flat_body_fn = FlatFn(self.body_fn) - jaxpr_body_fn = jax.make_jaxpr(flat_body_fn)(*init_state) - jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn)(*init_state) + jaxpr_body_fn = jax.make_jaxpr(flat_body_fn, abstracted_axes=abstracted_axes)(*init_state) + jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn, abstracted_axes=abstracted_axes)(*init_state) n_bf_c = len(jaxpr_body_fn.consts) n_cf_c = len(jaxpr_cond_fn.consts) + end_abstract_shapes = -len(abstract_shapes) if abstract_shapes else None body_consts = slice(0, n_bf_c) cond_consts = slice(n_bf_c, n_bf_c + n_cf_c) - args_slice = slice(n_cf_c + n_bf_c, None) + args_slice = slice(n_cf_c + n_bf_c, end_abstract_shapes) + abstract_shapes_slice = slice(end_abstract_shapes, None) if abstract_shapes else slice(0, 0) flat_args, _ = jax.tree_util.tree_flatten(init_state) results = while_loop_prim.bind( *jaxpr_body_fn.consts, *jaxpr_cond_fn.consts, *flat_args, + *abstract_shapes, jaxpr_body_fn=jaxpr_body_fn.jaxpr, jaxpr_cond_fn=jaxpr_cond_fn.jaxpr, body_slice=body_consts, cond_slice=cond_consts, args_slice=args_slice, + abstract_shapes_slice=abstract_shapes_slice, ) assert flat_body_fn.out_tree is not None, "Should be set when constructing the jaxpr" return jax.tree_util.tree_unflatten(flat_body_fn.out_tree, results) From 8e6a16707d391282bf1d510bce3f537b4c349628 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 9 Jan 2025 17:31:21 -0500 Subject: [PATCH 04/19] adding tests --- pennylane/capture/dynamic_shapes.py | 29 +++++++++++++++++--- pennylane/compiler/qjit_api.py | 2 +- tests/capture/test_capture_for_loop.py | 21 +++++++++++++++ tests/capture/test_capture_qnode.py | 21 +++++++++++++++ tests/capture/test_capture_while_loop.py | 22 +++++++++++++++ tests/capture/test_nested_plxpr.py | 34 ++++++++++++++++++++++++ 6 files changed, 124 insertions(+), 5 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index ea566724a6c..3d3ddb71bc3 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -23,14 +23,35 @@ has_jax = False -def determine_abstracted_axes(args, structure=None): - """Computed the abstracted axes and extracing the abstract shapes from the arguments.""" +def determine_abstracted_axes(args): + """Computed the abstracted axes and extracing the abstract shapes from the arguments. + + Args: + args (tuple): the arguments for a higher order primitive + + Returns: + tuple, tuple: the corresponding abstracted axes and dynamic shapes + + See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work. + + To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set. + Then, when calling the jaxpr, variables for the dynamic shapes must be passed. + + ``` + def f(n): + x = jax.numpy.ones((n,)) + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) + jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) + ``` + + """ if not has_jax: raise ImportError("jax must be installed to use determine_abstracted_axes") if not jax.config.jax_dynamic_shapes: return None, tuple() - if structure is None: - args, structure = jax.tree_util.tree_flatten(args) + + args, structure = jax.tree_util.tree_flatten(args) abstracted_axes = [] abstract_shapes = [] for l in args: diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 6f2ca052e95..0c1683a3c03 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -427,7 +427,7 @@ def _( abstract_shapes = args[abstract_shapes_slice] # If cond_fn(*init_state) is False, return the initial state fn_res = init_state - while jax.core.eval_jaxpr(jaxpr_cond_fn, jaxpr_consts_cond, *fn_res)[0]: + while jax.core.eval_jaxpr(jaxpr_cond_fn, jaxpr_consts_cond, *abstract_shapes, *fn_res)[0]: fn_res = jax.core.eval_jaxpr( jaxpr_body_fn, jaxpr_consts_body, *abstract_shapes, *fn_res ) diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index 7ba8e098db1..f967a1b5d1c 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -234,6 +234,27 @@ def loop_body(i, array, sum_val): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + def test_dynamic_shape_input(self): + jax.config.update("jax_dynamic_shapes", True) + try: + + def f(x): + n = jax.numpy.shape(x)[0] + + @qml.for_loop(n) + def g(_, y): + return y + y + + return g(x) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(5)) + + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.array([0, 8, 16]) # [0, 1, 2] * 2**3 + assert jax.numpy.allclose(output, expected) + finally: + jax.config.update("jax_dynamic_shapes", False) + class TestCaptureCircuitsForLoop: """Tests for capturing for loops into jaxpr in the context of quantum circuits.""" diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 0ac083a768c..5b35b13abb7 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -366,6 +366,27 @@ def circuit(x): assert qml.math.allclose(res, jax.numpy.cos(x)) +def test_dynamic_shape_input(): + """Test that the qnode can accept an input with a dynamic shape.""" + + jax.config.update("jax_dynamic_shapes", True) + try: + + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(jax.numpy.sum(x), 0) + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(circuit, abstracted_axes=("a",))(jax.numpy.arange(4)) + + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.cos(0 + 1 + 2) + assert jax.numpy.allclose(expected, output) + + finally: + jax.config.update("jax_dynamic_shapes", False) + + # pylint: disable=too-many-public-methods class TestQNodeVmapIntegration: """Tests for integrating JAX vmap with the QNode primitive.""" diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py index ac0e4d00d1e..c5de5b1fa39 100644 --- a/tests/capture/test_capture_while_loop.py +++ b/tests/capture/test_capture_while_loop.py @@ -81,6 +81,28 @@ def loop(a, b, idx): assert np.allclose(res_arr1_jxpr, expected), f"Expected {expected}, but got {res_arr1_jxpr}" assert np.allclose(res_idx, res_idx_jxpr) and res_idx_jxpr == 10 + def test_while_loop_dyanmic_shape_array(self): + """Test while loop can accept ararys with dynamic shapes.""" + + jax.config.update("jax_dynamic_shapes", True) + + try: + + def f(x): + @qml.while_loop(lambda res: jax.numpy.sum(res) < 10) + def g(res): + return res + res + + return g(x) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(2)) + + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.array([0, 4, 8]) + assert jax.numpy.allclose(output, expected) + finally: + jax.config.update("jax_dynamic_shapes", False) + class TestCaptureCircuitsWhileLoop: """Tests for capturing for while loops into jaxpr in the context of quantum circuits.""" diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index 031ae5c2328..919990dfd0d 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -186,6 +186,23 @@ def workflow(x): out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5) assert qml.math.isclose(out, qml.math.sin(-(0.5 + 0.3))) + def test_dynamic_shape_input(self): + """Test that the adjoint transform can accept arrays with dynamic shapes.""" + jax.config.update("jax_dynamic_shapes", True) + try: + + def f(x): + qml.adjoint(qml.RX)(x, 0) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) + + tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) + expected = qml.adjoint(qml.RX(jax.numpy.arange(2), 0)) + qml.assert_equal(tape[0], expected) + + finally: + jax.config.update("jax_dynamic_shapes", False) + class TestCtrlQfunc: """Tests for the ctrl primitive.""" @@ -361,3 +378,20 @@ def workflow(x): out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5) assert qml.math.isclose(out, -0.5 * qml.math.sin(0.5 + 0.3)) + + def test_dynamic_shape_input(self): + """Test that ctrl can accept dynamic shape inputs.""" + jax.config.update("jax_dynamic_shapes", True) + try: + + def f(x): + qml.ctrl(qml.RX, (2, 3))(x, 0) + + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) + + tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) + expected = qml.ctrl(qml.RX(jax.numpy.arange(2), 0), (2, 3)) + qml.assert_equal(tape[0], expected) + + finally: + jax.config.update("jax_dynamic_shapes", False) From 6fff950726cbb3364b41d1a4bf8bbf68735473da Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 14 Jan 2025 11:15:26 -0500 Subject: [PATCH 05/19] adding testing --- doc/releases/changelog-dev.md | 3 + pennylane/capture/dynamic_shapes.py | 23 ++--- pennylane/ops/op_math/adjoint.py | 3 +- pennylane/ops/op_math/controlled.py | 3 +- tests/capture/test_dynamic_shapes.py | 129 +++++++++++++++++++++++++++ tests/capture/test_nested_plxpr.py | 42 +++++++++ 6 files changed, 191 insertions(+), 12 deletions(-) create mode 100644 tests/capture/test_dynamic_shapes.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0f9e46664b9..67a4b0c4a26 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,9 @@

Improvements 🛠

+* The higher order primitives in program capture can now accept inputs with abstract shapes. + [(#6786)](https://github.com/PennyLaneAI/pennylane/pull/6786) +

Breaking changes 💔

Deprecations 👋

diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 3d3ddb71bc3..c88e44bc4da 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -46,7 +46,7 @@ def f(n): ``` """ - if not has_jax: + if not has_jax: # pragma: no cover raise ImportError("jax must be installed to use determine_abstracted_axes") if not jax.config.jax_dynamic_shapes: return None, tuple() @@ -55,16 +55,19 @@ def f(n): abstracted_axes = [] abstract_shapes = [] for l in args: - l_shape = [] - for s in getattr(l, "shape", ()): - if isinstance(s, int): # not abstract - l_shape.append(()) - else: - l_shape.append(ascii_lowercase[len(abstract_shapes)]) - if all(s is not x for x in abstract_shapes): - # not already added + l_shape = {} + for i, s in enumerate(getattr(l, "shape", ())): + if not isinstance(s, int): # not abstract + found = False + for j, previous_shape in enumerate(abstract_shapes): + if s is previous_shape: + l_shape[i] = ascii_lowercase[j] + found = True + continue + if not found: + l_shape[i] = ascii_lowercase[len(abstract_shapes)] abstract_shapes.append(s) - abstracted_axes.append(tuple(l_shape)) + abstracted_axes.append(l_shape) if not abstract_shapes: return None, () diff --git a/pennylane/ops/op_math/adjoint.py b/pennylane/ops/op_math/adjoint.py index 5e0366a5fa0..b9ee99bc8cd 100644 --- a/pennylane/ops/op_math/adjoint.py +++ b/pennylane/ops/op_math/adjoint.py @@ -224,10 +224,11 @@ def _capture_adjoint_transform(qfunc: Callable, lazy=True) -> Callable: def new_qfunc(*args, **kwargs): abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args) jaxpr = jax.make_jaxpr(partial(qfunc, **kwargs), abstracted_axes=abstracted_axes)(*args) + flat_args = jax.tree_util.tree_leaves(args) adjoint_prim.bind( *jaxpr.consts, *abstract_shapes, - *args, + *flat_args, jaxpr=jaxpr.jaxpr, lazy=lazy, n_consts=len(jaxpr.consts), diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index d768ec0a00e..2694d828996 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -272,11 +272,12 @@ def new_qfunc(*args, **kwargs): jaxpr = jax.make_jaxpr(functools.partial(qfunc, **kwargs), abstracted_axes=abstracted_axes)( *args ) + flat_args = jax.tree_util.tree_leaves(args) control_wires = qml.wires.Wires(control) # make sure is iterable ctrl_prim.bind( *jaxpr.consts, *abstract_shapes, - *args, + *flat_args, *control_wires, jaxpr=jaxpr.jaxpr, n_control=len(control_wires), diff --git a/tests/capture/test_dynamic_shapes.py b/tests/capture/test_dynamic_shapes.py new file mode 100644 index 00000000000..2216d4c6e99 --- /dev/null +++ b/tests/capture/test_dynamic_shapes.py @@ -0,0 +1,129 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests a function for determining abstracted axes and extracting the abstract shapes. +""" +# pylint: disable=redefined-outer-name, unused-argument + +import pytest + +from pennylane.capture import determine_abstracted_axes + +marks = pytest.mark.jax + +jax = pytest.importorskip("jax") +jnp = pytest.importorskip("jax.numpy") + + +@pytest.fixture +def enable_disable(): + jax.config.update("jax_dynamic_shapes", True) + try: + yield + finally: + jax.config.update("jax_dynamic_shapes", False) + + +def test_null_if_not_enabled(): + """Test None and an empty tuple are returned if dynamic shapes is not enabled.""" + + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + + assert abstracted_axes is None + assert abstract_shapes == () + + _ = jax.make_jaxpr(f)(jnp.eye(4)) + + +def test_null_if_no_abstract_shapes(enable_disable): + """Test the None and an empty tuple are returned if no dynamic shapes exist.""" + + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + + assert abstracted_axes is None + assert abstract_shapes == () + + _ = jax.make_jaxpr(f)(jnp.eye(4)) + + +def test_single_abstract_shape(enable_disable): + """Test we get the correct answer for a single abstract shape.""" + + initial_abstracted_axes = ({0: "a"},) + + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == 1 + + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) + + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.arange(4)) + + +@pytest.mark.parametrize( + "initial_abstracted_axes, num_shapes", + [ + (({0: "a", 1: "b"},), 2), + (({0: "a", 1: "a"},), 1), + (({1: "a"},), 1), + ], +) +def test_single_abstract_shape_multiple_abstract_axes( + enable_disable, initial_abstracted_axes, num_shapes +): + """Test we get the correct answer for a single input with two abstract axes.""" + + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == num_shapes + + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) + + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.eye(4)) + + +def test_pytree_input(enable_disable): + """Test a pytree input with dynamic shapes.""" + + initial_abstracted_axes = ( + {"input0": {}, "input1": {0: "a"}, "input2": {0: "a"}, "input3": {1: "b"}}, + ) + arg = { + "input0": jnp.arange(5), + "input1": jnp.arange(3), + "input2": jnp.arange(3), + "input3": jnp.eye(4), + } + + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == 2 + + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + flat_args = jax.tree_util.tree_leaves(args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *flat_args) + + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg) diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index 919990dfd0d..f2bfebc14d5 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -203,6 +203,34 @@ def f(x): finally: jax.config.update("jax_dynamic_shapes", False) + def test_complicated_dynamic_shape_input(self): + """Test a dynamic shape input with a more complicate shape.""" + jax.config.update("jax_dynamic_shapes", True) + try: + + def g(x, y): + qml.RX(x["a"], 0) + qml.RY(y, 0) + + def f(x, y): + qml.adjoint(g)(x, y) + + abstracted_axes = ({"a": {0: "n"}}, {0: "m"}) + jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)( + {"a": jax.numpy.arange(2)}, jax.numpy.arange(3) + ) + tape = qml.tape.plxpr_to_tape( + jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4) + ) + + op1 = qml.adjoint(qml.RY(jax.numpy.arange(4), 0)) + op2 = qml.adjoint(qml.RX(jax.numpy.arange(3), 0)) + qml.assert_equal(op1, tape[0]) + qml.assert_equal(op2, tape[1]) + + finally: + jax.config.update("jax_dynamic_shapes", False) + class TestCtrlQfunc: """Tests for the ctrl primitive.""" @@ -395,3 +423,17 @@ def f(x): finally: jax.config.update("jax_dynamic_shapes", False) + + def test_pytree_input(self): + """Test that ctrl can accept pytree inputs.""" + + def g(x): + qml.RX(x["a"], x["wire"]) + + def f(x): + qml.ctrl(g, [1])(x) + + jaxpr = jax.make_jaxpr(f)({"a": 0.5, "wire": 0}) + tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 0.5, 0) + expected = qml.ctrl(qml.RX(0.5, 0), [1]) + qml.assert_equal(tape[0], expected) From 4b04dae072e7d3ca6df6ef7e5cfc320876515f60 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 14 Jan 2025 11:42:37 -0500 Subject: [PATCH 06/19] fix marking --- tests/capture/test_dynamic_shapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/capture/test_dynamic_shapes.py b/tests/capture/test_dynamic_shapes.py index 2216d4c6e99..32ff32d875f 100644 --- a/tests/capture/test_dynamic_shapes.py +++ b/tests/capture/test_dynamic_shapes.py @@ -20,7 +20,7 @@ from pennylane.capture import determine_abstracted_axes -marks = pytest.mark.jax +pytestmark = pytest.mark.jax jax = pytest.importorskip("jax") jnp = pytest.importorskip("jax.numpy") From a5fb9eecf6d6c11c8c8db3a8d11c3441567537b7 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 14 Jan 2025 12:01:49 -0500 Subject: [PATCH 07/19] Update pennylane/capture/dynamic_shapes.py --- pennylane/capture/dynamic_shapes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index c88e44bc4da..52ba0f93ebf 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -19,8 +19,8 @@ has_jax = True try: import jax -except ImportError: - has_jax = False +except ImportError: # pragma: no cover + has_jax = False # pragma: no cover def determine_abstracted_axes(args): From 8604e5b031bd42832ab1f3c261ef7685039ebb12 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 14 Jan 2025 12:06:13 -0500 Subject: [PATCH 08/19] Apply suggestions from code review --- pennylane/capture/base_interpreter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index f49e11d3a6b..e7727a46a65 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -512,11 +512,11 @@ def flatten_while_loop( consts_body = invals[body_slice] consts_cond = invals[cond_slice] init_state = invals[args_slice] - abstract_shapes_slice = invals[abstract_shapes_slice] + abstract_shapes = invals[abstract_shapes_slice] fn_res = init_state - while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes_slice, *fn_res)[0]: - fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes_slice, *fn_res) + while copy(self).eval(jaxpr_cond_fn, consts_cond, *abstract_shapes, *fn_res)[0]: + fn_res = copy(self).eval(jaxpr_body_fn, consts_body, *abstract_shapes, *fn_res) return fn_res From 47bdf117595545e11b927ec82b4006f36b497a2e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 14 Jan 2025 13:35:05 -0500 Subject: [PATCH 09/19] black and changelog --- doc/releases/changelog-dev.md | 4 +++- pennylane/capture/dynamic_shapes.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 67a4b0c4a26..f5a42e8c2bb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -23,4 +23,6 @@

Contributors ✍️

This release contains contributions from (in alphabetical order): -Diksha Dhawan \ No newline at end of file + +Diksha Dhawan +Christina Lee \ No newline at end of file diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 52ba0f93ebf..b13aad05155 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -19,8 +19,8 @@ has_jax = True try: import jax -except ImportError: # pragma: no cover - has_jax = False # pragma: no cover +except ImportError: # pragma: no cover + has_jax = False # pragma: no cover def determine_abstracted_axes(args): From a44e7b998ad4afa37b2adeade6e67bc0cb47c8e1 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 15 Jan 2025 13:53:03 -0500 Subject: [PATCH 10/19] respond to feedback --- pennylane/capture/__init__.py | 1 + pennylane/capture/dynamic_shapes.py | 71 +++++++++++++------- pennylane/capture/intro_to_dynamic_shapes.md | 41 ++++++++++- pennylane/compiler/qjit_api.py | 24 +++---- pennylane/ops/op_math/condition.py | 7 +- tests/capture/test_capture_cond.py | 20 ++++++ 6 files changed, 124 insertions(+), 40 deletions(-) diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 8bd4d358ce0..2d99d649c6b 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -33,6 +33,7 @@ ~create_measurement_obs_primitive ~create_measurement_wires_primitive ~create_measurement_mcm_primitive + ~determine_abstracted_axes ~expand_plxpr_transforms ~run_autograph ~make_plxpr diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index b13aad05155..89cf8d32509 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -23,6 +23,42 @@ has_jax = False # pragma: no cover +def _get_shape_for_array(x, abstract_shapes: list) -> dict: + """ + Populate the dictionay of abstract axes for a single tensorlike. + + This dictionary has dimensions as axes, and a string marker as the value. + + Examples of shape -> abstract axes: + + * `(3,4) -> {}` + * `(tracer1, ) -> {0: "a"}` + * `(tracer1, tracer1) -> {0: "a", 1: "a"}` + * `(3, tracer1) -> {1: "a"}` + * `(tracer1, 2, tracer2) -> {0: "a", 2: "b"}` + + `abstract_shapes` contains all the tracers found in shapes. + + """ + abstract_axes = {} + for i, s in enumerate(getattr(x, "shape", ())): + if not isinstance(s, int): # if not int, then abstract + found = False + # check if the shape tracer is one we have already encountered + for previous_idx, previous_shape in enumerate(abstract_shapes): + if s is previous_shape: + abstract_axes[i] = ascii_lowercase[previous_idx] + found = True + continue + # haven't encountered it, so add it to abstract_axes + # and use new letter designation + if not found: + abstract_axes[i] = ascii_lowercase[len(abstract_shapes)] + abstract_shapes.append(s) + + return abstract_axes + + def determine_abstracted_axes(args): """Computed the abstracted axes and extracing the abstract shapes from the arguments. @@ -37,37 +73,26 @@ def determine_abstracted_axes(args): To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set. Then, when calling the jaxpr, variables for the dynamic shapes must be passed. - ``` - def f(n): - x = jax.numpy.ones((n,)) - abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) - jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) - return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) - ``` + .. code-block:: python + + def f(n): + x = jax.numpy.ones((n,)) + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) + jaxpr = jax.make_jaxpr(jax.numpy.sum, abstracted_axes=abstracted_axes)(x) + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, x) """ if not has_jax: # pragma: no cover raise ImportError("jax must be installed to use determine_abstracted_axes") - if not jax.config.jax_dynamic_shapes: + if not jax.config.jax_dynamic_shapes: # pylint: disable=no-member return None, tuple() args, structure = jax.tree_util.tree_flatten(args) - abstracted_axes = [] + abstract_shapes = [] - for l in args: - l_shape = {} - for i, s in enumerate(getattr(l, "shape", ())): - if not isinstance(s, int): # not abstract - found = False - for j, previous_shape in enumerate(abstract_shapes): - if s is previous_shape: - l_shape[i] = ascii_lowercase[j] - found = True - continue - if not found: - l_shape[i] = ascii_lowercase[len(abstract_shapes)] - abstract_shapes.append(s) - abstracted_axes.append(l_shape) + # note: this function in-place mutates abstract_shapes + # adding any additional abstract shapes found + abstracted_axes = [_get_shape_for_array(a, abstract_shapes) for a in args] if not abstract_shapes: return None, () diff --git a/pennylane/capture/intro_to_dynamic_shapes.md b/pennylane/capture/intro_to_dynamic_shapes.md index 1e1e4f40adb..f2b699c1411 100644 --- a/pennylane/capture/intro_to_dynamic_shapes.md +++ b/pennylane/capture/intro_to_dynamic_shapes.md @@ -86,6 +86,45 @@ jax.make_jaxpr(g, abstracted_axes=("x",))(jax.numpy.array([1,2,3])) { lambda ; a:i32[] b:i32[a]. let c:i32[] = reduce_sum[axes=(0,)] b in (c,) } +### Understanding `abstracted_axes` + +Suppose we want to have two arrays with dynamic array dimensions `a` and `b`. +`x` has two dynamic axes, with a shape `(a, b)`. This corresponds to an abstracted axes specification of `{0:"a", 1:"b"}`. +`y` has one dynamic axis and one static axis, with a shape `(4, b)`. This corresponds to an abstracted axes specification of +`{1:"b"}`. As the `0` dimension is static, it is not included in the dictionary. + +The abstracted axes for both `x` and `y` include the string `"b"`. This is because the second dimension of `x` and the second dimension +of `y` should always match and should be represented by a single tracer variable. + +``` +a = 3 +b = 4 +x = jnp.zeros((a, b)) +y = jnp.zeros((4, b)) +x_axes = {0: "a", 1: "b"} +y_axes = {1: "b"} +args = (x, y) +abstracted_axes = (x_axes, y_axes) +jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args) +``` +``` +{ lambda ; a:i32[] b:i32[] c:f32[a,b] d:f32[4,b]. let in (0,) } +``` + +The abstracted axes should have the same pytree structure as `args`, but with each tensor replaced by a dictionary indicating which axes +are abstract. Suppose our first argument is instead a dictionary with tensorlike leaves. Then we should provide an `abstracted_axes` with +the same tree structure. + +``` +args = ({"x": x, "y": y},) +abstracted_axes = ({"x": x_axes, "y": y_axes},) +jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args) +``` +``` +{ lambda ; a:i32[] b:i32[] c:f32[a,b] d:f32[4,b]. let in (0,) } +``` + + ## Limitations of dynamic shapes and numerical manipulations 1. Slicing into a dynamically sized array. @@ -152,7 +191,7 @@ jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3) -## Extending support to PLXPR HOP's +## Extending support to PLXPR Higher Order Primitives (HOP's) When capturing higher order primitives, we call `jax.make_jaxpr(f)` with arguments whose shapes are tracers. diff --git a/pennylane/compiler/qjit_api.py b/pennylane/compiler/qjit_api.py index 0c1683a3c03..081f032a96a 100644 --- a/pennylane/compiler/qjit_api.py +++ b/pennylane/compiler/qjit_api.py @@ -410,6 +410,7 @@ def _get_while_loop_qfunc_prim(): while_loop_prim = create_non_interpreted_prim()("while_loop") while_loop_prim.multiple_results = True + # pylint: disable=too-many-arguments @while_loop_prim.def_impl def _( *args, @@ -477,20 +478,17 @@ def _call_capture_enabled(self, *init_state): jaxpr_body_fn = jax.make_jaxpr(flat_body_fn, abstracted_axes=abstracted_axes)(*init_state) jaxpr_cond_fn = jax.make_jaxpr(self.cond_fn, abstracted_axes=abstracted_axes)(*init_state) - n_bf_c = len(jaxpr_body_fn.consts) - n_cf_c = len(jaxpr_cond_fn.consts) - end_abstract_shapes = -len(abstract_shapes) if abstract_shapes else None - body_consts = slice(0, n_bf_c) - cond_consts = slice(n_bf_c, n_bf_c + n_cf_c) - args_slice = slice(n_cf_c + n_bf_c, end_abstract_shapes) - abstract_shapes_slice = slice(end_abstract_shapes, None) if abstract_shapes else slice(0, 0) + body_consts = slice(0, len(jaxpr_body_fn.consts)) + cond_consts = slice(body_consts.stop, body_consts.stop + len(jaxpr_cond_fn.consts)) + abstract_shapes_slice = slice(cond_consts.stop, cond_consts.stop + len(abstract_shapes)) + args_slice = slice(abstract_shapes_slice.stop, None) flat_args, _ = jax.tree_util.tree_flatten(init_state) results = while_loop_prim.bind( *jaxpr_body_fn.consts, *jaxpr_cond_fn.consts, - *flat_args, *abstract_shapes, + *flat_args, jaxpr_body_fn=jaxpr_body_fn.jaxpr, jaxpr_cond_fn=jaxpr_cond_fn.jaxpr, body_slice=body_consts, @@ -715,12 +713,8 @@ def _call_capture_enabled(self, *init_state): jaxpr_body_fn = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(0, *init_state) consts_slice = slice(0, len(jaxpr_body_fn.consts)) - args_slice = slice( - len(jaxpr_body_fn.consts), -len(abstract_shapes) if abstract_shapes else None - ) - abstract_shapes_slice = ( - slice(-len(abstract_shapes), None) if abstract_shapes else slice(0, 0) - ) + abstract_shapes_slice = slice(consts_slice.stop, consts_slice.stop + len(abstract_shapes)) + args_slice = slice(abstract_shapes_slice.stop, None) flat_args, _ = jax.tree_util.tree_flatten(init_state) @@ -729,8 +723,8 @@ def _call_capture_enabled(self, *init_state): self.stop, self.step, *jaxpr_body_fn.consts, - *flat_args, *abstract_shapes, + *flat_args, jaxpr_body_fn=jaxpr_body_fn.jaxpr, consts_slice=consts_slice, args_slice=args_slice, diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index deace92e73c..8b64802bb2c 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -234,13 +234,17 @@ def __call_capture_enabled(self, *args, **kwargs): consts = [] consts_slices = [] + abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args) + for pred, fn in branches: conditions.append(pred) if fn is None: jaxpr_branches.append(None) consts_slices.append(slice(0, 0)) else: - jaxpr = jax.make_jaxpr(functools.partial(fn, **kwargs))(*args) + jaxpr = jax.make_jaxpr( + functools.partial(fn, **kwargs), abstracted_axes=abstracted_axes + )(*args) jaxpr_branches.append(jaxpr.jaxpr) consts_slices.append(slice(end_const_ind, end_const_ind + len(jaxpr.consts))) consts += jaxpr.consts @@ -250,6 +254,7 @@ def __call_capture_enabled(self, *args, **kwargs): results = cond_prim.bind( *conditions, *consts, + *abstract_shapes, *flat_args, jaxpr_branches=jaxpr_branches, consts_slices=consts_slices, diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index ed8922b1d0b..3007e45ed42 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -847,3 +847,23 @@ def f(x): assert isinstance(q.queue[0], qml.measurements.MidMeasureMP) assert isinstance(q.queue[1], qml.ops.Conditional) qml.assert_equal(q.queue[1].base, qml.RX(0.5, 0)) + + +def test_cond_abstracted_axes(): + """Test cond can accept inputs with dynamic shapes.""" + jax.config.update("jax_dynamic_shapes", True) + try: + + def workflow(x, predicate): + return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x) + + jaxpr = jax.make_jaxpr(workflow, abstracted_axes=({0: "a"}, {}))(jax.numpy.arange(3), True) + + output_true = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 4, jax.numpy.arange(4), True) + assert qml.math.allclose(output_true[0], 6) # 0 + 1 + 2 + 3 + + output_false = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2), False) + assert qml.math.allclose(output_false[0], 0) # 0 * 1 + + finally: + jax.config.update("jax_dynamic_shapes", False) From c57ff027cdaad88d0e826a642f805fccac2ea744 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Wed, 15 Jan 2025 13:57:16 -0500 Subject: [PATCH 11/19] Apply suggestions from code review Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com> --- pennylane/capture/dynamic_shapes.py | 2 +- pennylane/capture/intro_to_dynamic_shapes.md | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 89cf8d32509..e1488f11bb3 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -27,7 +27,7 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: """ Populate the dictionay of abstract axes for a single tensorlike. - This dictionary has dimensions as axes, and a string marker as the value. + This dictionary has dimensions as keys, and a string marker as the value. Examples of shape -> abstract axes: diff --git a/pennylane/capture/intro_to_dynamic_shapes.md b/pennylane/capture/intro_to_dynamic_shapes.md index f2b699c1411..af42e754768 100644 --- a/pennylane/capture/intro_to_dynamic_shapes.md +++ b/pennylane/capture/intro_to_dynamic_shapes.md @@ -5,14 +5,11 @@ import jax ``` -Dynamic shapes are experimental feature of jax with limited support and feature coverage. - +Dynamic shapes are an experimental feature of jax with limited support and feature coverage. +Without the `"jax_dynamic_shapes"` feature, we can't create arrays whose size depends on an abstract value. ```python jax.config.update("jax_dynamic_shapes", False) -``` - -Without this setup, we can't create arrays whose size depends on an abstract value. ```python @@ -279,7 +276,7 @@ jax.make_jaxpr(f)(3) -We can now take these learnings a make custom higher order primitive that supports dynamically shaped inputs: +We can now take these learnings to make a custom higher order primitive that supports dynamically shaped inputs: ```python From 8403a5803d0c68f1bb62d7dfb3ac14cb4e95806e Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Wed, 15 Jan 2025 16:24:46 -0500 Subject: [PATCH 12/19] Apply suggestions from code review Co-authored-by: Mudit Pandey --- pennylane/capture/dynamic_shapes.py | 4 ++-- pennylane/capture/intro_to_dynamic_shapes.md | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index e1488f11bb3..6d17b5061f2 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -25,7 +25,7 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: """ - Populate the dictionay of abstract axes for a single tensorlike. + Populate the dictionary of abstract axes for a single tensorlike. This dictionary has dimensions as keys, and a string marker as the value. @@ -60,7 +60,7 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: def determine_abstracted_axes(args): - """Computed the abstracted axes and extracing the abstract shapes from the arguments. + """Computed the abstracted axes and extracting the abstract shapes from the arguments. Args: args (tuple): the arguments for a higher order primitive diff --git a/pennylane/capture/intro_to_dynamic_shapes.md b/pennylane/capture/intro_to_dynamic_shapes.md index af42e754768..3cc139726c8 100644 --- a/pennylane/capture/intro_to_dynamic_shapes.md +++ b/pennylane/capture/intro_to_dynamic_shapes.md @@ -133,7 +133,7 @@ Erick has an open PR to fix this issue on the jax github. Catalyst currently pa def h(x): return x[0] -jax.make_jaxpr(h, abstracted_axes=("x", ) )(jax.numpy.array([0, 1,2])) +jax.make_jaxpr(h, abstracted_axes=("x", ) )(jax.numpy.array([0, 1, 2])) ``` @@ -143,7 +143,7 @@ jax.make_jaxpr(h, abstracted_axes=("x", ) )(jax.numpy.array([0, 1,2])) -Executing with `eval_jaxpr`: +2. Executing with `eval_jaxpr`: No idea how to fix this right now. @@ -231,7 +231,7 @@ inner jaxpr: { lambda ; a:i32[] b:f32[a]. let c:f32[] = reduce_sum[axes=(0,)] b Note in this case that I am passing `n` when evaluating the jaxpr, even though `n` wasn't an argument that produced the jaxpr. -`n` was an implicit argument contained inside of `x`, so `make_jaxpr` promotes it explicit input. We can see this in the "inner jaxpr" printed out inside the function. Even though the function that produced it only had `x` as an input, the jaxpr has `a:i32[], b:f32[a]` as two arguments. When re-evaluating the jaxpr later, we need to make sure to pass the value for `n` as well. +`n` was an implicit argument contained inside of `x`, so `make_jaxpr` promotes it to an explicit input. We can see this in the "inner jaxpr" printed out inside the function. Even though the function that produced it only had `x` as an input, the jaxpr has `a:i32[], b:f32[a]` as two arguments. When re-evaluating the jaxpr later, we need to make sure to pass the value for `n` as well. To handle generic functions, we must then be able to determine which axes are dynamic from the arguments, and extract the tracer values for all the abstract dimensions. @@ -431,9 +431,6 @@ prim3.multiple_results = True ``` -```python - -``` ```python @@ -456,9 +453,6 @@ pe.custom_staging_rules[prim3] = custom_staging_rule ``` -```python - -``` ```python From b54d092868e6b61a75d3591e4937b91b8b2fe2c4 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Wed, 15 Jan 2025 16:27:36 -0500 Subject: [PATCH 13/19] Apply suggestions from code review Co-authored-by: Mudit Pandey --- pennylane/capture/dynamic_shapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 6d17b5061f2..2ad58e7d798 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -49,7 +49,7 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: if s is previous_shape: abstract_axes[i] = ascii_lowercase[previous_idx] found = True - continue + break # haven't encountered it, so add it to abstract_axes # and use new letter designation if not found: From 88361bc9ae8566cd07edc0365b9e4d7364e99de6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 15 Jan 2025 17:45:42 -0500 Subject: [PATCH 14/19] all the dynamic shapes --- pennylane/capture/dynamic_shapes.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 2ad58e7d798..6269d82d460 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -14,8 +14,11 @@ """ Contains a utility for handling inputs with dynamically shaped arrays. """ +from functools import lru_cache from string import ascii_lowercase +import numpy as np + has_jax = True try: import jax @@ -23,6 +26,16 @@ has_jax = False # pragma: no cover +@lru_cache +def _get_letter(ind: int) -> str: + if ind < 26: + return ascii_lowercase[ind] + # absolutely overkill, but it works + num_letters = int(np.ceil(np.log(ind) / np.log(26))) + letters = (ascii_lowercase[(ind // 26**i) % 26] for i in range(num_letters)) + return "".join(letters) + + def _get_shape_for_array(x, abstract_shapes: list) -> dict: """ Populate the dictionary of abstract axes for a single tensorlike. @@ -47,13 +60,13 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: # check if the shape tracer is one we have already encountered for previous_idx, previous_shape in enumerate(abstract_shapes): if s is previous_shape: - abstract_axes[i] = ascii_lowercase[previous_idx] + abstract_axes[i] = _get_letter(previous_idx) found = True break # haven't encountered it, so add it to abstract_axes # and use new letter designation if not found: - abstract_axes[i] = ascii_lowercase[len(abstract_shapes)] + abstract_axes[i] = _get_letter(len(abstract_shapes)) abstract_shapes.append(s) return abstract_axes From d4c7a6b95835cceac6c56ffc8b7e2ffd4d04dc34 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 21 Jan 2025 09:36:27 -0500 Subject: [PATCH 15/19] some clarirications and tests --- pennylane/capture/dynamic_shapes.py | 18 +++++++----- pennylane/capture/intro_to_dynamic_shapes.md | 6 +++- tests/capture/test_dynamic_shapes.py | 31 ++++++++++++++++++++ tests/capture/test_nested_plxpr.py | 11 ++++--- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index 6269d82d460..ba7d33f8495 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -15,9 +15,7 @@ Contains a utility for handling inputs with dynamically shaped arrays. """ from functools import lru_cache -from string import ascii_lowercase - -import numpy as np +from string import ascii_lowercase as letters has_jax = True try: @@ -29,11 +27,10 @@ @lru_cache def _get_letter(ind: int) -> str: if ind < 26: - return ascii_lowercase[ind] - # absolutely overkill, but it works - num_letters = int(np.ceil(np.log(ind) / np.log(26))) - letters = (ascii_lowercase[(ind // 26**i) % 26] for i in range(num_letters)) - return "".join(letters) + return letters[ind] + if ind < 702: + return letters[ind // 26 - 1] + letters[ind % 26] + raise NotImplementedError("we only support up to 702 dynamic axes") # pragma: no cover def _get_shape_for_array(x, abstract_shapes: list) -> dict: @@ -81,6 +78,9 @@ def determine_abstracted_axes(args): Returns: tuple, tuple: the corresponding abstracted axes and dynamic shapes + Note that "dynamic shapes" only refers to the size of dimensions, but not the number of dimensions. + Even with dynamic shapes mode enabled, we cannot change the number of dimensions. + See the ``intro_to_dynamic_shapes.md`` document for more information on how dynamic shapes work. To make jaxpr from arguments with dynamic shapes, the ``abstracted_axes`` keyword argument must be set. @@ -88,6 +88,8 @@ def determine_abstracted_axes(args): .. code-block:: python + jax.config.update("jax_dynamic_shapes", True) + def f(n): x = jax.numpy.ones((n,)) abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes((x,)) diff --git a/pennylane/capture/intro_to_dynamic_shapes.md b/pennylane/capture/intro_to_dynamic_shapes.md index 3cc139726c8..8c29e75a29c 100644 --- a/pennylane/capture/intro_to_dynamic_shapes.md +++ b/pennylane/capture/intro_to_dynamic_shapes.md @@ -8,9 +8,13 @@ import jax Dynamic shapes are an experimental feature of jax with limited support and feature coverage. Without the `"jax_dynamic_shapes"` feature, we can't create arrays whose size depends on an abstract value. +Note that "dynamic shapes" reference an array with a dynamic size for one or more dimensions. +The number of dimensions must still remain fixed. We cannot do `jax.numpy.ones([3] * n)` for +a tracer `n`. + ```python jax.config.update("jax_dynamic_shapes", False) - +``` ```python %xmode Minimal diff --git a/tests/capture/test_dynamic_shapes.py b/tests/capture/test_dynamic_shapes.py index 32ff32d875f..f0c63920815 100644 --- a/tests/capture/test_dynamic_shapes.py +++ b/tests/capture/test_dynamic_shapes.py @@ -127,3 +127,34 @@ def f(*args): _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *flat_args) _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg) + + +def test_input_created_with_jnp_ones(enable_disable): + """Test that determine_abstracted_axes works with manually created dynamic arrays.""" + + def f(n): + m = n + 1 + ones = jax.numpy.ones((m, 3)) + zeros = jax.numpy.zeros((4, n)) + + abstracted_axes, abstract_shapes = determine_abstracted_axes((ones, zeros)) + assert abstracted_axes == ({0: "a"}, {1: "b"}) + assert len(abstract_shapes) == 2 + assert abstract_shapes[0] is m + assert abstract_shapes[1] is n + + _ = jax.make_jaxpr(f)(3) + + +def test_large_number_of_abstract_axes(enable_disable): + """Test that determine_abstracted_axes can handle over 26 abstract axes.""" + + def f(shapes): + ones = jax.numpy.zeros(shapes) + abstracted_axes, abstract_shapes = determine_abstracted_axes((ones,)) + + assert abstracted_axes + assert len(set(abstracted_axes[0].keys())) == 30 # unique keys for each axis + assert len(abstract_shapes) == 30 + + _ = jax.make_jaxpr(f)(list(range(30))) diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index f2bfebc14d5..62f30a9c49a 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -215,10 +215,13 @@ def g(x, y): def f(x, y): qml.adjoint(g)(x, y) - abstracted_axes = ({"a": {0: "n"}}, {0: "m"}) - jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)( - {"a": jax.numpy.arange(2)}, jax.numpy.arange(3) - ) + x_a_axes = {0: "n"} + y_axes = {0: "m"} + x = {"a": jax.numpy.arange(2)} + y = jax.numpy.arange(3) + + abstracted_axes = ({"a": x_a_axes}, y_axes) + jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(x, y) tape = qml.tape.plxpr_to_tape( jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4) ) From 9192ebd0ebba9e6ce221349adae397afba45b17a Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 21 Jan 2025 14:33:27 -0500 Subject: [PATCH 16/19] Update pennylane/capture/dynamic_shapes.py Co-authored-by: Mudit Pandey --- pennylane/capture/dynamic_shapes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pennylane/capture/dynamic_shapes.py b/pennylane/capture/dynamic_shapes.py index ba7d33f8495..c3c9c7237da 100644 --- a/pennylane/capture/dynamic_shapes.py +++ b/pennylane/capture/dynamic_shapes.py @@ -41,13 +41,13 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict: Examples of shape -> abstract axes: - * `(3,4) -> {}` - * `(tracer1, ) -> {0: "a"}` - * `(tracer1, tracer1) -> {0: "a", 1: "a"}` - * `(3, tracer1) -> {1: "a"}` - * `(tracer1, 2, tracer2) -> {0: "a", 2: "b"}` + * ``(3,4) -> {}`` + * ``(tracer1, ) -> {0: "a"}`` + * ``(tracer1, tracer1) -> {0: "a", 1: "a"}`` + * ``(3, tracer1) -> {1: "a"}`` + * ``(tracer1, 2, tracer2) -> {0: "a", 2: "b"}`` - `abstract_shapes` contains all the tracers found in shapes. + ``abstract_shapes`` contains all the tracers found in shapes. """ abstract_axes = {} From ab090bf433d5f48c66bda659db180f6e71460b44 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 21 Jan 2025 14:47:56 -0500 Subject: [PATCH 17/19] fix failing tests --- tests/capture/test_capture_cond.py | 21 +++---- tests/capture/test_capture_for_loop.py | 26 ++++---- tests/capture/test_capture_qnode.py | 23 +++---- tests/capture/test_capture_while_loop.py | 25 +++----- tests/capture/test_dynamic_shapes.py | 26 ++++---- tests/capture/test_nested_plxpr.py | 80 ++++++++++-------------- tests/conftest.py | 8 +++ 7 files changed, 92 insertions(+), 117 deletions(-) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 3007e45ed42..d3bd0017b68 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -849,21 +849,16 @@ def f(x): qml.assert_equal(q.queue[1].base, qml.RX(0.5, 0)) +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_cond_abstracted_axes(): """Test cond can accept inputs with dynamic shapes.""" - jax.config.update("jax_dynamic_shapes", True) - try: + def workflow(x, predicate): + return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x) - def workflow(x, predicate): - return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x) + jaxpr = jax.make_jaxpr(workflow, abstracted_axes=({0: "a"}, {}))(jax.numpy.arange(3), True) - jaxpr = jax.make_jaxpr(workflow, abstracted_axes=({0: "a"}, {}))(jax.numpy.arange(3), True) + output_true = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 4, jax.numpy.arange(4), True) + assert qml.math.allclose(output_true[0], 6) # 0 + 1 + 2 + 3 - output_true = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 4, jax.numpy.arange(4), True) - assert qml.math.allclose(output_true[0], 6) # 0 + 1 + 2 + 3 - - output_false = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2), False) - assert qml.math.allclose(output_false[0], 0) # 0 * 1 - - finally: - jax.config.update("jax_dynamic_shapes", False) + output_false = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2), False) + assert qml.math.allclose(output_false[0], 0) # 0 * 1 diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index f967a1b5d1c..174ba768c0e 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -234,26 +234,24 @@ def loop_body(i, array, sum_val): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_dynamic_shape_input(self): - jax.config.update("jax_dynamic_shapes", True) - try: + """Test that the for loop can accept inputs with dynamic shapes.""" - def f(x): - n = jax.numpy.shape(x)[0] + def f(x): + n = jax.numpy.shape(x)[0] - @qml.for_loop(n) - def g(_, y): - return y + y + @qml.for_loop(n) + def g(_, y): + return y + y - return g(x) + return g(x) - jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(5)) + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(5)) - [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) - expected = jax.numpy.array([0, 8, 16]) # [0, 1, 2] * 2**3 - assert jax.numpy.allclose(output, expected) - finally: - jax.config.update("jax_dynamic_shapes", False) + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.array([0, 8, 16]) # [0, 1, 2] * 2**3 + assert jax.numpy.allclose(output, expected) class TestCaptureCircuitsForLoop: diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 5b35b13abb7..617a4a655d9 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -366,25 +366,20 @@ def circuit(x): assert qml.math.allclose(res, jax.numpy.cos(x)) +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_dynamic_shape_input(): """Test that the qnode can accept an input with a dynamic shape.""" - jax.config.update("jax_dynamic_shapes", True) - try: - - @qml.qnode(qml.device("default.qubit", wires=1)) - def circuit(x): - qml.RX(jax.numpy.sum(x), 0) - return qml.expval(qml.Z(0)) - - jaxpr = jax.make_jaxpr(circuit, abstracted_axes=("a",))(jax.numpy.arange(4)) + @qml.qnode(qml.device("default.qubit", wires=1)) + def circuit(x): + qml.RX(jax.numpy.sum(x), 0) + return qml.expval(qml.Z(0)) - [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) - expected = jax.numpy.cos(0 + 1 + 2) - assert jax.numpy.allclose(expected, output) + jaxpr = jax.make_jaxpr(circuit, abstracted_axes=("a",))(jax.numpy.arange(4)) - finally: - jax.config.update("jax_dynamic_shapes", False) + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.cos(0 + 1 + 2) + assert jax.numpy.allclose(expected, output) # pylint: disable=too-many-public-methods diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py index c5de5b1fa39..93aad8b73ab 100644 --- a/tests/capture/test_capture_while_loop.py +++ b/tests/capture/test_capture_while_loop.py @@ -81,27 +81,22 @@ def loop(a, b, idx): assert np.allclose(res_arr1_jxpr, expected), f"Expected {expected}, but got {res_arr1_jxpr}" assert np.allclose(res_idx, res_idx_jxpr) and res_idx_jxpr == 10 + @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_while_loop_dyanmic_shape_array(self): """Test while loop can accept ararys with dynamic shapes.""" - jax.config.update("jax_dynamic_shapes", True) + def f(x): + @qml.while_loop(lambda res: jax.numpy.sum(res) < 10) + def g(res): + return res + res - try: + return g(x) - def f(x): - @qml.while_loop(lambda res: jax.numpy.sum(res) < 10) - def g(res): - return res + res + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(2)) - return g(x) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(2)) - - [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) - expected = jax.numpy.array([0, 4, 8]) - assert jax.numpy.allclose(output, expected) - finally: - jax.config.update("jax_dynamic_shapes", False) + [output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3)) + expected = jax.numpy.array([0, 4, 8]) + assert jax.numpy.allclose(output, expected) class TestCaptureCircuitsWhileLoop: diff --git a/tests/capture/test_dynamic_shapes.py b/tests/capture/test_dynamic_shapes.py index f0c63920815..a7c2fe8f2bf 100644 --- a/tests/capture/test_dynamic_shapes.py +++ b/tests/capture/test_dynamic_shapes.py @@ -26,14 +26,6 @@ jnp = pytest.importorskip("jax.numpy") -@pytest.fixture -def enable_disable(): - jax.config.update("jax_dynamic_shapes", True) - try: - yield - finally: - jax.config.update("jax_dynamic_shapes", False) - def test_null_if_not_enabled(): """Test None and an empty tuple are returned if dynamic shapes is not enabled.""" @@ -47,7 +39,8 @@ def f(*args): _ = jax.make_jaxpr(f)(jnp.eye(4)) -def test_null_if_no_abstract_shapes(enable_disable): +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") +def test_null_if_no_abstract_shapes(): """Test the None and an empty tuple are returned if no dynamic shapes exist.""" def f(*args): @@ -59,7 +52,8 @@ def f(*args): _ = jax.make_jaxpr(f)(jnp.eye(4)) -def test_single_abstract_shape(enable_disable): +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") +def test_single_abstract_shape(): """Test we get the correct answer for a single abstract shape.""" initial_abstracted_axes = ({0: "a"},) @@ -77,6 +71,7 @@ def f(*args): _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.arange(4)) +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") @pytest.mark.parametrize( "initial_abstracted_axes, num_shapes", [ @@ -86,7 +81,7 @@ def f(*args): ], ) def test_single_abstract_shape_multiple_abstract_axes( - enable_disable, initial_abstracted_axes, num_shapes + initial_abstracted_axes, num_shapes ): """Test we get the correct answer for a single input with two abstract axes.""" @@ -103,7 +98,8 @@ def f(*args): _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.eye(4)) -def test_pytree_input(enable_disable): +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") +def test_pytree_input(): """Test a pytree input with dynamic shapes.""" initial_abstracted_axes = ( @@ -129,7 +125,8 @@ def f(*args): _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg) -def test_input_created_with_jnp_ones(enable_disable): +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") +def test_input_created_with_jnp_ones(): """Test that determine_abstracted_axes works with manually created dynamic arrays.""" def f(n): @@ -146,7 +143,8 @@ def f(n): _ = jax.make_jaxpr(f)(3) -def test_large_number_of_abstract_axes(enable_disable): +@pytest.mark.usefixtures("enable_disable_dynamic_shapes") +def test_large_number_of_abstract_axes(): """Test that determine_abstracted_axes can handle over 26 abstract axes.""" def f(shapes): diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index 62f30a9c49a..ee94c9b6599 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -186,54 +186,45 @@ def workflow(x): out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5) assert qml.math.isclose(out, qml.math.sin(-(0.5 + 0.3))) + @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_dynamic_shape_input(self): """Test that the adjoint transform can accept arrays with dynamic shapes.""" - jax.config.update("jax_dynamic_shapes", True) - try: - def f(x): - qml.adjoint(qml.RX)(x, 0) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) + def f(x): + qml.adjoint(qml.RX)(x, 0) - tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) - expected = qml.adjoint(qml.RX(jax.numpy.arange(2), 0)) - qml.assert_equal(tape[0], expected) + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) - finally: - jax.config.update("jax_dynamic_shapes", False) + tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) + expected = qml.adjoint(qml.RX(jax.numpy.arange(2), 0)) + qml.assert_equal(tape[0], expected) + @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_complicated_dynamic_shape_input(self): """Test a dynamic shape input with a more complicate shape.""" - jax.config.update("jax_dynamic_shapes", True) - try: - - def g(x, y): - qml.RX(x["a"], 0) - qml.RY(y, 0) - - def f(x, y): - qml.adjoint(g)(x, y) - x_a_axes = {0: "n"} - y_axes = {0: "m"} - x = {"a": jax.numpy.arange(2)} - y = jax.numpy.arange(3) + def g(x, y): + qml.RX(x["a"], 0) + qml.RY(y, 0) - abstracted_axes = ({"a": x_a_axes}, y_axes) - jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(x, y) - tape = qml.tape.plxpr_to_tape( - jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4) - ) + def f(x, y): + qml.adjoint(g)(x, y) - op1 = qml.adjoint(qml.RY(jax.numpy.arange(4), 0)) - op2 = qml.adjoint(qml.RX(jax.numpy.arange(3), 0)) - qml.assert_equal(op1, tape[0]) - qml.assert_equal(op2, tape[1]) + x_a_axes = {0: "n"} + y_axes = {0: "m"} + x = {"a": jax.numpy.arange(2)} + y = jax.numpy.arange(3) - finally: - jax.config.update("jax_dynamic_shapes", False) + abstracted_axes = ({"a": x_a_axes}, y_axes) + jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(x, y) + tape = qml.tape.plxpr_to_tape( + jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4) + ) + op1 = qml.adjoint(qml.RY(jax.numpy.arange(4), 0)) + op2 = qml.adjoint(qml.RX(jax.numpy.arange(3), 0)) + qml.assert_equal(op1, tape[0]) + qml.assert_equal(op2, tape[1]) class TestCtrlQfunc: """Tests for the ctrl primitive.""" @@ -410,22 +401,17 @@ def workflow(x): out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5) assert qml.math.isclose(out, -0.5 * qml.math.sin(0.5 + 0.3)) + @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_dynamic_shape_input(self): """Test that ctrl can accept dynamic shape inputs.""" - jax.config.update("jax_dynamic_shapes", True) - try: - - def f(x): - qml.ctrl(qml.RX, (2, 3))(x, 0) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) + def f(x): + qml.ctrl(qml.RX, (2, 3))(x, 0) - tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) - expected = qml.ctrl(qml.RX(jax.numpy.arange(2), 0), (2, 3)) - qml.assert_equal(tape[0], expected) + jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4)) - finally: - jax.config.update("jax_dynamic_shapes", False) + tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2)) + expected = qml.ctrl(qml.RX(jax.numpy.arange(2), 0), (2, 3)) + qml.assert_equal(tape[0], expected) def test_pytree_input(self): """Test that ctrl can accept pytree inputs.""" diff --git a/tests/conftest.py b/tests/conftest.py index 8c70bf17c44..ab2f34202fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -170,6 +170,14 @@ def enable_disable_plxpr(): qml.capture.disable() +@pytest.fixture(scope="function") +def enable_disable_dynamic_shapes(): + jax.config.update("jax_dynamic_shapes", True) + try: + yield + finally: + jax.config.update("jax_dynamic_shapes", False) + ####################################################################### try: From c9508ab76f98f8e9433984473ec58a79d33e897e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 22 Jan 2025 12:01:26 -0500 Subject: [PATCH 18/19] change fixture usage --- tests/capture/test_capture_cond.py | 5 +- tests/capture/test_capture_for_loop.py | 4 +- tests/capture/test_capture_qnode.py | 4 +- tests/capture/test_capture_while_loop.py | 4 +- tests/capture/test_dynamic_shapes.py | 169 +++++++++++------------ 5 files changed, 89 insertions(+), 97 deletions(-) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index d3bd0017b68..e9c160dc2f1 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -849,9 +849,10 @@ def f(x): qml.assert_equal(q.queue[1].base, qml.RX(0.5, 0)) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_cond_abstracted_axes(): +# pylint: disable=unused-argument +def test_cond_abstracted_axes(enable_disable_dynamic_shapes): """Test cond can accept inputs with dynamic shapes.""" + def workflow(x, predicate): return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x) diff --git a/tests/capture/test_capture_for_loop.py b/tests/capture/test_capture_for_loop.py index 174ba768c0e..d6e09ad3c81 100644 --- a/tests/capture/test_capture_for_loop.py +++ b/tests/capture/test_capture_for_loop.py @@ -234,8 +234,8 @@ def loop_body(i, array, sum_val): res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array) assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" - @pytest.mark.usefixtures("enable_disable_dynamic_shapes") - def test_dynamic_shape_input(self): + # pylint: disable=unused-argument + def test_dynamic_shape_input(self, enable_disable_dynamic_shapes): """Test that the for loop can accept inputs with dynamic shapes.""" def f(x): diff --git a/tests/capture/test_capture_qnode.py b/tests/capture/test_capture_qnode.py index 617a4a655d9..c486dc520c5 100644 --- a/tests/capture/test_capture_qnode.py +++ b/tests/capture/test_capture_qnode.py @@ -366,8 +366,8 @@ def circuit(x): assert qml.math.allclose(res, jax.numpy.cos(x)) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_dynamic_shape_input(): +# pylint: disable=unused-argument +def test_dynamic_shape_input(enable_disable_dynamic_shapes): """Test that the qnode can accept an input with a dynamic shape.""" @qml.qnode(qml.device("default.qubit", wires=1)) diff --git a/tests/capture/test_capture_while_loop.py b/tests/capture/test_capture_while_loop.py index 93aad8b73ab..cf6bd8ac041 100644 --- a/tests/capture/test_capture_while_loop.py +++ b/tests/capture/test_capture_while_loop.py @@ -81,8 +81,8 @@ def loop(a, b, idx): assert np.allclose(res_arr1_jxpr, expected), f"Expected {expected}, but got {res_arr1_jxpr}" assert np.allclose(res_idx, res_idx_jxpr) and res_idx_jxpr == 10 - @pytest.mark.usefixtures("enable_disable_dynamic_shapes") - def test_while_loop_dyanmic_shape_array(self): + # pylint: disable=unused-argument + def test_while_loop_dyanmic_shape_array(self, enable_disable_dynamic_shapes): """Test while loop can accept ararys with dynamic shapes.""" def f(x): diff --git a/tests/capture/test_dynamic_shapes.py b/tests/capture/test_dynamic_shapes.py index a7c2fe8f2bf..58d2c463a9e 100644 --- a/tests/capture/test_dynamic_shapes.py +++ b/tests/capture/test_dynamic_shapes.py @@ -26,7 +26,6 @@ jnp = pytest.importorskip("jax.numpy") - def test_null_if_not_enabled(): """Test None and an empty tuple are returned if dynamic shapes is not enabled.""" @@ -40,119 +39,111 @@ def f(*args): @pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_null_if_no_abstract_shapes(): - """Test the None and an empty tuple are returned if no dynamic shapes exist.""" - - def f(*args): - abstracted_axes, abstract_shapes = determine_abstracted_axes(args) - - assert abstracted_axes is None - assert abstract_shapes == () +class TestDyanmicShapes: - _ = jax.make_jaxpr(f)(jnp.eye(4)) + def test_null_if_no_abstract_shapes(self): + """Test the None and an empty tuple are returned if no dynamic shapes exist.""" + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_single_abstract_shape(): - """Test we get the correct answer for a single abstract shape.""" + assert abstracted_axes is None + assert abstract_shapes == () - initial_abstracted_axes = ({0: "a"},) + _ = jax.make_jaxpr(f)(jnp.eye(4)) - def f(*args): - abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + def test_single_abstract_shape(self): + """Test we get the correct answer for a single abstract shape.""" - assert abstracted_axes == initial_abstracted_axes - assert len(abstract_shapes) == 1 + initial_abstracted_axes = ({0: "a"},) - # test we can make jaxpr with these abstracted axes - jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) - _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) - _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.arange(4)) + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == 1 + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -@pytest.mark.parametrize( - "initial_abstracted_axes, num_shapes", - [ - (({0: "a", 1: "b"},), 2), - (({0: "a", 1: "a"},), 1), - (({1: "a"},), 1), - ], -) -def test_single_abstract_shape_multiple_abstract_axes( - initial_abstracted_axes, num_shapes -): - """Test we get the correct answer for a single input with two abstract axes.""" + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.arange(4)) - def f(*args): - abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + @pytest.mark.parametrize( + "initial_abstracted_axes, num_shapes", + [ + (({0: "a", 1: "b"},), 2), + (({0: "a", 1: "a"},), 1), + (({1: "a"},), 1), + ], + ) + def test_single_abstract_shape_multiple_abstract_axes( + self, initial_abstracted_axes, num_shapes + ): + """Test we get the correct answer for a single input with two abstract axes.""" - assert abstracted_axes == initial_abstracted_axes - assert len(abstract_shapes) == num_shapes + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) - # test we can make jaxpr with these abstracted axes - jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) - _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == num_shapes - _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.eye(4)) + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *args) + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.eye(4)) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_pytree_input(): - """Test a pytree input with dynamic shapes.""" + def test_pytree_input(self): + """Test a pytree input with dynamic shapes.""" - initial_abstracted_axes = ( - {"input0": {}, "input1": {0: "a"}, "input2": {0: "a"}, "input3": {1: "b"}}, - ) - arg = { - "input0": jnp.arange(5), - "input1": jnp.arange(3), - "input2": jnp.arange(3), - "input3": jnp.eye(4), - } + initial_abstracted_axes = ( + {"input0": {}, "input1": {0: "a"}, "input2": {0: "a"}, "input3": {1: "b"}}, + ) + arg = { + "input0": jnp.arange(5), + "input1": jnp.arange(3), + "input2": jnp.arange(3), + "input3": jnp.eye(4), + } - def f(*args): - abstracted_axes, abstract_shapes = determine_abstracted_axes(args) - assert abstracted_axes == initial_abstracted_axes - assert len(abstract_shapes) == 2 - - # test we can make jaxpr with these abstracted axes - jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) - flat_args = jax.tree_util.tree_leaves(args) - _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *flat_args) + def f(*args): + abstracted_axes, abstract_shapes = determine_abstracted_axes(args) + assert abstracted_axes == initial_abstracted_axes + assert len(abstract_shapes) == 2 - _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg) + # test we can make jaxpr with these abstracted axes + jaxpr = jax.make_jaxpr(lambda *args: 0, abstracted_axes=abstracted_axes)(*args) + flat_args = jax.tree_util.tree_leaves(args) + _ = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *abstract_shapes, *flat_args) + _ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg) -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_input_created_with_jnp_ones(): - """Test that determine_abstracted_axes works with manually created dynamic arrays.""" + def test_input_created_with_jnp_ones(self): + """Test that determine_abstracted_axes works with manually created dynamic arrays.""" - def f(n): - m = n + 1 - ones = jax.numpy.ones((m, 3)) - zeros = jax.numpy.zeros((4, n)) + def f(n): + m = n + 1 + ones = jax.numpy.ones((m, 3)) + zeros = jax.numpy.zeros((4, n)) - abstracted_axes, abstract_shapes = determine_abstracted_axes((ones, zeros)) - assert abstracted_axes == ({0: "a"}, {1: "b"}) - assert len(abstract_shapes) == 2 - assert abstract_shapes[0] is m - assert abstract_shapes[1] is n + abstracted_axes, abstract_shapes = determine_abstracted_axes((ones, zeros)) + assert abstracted_axes == ({0: "a"}, {1: "b"}) + assert len(abstract_shapes) == 2 + assert abstract_shapes[0] is m + assert abstract_shapes[1] is n - _ = jax.make_jaxpr(f)(3) + _ = jax.make_jaxpr(f)(3) + def test_large_number_of_abstract_axes(self): + """Test that determine_abstracted_axes can handle over 26 abstract axes.""" -@pytest.mark.usefixtures("enable_disable_dynamic_shapes") -def test_large_number_of_abstract_axes(): - """Test that determine_abstracted_axes can handle over 26 abstract axes.""" - - def f(shapes): - ones = jax.numpy.zeros(shapes) - abstracted_axes, abstract_shapes = determine_abstracted_axes((ones,)) + def f(shapes): + ones = jax.numpy.zeros(shapes) + abstracted_axes, abstract_shapes = determine_abstracted_axes((ones,)) - assert abstracted_axes - assert len(set(abstracted_axes[0].keys())) == 30 # unique keys for each axis - assert len(abstract_shapes) == 30 + assert abstracted_axes + assert len(set(abstracted_axes[0].keys())) == 30 # unique keys for each axis + assert len(abstract_shapes) == 30 - _ = jax.make_jaxpr(f)(list(range(30))) + _ = jax.make_jaxpr(f)(list(range(30))) From 35b00915c6379ec876ffc81e17f4e7fcd449b8ac Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 27 Jan 2025 10:08:04 -0500 Subject: [PATCH 19/19] black --- tests/capture/test_nested_plxpr.py | 2 ++ tests/conftest.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/capture/test_nested_plxpr.py b/tests/capture/test_nested_plxpr.py index ee94c9b6599..06578cdd126 100644 --- a/tests/capture/test_nested_plxpr.py +++ b/tests/capture/test_nested_plxpr.py @@ -226,6 +226,7 @@ def f(x, y): qml.assert_equal(op1, tape[0]) qml.assert_equal(op2, tape[1]) + class TestCtrlQfunc: """Tests for the ctrl primitive.""" @@ -404,6 +405,7 @@ def workflow(x): @pytest.mark.usefixtures("enable_disable_dynamic_shapes") def test_dynamic_shape_input(self): """Test that ctrl can accept dynamic shape inputs.""" + def f(x): qml.ctrl(qml.RX, (2, 3))(x, 0) diff --git a/tests/conftest.py b/tests/conftest.py index ab2f34202fb..a2296be1239 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -178,6 +178,7 @@ def enable_disable_dynamic_shapes(): finally: jax.config.update("jax_dynamic_shapes", False) + ####################################################################### try: