Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CAPTURE] defer_measurements is plxpr compatible #6838

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 172 additions & 1 deletion pennylane/transforms/defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Code for the tape transform implementing the deferred measurement principle."""

from functools import lru_cache, partial

import pennylane as qml
from pennylane.measurements import CountsMP, MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP
from pennylane.ops.op_math import ctrl
Expand Down Expand Up @@ -101,7 +103,176 @@
return results[0]


@transform
@lru_cache
def _get_plxpr_defer_measurements():
try:
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture import PlxprInterpreter
from pennylane.capture.primitives import (
AbstractMeasurement,
AbstractOperator,
cond_prim,
ctrl_transform_prim,
measure_prim,
)
except ImportError:
return None, None

Check warning on line 121 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L120-L121

Added lines #L120 - L121 were not covered by tests

class DeferMeasurementsInterpreter(PlxprInterpreter):

Check notice on line 123 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L123

Redefining name 'DeferMeasurementsInterpreter' from outer scope (line 272) (redefined-outer-name)
"""Interpreter for applying the defer_measurements transform to plxpr."""

def __init__(self, num_wires):
super().__init__()
self._num_wires = num_wires
self._measurements_map = {}

Check warning on line 129 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L127-L129

Added lines #L127 - L129 were not covered by tests

# State variables
self._cur_wire = None
self._cur_measurement_idx = None

Check warning on line 133 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L132-L133

Added lines #L132 - L133 were not covered by tests

def setup(self) -> None:
"""Initialize the instance before interpreting equations.

Blank by default, this method can initialize any additional instance variables
needed by an interpreter. For example, a device interpreter could initialize a statevector,
or a compilation interpreter could initialize a staging area for the latest operation on each wire.
"""
self._cur_wire = self._num_wires - 1
self._cur_measurement_idx = 0

Check warning on line 143 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L142-L143

Added lines #L142 - L143 were not covered by tests

def cleanup(self) -> None:
"""Perform any final steps after iterating through all equations.

Blank by default, this method can clean up instance variables. Particularly,
this method can be used to deallocate qubits and registers when converting to
a Catalyst variant jaxpr.
"""
self._measurements_map = {}
self._cur_wire = None
self._cur_measurement_idx = None

Check warning on line 154 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L152-L154

Added lines #L152 - L154 were not covered by tests

def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
"""Evaluate a jaxpr.

Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args (tuple[TensorLike]): The arguments for the jaxpr.

Returns:
list[TensorLike]: the results of the execution.

"""
self._env = {}

Check notice on line 168 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L168

Attribute '_env' defined outside __init__ (attribute-defined-outside-init)
self.setup()

Check warning on line 169 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L168-L169

Added lines #L168 - L169 were not covered by tests

for arg, invar in zip(args, jaxpr.invars, strict=True):
self._env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env[constvar] = const

Check warning on line 174 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L171-L174

Added lines #L171 - L174 were not covered by tests

for eqn in jaxpr.eqns:

Check warning on line 176 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L176

Added line #L176 was not covered by tests

custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
outvals = self.interpret_operation_eqn(eqn)
elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
outvals = self.interpret_measurement_eqn(eqn)

Check warning on line 185 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L178-L185

Added lines #L178 - L185 were not covered by tests
else:
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)

Check warning on line 188 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L187-L188

Added lines #L187 - L188 were not covered by tests

if not eqn.primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval

Check warning on line 193 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L190-L193

Added lines #L190 - L193 were not covered by tests

# Read the final result of the Jaxpr from the environment
outvals = []
for var in jaxpr.outvars:
outval = self.read(var)
if isinstance(outval, qml.operation.Operator):
outvals.append(self.interpret_operation(outval))

Check warning on line 200 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L196-L200

Added lines #L196 - L200 were not covered by tests
else:
outvals.append(outval)
self.cleanup()
self._env = {}

Check notice on line 204 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L204

Attribute '_env' defined outside __init__ (attribute-defined-outside-init)
return outvals

Check warning on line 205 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L202-L205

Added lines #L202 - L205 were not covered by tests

@DeferMeasurementsInterpreter.register_primitive(measure_prim)
def _(self, wires, reset=False, postselect=None):
with qml.QueuingManager.stop_recording():
meas = type.__call__(

Check warning on line 210 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L209-L210

Added lines #L209 - L210 were not covered by tests
MidMeasureMP, Wires(self._cur_wire), reset=reset, postselect=postselect
)

cnot_wires = (wires, self._cur_wire)
self._measurements_map[self._cur_measurement_idx] = self._cur_wire

Check warning on line 215 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L214-L215

Added lines #L214 - L215 were not covered by tests

if postselect is not None:
qml.Projector(jax.numpy.array([postselect]), wires=wires)

Check warning on line 218 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L217-L218

Added lines #L217 - L218 were not covered by tests

qml.CNOT(wires=cnot_wires)

Check warning on line 220 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L220

Added line #L220 was not covered by tests

if reset:
if postselect is None:
qml.CNOT(wires=cnot_wires[::-1])
elif postselect == 1:
qml.X(wires=wires)

Check warning on line 226 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L222-L226

Added lines #L222 - L226 were not covered by tests

# cur_idx = self._cur_measurement_idx
self._cur_measurement_idx += 1
self._cur_wire -= 1
return MeasurementValue([meas], lambda x: x)

Check warning on line 231 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L229-L231

Added lines #L229 - L231 were not covered by tests

@DeferMeasurementsInterpreter.register_primitive(cond_prim)
def _(self, *invals, jaxpr_branches, consts_slices, args_slice):

Check notice on line 234 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L234

Unused argument 'self' (unused-argument)
n_branches = len(jaxpr_branches)
conditions = invals[:n_branches]
args = invals[args_slice]

Check warning on line 237 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L235-L237

Added lines #L235 - L237 were not covered by tests

for i, (condition, jaxpr) in enumerate(zip(conditions, jaxpr_branches, strict=True)):
if isinstance(condition, MeasurementValue):
control_wires = Wires([m.wires[0] for m in condition.measurements])
for branch, value in condition._items():
if value:
cur_consts = invals[consts_slices[i]]
ctrl_transform_prim.bind(

Check warning on line 245 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L239-L245

Added lines #L239 - L245 were not covered by tests
*cur_consts,
*args,
*control_wires,
jaxpr=jaxpr,
n_control=len(control_wires),
control_values=branch,
work_wires=None,
n_consts=len(cur_consts),
)

return []

Check warning on line 256 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L256

Added line #L256 was not covered by tests

def defer_measurements_plxpr_to_plxpr(

Check notice on line 258 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L258

Redefining name 'defer_measurements_plxpr_to_plxpr' from outer scope (line 272) (redefined-outer-name)
jaxpr, consts, targs, tkwargs, *args
): # pylint: disable=unused-argument

interpreter = DeferMeasurementsInterpreter()

Check warning on line 262 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L262

Added line #L262 was not covered by tests

def wrapper(*inner_args):
return interpreter.eval(jaxpr, consts, *inner_args)

Check warning on line 265 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L264-L265

Added lines #L264 - L265 were not covered by tests

return jax.make_jaxpr(wrapper)(*args)

Check warning on line 267 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L267

Added line #L267 was not covered by tests

return DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr


DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr = _get_plxpr_defer_measurements()


@partial(transform, plxpr_transform=defer_measurements_plxpr_to_plxpr)
def defer_measurements(
tape: QuantumScript, reduce_postselected: bool = True, allow_postselect: bool = True
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
Expand Down
Loading