-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BasePass.get_pre_conditions
and BasePass.get_post_conditions
#1689
Changes from 17 commits
e494551
da64baa
135fa39
801bb2a
297cd4e
34ed2c2
b07b052
04d318b
ea9700d
3ca2716
e94838a
a94fa69
c41dab2
376d87c
d104331
2e4ca31
dc05f2f
ed275e6
f0fb617
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -145,6 +145,31 @@ const PassPtr &DecomposeClassicalExp() { | |||||||||||||
return pp; | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
std::optional<OpTypeSet> get_gate_set(const BasePass &base_pass) { | ||||||||||||||
OpTypeSet allowed_ops; | ||||||||||||||
for (const std::pair<const std::type_index, std::shared_ptr<tket::Predicate>> | ||||||||||||||
&p : base_pass.get_conditions().first) { | ||||||||||||||
if (p.second->to_string().substr(0, 17) == "GateSetPredicate:") { | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this check is necessary. We can just use |
||||||||||||||
std::shared_ptr<GateSetPredicate> gsp_ptr = | ||||||||||||||
std::dynamic_pointer_cast<GateSetPredicate>(p.second); | ||||||||||||||
if (allowed_ops.empty()) { | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will end up giving the wrong result in the (admittedly unlikely) case that we have already encountered two There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done (I think) |
||||||||||||||
allowed_ops = gsp_ptr->get_allowed_types(); | ||||||||||||||
continue; | ||||||||||||||
} | ||||||||||||||
OpTypeSet intersection; | ||||||||||||||
OpTypeSet candidate_allowed_ops = gsp_ptr->get_allowed_types(); | ||||||||||||||
std::set_intersection( | ||||||||||||||
candidate_allowed_ops.begin(), candidate_allowed_ops.end(), | ||||||||||||||
allowed_ops.begin(), allowed_ops.end(), | ||||||||||||||
std::inserter(intersection, intersection.begin())); | ||||||||||||||
allowed_ops = intersection; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
if (allowed_ops.empty()) return {}; | ||||||||||||||
|
||||||||||||||
return allowed_ops; | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
PYBIND11_MODULE(passes, m) { | ||||||||||||||
py::module_::import("pytket._tket.predicates"); | ||||||||||||||
m.def( | ||||||||||||||
|
@@ -212,7 +237,6 @@ PYBIND11_MODULE(passes, m) { | |||||||||||||
); | ||||||||||||||
} | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
py::class_<BasePass, PassPtr, PyBasePass>( | ||||||||||||||
m, "BasePass", "Base class for passes.") | ||||||||||||||
.def( | ||||||||||||||
|
@@ -268,6 +292,37 @@ PYBIND11_MODULE(passes, m) { | |||||||||||||
return py::cast(serialise(base_pass)); | ||||||||||||||
}, | ||||||||||||||
":return: A JSON serializable dictionary representation of the Pass.") | ||||||||||||||
.def( | ||||||||||||||
"get_preconditions", | ||||||||||||||
[](const BasePass &base_pass) { | ||||||||||||||
std::vector<PredicatePtr> pre_conditions; | ||||||||||||||
for (const std::pair< | ||||||||||||||
const std::type_index, std::shared_ptr<tket::Predicate>> | ||||||||||||||
&p : base_pass.get_conditions().first) { | ||||||||||||||
pre_conditions.push_back(p.second); | ||||||||||||||
} | ||||||||||||||
return pre_conditions; | ||||||||||||||
}, | ||||||||||||||
"Returns the precondition Predicates for the given pass." | ||||||||||||||
"\n:return: A list of Predicate") | ||||||||||||||
.def( | ||||||||||||||
"get_postconditions", | ||||||||||||||
[](const BasePass &base_pass) { | ||||||||||||||
std::vector<PredicatePtr> post_conditions; | ||||||||||||||
for (const std::pair< | ||||||||||||||
const std::type_index, std::shared_ptr<tket::Predicate>> & | ||||||||||||||
p : base_pass.get_conditions().second.specific_postcons_) { | ||||||||||||||
post_conditions.push_back(p.second); | ||||||||||||||
} | ||||||||||||||
return post_conditions; | ||||||||||||||
}, | ||||||||||||||
"Returns the postcondition Predicates for the given pass." | ||||||||||||||
"\n\n:return: A list of :py:class:`Predicate`") | ||||||||||||||
.def( | ||||||||||||||
"get_gate_set", &get_gate_set, | ||||||||||||||
"Returns the intersection of all set of OpType for all " | ||||||||||||||
"GateSetPredicate in the `BasePass` preconditions.", | ||||||||||||||
"\n\n:return: A set of allowed OpType") | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||||||||||
.def_static( | ||||||||||||||
"from_dict", | ||||||||||||||
[](const py::dict &base_pass_dict, | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,7 +38,7 @@ def requirements(self): | |
self.requires("pybind11_json/0.2.14") | ||
self.requires("symengine/0.13.0") | ||
self.requires("tkassert/0.3.4@tket/stable") | ||
self.requires("tket/1.3.49@tket/stable") | ||
self.requires("tket/1.3.50@tket/stable") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't need updating. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've reverted - let's see if the same issue comes up |
||
self.requires("tklog/0.3.3@tket/stable") | ||
self.requires("tkrng/0.3.3@tket/stable") | ||
self.requires("tktokenswap/0.3.9@tket/stable") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1092,6 +1092,26 @@ def test_greedy_pauli_synth() -> None: | |
assert GreedyPauliSimp().apply(c) | ||
|
||
|
||
def test_get_pre_conditions() -> None: | ||
pre_cons = GreedyPauliSimp().get_preconditions() | ||
gate_set = pre_cons[0].gate_set # type: ignore | ||
assert OpType.CX in gate_set | ||
assert OpType.Measure in gate_set | ||
|
||
|
||
def test_get_post_conditions() -> None: | ||
gate_set = {OpType.CX, OpType.Rz, OpType.H, OpType.Reset, OpType.Measure} | ||
post_cons = AutoRebase(gate_set).get_postconditions() | ||
assert post_cons[0].gate_set == gate_set # type: ignore | ||
|
||
|
||
def test_get_gate_set() -> None: | ||
gate_set = GreedyPauliSimp().get_gate_set() | ||
assert OpType.CX in gate_set # type: ignore | ||
assert OpType.Measure in gate_set # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be able to avoid the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert CliffordPushThroughMeasures().get_gate_set() is None | ||
|
||
|
||
if __name__ == "__main__": | ||
test_predicate_generation() | ||
test_compilation_unit_generation() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
|
||
class TketConan(ConanFile): | ||
name = "tket" | ||
version = "1.3.49" | ||
version = "1.3.50" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be needed. |
||
package_type = "library" | ||
license = "Apache 2" | ||
homepage = "https://github.com/CQCL/tket" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is also worth adding a function specifically to get the gate-set requirements of a pass; i.e. which gets the preconditions, checks for any
GateSetPredicate
s, and returns the set intersection of their gate sets (orNone
if there are noGateSetPredicate
s.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done