Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Allow higher order primitives to accept dynamically shaped arrays #6786

Merged
merged 26 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
54f7c67
support dynamic shape inputs for hop's
albi3ro Jan 6, 2025
b1303c1
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 8, 2025
1469ad2
explanation doc
albi3ro Jan 8, 2025
a0c9176
Merge branch 'dynamic-capture-hop-2' of https://github.com/PennyLaneA…
albi3ro Jan 8, 2025
2c59a64
add while loop support
albi3ro Jan 8, 2025
8e6a167
adding tests
albi3ro Jan 9, 2025
5bc1a90
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 9, 2025
6fff950
adding testing
albi3ro Jan 14, 2025
b6a7eb2
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 14, 2025
4b04dae
fix marking
albi3ro Jan 14, 2025
0916fba
Merge branch 'dynamic-capture-hop-2' of https://github.com/PennyLaneA…
albi3ro Jan 14, 2025
a5fb9ee
Update pennylane/capture/dynamic_shapes.py
albi3ro Jan 14, 2025
8604e5b
Apply suggestions from code review
albi3ro Jan 14, 2025
47bdf11
black and changelog
albi3ro Jan 14, 2025
a44e7b9
respond to feedback
albi3ro Jan 15, 2025
53d90b7
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 15, 2025
c57ff02
Apply suggestions from code review
albi3ro Jan 15, 2025
8403a58
Apply suggestions from code review
albi3ro Jan 15, 2025
b54d092
Apply suggestions from code review
albi3ro Jan 15, 2025
88361bc
all the dynamic shapes
albi3ro Jan 15, 2025
d4c7a6b
some clarirications and tests
albi3ro Jan 21, 2025
9192ebd
Update pennylane/capture/dynamic_shapes.py
albi3ro Jan 21, 2025
ab090bf
fix failing tests
albi3ro Jan 21, 2025
c9508ab
change fixture usage
albi3ro Jan 22, 2025
bc12b41
Merge branch 'master' into dynamic-capture-hop-2
albi3ro Jan 27, 2025
35b0091
black
albi3ro Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pennylane/capture/dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def _get_shape_for_array(x, abstract_shapes: list) -> dict:

Examples of shape -> abstract axes:

* `(3,4) -> {}`
* `(tracer1, ) -> {0: "a"}`
* `(tracer1, tracer1) -> {0: "a", 1: "a"}`
* `(3, tracer1) -> {1: "a"}`
* `(tracer1, 2, tracer2) -> {0: "a", 2: "b"}`
* ``(3,4) -> {}``
* ``(tracer1, ) -> {0: "a"}``
* ``(tracer1, tracer1) -> {0: "a", 1: "a"}``
* ``(3, tracer1) -> {1: "a"}``
* ``(tracer1, 2, tracer2) -> {0: "a", 2: "b"}``

`abstract_shapes` contains all the tracers found in shapes.
``abstract_shapes`` contains all the tracers found in shapes.

"""
abstract_axes = {}
Expand Down
21 changes: 8 additions & 13 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,21 +849,16 @@ def f(x):
qml.assert_equal(q.queue[1].base, qml.RX(0.5, 0))


@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_cond_abstracted_axes():
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Test cond can accept inputs with dynamic shapes."""
jax.config.update("jax_dynamic_shapes", True)
try:
def workflow(x, predicate):
return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x)

def workflow(x, predicate):
return qml.cond(predicate, jax.numpy.sum, false_fn=jax.numpy.prod)(x)
jaxpr = jax.make_jaxpr(workflow, abstracted_axes=({0: "a"}, {}))(jax.numpy.arange(3), True)

jaxpr = jax.make_jaxpr(workflow, abstracted_axes=({0: "a"}, {}))(jax.numpy.arange(3), True)
output_true = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 4, jax.numpy.arange(4), True)
assert qml.math.allclose(output_true[0], 6) # 0 + 1 + 2 + 3

output_true = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 4, jax.numpy.arange(4), True)
assert qml.math.allclose(output_true[0], 6) # 0 + 1 + 2 + 3

output_false = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2), False)
assert qml.math.allclose(output_false[0], 0) # 0 * 1

finally:
jax.config.update("jax_dynamic_shapes", False)
output_false = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2), False)
assert qml.math.allclose(output_false[0], 0) # 0 * 1
26 changes: 12 additions & 14 deletions tests/capture/test_capture_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,26 +234,24 @@ def loop_body(i, array, sum_val):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, array)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_dynamic_shape_input(self):
jax.config.update("jax_dynamic_shapes", True)
try:
"""Test that the for loop can accept inputs with dynamic shapes."""

def f(x):
n = jax.numpy.shape(x)[0]
def f(x):
n = jax.numpy.shape(x)[0]

@qml.for_loop(n)
def g(_, y):
return y + y
@qml.for_loop(n)
def g(_, y):
return y + y

