Skip to content

Commit

Permalink
Add `stim.CompiledDetectorSampler.sample(..., dets_out=None, obs_out=…
Browse files Browse the repository at this point in the history
…None)` (#782)

- Add `obs_out` and `dets_out` parameters to the sinter-hot-path
detection event sampling method
- Rewrite bit-table-to-numpy code to allow passing in the buffer to
write to
- Avoids doing large allocations for every call to sample
- Avoids some extra copies that were previously present
- Also, release the GIL when doing the actual frame simulation call

Benchmarked by sampling 250 batches of 1024 shots each from a distance-11
surface code circuit running for 33 rounds:

- Old version: 3.16 seconds
- New version (without pre-allocated output buffers): 2.36 seconds
- New version (with pre-allocated output buffers): 2.34 seconds

So passing in pre-allocated output buffers makes little difference (2.36 s
vs 2.34 s), but the copy reduction was very useful: it cut the total runtime
by roughly 25% (3.16 s down to 2.36 s).

```
import time

import numpy as np
import stim

# Benchmark parameters: number of calls to `sample`, and shots drawn per call.
NUM_BATCHES = 250
SHOTS_PER_BATCH = 1024
# True times the path that writes into pre-allocated buffers; False times the
# path where `sample` allocates its own result arrays on every call.
USE_OUTPUT_BUFFERS = True

circuit = stim.Circuit.generated(
    "surface_code:rotated_memory_x",
    distance=11,
    rounds=33,
    after_clifford_depolarization=1e-3,
    before_measure_flip_probability=1e-3,
    after_reset_flip_probability=1e-3,
    before_round_data_depolarization=1e-3,
)
sampler = circuit.compile_detector_sampler()

sample_kwargs = {"shots": SHOTS_PER_BATCH, "bit_packed": True}
if USE_OUTPUT_BUFFERS:
    # Bit-packed rows store 8 bits per uint8, so each row needs ceil(bits / 8)
    # bytes. The buffers are allocated once, outside the timed region.
    sample_kwargs["dets_out"] = np.empty(
        (SHOTS_PER_BATCH, (circuit.num_detectors + 7) // 8),
        dtype=np.uint8,
    )
    sample_kwargs["obs_out"] = np.empty(
        (SHOTS_PER_BATCH, (circuit.num_observables + 7) // 8),
        dtype=np.uint8,
    )

t0 = time.perf_counter()
for _ in range(NUM_BATCHES):
    sampler.sample(**sample_kwargs)
dt = time.perf_counter() - t0

# Normalize by the number of batches actually run (not by the shot count), so
# the derived figures stay correct when NUM_BATCHES != SHOTS_PER_BATCH.
per_batch = dt / NUM_BATCHES
per_shot = per_batch / SHOTS_PER_BATCH
per_detector = per_shot / circuit.num_detectors
print(f"total: {dt} s")
print(f"per batch: {per_batch} s")
print(f"per shot: {per_shot} s")
print(f"per detector per shot: {per_detector} s")
```
  • Loading branch information
Strilanc authored Jun 11, 2024
1 parent 985d6a7 commit 320288c
Show file tree
Hide file tree
Showing 16 changed files with 418 additions and 1,811 deletions.
1,695 changes: 74 additions & 1,621 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

27 changes: 8 additions & 19 deletions doc/python_api_reference_vDev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4304,25 +4304,6 @@ def __repr__(
# stim.CompiledDetectorSampler.sample

# (in class stim.CompiledDetectorSampler)
@overload
def sample(
self,
shots: int,
*,
prepend_observables: bool = False,
append_observables: bool = False,
bit_packed: bool = False,
) -> np.ndarray:
pass
@overload
def sample(
self,
shots: int,
*,
separate_observables: Literal[True],
bit_packed: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
pass
def sample(
self,
shots: int,
Expand All @@ -4331,6 +4312,8 @@ def sample(
append_observables: bool = False,
separate_observables: bool = False,
bit_packed: bool = False,
dets_out: Optional[np.ndarray] = None,
obs_out: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Returns a numpy array containing a batch of detector samples from the circuit.
Expand All @@ -4349,6 +4332,12 @@ def sample(
with the detectors and are placed at the end of the results.
bit_packed: Returns a uint8 numpy array with 8 bits per byte, instead of
a bool_ numpy array with 1 bit per byte. Uses little endian packing.
dets_out: Defaults to None. Specifies a pre-allocated numpy array to write
the detection event data into. This array must have the correct shape
and dtype.
obs_out: Defaults to None. Specifies a pre-allocated numpy array to write
the observable flip data into. This array must have the correct shape
and dtype.
Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand Down
27 changes: 8 additions & 19 deletions doc/stim.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3242,25 +3242,6 @@ class CompiledDetectorSampler:
) -> str:
"""Returns valid python code evaluating to an equivalent `stim.CompiledDetectorSampler`.
"""
@overload
def sample(
self,
shots: int,
*,
prepend_observables: bool = False,
append_observables: bool = False,
bit_packed: bool = False,
) -> np.ndarray:
pass
@overload
def sample(
self,
shots: int,
*,
separate_observables: Literal[True],
bit_packed: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
pass
def sample(
self,
shots: int,
Expand All @@ -3269,6 +3250,8 @@ class CompiledDetectorSampler:
append_observables: bool = False,
separate_observables: bool = False,
bit_packed: bool = False,
dets_out: Optional[np.ndarray] = None,
obs_out: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Returns a numpy array containing a batch of detector samples from the circuit.
Expand All @@ -3287,6 +3270,12 @@ class CompiledDetectorSampler:
with the detectors and are placed at the end of the results.
bit_packed: Returns a uint8 numpy array with 8 bits per byte, instead of
a bool_ numpy array with 1 bit per byte. Uses little endian packing.
dets_out: Defaults to None. Specifies a pre-allocated numpy array to write
the detection event data into. This array must have the correct shape
and dtype.
obs_out: Defaults to None. Specifies a pre-allocated numpy array to write
the observable flip data into. This array must have the correct shape
and dtype.
Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand Down
27 changes: 8 additions & 19 deletions glue/python/src/stim/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3242,25 +3242,6 @@ class CompiledDetectorSampler:
) -> str:
"""Returns valid python code evaluating to an equivalent `stim.CompiledDetectorSampler`.
"""
@overload
def sample(
self,
shots: int,
*,
prepend_observables: bool = False,
append_observables: bool = False,
bit_packed: bool = False,
) -> np.ndarray:
pass
@overload
def sample(
self,
shots: int,
*,
separate_observables: Literal[True],
bit_packed: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
pass
def sample(
self,
shots: int,
Expand All @@ -3269,6 +3250,8 @@ class CompiledDetectorSampler:
append_observables: bool = False,
separate_observables: bool = False,
bit_packed: bool = False,
dets_out: Optional[np.ndarray] = None,
obs_out: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Returns a numpy array containing a batch of detector samples from the circuit.
Expand All @@ -3287,6 +3270,12 @@ class CompiledDetectorSampler:
with the detectors and are placed at the end of the results.
bit_packed: Returns a uint8 numpy array with 8 bits per byte, instead of
a bool_ numpy array with 1 bit per byte. Uses little endian packing.
dets_out: Defaults to None. Specifies a pre-allocated numpy array to write
the detection event data into. This array must have the correct shape
and dtype.
obs_out: Defaults to None. Specifies a pre-allocated numpy array to write
the observable flip data into. This array must have the correct shape
and dtype.
Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand Down
70 changes: 45 additions & 25 deletions src/stim/py/compiled_detector_sampler.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,49 @@ CompiledDetectorSampler::CompiledDetectorSampler(Circuit init_circuit, std::mt19
}

pybind11::object CompiledDetectorSampler::sample_to_numpy(
size_t num_shots, bool prepend_observables, bool append_observables, bool separate_observables, bool bit_packed) {
size_t num_shots, bool prepend_observables, bool append_observables, bool separate_observables, bool bit_packed, pybind11::object dets_out, pybind11::object obs_out) {
if (separate_observables && (append_observables || prepend_observables)) {
throw std::invalid_argument(
"Can't specify separate_observables=True with append_observables=True or prepend_observables=True");
}

frame_sim.configure_for(circuit_stats, FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots);
frame_sim.reset_all();
frame_sim.do_circuit(circuit);
{
pybind11::gil_scoped_release release;
frame_sim.configure_for(circuit_stats, FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots);
frame_sim.reset_all();
frame_sim.do_circuit(circuit);
}

const auto &det_data = frame_sim.det_record.storage;
const auto &obs_data = frame_sim.obs_record;
uint64_t num_dets = circuit_stats.num_detectors;
uint64_t num_obs = circuit_stats.num_observables;
if (separate_observables) {
pybind11::object py_det_data = transposed_simd_bit_table_to_numpy(det_data, num_dets, num_shots, bit_packed);
pybind11::object py_obs_data = transposed_simd_bit_table_to_numpy(obs_data, num_obs, num_shots, bit_packed);
return pybind11::make_tuple(py_det_data, py_obs_data);

pybind11::object py_obs_data = pybind11::none();
if (separate_observables || !obs_out.is_none()) {
py_obs_data = simd_bit_table_to_numpy(obs_data, circuit_stats.num_observables, num_shots, bit_packed, true, obs_out);
}

size_t num_concat = circuit_stats.num_detectors;
simd_bit_table<MAX_BITWORD_WIDTH> concat_data = det_data;
if (append_observables) {
concat_data = concat_data.concat_major(obs_data, num_concat, circuit_stats.num_observables);
num_concat += circuit_stats.num_observables;
pybind11::object py_det_data = pybind11::none();
if (append_observables || prepend_observables) {
size_t num_concat = circuit_stats.num_detectors;
simd_bit_table<MAX_BITWORD_WIDTH> concat_data = det_data;
if (append_observables) {
concat_data = concat_data.concat_major(obs_data, num_concat, circuit_stats.num_observables);
num_concat += circuit_stats.num_observables;
}
if (prepend_observables) {
concat_data = obs_data.concat_major(concat_data, circuit_stats.num_observables, num_concat);
num_concat += circuit_stats.num_observables;
}
py_det_data = simd_bit_table_to_numpy(concat_data, num_concat, num_shots, bit_packed, true, dets_out);
} else {
py_det_data = simd_bit_table_to_numpy(det_data, circuit_stats.num_detectors, num_shots, bit_packed, true, dets_out);
}
if (prepend_observables) {
concat_data = obs_data.concat_major(concat_data, circuit_stats.num_observables, num_concat);
num_concat += circuit_stats.num_observables;

if (separate_observables) {
return pybind11::make_tuple(py_det_data, py_obs_data);
} else {
return py_det_data;
}
return transposed_simd_bit_table_to_numpy(concat_data, num_concat, num_shots, bit_packed);
}

void CompiledDetectorSampler::sample_write(
Expand Down Expand Up @@ -200,20 +212,22 @@ void stim_pybind::pybind_compiled_detector_sampler_methods(
bool prepend,
bool append,
bool separate_observables,
bool bit_packed) {
return self.sample_to_numpy(shots, prepend, append, separate_observables, bit_packed);
bool bit_packed,
pybind11::object dets_out,
pybind11::object obs_out) {
return self.sample_to_numpy(shots, prepend, append, separate_observables, bit_packed, dets_out, obs_out);
},
pybind11::arg("shots"),
pybind11::kw_only(),
pybind11::arg("prepend_observables") = false,
pybind11::arg("append_observables") = false,
pybind11::arg("separate_observables") = false,
pybind11::arg("bit_packed") = false,
pybind11::arg("dets_out") = pybind11::none(),
pybind11::arg("obs_out") = pybind11::none(),
clean_doc_string(R"DOC(
@signature def sample(self, shots: int, *, prepend_observables: bool = False, append_observables: bool = False, separate_observables: bool = False, bit_packed: bool = False, dets_out: Optional[np.ndarray] = None, obs_out: Optional[np.ndarray] = None) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
Returns a numpy array containing a batch of detector samples from the circuit.
@overload def sample(self, shots: int, *, prepend_observables: bool = False, append_observables: bool = False, bit_packed: bool = False) -> np.ndarray:
@overload def sample(self, shots: int, *, separate_observables: Literal[True], bit_packed: bool = False) -> Tuple[np.ndarray, np.ndarray]:
@signature def sample(self, shots: int, *, prepend_observables: bool = False, append_observables: bool = False, separate_observables: bool = False, bit_packed: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
The circuit must define the detectors using DETECTOR instructions. Observables
defined by OBSERVABLE_INCLUDE instructions can also be included in the results
Expand All @@ -230,6 +244,12 @@ void stim_pybind::pybind_compiled_detector_sampler_methods(
with the detectors and are placed at the end of the results.
bit_packed: Returns a uint8 numpy array with 8 bits per byte, instead of
a bool_ numpy array with 1 bit per byte. Uses little endian packing.
dets_out: Defaults to None. Specifies a pre-allocated numpy array to write
the detection event data into. This array must have the correct shape
and dtype.
obs_out: Defaults to None. Specifies a pre-allocated numpy array to write
the observable flip data into. This array must have the correct shape
and dtype.
Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand Down Expand Up @@ -283,7 +303,7 @@ void stim_pybind::pybind_compiled_detector_sampler_methods(
c.def(
"sample_bit_packed",
[](CompiledDetectorSampler &self, size_t shots, bool prepend, bool append) {
return self.sample_to_numpy(shots, prepend, append, false, true);
return self.sample_to_numpy(shots, prepend, append, false, true, pybind11::none(), pybind11::none());
},
pybind11::arg("shots"),
pybind11::kw_only(),
Expand Down
4 changes: 3 additions & 1 deletion src/stim/py/compiled_detector_sampler.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ struct CompiledDetectorSampler {
bool prepend_observables,
bool append_observables,
bool separate_observables,
bool bit_packed);
bool bit_packed,
pybind11::object dets_out,
pybind11::object obs_out);
void sample_write(
size_t num_samples,
pybind11::object filepath_obj,
Expand Down
103 changes: 103 additions & 0 deletions src/stim/py/compiled_detector_sampler_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile

import numpy as np
import pytest
import stim


Expand Down Expand Up @@ -209,3 +210,105 @@ def test_detector_sampler_actually_fills_array():
sampler = circuit.compile_detector_sampler()
detector_data = sampler.sample(shots=10000)
assert np.all(detector_data)


def test_manual_output_buffer():
    """Exercise the `dets_out` / `obs_out` arguments of the detector sampler.

    Checks that buffers with a wrong shape or dtype raise ValueError, and that
    accepted buffers are returned as the very same object and hold the expected
    data for the unpacked, bit-packed, transposed-view, separate-observables
    and append-observables cases.
    """
    circuit = stim.Circuit('''
X_ERROR(1) 0
M 0
DETECTOR
DETECTOR
DETECTOR rec[-1]
DETECTOR rec[-1]
DETECTOR rec[-1]
DETECTOR
DETECTOR
DETECTOR rec[-1]
DETECTOR rec[-1]
OBSERVABLE_INCLUDE(0) rec[-1]
''')
    sampler = circuit.compile_detector_sampler()
    num_shots = 17

    # Every shot is expected to give the same row: the detectors declared with
    # rec[-1] read 1, the empty detectors read 0, and the observable reads 1.
    det_row = [0, 0, 1, 1, 1, 0, 0, 1, 1]
    expected_dets = [det_row] * num_shots
    expected_dets_then_obs = [det_row + [1]] * num_shots
    expected_packed_dets = [[0b10011100, 0b1]] * num_shots
    expected_obs = [[1]] * num_shots

    # Buffers that must be rejected for an unpacked 17-shot, 9-detector sample:
    # too narrow, too wide, too many rows, too few rows, and wrong dtype.
    rejected_buffers = [
        ((17, 8), np.bool_),
        ((17, 10), np.bool_),
        ((18, 9), np.bool_),
        ((16, 9), np.bool_),
        ((17, 9), np.uint8),
    ]
    for bad_shape, bad_dtype in rejected_buffers:
        with pytest.raises(ValueError):
            sampler.sample(shots=num_shots, dets_out=np.zeros(shape=bad_shape, dtype=bad_dtype))

    # Unpacked detector buffer.
    dets = np.zeros(shape=(num_shots, 9), dtype=np.bool_)
    result = sampler.sample(shots=num_shots, dets_out=dets)
    assert result is dets
    assert np.array_equal(dets, expected_dets)

    # Bit-packed detector buffer: first as a plain array, then as a transposed
    # view of a (2, 17) array (same shape, different memory layout).
    packed_buffers = (
        np.zeros(shape=(num_shots, 2), dtype=np.uint8),
        np.zeros(shape=(2, num_shots), dtype=np.uint8).transpose(),
    )
    for packed in packed_buffers:
        result = sampler.sample(
            shots=num_shots,
            dets_out=packed,
            bit_packed=True,
        )
        assert result is packed
        assert np.array_equal(packed, expected_packed_dets)

    # Both buffers supplied without separate_observables: only the detector
    # buffer is returned, but the observable buffer is still filled.
    dets = np.zeros(shape=(num_shots, 9), dtype=np.bool_)
    obs = np.zeros(shape=(num_shots, 1), dtype=np.bool_)
    result = sampler.sample(
        shots=num_shots,
        dets_out=dets,
        obs_out=obs,
    )
    assert result is dets
    assert np.array_equal(dets, expected_dets)
    assert np.array_equal(obs, expected_obs)

    # Both buffers supplied with separate_observables=True: both come back.
    dets = np.zeros(shape=(num_shots, 9), dtype=np.bool_)
    obs = np.zeros(shape=(num_shots, 1), dtype=np.bool_)
    det_result, obs_result = sampler.sample(
        shots=num_shots,
        dets_out=dets,
        obs_out=obs,
        separate_observables=True,
    )
    assert det_result is dets
    assert obs_result is obs
    assert np.array_equal(dets, expected_dets)
    assert np.array_equal(obs, expected_obs)

    # append_observables=True widens the detector buffer by one column; again
    # checked with a plain array and with a transposed view.
    appended_buffers = (
        np.zeros(shape=(num_shots, 10), dtype=np.bool_),
        np.zeros(shape=(10, num_shots), dtype=np.bool_).transpose(),
    )
    for dets in appended_buffers:
        obs = np.zeros(shape=(num_shots, 1), dtype=np.bool_)
        result = sampler.sample(
            shots=num_shots,
            dets_out=dets,
            obs_out=obs,
            append_observables=True,
        )
        assert result is dets
        assert np.array_equal(dets, expected_dets_then_obs)
        assert np.array_equal(obs, expected_obs)
4 changes: 2 additions & 2 deletions src/stim/py/compiled_measurement_sampler.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ CompiledMeasurementSampler::CompiledMeasurementSampler(
}

pybind11::object CompiledMeasurementSampler::sample_to_numpy(size_t num_shots, bool bit_packed) {
simd_bit_table<MAX_BITWORD_WIDTH> sample = sample_batch_measurements(circuit, ref_sample, num_shots, rng, true);
simd_bit_table<MAX_BITWORD_WIDTH> sample = sample_batch_measurements(circuit, ref_sample, num_shots, rng, false);
size_t bits_per_sample = circuit.count_measurements();
return simd_bit_table_to_numpy(sample, num_shots, bits_per_sample, bit_packed);
return simd_bit_table_to_numpy(sample, bits_per_sample, num_shots, bit_packed, true, pybind11::none());
}

void CompiledMeasurementSampler::sample_write(size_t num_samples, std::string_view filepath, std::string_view format) {
Expand Down
Loading

0 comments on commit 320288c

Please sign in to comment.