diff --git a/include/cudaq/Optimizer/Transforms/Passes.h b/include/cudaq/Optimizer/Transforms/Passes.h index cf39c803d6..77f461bec2 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.h +++ b/include/cudaq/Optimizer/Transforms/Passes.h @@ -34,7 +34,7 @@ std::unique_ptr createDelayMeasurementsPass(); std::unique_ptr createExpandMeasurementsPass(); std::unique_ptr createLambdaLiftingPass(); std::unique_ptr createLowerToCFGPass(); -std::unique_ptr createObserveAnsatzPass(std::vector &); +std::unique_ptr createObserveAnsatzPass(const std::vector &); std::unique_ptr createQuakeAddMetadata(); std::unique_ptr createQuakeAddDeallocs(); std::unique_ptr createQuakeSynthesizer(); diff --git a/lib/Optimizer/Transforms/ObserveAnsatz.cpp b/lib/Optimizer/Transforms/ObserveAnsatz.cpp index 0d2ba3b38a..79c87f00fb 100644 --- a/lib/Optimizer/Transforms/ObserveAnsatz.cpp +++ b/lib/Optimizer/Transforms/ObserveAnsatz.cpp @@ -200,14 +200,14 @@ struct AnsatzFunctionAnalysis { /// measurement basis change operations. struct AppendMeasurements : public OpRewritePattern { explicit AppendMeasurements(MLIRContext *ctx, const AnsatzFunctionInfo &info, - std::vector &bsf) + ArrayRef bsf) : OpRewritePattern(ctx), infoMap(info), termBSF(bsf) {} /// The pre-computed analysis information AnsatzFunctionInfo infoMap; /// The Pauli term representation - std::vector &termBSF; + ArrayRef termBSF; LogicalResult matchAndRewrite(func::FuncOp funcOp, PatternRewriter &rewriter) const override { @@ -215,31 +215,24 @@ struct AppendMeasurements : public OpRewritePattern { // Use an Analysis to count the number of qubits. auto iter = infoMap.find(funcOp); - if (iter == infoMap.end()) { - std::string msg = "Errors encountered in pass analysis\n"; - funcOp.emitError(msg); - return failure(); - } + if (iter == infoMap.end()) + return funcOp.emitOpError("Errors encountered in pass analysis"); auto nQubits = iter->second.nQubits; - if (nQubits != termBSF.size() / 2) { - std::string msg = "Invalid number of binary-symplectic elements " - "provided. Must provide 2 * NQubits = " + - std::to_string(2 * nQubits) + "\n"; - funcOp.emitError(msg); - return failure(); - } + if (nQubits != termBSF.size() / 2) + return funcOp.emitOpError("Invalid number of binary-symplectic elements " + "provided. Must provide 2 * NQubits = " + + std::to_string(2 * nQubits)); // If the mapping pass was not run, we expect no pre-existing measurements. - if (!iter->second.mappingPassRan && !iter->second.measurements.empty()) { - std::string msg = "Cannot observe kernel with measures in it.\n"; - funcOp.emitError(msg); - return failure(); - } + if (!iter->second.mappingPassRan && !iter->second.measurements.empty()) + return funcOp.emitOpError("Cannot observe kernel with measures in it."); + // Attempt to remove measurements. Note that the mapping pass may add - // measurements to kernels that don't contain any measurements. For - // observe kernels, we remove them here since we are adding specific - // measurements below. + // measurements to kernels that don't contain any measurements. For observe + // kernels, we remove them here since we are adding specific measurements + // below. Note: each `op` in the list of measurements must be removed by the + // end of this loop, otherwise the end result may be incorrect. for (auto *op : iter->second.measurements) { bool safeToRemove = [&]() { for (auto user : op->getUsers()) @@ -247,15 +240,16 @@ struct AppendMeasurements : public OpRewritePattern { return false; return true; }(); - if (!safeToRemove) { - std::string msg = - "Cannot observe kernel with non dangling measurements.\n"; - funcOp.emitError(msg); - return failure(); - } + if (!safeToRemove) + return funcOp.emitOpError( + "Cannot observe kernel with non dangling measurements."); + for (auto result : op->getResults()) if (quake::isLinearType(result.getType())) result.replaceAllUsesWith(op->getOperand(0)); + + // Force remove `op`. + op->dropAllReferences(); op->erase(); } @@ -272,7 +266,7 @@ struct AppendMeasurements : public OpRewritePattern { // Loop over the binary-symplectic form provided and append // measurements as necessary. - std::vector qubitsToMeasure; + SmallVector qubitsToMeasure; for (std::size_t i = 0; i < termBSF.size() / 2; i++) { bool xElement = termBSF[i]; bool zElement = termBSF[i + nQubits]; @@ -328,13 +322,13 @@ struct AppendMeasurements : public OpRewritePattern { class ObserveAnsatzPass : public cudaq::opt::impl::ObserveAnsatzBase { protected: - std::vector binarySymplecticForm; + SmallVector binarySymplecticForm; public: using ObserveAnsatzBase::ObserveAnsatzBase; - ObserveAnsatzPass(std::vector &bsfData) - : binarySymplecticForm(bsfData) {} + ObserveAnsatzPass(const std::vector &bsfData) + : binarySymplecticForm{bsfData.begin(), bsfData.end()} {} void runOnOperation() override { auto funcOp = dyn_cast(getOperation()); @@ -359,16 +353,14 @@ class ObserveAnsatzPass target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - emitError(funcOp.getLoc(), "failed to observe ansatz"); + std::move(patterns)))) signalPassFailure(); - } } }; } // namespace std::unique_ptr -cudaq::opt::createObserveAnsatzPass(std::vector &bsfData) { - return std::make_unique(bsfData); +cudaq::opt::createObserveAnsatzPass(const std::vector &packed) { + return std::make_unique(packed); }