Skip to content

Commit

Permalink
Non-functioning first draft of plxpr defer_measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Jan 15, 2025
1 parent 4befb11 commit ae67a0b
Showing 1 changed file with 172 additions and 1 deletion.
173 changes: 172 additions & 1 deletion pennylane/transforms/defer_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Code for the tape transform implementing the deferred measurement principle."""

from functools import lru_cache, partial

import pennylane as qml
from pennylane.measurements import CountsMP, MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP
from pennylane.ops.op_math import ctrl
Expand Down Expand Up @@ -101,7 +103,176 @@ def null_postprocessing(results):
return results[0]


@transform
@lru_cache
def _get_plxpr_defer_measurements():
try:
# pylint: disable=import-outside-toplevel
import jax

from pennylane.capture import PlxprInterpreter
from pennylane.capture.primitives import (
AbstractMeasurement,
AbstractOperator,
cond_prim,
ctrl_transform_prim,
measure_prim,
)
except ImportError:
return None, None

Check warning on line 121 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L120-L121

Added lines #L120 - L121 were not covered by tests

class DeferMeasurementsInterpreter(PlxprInterpreter):

Check notice on line 123 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L123

Redefining name 'DeferMeasurementsInterpreter' from outer scope (line 272) (redefined-outer-name)
"""Interpreter for applying the defer_measurements transform to plxpr."""

def __init__(self, num_wires):
super().__init__()
self._num_wires = num_wires
self._measurements_map = {}

Check warning on line 129 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L127-L129

Added lines #L127 - L129 were not covered by tests

# State variables
self._cur_wire = None
self._cur_measurement_idx = None

Check warning on line 133 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L132-L133

Added lines #L132 - L133 were not covered by tests

def setup(self) -> None:
"""Initialize the instance before interpreting equations.
Blank by default, this method can initialize any additional instance variables
needed by an interpreter. For example, a device interpreter could initialize a statevector,
or a compilation interpreter could initialize a staging area for the latest operation on each wire.
"""
self._cur_wire = self._num_wires - 1
self._cur_measurement_idx = 0

Check warning on line 143 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L142-L143

Added lines #L142 - L143 were not covered by tests

def cleanup(self) -> None:
"""Perform any final steps after iterating through all equations.
Blank by default, this method can clean up instance variables. Particularly,
this method can be used to deallocate qubits and registers when converting to
a Catalyst variant jaxpr.
"""
self._measurements_map = {}
self._cur_wire = None
self._cur_measurement_idx = None

Check warning on line 154 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L152-L154

Added lines #L152 - L154 were not covered by tests

def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
"""Evaluate a jaxpr.
Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args (tuple[TensorLike]): The arguments for the jaxpr.
Returns:
list[TensorLike]: the results of the execution.
"""
self._env = {}

Check notice on line 168 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L168

Attribute '_env' defined outside __init__ (attribute-defined-outside-init)
self.setup()

Check warning on line 169 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L168-L169

Added lines #L168 - L169 were not covered by tests

for arg, invar in zip(args, jaxpr.invars, strict=True):
self._env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env[constvar] = const

Check warning on line 174 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L171-L174

Added lines #L171 - L174 were not covered by tests

for eqn in jaxpr.eqns:

Check warning on line 176 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L176

Added line #L176 was not covered by tests

custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
outvals = self.interpret_operation_eqn(eqn)
elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
outvals = self.interpret_measurement_eqn(eqn)

Check warning on line 185 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L178-L185

Added lines #L178 - L185 were not covered by tests
else:
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)

Check warning on line 188 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L187-L188

Added lines #L187 - L188 were not covered by tests

if not eqn.primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval

Check warning on line 193 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L190-L193

Added lines #L190 - L193 were not covered by tests

# Read the final result of the Jaxpr from the environment
outvals = []
for var in jaxpr.outvars:
outval = self.read(var)
if isinstance(outval, qml.operation.Operator):
outvals.append(self.interpret_operation(outval))

Check warning on line 200 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L196-L200

Added lines #L196 - L200 were not covered by tests
else:
outvals.append(outval)
self.cleanup()
self._env = {}

Check notice on line 204 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L204

Attribute '_env' defined outside __init__ (attribute-defined-outside-init)
return outvals

Check warning on line 205 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L202-L205

Added lines #L202 - L205 were not covered by tests

@DeferMeasurementsInterpreter.register_primitive(measure_prim)
def _(self, wires, reset=False, postselect=None):
with qml.QueuingManager.stop_recording():
meas = type.__call__(

Check warning on line 210 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L209-L210

Added lines #L209 - L210 were not covered by tests
MidMeasureMP, Wires(self._cur_wire), reset=reset, postselect=postselect
)

cnot_wires = (wires, self._cur_wire)
self._measurements_map[self._cur_measurement_idx] = self._cur_wire

Check warning on line 215 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L214-L215

Added lines #L214 - L215 were not covered by tests

if postselect is not None:
qml.Projector(jax.numpy.array([postselect]), wires=wires)

Check warning on line 218 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L217-L218

Added lines #L217 - L218 were not covered by tests

qml.CNOT(wires=cnot_wires)

Check warning on line 220 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L220

Added line #L220 was not covered by tests

if reset:
if postselect is None:
qml.CNOT(wires=cnot_wires[::-1])
elif postselect == 1:
qml.X(wires=wires)

Check warning on line 226 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L222-L226

Added lines #L222 - L226 were not covered by tests

# cur_idx = self._cur_measurement_idx
self._cur_measurement_idx += 1
self._cur_wire -= 1
return MeasurementValue([meas], lambda x: x)

Check warning on line 231 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L229-L231

Added lines #L229 - L231 were not covered by tests

@DeferMeasurementsInterpreter.register_primitive(cond_prim)
def _(self, *invals, jaxpr_branches, consts_slices, args_slice):

Check notice on line 234 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L234

Unused argument 'self' (unused-argument)
n_branches = len(jaxpr_branches)
conditions = invals[:n_branches]
args = invals[args_slice]

Check warning on line 237 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L235-L237

Added lines #L235 - L237 were not covered by tests

for i, (condition, jaxpr) in enumerate(zip(conditions, jaxpr_branches, strict=True)):
if isinstance(condition, MeasurementValue):
control_wires = Wires([m.wires[0] for m in condition.measurements])
for branch, value in condition._items():
if value:
cur_consts = invals[consts_slices[i]]
ctrl_transform_prim.bind(

Check warning on line 245 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L239-L245

Added lines #L239 - L245 were not covered by tests
*cur_consts,
*args,
*control_wires,
jaxpr=jaxpr,
n_control=len(control_wires),
control_values=branch,
work_wires=None,
n_consts=len(cur_consts),
)

return []

Check warning on line 256 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L256

Added line #L256 was not covered by tests

def defer_measurements_plxpr_to_plxpr(

Check notice on line 258 in pennylane/transforms/defer_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/defer_measurements.py#L258

Redefining name 'defer_measurements_plxpr_to_plxpr' from outer scope (line 272) (redefined-outer-name)
jaxpr, consts, targs, tkwargs, *args
): # pylint: disable=unused-argument

interpreter = DeferMeasurementsInterpreter()

Check warning on line 262 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L262

Added line #L262 was not covered by tests

def wrapper(*inner_args):
return interpreter.eval(jaxpr, consts, *inner_args)

Check warning on line 265 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L264-L265

Added lines #L264 - L265 were not covered by tests

return jax.make_jaxpr(wrapper)(*args)

Check warning on line 267 in pennylane/transforms/defer_measurements.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/defer_measurements.py#L267

Added line #L267 was not covered by tests

return DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr


DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr = _get_plxpr_defer_measurements()


@partial(transform, plxpr_transform=defer_measurements_plxpr_to_plxpr)
def defer_measurements(
tape: QuantumScript, reduce_postselected: bool = True, allow_postselect: bool = True
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
Expand Down

0 comments on commit ae67a0b

Please sign in to comment.