Skip to content

Commit

Permalink
[core] Refactor observe-ansatz pass. (#2288)
Browse files Browse the repository at this point in the history
Replace some data structures with LLVM ADTs for consistency.

Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi authored Oct 17, 2024
1 parent 78282ef commit 1a05616
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 38 deletions.
2 changes: 1 addition & 1 deletion include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::unique_ptr<mlir::Pass> createDelayMeasurementsPass();
std::unique_ptr<mlir::Pass> createExpandMeasurementsPass();
std::unique_ptr<mlir::Pass> createLambdaLiftingPass();
std::unique_ptr<mlir::Pass> createLowerToCFGPass();
std::unique_ptr<mlir::Pass> createObserveAnsatzPass(std::vector<bool> &);
std::unique_ptr<mlir::Pass> createObserveAnsatzPass(const std::vector<bool> &);
std::unique_ptr<mlir::Pass> createQuakeAddMetadata();
std::unique_ptr<mlir::Pass> createQuakeAddDeallocs();
std::unique_ptr<mlir::Pass> createQuakeSynthesizer();
Expand Down
66 changes: 29 additions & 37 deletions lib/Optimizer/Transforms/ObserveAnsatz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,62 +200,56 @@ struct AnsatzFunctionAnalysis {
/// measurement basis change operations.
struct AppendMeasurements : public OpRewritePattern<func::FuncOp> {
explicit AppendMeasurements(MLIRContext *ctx, const AnsatzFunctionInfo &info,
std::vector<bool> &bsf)
ArrayRef<bool> bsf)
: OpRewritePattern(ctx), infoMap(info), termBSF(bsf) {}

/// The pre-computed analysis information
AnsatzFunctionInfo infoMap;

/// The Pauli term representation
std::vector<bool> &termBSF;
ArrayRef<bool> termBSF;

LogicalResult matchAndRewrite(func::FuncOp funcOp,
PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(funcOp);

// Use an Analysis to count the number of qubits.
auto iter = infoMap.find(funcOp);
if (iter == infoMap.end()) {
std::string msg = "Errors encountered in pass analysis\n";
funcOp.emitError(msg);
return failure();
}
if (iter == infoMap.end())
return funcOp.emitOpError("Errors encountered in pass analysis");
auto nQubits = iter->second.nQubits;

if (nQubits != termBSF.size() / 2) {
std::string msg = "Invalid number of binary-symplectic elements "
"provided. Must provide 2 * NQubits = " +
std::to_string(2 * nQubits) + "\n";
funcOp.emitError(msg);
return failure();
}
if (nQubits != termBSF.size() / 2)
return funcOp.emitOpError("Invalid number of binary-symplectic elements "
"provided. Must provide 2 * NQubits = " +
std::to_string(2 * nQubits));

// If the mapping pass was not run, we expect no pre-existing measurements.
if (!iter->second.mappingPassRan && !iter->second.measurements.empty()) {
std::string msg = "Cannot observe kernel with measures in it.\n";
funcOp.emitError(msg);
return failure();
}
if (!iter->second.mappingPassRan && !iter->second.measurements.empty())
return funcOp.emitOpError("Cannot observe kernel with measures in it.");

// Attempt to remove measurements. Note that the mapping pass may add
// measurements to kernels that don't contain any measurements. For
// observe kernels, we remove them here since we are adding specific
// measurements below.
// measurements to kernels that don't contain any measurements. For observe
// kernels, we remove them here since we are adding specific measurements
// below. Note: each `op` in the list of measurements must be removed by the
// end of this loop, otherwise the end result may be incorrect.
for (auto *op : iter->second.measurements) {
bool safeToRemove = [&]() {
for (auto user : op->getUsers())
if (!isa<quake::SinkOp, quake::ReturnWireOp>(user))
return false;
return true;
}();
if (!safeToRemove) {
std::string msg =
"Cannot observe kernel with non dangling measurements.\n";
funcOp.emitError(msg);
return failure();
}
if (!safeToRemove)
return funcOp.emitOpError(
"Cannot observe kernel with non dangling measurements.");

for (auto result : op->getResults())
if (quake::isLinearType(result.getType()))
result.replaceAllUsesWith(op->getOperand(0));

// Force remove `op`.
op->dropAllReferences();
op->erase();
}

Expand All @@ -272,7 +266,7 @@ struct AppendMeasurements : public OpRewritePattern<func::FuncOp> {

// Loop over the binary-symplectic form provided and append
// measurements as necessary.
std::vector<Value> qubitsToMeasure;
SmallVector<Value> qubitsToMeasure;
for (std::size_t i = 0; i < termBSF.size() / 2; i++) {
bool xElement = termBSF[i];
bool zElement = termBSF[i + nQubits];
Expand Down Expand Up @@ -328,13 +322,13 @@ struct AppendMeasurements : public OpRewritePattern<func::FuncOp> {
class ObserveAnsatzPass
: public cudaq::opt::impl::ObserveAnsatzBase<ObserveAnsatzPass> {
protected:
std::vector<bool> binarySymplecticForm;
SmallVector<bool> binarySymplecticForm;

public:
using ObserveAnsatzBase::ObserveAnsatzBase;

ObserveAnsatzPass(std::vector<bool> &bsfData)
: binarySymplecticForm(bsfData) {}
ObserveAnsatzPass(const std::vector<bool> &bsfData)
: binarySymplecticForm{bsfData.begin(), bsfData.end()} {}

void runOnOperation() override {
auto funcOp = dyn_cast<func::FuncOp>(getOperation());
Expand All @@ -359,16 +353,14 @@ class ObserveAnsatzPass
target.addLegalDialect<quake::QuakeDialect>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
emitError(funcOp.getLoc(), "failed to observe ansatz");
std::move(patterns))))
signalPassFailure();
}
}
};

} // namespace

std::unique_ptr<mlir::Pass>
cudaq::opt::createObserveAnsatzPass(std::vector<bool> &bsfData) {
return std::make_unique<ObserveAnsatzPass>(bsfData);
cudaq::opt::createObserveAnsatzPass(const std::vector<bool> &packed) {
return std::make_unique<ObserveAnsatzPass>(packed);
}

0 comments on commit 1a05616

Please sign in to comment.