Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BasePass.get_pre_conditions and BasePass.get_post_conditions #1689

Merged
merged 19 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion pytket/binders/passes.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is also worth adding a function specifically to get the gate-set requirements of a pass; i.e. which gets the preconditions, checks for any GateSetPredicates, and returns the set intersection of their gate sets (or None if there are no GateSetPredicates.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,30 @@ const PassPtr &DecomposeClassicalExp() {
return pp;
}

std::optional<OpTypeSet> get_gate_set(const BasePass &base_pass) {
std::optional<OpTypeSet> allowed_ops;
for (const std::pair<const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
std::shared_ptr<GateSetPredicate> gsp_ptr =
std::dynamic_pointer_cast<GateSetPredicate>(p.second);
if (!gsp_ptr) {
continue;
}
OpTypeSet candidate_allowed_ops = gsp_ptr->get_allowed_types();
if (!allowed_ops) {
allowed_ops = candidate_allowed_ops;
} else {
OpTypeSet intersection;
std::set_intersection(
candidate_allowed_ops.begin(), candidate_allowed_ops.end(),
allowed_ops->begin(), allowed_ops->end(),
std::inserter(intersection, intersection.begin()));
allowed_ops = intersection;
}
}
return allowed_ops;
}

PYBIND11_MODULE(passes, m) {
py::module_::import("pytket._tket.predicates");
m.def(
Expand Down Expand Up @@ -212,7 +236,6 @@ PYBIND11_MODULE(passes, m) {
);
}
};

py::class_<BasePass, PassPtr, PyBasePass>(
m, "BasePass", "Base class for passes.")
.def(
Expand Down Expand Up @@ -268,6 +291,38 @@ PYBIND11_MODULE(passes, m) {
return py::cast(serialise(base_pass));
},
":return: A JSON serializable dictionary representation of the Pass.")
.def(
"get_preconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> pre_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
pre_conditions.push_back(p.second);
}
return pre_conditions;
},
"Returns the precondition Predicates for the given pass."
"\n:return: A list of Predicate")
.def(
"get_postconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> post_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>> &
p : base_pass.get_conditions().second.specific_postcons_) {
post_conditions.push_back(p.second);
}
return post_conditions;
},
"Returns the postcondition Predicates for the given pass."
"\n\n:return: A list of :py:class:`Predicate`")
.def(
"get_gate_set", &get_gate_set,
"Returns the intersection of all set of OpType for all "
"GateSetPredicate in the `BasePass` preconditions, or `None` "
"if there are no gate-set predicates.",
"\n\n:return: A set of allowed OpType")
.def_static(
"from_dict",
[](const py::dict &base_pass_dict,
Expand Down
2 changes: 1 addition & 1 deletion pytket/binders/predicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ PYBIND11_MODULE(predicates, m) {
"implies", &Predicate::implies,
":return: True if predicate implies another one, else False",
py::arg("other"))
.def("__str__", [](const Predicate &) { return "<tket::Predicate>"; })
.def("__str__", &Predicate::to_string)
.def("__repr__", &Predicate::to_string)
.def(
"to_dict",
Expand Down
4 changes: 4 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
Unreleased
----------

Features:

* Add `BasePass.get_preconditions()` and `BasePass.getpost_conditions()`.

API changes:

* Remove the deprecated methods `auto_rebase_pass()` and `auto_squash_pass()`.
Expand Down
15 changes: 15 additions & 0 deletions pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ class BasePass:
:param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback.
:return: True if pass modified the circuit, else False
"""
def get_gate_set(self) -> set[pytket._tket.circuit.OpType] | None:
"""
:return: A set of allowed OpType
"""
def get_postconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the postcondition Predicates for the given pass.

:return: A list of :py:class:`Predicate`
"""
def get_preconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the precondition Predicates for the given pass.
:return: A list of Predicate
"""
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the Pass.
Expand Down
21 changes: 21 additions & 0 deletions pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,27 @@ def test_greedy_pauli_synth() -> None:
assert GreedyPauliSimp().apply(c)


def test_get_pre_conditions() -> None:
pre_cons = GreedyPauliSimp().get_preconditions()
gate_set = pre_cons[0].gate_set # type: ignore
assert OpType.CX in gate_set
assert OpType.Measure in gate_set


def test_get_post_conditions() -> None:
gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure}
post_cons = AutoRebase(gate_set).get_postconditions()
assert post_cons[0].gate_set == gate_set # type: ignore


def test_get_gate_set() -> None:
gate_set = GreedyPauliSimp().get_gate_set()
assert gate_set is not None
assert OpType.CX in gate_set
assert OpType.Measure in gate_set
assert CliffordPushThroughMeasures().get_gate_set() is None


if __name__ == "__main__":
test_predicate_generation()
test_compilation_unit_generation()
Expand Down
Loading