diff --git a/pennylane/transforms/defer_measurements.py b/pennylane/transforms/defer_measurements.py index 02f299507c2..dca07daa1a8 100644 --- a/pennylane/transforms/defer_measurements.py +++ b/pennylane/transforms/defer_measurements.py @@ -13,6 +13,8 @@ # limitations under the License. """Code for the tape transform implementing the deferred measurement principle.""" +from functools import lru_cache, partial + import pennylane as qml from pennylane.measurements import CountsMP, MeasurementValue, MidMeasureMP, ProbabilityMP, SampleMP from pennylane.ops.op_math import ctrl @@ -101,7 +103,176 @@ def null_postprocessing(results): return results[0] -@transform +@lru_cache +def _get_plxpr_defer_measurements(): + try: + # pylint: disable=import-outside-toplevel + import jax + + from pennylane.capture import PlxprInterpreter + from pennylane.capture.primitives import ( + AbstractMeasurement, + AbstractOperator, + cond_prim, + ctrl_transform_prim, + measure_prim, + ) + except ImportError: + return None, None + + class DeferMeasurementsInterpreter(PlxprInterpreter): + """Interpreter for applying the defer_measurements transform to plxpr.""" + + def __init__(self, num_wires): + super().__init__() + self._num_wires = num_wires + self._measurements_map = {} + + # State variables + self._cur_wire = None + self._cur_measurement_idx = None + + def setup(self) -> None: + """Initialize the instance before interpreting equations. + + Blank by default, this method can initialize any additional instance variables + needed by an interpreter. For example, a device interpreter could initialize a statevector, + or a compilation interpreter could initialize a staging area for the latest operation on each wire. + """ + self._cur_wire = self._num_wires - 1 + self._cur_measurement_idx = 0 + + def cleanup(self) -> None: + """Perform any final steps after iterating through all equations. + + Blank by default, this method can clean up instance variables. Particularly, + this method can be used to deallocate qubits and registers when converting to + a Catalyst variant jaxpr. + """ + self._measurements_map = {} + self._cur_wire = None + self._cur_measurement_idx = None + + def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: + """Evaluate a jaxpr. + + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. + + Returns: + list[TensorLike]: the results of the execution. + + """ + self._env = {} + self.setup() + + for arg, invar in zip(args, jaxpr.invars, strict=True): + self._env[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars, strict=True): + self._env[constvar] = const + + for eqn in jaxpr.eqns: + + custom_handler = self._primitive_registrations.get(eqn.primitive, None) + if custom_handler: + invals = [self.read(invar) for invar in eqn.invars] + outvals = custom_handler(self, *invals, **eqn.params) + elif isinstance(eqn.outvars[0].aval, AbstractOperator): + outvals = self.interpret_operation_eqn(eqn) + elif isinstance(eqn.outvars[0].aval, AbstractMeasurement): + outvals = self.interpret_measurement_eqn(eqn) + else: + invals = [self.read(invar) for invar in eqn.invars] + outvals = eqn.primitive.bind(*invals, **eqn.params) + + if not eqn.primitive.multiple_results: + outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals, strict=True): + self._env[outvar] = outval + + # Read the final result of the Jaxpr from the environment + outvals = [] + for var in jaxpr.outvars: + outval = self.read(var) + if isinstance(outval, qml.operation.Operator): + outvals.append(self.interpret_operation(outval)) + else: + outvals.append(outval) + self.cleanup() + self._env = {} + return outvals + + @DeferMeasurementsInterpreter.register_primitive(measure_prim) + def _(self, wires, reset=False, postselect=None): + with qml.QueuingManager.stop_recording(): + meas = type.__call__( + MidMeasureMP, Wires(self._cur_wire), reset=reset, postselect=postselect + ) + + cnot_wires = (wires, self._cur_wire) + self._measurements_map[self._cur_measurement_idx] = self._cur_wire + + if postselect is not None: + qml.Projector(jax.numpy.array([postselect]), wires=wires) + + qml.CNOT(wires=cnot_wires) + + if reset: + if postselect is None: + qml.CNOT(wires=cnot_wires[::-1]) + elif postselect == 1: + qml.X(wires=wires) + + # cur_idx = self._cur_measurement_idx + self._cur_measurement_idx += 1 + self._cur_wire -= 1 + return MeasurementValue([meas], lambda x: x) + + @DeferMeasurementsInterpreter.register_primitive(cond_prim) + def _(self, *invals, jaxpr_branches, consts_slices, args_slice): + n_branches = len(jaxpr_branches) + conditions = invals[:n_branches] + args = invals[args_slice] + + for i, (condition, jaxpr) in enumerate(zip(conditions, jaxpr_branches, strict=True)): + if isinstance(condition, MeasurementValue): + control_wires = Wires([m.wires[0] for m in condition.measurements]) + for branch, value in condition._items(): + if value: + cur_consts = invals[consts_slices[i]] + ctrl_transform_prim.bind( + *cur_consts, + *args, + *control_wires, + jaxpr=jaxpr, + n_control=len(control_wires), + control_values=branch, + work_wires=None, + n_consts=len(cur_consts), + ) + + return [] + + def defer_measurements_plxpr_to_plxpr( + jaxpr, consts, targs, tkwargs, *args + ): # pylint: disable=unused-argument + + interpreter = DeferMeasurementsInterpreter() + + def wrapper(*inner_args): + return interpreter.eval(jaxpr, consts, *inner_args) + + return jax.make_jaxpr(wrapper)(*args) + + return DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr + + +DeferMeasurementsInterpreter, defer_measurements_plxpr_to_plxpr = _get_plxpr_defer_measurements() + + +@partial(transform, plxpr_transform=defer_measurements_plxpr_to_plxpr) def defer_measurements( tape: QuantumScript, reduce_postselected: bool = True, allow_postselect: bool = True ) -> tuple[QuantumScriptBatch, PostprocessingFn]: