Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BasePass.get_pre_conditions and BasePass.get_post_conditions #1689

Merged
merged 19 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion pytket/binders/passes.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is also worth adding a function specifically to get the gate-set requirements of a pass; i.e. which gets the preconditions, checks for any GateSetPredicates, and returns the set intersection of their gate sets (or None if there are no GateSetPredicates.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,31 @@ const PassPtr &DecomposeClassicalExp() {
return pp;
}

std::optional<OpTypeSet> get_gate_set(const BasePass &base_pass) {
OpTypeSet allowed_ops;
for (const std::pair<const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
if (p.second->to_string().substr(0, 17) == "GateSetPredicate:") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this check is necessary. We can just use dynamic_cast and ignore it if it returns nullptr.

std::shared_ptr<GateSetPredicate> gsp_ptr =
std::dynamic_pointer_cast<GateSetPredicate>(p.second);
if (allowed_ops.empty()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will end up giving the wrong result in the (admittedly unlikely) case that we have already encountered two GateSetPredicates with empty intersection. I suggest using a std::optional<OpTypeSet>, where no-value means all all ops are allowed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done (I think)

allowed_ops = gsp_ptr->get_allowed_types();
continue;
}
OpTypeSet intersection;
OpTypeSet candidate_allowed_ops = gsp_ptr->get_allowed_types();
std::set_intersection(
candidate_allowed_ops.begin(), candidate_allowed_ops.end(),
allowed_ops.begin(), allowed_ops.end(),
std::inserter(intersection, intersection.begin()));
allowed_ops = intersection;
}
}
if (allowed_ops.empty()) return {};

return allowed_ops;
}

PYBIND11_MODULE(passes, m) {
py::module_::import("pytket._tket.predicates");
m.def(
Expand Down Expand Up @@ -212,7 +237,6 @@ PYBIND11_MODULE(passes, m) {
);
}
};

py::class_<BasePass, PassPtr, PyBasePass>(
m, "BasePass", "Base class for passes.")
.def(
Expand Down Expand Up @@ -268,6 +292,37 @@ PYBIND11_MODULE(passes, m) {
return py::cast(serialise(base_pass));
},
":return: A JSON serializable dictionary representation of the Pass.")
.def(
"get_preconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> pre_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>>
&p : base_pass.get_conditions().first) {
pre_conditions.push_back(p.second);
}
return pre_conditions;
},
"Returns the precondition Predicates for the given pass."
"\n:return: A list of Predicate")
.def(
"get_postconditions",
[](const BasePass &base_pass) {
std::vector<PredicatePtr> post_conditions;
for (const std::pair<
const std::type_index, std::shared_ptr<tket::Predicate>> &
p : base_pass.get_conditions().second.specific_postcons_) {
post_conditions.push_back(p.second);
}
return post_conditions;
},
"Returns the postcondition Predicates for the given pass."
"\n\n:return: A list of :py:class:`Predicate`")
.def(
"get_gate_set", &get_gate_set,
"Returns the intersection of all set of OpType for all "
"GateSetPredicate in the `BasePass` preconditions.",
"\n\n:return: A set of allowed OpType")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Returns the intersection of all set of OpType for all "
"GateSetPredicate in the `BasePass` preconditions.",
"\n\n:return: A set of allowed OpType")
"Returns the intersection of all set of OpType for all "
"GateSetPredicate in the `BasePass` preconditions, or `None` if there are no gate-set predicares.",
"\n\n:return: A set of allowed OpType, or `None`")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.def_static(
"from_dict",
[](const py::dict &base_pass_dict,
Expand Down
2 changes: 1 addition & 1 deletion pytket/binders/predicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ PYBIND11_MODULE(predicates, m) {
"implies", &Predicate::implies,
":return: True if predicate implies another one, else False",
py::arg("other"))
.def("__str__", [](const Predicate &) { return "<tket::Predicate>"; })
.def("__str__", &Predicate::to_string)
.def("__repr__", &Predicate::to_string)
.def(
"to_dict",
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def requirements(self):
self.requires("pybind11_json/0.2.14")
self.requires("symengine/0.13.0")
self.requires("tkassert/0.3.4@tket/stable")
self.requires("tket/1.3.49@tket/stable")
self.requires("tket/1.3.50@tket/stable")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't need updating.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've reverted - let's see if the same issue comes up

self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tktokenswap/0.3.9@tket/stable")
Expand Down
4 changes: 4 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
Unreleased
----------

Features:

* Add `BasePass.get_preconditions()` and `BasePass.getpost_conditions()`.

API changes:

* Remove the deprecated methods `auto_rebase_pass()` and `auto_squash_pass()`.
Expand Down
15 changes: 15 additions & 0 deletions pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ class BasePass:
:param after_apply: Invoked after a pass is applied. The CompilationUnit and a summary of the pass configuration are passed into the callback.
:return: True if pass modified the circuit, else False
"""
def get_gate_set(self) -> set[pytket._tket.circuit.OpType] | None:
"""
:return: A set of allowed OpType
"""
def get_postconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the postcondition Predicates for the given pass.

:return: A list of :py:class:`Predicate`
"""
def get_preconditions(self) -> list[pytket._tket.predicates.Predicate]:
"""
Returns the precondition Predicates for the given pass.
:return: A list of Predicate
"""
def to_dict(self) -> typing.Any:
"""
:return: A JSON serializable dictionary representation of the Pass.
Expand Down
20 changes: 20 additions & 0 deletions pytket/tests/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,26 @@ def test_greedy_pauli_synth() -> None:
assert GreedyPauliSimp().apply(c)


def test_get_pre_conditions() -> None:
pre_cons = GreedyPauliSimp().get_preconditions()
gate_set = pre_cons[0].gate_set # type: ignore
assert OpType.CX in gate_set
assert OpType.Measure in gate_set


def test_get_post_conditions() -> None:
gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure}
post_cons = AutoRebase(gate_set).get_postconditions()
assert post_cons[0].gate_set == gate_set # type: ignore


def test_get_gate_set() -> None:
gate_set = GreedyPauliSimp().get_gate_set()
assert OpType.CX in gate_set # type: ignore
assert OpType.Measure in gate_set # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be able to avoid the type: ignores by first asserting that gate_set is not None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

assert CliffordPushThroughMeasures().get_gate_set() is None


if __name__ == "__main__":
test_predicate_generation()
test_compilation_unit_generation()
Expand Down
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.49"
version = "1.3.50"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be needed.

package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
Loading