return g(x)
return g(x)

jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(5))
jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(5))

[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.array([0, 8, 16]) # [0, 1, 2] * 2**3
assert jax.numpy.allclose(output, expected)
finally:
jax.config.update("jax_dynamic_shapes", False)
[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.array([0, 8, 16]) # [0, 1, 2] * 2**3
assert jax.numpy.allclose(output, expected)


class TestCaptureCircuitsForLoop:
Expand Down
23 changes: 9 additions & 14 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,25 +366,20 @@ def circuit(x):
assert qml.math.allclose(res, jax.numpy.cos(x))


@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_dynamic_shape_input():
"""Test that the qnode can accept an input with a dynamic shape."""

jax.config.update("jax_dynamic_shapes", True)
try:

@qml.qnode(qml.device("default.qubit", wires=1))
def circuit(x):
qml.RX(jax.numpy.sum(x), 0)
return qml.expval(qml.Z(0))

jaxpr = jax.make_jaxpr(circuit, abstracted_axes=("a",))(jax.numpy.arange(4))
@qml.qnode(qml.device("default.qubit", wires=1))
def circuit(x):
qml.RX(jax.numpy.sum(x), 0)
return qml.expval(qml.Z(0))

[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.cos(0 + 1 + 2)
assert jax.numpy.allclose(expected, output)
jaxpr = jax.make_jaxpr(circuit, abstracted_axes=("a",))(jax.numpy.arange(4))

finally:
jax.config.update("jax_dynamic_shapes", False)
[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.cos(0 + 1 + 2)
assert jax.numpy.allclose(expected, output)


# pylint: disable=too-many-public-methods
Expand Down
25 changes: 10 additions & 15 deletions tests/capture/test_capture_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,22 @@ def loop(a, b, idx):
assert np.allclose(res_arr1_jxpr, expected), f"Expected {expected}, but got {res_arr1_jxpr}"
assert np.allclose(res_idx, res_idx_jxpr) and res_idx_jxpr == 10

@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_while_loop_dyanmic_shape_array(self):
"""Test while loop can accept ararys with dynamic shapes."""

jax.config.update("jax_dynamic_shapes", True)
def f(x):
@qml.while_loop(lambda res: jax.numpy.sum(res) < 10)
def g(res):
return res + res

try:
return g(x)

def f(x):
@qml.while_loop(lambda res: jax.numpy.sum(res) < 10)
def g(res):
return res + res
jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(2))

return g(x)

jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(2))

[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.array([0, 4, 8])
assert jax.numpy.allclose(output, expected)
finally:
jax.config.update("jax_dynamic_shapes", False)
[output] = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 3, jax.numpy.arange(3))
expected = jax.numpy.array([0, 4, 8])
assert jax.numpy.allclose(output, expected)


class TestCaptureCircuitsWhileLoop:
Expand Down
26 changes: 12 additions & 14 deletions tests/capture/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@
jnp = pytest.importorskip("jax.numpy")


@pytest.fixture
def enable_disable():
jax.config.update("jax_dynamic_shapes", True)
try:
yield
finally:
jax.config.update("jax_dynamic_shapes", False)


def test_null_if_not_enabled():
"""Test None and an empty tuple are returned if dynamic shapes is not enabled."""
Expand All @@ -47,7 +39,8 @@ def f(*args):
_ = jax.make_jaxpr(f)(jnp.eye(4))


def test_null_if_no_abstract_shapes(enable_disable):
@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_null_if_no_abstract_shapes():
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Test the None and an empty tuple are returned if no dynamic shapes exist."""

def f(*args):
Expand All @@ -59,7 +52,8 @@ def f(*args):
_ = jax.make_jaxpr(f)(jnp.eye(4))


def test_single_abstract_shape(enable_disable):
@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_single_abstract_shape():
"""Test we get the correct answer for a single abstract shape."""

initial_abstracted_axes = ({0: "a"},)
Expand All @@ -77,6 +71,7 @@ def f(*args):
_ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.arange(4))


@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
@pytest.mark.parametrize(
"initial_abstracted_axes, num_shapes",
[
Expand All @@ -86,7 +81,7 @@ def f(*args):
],
)
def test_single_abstract_shape_multiple_abstract_axes(
enable_disable, initial_abstracted_axes, num_shapes
initial_abstracted_axes, num_shapes
):
"""Test we get the correct answer for a single input with two abstract axes."""

Expand All @@ -103,7 +98,8 @@ def f(*args):
_ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(jnp.eye(4))
lillian542 marked this conversation as resolved.
Show resolved Hide resolved


def test_pytree_input(enable_disable):
@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_pytree_input():
"""Test a pytree input with dynamic shapes."""

initial_abstracted_axes = (
Expand All @@ -129,7 +125,8 @@ def f(*args):
_ = jax.make_jaxpr(f, abstracted_axes=initial_abstracted_axes)(arg)


def test_input_created_with_jnp_ones(enable_disable):
@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_input_created_with_jnp_ones():
"""Test that determine_abstracted_axes works with manually created dynamic arrays."""

def f(n):
Expand All @@ -146,7 +143,8 @@ def f(n):
_ = jax.make_jaxpr(f)(3)


def test_large_number_of_abstract_axes(enable_disable):
@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_large_number_of_abstract_axes():
"""Test that determine_abstracted_axes can handle over 26 abstract axes."""

def f(shapes):
Expand Down
80 changes: 33 additions & 47 deletions tests/capture/test_nested_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,54 +186,45 @@ def workflow(x):
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5)
assert qml.math.isclose(out, qml.math.sin(-(0.5 + 0.3)))

@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_dynamic_shape_input(self):
"""Test that the adjoint transform can accept arrays with dynamic shapes."""
jax.config.update("jax_dynamic_shapes", True)
try:

def f(x):
qml.adjoint(qml.RX)(x, 0)

jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4))
def f(x):
qml.adjoint(qml.RX)(x, 0)

tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2))
expected = qml.adjoint(qml.RX(jax.numpy.arange(2), 0))
qml.assert_equal(tape[0], expected)
jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4))

finally:
jax.config.update("jax_dynamic_shapes", False)
tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2))
expected = qml.adjoint(qml.RX(jax.numpy.arange(2), 0))
qml.assert_equal(tape[0], expected)

@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_complicated_dynamic_shape_input(self):
"""Test a dynamic shape input with a more complicate shape."""
jax.config.update("jax_dynamic_shapes", True)
try:

def g(x, y):
qml.RX(x["a"], 0)
qml.RY(y, 0)

def f(x, y):
qml.adjoint(g)(x, y)

x_a_axes = {0: "n"}
y_axes = {0: "m"}
x = {"a": jax.numpy.arange(2)}
y = jax.numpy.arange(3)
def g(x, y):
qml.RX(x["a"], 0)
qml.RY(y, 0)

abstracted_axes = ({"a": x_a_axes}, y_axes)
jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(x, y)
tape = qml.tape.plxpr_to_tape(
jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4)
)
def f(x, y):
qml.adjoint(g)(x, y)

op1 = qml.adjoint(qml.RY(jax.numpy.arange(4), 0))
op2 = qml.adjoint(qml.RX(jax.numpy.arange(3), 0))
qml.assert_equal(op1, tape[0])
qml.assert_equal(op2, tape[1])
x_a_axes = {0: "n"}
y_axes = {0: "m"}
x = {"a": jax.numpy.arange(2)}
y = jax.numpy.arange(3)

finally:
jax.config.update("jax_dynamic_shapes", False)
abstracted_axes = ({"a": x_a_axes}, y_axes)
jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(x, y)
tape = qml.tape.plxpr_to_tape(
jaxpr.jaxpr, jaxpr.consts, 3, 4, jax.numpy.arange(3), jax.numpy.arange(4)
)

op1 = qml.adjoint(qml.RY(jax.numpy.arange(4), 0))
op2 = qml.adjoint(qml.RX(jax.numpy.arange(3), 0))
qml.assert_equal(op1, tape[0])
qml.assert_equal(op2, tape[1])

class TestCtrlQfunc:
"""Tests for the ctrl primitive."""
Expand Down Expand Up @@ -410,22 +401,17 @@ def workflow(x):
out = jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.5)
assert qml.math.isclose(out, -0.5 * qml.math.sin(0.5 + 0.3))

@pytest.mark.usefixtures("enable_disable_dynamic_shapes")
def test_dynamic_shape_input(self):
"""Test that ctrl can accept dynamic shape inputs."""
jax.config.update("jax_dynamic_shapes", True)
try:

def f(x):
qml.ctrl(qml.RX, (2, 3))(x, 0)

jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4))
def f(x):
qml.ctrl(qml.RX, (2, 3))(x, 0)

tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2))
expected = qml.ctrl(qml.RX(jax.numpy.arange(2), 0), (2, 3))
qml.assert_equal(tape[0], expected)
jaxpr = jax.make_jaxpr(f, abstracted_axes=("a",))(jax.numpy.arange(4))

finally:
jax.config.update("jax_dynamic_shapes", False)
tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, 2, jax.numpy.arange(2))
expected = qml.ctrl(qml.RX(jax.numpy.arange(2), 0), (2, 3))
qml.assert_equal(tape[0], expected)

def test_pytree_input(self):
"""Test that ctrl can accept pytree inputs."""
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ def enable_disable_plxpr():
qml.capture.disable()


@pytest.fixture(scope="function")
def enable_disable_dynamic_shapes():
jax.config.update("jax_dynamic_shapes", True)
try:
yield
finally:
jax.config.update("jax_dynamic_shapes", False)

#######################################################################

try:
Expand Down