Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/stim/circuit/circuit_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,3 +1547,29 @@ def test_shortest_graphlike_error_many_obs():
OBSERVABLE_INCLUDE(1200) rec[-1]
""")
assert len(c.shortest_graphlike_error()) == 5


def test_detslice_filter_coords_flexibility():
c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3)
d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1)])
d2 = c.diagram("detslice-svg", filter_coords=stim.DemTarget.relative_detector_id(1))
d3 = c.diagram("detslice", filter_coords=["D1"])
d4 = c.diagram("detslice", filter_coords="D1")
d5 = c.diagram("detector-slice-svg", filter_coords=[3, 0])
d6 = c.diagram("detslice-svg", filter_coords=[[3, 0]])
assert str(d1) == str(d2)
assert str(d1) == str(d3)
assert str(d1) == str(d4)
assert str(d1) == str(d5)
assert str(d1) == str(d6)
assert str(d1) != str(c.diagram("detslice", filter_coords="L0"))

d1 = c.diagram("detslice", filter_coords=[stim.DemTarget.relative_detector_id(1), stim.DemTarget.relative_detector_id(3), stim.DemTarget.relative_detector_id(5), "D7"])
d2 = c.diagram("detslice", filter_coords=["D1", "D3", "D5", "D7"])
d3 = c.diagram("detslice-svg", filter_coords=[3,])
d4 = c.diagram("detslice-svg", filter_coords=[[3,]])
d5 = c.diagram("detslice-svg", filter_coords=[[3, 0], [3, 1], [3, 2], [3, 3]])
assert str(d1) == str(d2)
assert str(d1) == str(d3)
assert str(d1) == str(d4)
assert str(d1) == str(d5)
81 changes: 60 additions & 21 deletions src/stim/cmd/command_diagram.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#include "stim/cmd/command_diagram.pybind.h"

#include "stim/arg_parse.h"
#include "stim/cmd/command_help.h"
#include "stim/dem/detector_error_model_target.pybind.h"
#include "stim/diagram/base64.h"
#include "stim/diagram/crumble.h"
#include "stim/diagram/detector_slice/detector_slice_set.h"
Expand Down Expand Up @@ -143,31 +145,68 @@ DiagramHelper stim_pybind::dem_diagram(const DetectorErrorModel &dem, const std:
throw std::invalid_argument("Unrecognized diagram type: " + type);
}
}

CoordFilter item_to_filter_single(const pybind11::handle &obj) {
if (pybind11::isinstance<ExposedDemTarget>(obj)) {
CoordFilter filter;
filter.exact_target = pybind11::cast<ExposedDemTarget>(obj).internal();
filter.use_target = true;
return filter;
}

try {
std::string text = pybind11::cast<std::string>(obj);
if (text.size() > 1 && text[0] == 'D') {
CoordFilter filter;
filter.exact_target = DemTarget::relative_detector_id(parse_exact_uint64_t_from_string(text.substr(1)));
filter.use_target = true;
return filter;
}
if (text.size() > 1 && text[0] == 'L') {
CoordFilter filter;
filter.exact_target = DemTarget::observable_id(parse_exact_uint64_t_from_string(text.substr(1)));
filter.use_target = true;
return filter;
}
} catch (const pybind11::cast_error &) {
} catch (const std::invalid_argument &) {
}

CoordFilter filter;
for (const auto &c : obj) {
filter.coordinates.push_back(pybind11::cast<double>(c));
}
return filter;
}

std::vector<CoordFilter> item_to_filter_multi(const pybind11::object &obj) {
if (obj.is_none()) {
return {CoordFilter{}};
}

try {
return {item_to_filter_single(obj)};
} catch (const pybind11::cast_error &) {
} catch (const std::invalid_argument &) {
}

std::vector<CoordFilter> filters;
for (const auto &filter_case : obj) {
filters.push_back(item_to_filter_single(filter_case));
}
return filters;
}

DiagramHelper stim_pybind::circuit_diagram(
const Circuit &circuit,
const std::string &type,
const pybind11::object &tick,
const pybind11::object &filter_coords_obj) {
std::vector<CoordFilter> filter_coords;
try {
if (filter_coords_obj.is_none()) {
filter_coords.push_back({});
} else {
for (const auto &filter_case : filter_coords_obj) {
CoordFilter filter;
if (pybind11::isinstance<DemTarget>(filter_case)) {
filter.exact_target = pybind11::cast<DemTarget>(filter_case);
filter.use_target = true;
} else {
for (const auto &c : filter_case) {
filter.coordinates.push_back(pybind11::cast<double>(c));
}
}
filter_coords.push_back(std::move(filter));
}
}
filter_coords = item_to_filter_multi(filter_coords_obj);
} catch (const std::exception &_) {
throw std::invalid_argument("filter_coords wasn't a list of list of floats.");
throw std::invalid_argument("filter_coords wasn't an Iterable[stim.DemTarget | Iterable[float]].");
}

uint64_t tick_min;
Expand Down Expand Up @@ -198,21 +237,21 @@ DiagramHelper stim_pybind::circuit_diagram(
std::stringstream out;
out << DiagramTimelineAsciiDrawer::make_diagram(circuit);
return DiagramHelper{DIAGRAM_TYPE_TEXT, out.str()};
} else if (type == "timeline-svg") {
} else if (type == "timeline-svg" || type == "timeline") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIMELINE, filter_coords);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "time-slice-svg" || type == "timeslice-svg") {
} else if (type == "time-slice-svg" || type == "timeslice-svg" || type == "timeslice" || type == "time-slice") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_SLICE, filter_coords);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "detslice-svg" || type == "detector-slice-svg") {
} else if (type == "detslice-svg" || type == "detslice" || type == "detector-slice-svg" || type == "detector-slice") {
std::stringstream out;
DetectorSliceSet::from_circuit_ticks(circuit, tick_min, num_ticks, filter_coords).write_svg_diagram_to(out);
return DiagramHelper{DIAGRAM_TYPE_SVG, out.str()};
} else if (type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
} else if (type == "detslice-with-ops" || type == "detslice-with-ops-svg" || type == "time+detector-slice-svg") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, SVG_MODE_TIME_DETECTOR_SLICE, filter_coords);
Expand Down
8 changes: 8 additions & 0 deletions src/stim/diagram/detector_slice/detector_slice_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ std::ostream &stim_draw_internal::operator<<(std::ostream &out, const DetectorSl
slice.write_text_diagram_to(out);
return out;
}
std::ostream &stim_draw_internal::operator<<(std::ostream &out, const CoordFilter &filter) {
if (filter.use_target) {
out << filter.exact_target;
} else {
out << comma_sep(filter.coordinates);
}
return out;
}

void DetectorSliceSet::write_text_diagram_to(std::ostream &out) const {
DiagramTimelineAsciiDrawer drawer(num_qubits, false);
Expand Down
1 change: 1 addition & 0 deletions src/stim/diagram/detector_slice/detector_slice_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Coord<2> pick_polygon_center(stim::SpanRef<const Coord<2>> coords);
bool is_colinear(Coord<2> a, Coord<2> b, Coord<2> c, float atol);

std::ostream &operator<<(std::ostream &out, const DetectorSliceSet &slice);
std::ostream &operator<<(std::ostream &out, const CoordFilter &filter);

} // namespace stim_draw_internal

Expand Down