Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase flexibility of stim.Circuit.diagram filter_coords args #618

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/stim/circuit/circuit_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,3 +1547,29 @@ def test_shortest_graphlike_error_many_obs():
OBSERVABLE_INCLUDE(1200) rec[-1]
""")
assert len(c.shortest_graphlike_error()) == 5


def test_detslice_filter_coords_flexibility():
c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1)])
d2 = c.diagram("detslice-svg", filter_coords=stim.DemTarget.relative_detector_id(1))
d3 = c.diagram("detslice", filter_coords=["D1"])
d4 = c.diagram("detslice", filter_coords="D1")
d5 = c.diagram("detector-slice-svg", filter_coords=[3, 0])
d6 = c.diagram("detslice-svg", filter_coords=[[3, 0]])
assert str(d1) == str(d2)
assert str(d1) == str(d3)
assert str(d1) == str(d4)
assert str(d1) == str(d5)
assert str(d1) == str(d6)
assert str(d1) != str(c.diagram("detslice", filter_coords="L0"))

d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1), stim.DemTarget.relative_detector_id(3), stim.DemTarget.relative_detector_id(5), "D7"])
d2 = c.diagram("detslice", filter_coords=["D1", "D3", "D5", "D7"])
d3 = c.diagram("detslice-svg", filter_coords=[3,])
d4 = c.diagram("detslice-svg", filter_coords=[[3,]])
d5 = c.diagram("detslice-svg", filter_coords=[[3, 0], [3, 1], [3, 2], [3, 3]])
assert str(d1) == str(d2)
assert str(d1) == str(d3)
assert str(d1) == str(d4)
assert str(d1) == str(d5)
81 changes: 60 additions & 21 deletions src/stim/cmd/command_diagram.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#include "stim/cmd/command_diagram.pybind.h"

#include "stim/arg_parse.h"
#include "stim/cmd/command_help.h"
#include "stim/dem/detector_error_model_target.pybind.h"
#include "stim/diagram/base64.h"
#include "stim/diagram/crumble.h"
#include "stim/diagram/detector_slice/detector_slice_set.h"
Expand Down Expand Up @@ -143,31 +145,68 @@ DiagramHelper stim_pybind::dem_diagram(const DetectorErrorModel &dem, const std:
throw std::invalid_argument("Unrecognized diagram type: " + type);
}
}

CoordFilter item_to_filter_single(const pybind11::handle &obj) {
if (pybind11::isinstance<ExposedDemTarget>(obj)) {
CoordFilter filter;
filter.exact_target = pybind11::cast<ExposedDemTarget>(obj).internal();
filter.use_target = true;
return filter;
}

try {
std::string text = pybind11::cast<std::string>(obj);
if (text.size() > 1 && text[0] == 'D') {
CoordFilter filter;
filter.exact_target = DemTarget::relative_detector_id(parse_exact_uint64_t_from_string(text.substr(1)));
filter.use_target = true;
return filter;
}
if (text.size() > 1 && text[0] == 'L') {
CoordFilter filter;
filter.exact_target = DemTarget::observable_id(parse_exact_uint64_t_from_string(text.substr(1)));
filter.use_target = true;
return filter;
}
} catch (const pybind11::cast_error &) {
} catch (const std::invalid_argument &) {
}

CoordFilter filter;
for (const auto &c : obj) {
filter.coordinates.push_back(pybind11::cast<double>(c));
}
return filter;
}

std::vector<CoordFilter> item_to_filter_multi(const pybind11::object &obj) {
if (obj.is_none()) {
return {CoordFilter{}};
}

try {
return {item_to_filter_single(obj)};
} catch (const pybind11::cast_error &) {
} catch (const std::invalid_argument &) {
}

std::vector<CoordFilter> filters;
for (const auto &filter_case : obj) {
filters.push_back(item_to_filter_single(filter_case));
}
return filters;
}

DiagramHelper stim_pybind::circuit_diagram(
const Circuit &circuit,
const std::string &type,
const pybind11::object &tick,
const pybind11::object &filter_coords_obj) {
std::vector<CoordFilter> filter_coords;
try {
if (filter_coords_obj.is_none()) {
filter_coords.push_back({});
} else {
for (const auto &filter_case : filter_coords_obj) {
CoordFilter filter;
if (pybind11::isinstance<DemTarget>(filter_case)) {
filter.exact_target = pybind11::cast<DemTarget>(filter_case);
filter.use_target = true;
} else {
for (const auto &c : filter_case) {
filter.coordinates.push_back(pybind11::cast<double>(c));
}
}
filter_coords.push_back(std::move(filter));
}
}
filter_coords = item_to_filter_multi(filter_coords_obj);
} catch (const std::exception &_) {
throw std::invalid_argument("filter_coords wasn't a list of list of floats.");
throw std::invalid_argument("filter_coords wasn't an Iterable[stim.DemTarget | Iterable[float]].");
}

uint64_t tick_min;
Expand Down Expand Up @@ -198,21 +237,21 @@ DiagramHelper stim_pybind::circuit_diagram(
std::stringstream out;
out << DiagramTimelineAsciiDrawer::make_diagram(circuit);
return DiagramHelper{DIAGRAM_TYPE_TEXT, out.str()};
} else if (type == "timeline-svg") {
} else if (type == "timeline-svg" || type == "timeline") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIMELINE, filter_coords);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "time-slice-svg" || type == "timeslice-svg") {
} else if (type == "time-slice-svg" || type == "timeslice-svg" || type == "timeslice" || type == "time-slice") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_SLICE, filter_coords);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "detslice-svg" || type == "detector-slice-svg") {
} else if (type == "detslice-svg" || type == "detslice" || type == "detector-slice-svg" || type == "detector-slice") {
std::stringstream out;
DetectorSliceSet::from_circuit_ticks(circuit, tick_min, num_ticks, filter_coords).write_svg_diagram_to(out);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
} else if (type == "detslice-with-ops" || type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_DETECTOR_SLICE, filter_coords);
Expand Down
8 changes: 8 additions & 0 deletions src/stim/diagram/detector_slice/detector_slice_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ std::ostream &stim_draw_internal::operator<<(std::ostream &out, const DetectorSl
slice.write_text_diagram_to(out);
return out;
}
std::ostream &stim_draw_internal::operator<<(std::ostream &out, const CoordFilter &filter) {
if (filter.use_target) {
out << filter.exact_target;
} else {
out << comma_sep(filter.coordinates);
}
return out;
}

void DetectorSliceSet::write_text_diagram_to(std::ostream &out) const {
DiagramTimelineAsciiDrawer drawer(num_qubits, false);
Expand Down
1 change: 1 addition & 0 deletions src/stim/diagram/detector_slice/detector_slice_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Coord<2> pick_polygon_center(stim::SpanRef<const Coord<2>> coords);
bool is_colinear(Coord<2> a, Coord<2> b, Coord<2> c, float atol);

std::ostream &operator<<(std::ostream &out, const DetectorSliceSet &slice);
std::ostream &operator<<(std::ostream &out, const CoordFilter &filter);

} // namespace stim_draw_internal

Expand Down