Skip to content

Commit

Permalink
Add new methods for getting predicates from python passes
Browse files Browse the repository at this point in the history
  • Loading branch information
sjdilkes committed Nov 21, 2024
1 parent e9ceb8b commit e494551
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 1 deletion.
26 changes: 26 additions & 0 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,32 @@ PYBIND11_MODULE(passes, m) {
return py::cast(serialise(base_pass));
},
":return: A JSON serializable dictionary representation of the Pass.")
.def(
"get_pre_conditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> pre_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
pre_conditions.push_back(p.second);
}
return pre_conditions;
},
"Returns the pre condition Predicates for the given pass."
"\n:return: A list of Predicate")
.def(
"get_post_conditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> post_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>> &
p : base_pass.get_conditions().second.specific_postcons_) {
post_conditions.push_back(p.second);
}
return post_conditions;
},
"Returns the post condition Predicates for the given pass."
"\n:return: A list of Predicate")
.def_static(
"from_dict",
[](const py::dict &base_pass_dict,
Expand Down
2 changes: 1 addition & 1 deletion pytket/binders/predicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ PYBIND11_MODULE(predicates, m) {
"implies", &Predicate::implies,
":return: True if predicate implies another one, else False",
py::arg("other"))
.def("__str__", [](const Predicate &) { return "<tket::Predicate>"; })
.def("__str__", &Predicate::to_string)
.def("__repr__", &Predicate::to_string)
.def(
"to_dict",
Expand Down
7 changes: 7 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changelog
=========

Unreleased
----------

Features:

* Add `BasePass.get_pre_conditions()` and `BasePass.get_post_conditions()`.

1.35.0 (November 2024)
----------------------

Expand Down
10 changes: 10 additions & 0 deletions pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ class BasePass:
:param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback.
:return: True if pass modified the circuit, else False
"""
def get_post_conditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the post condition Predicates for the given pass.
:return: A list of Predicate
"""
def get_pre_conditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the pre condition Predicates for the given pass.
:return: A list of Predicate
"""
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the Pass.
Expand Down
13 changes: 13 additions & 0 deletions pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,19 @@ def test_auto_rebase_deprecation(recwarn: Any) -> None:
assert "deprecated" in str(w.message)


def test_get_pre_conditions() -> None:
pre_cons = GreedyPauliSimp().get_pre_conditions()
gate_set = pre_cons[0].gate_set
assert OpType.CX in gate_set
assert OpType.Measure in gate_set


def test_get_post_conditions() -> None:
gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure}
post_cons = AutoRebase(gate_set).get_post_conditions()
assert post_cons[0].gate_set == gate_set


if __name__ == "__main__":
test_predicate_generation()
test_compilation_unit_generation()
Expand Down

0 comments on commit e494551

Please sign in to comment.