Skip to content

Commit

Permalink
Add stim.FlipSimulator (#612)
Browse files Browse the repository at this point in the history
Python API changes:

- Add `stim.FlipSimulator`
- Add `stim.FlipSimulator.__init__`
- Add `stim.FlipSimulator.batch_size`
- Add `stim.FlipSimulator.do`
- Add `stim.FlipSimulator.get_detector_flips`
- Add `stim.FlipSimulator.get_observable_flips`
- Add `stim.FlipSimulator.get_measurement_flips`
- Add `stim.FlipSimulator.num_detectors`
- Add `stim.FlipSimulator.num_measurements`
- Add `stim.FlipSimulator.num_observables`
- Add `stim.FlipSimulator.num_qubits`
- Add `stim.FlipSimulator.peek_pauli_flips`
- Add `stim.FlipSimulator.set_pauli_flip`

C++ changes:

- Add `frame_simulator.pybind.{h,cc}`
- Add `stim::simd_bit_table<W>::read_across_majors_at_minor_index`
- Add `stim::simd_bit_table<W>::copy_into_different_size_table`
- Add `stim::simd_bit_table<W>::resize`
- Add `stim.FlipSimulator`
- Split `stim::FrameSimulator::reset_all_and_run` into two methods
- Add `FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY`
- Add `FrameSimulator::ensure_safe_to_do_circuit_with_stats`
- Add `FrameSimulator::safe_do_instruction`
- Add `FrameSimulator::safe_do_circuit`
- Move `stim::CircuitStats` from circuit file to circuit instruction file
- Add `stim::CircuitInstruction::compute_stats`

Fixes #306
  • Loading branch information
Strilanc authored Aug 20, 2023
1 parent b56c1d7 commit 836a564
Show file tree
Hide file tree
Showing 21 changed files with 3,068 additions and 84 deletions.
624 changes: 624 additions & 0 deletions doc/python_api_reference_vDev.md

Large diffs are not rendered by default.

520 changes: 520 additions & 0 deletions doc/stim.pyi

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions file_lists/python_api_files
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ src/stim/py/march.pybind.cc
src/stim/py/numpy.pybind.cc
src/stim/py/stim.pybind.cc
src/stim/simulators/dem_sampler.pybind.cc
src/stim/simulators/frame_simulator.pybind.cc
src/stim/simulators/matched_error.pybind.cc
src/stim/simulators/measurements_to_detection_events.pybind.cc
src/stim/simulators/tableau_simulator.pybind.cc
Expand Down
520 changes: 520 additions & 0 deletions glue/python/src/stim/__init__.pyi

Large diffs are not rendered by default.

52 changes: 1 addition & 51 deletions src/stim/circuit/circuit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -706,59 +706,9 @@ size_t Circuit::count_sweep_bits() const {

CircuitStats Circuit::compute_stats() const {
CircuitStats total;

for (const auto &op : operations) {
if (op.gate_type == REPEAT) {
// Recurse into blocks.
auto sub = op.repeat_block_body(*this).compute_stats();
auto reps = op.repeat_block_rep_count();
total.num_observables = std::max(total.num_observables, sub.num_observables);
total.num_qubits = std::max(total.num_qubits, sub.num_qubits);
total.max_lookback = std::max(total.max_lookback, sub.max_lookback);
total.num_sweep_bits = std::max(total.num_sweep_bits, sub.num_sweep_bits);
total.num_detectors = add_saturate(total.num_detectors, mul_saturate(sub.num_detectors, reps));
total.num_measurements = add_saturate(total.num_measurements, mul_saturate(sub.num_measurements, reps));
total.num_ticks = add_saturate(total.num_ticks, mul_saturate(sub.num_ticks, reps));
continue;
}

for (auto t : op.targets) {
auto v = t.data & TARGET_VALUE_MASK;
// Qubit counting.
if (!(t.data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) {
total.num_qubits = std::max(total.num_qubits, v + 1);
}
// Lookback counting.
if (t.data & TARGET_RECORD_BIT) {
total.max_lookback = std::max(total.max_lookback, v);
}
// Sweep bit counting.
if (t.data & TARGET_SWEEP_BIT) {
total.num_sweep_bits = std::max(total.num_sweep_bits, v + 1);
}
}

// Measurement counting.
total.num_measurements += op.count_measurement_results();

switch (op.gate_type) {
case GateType::DETECTOR:
// Detector counting.
total.num_detectors += total.num_detectors < UINT64_MAX;
break;
case GateType::OBSERVABLE_INCLUDE:
// Observable counting.
total.num_observables = std::max(total.num_observables, (uint64_t)op.args[0] + 1);
break;
case GateType::TICK:
// Tick counting.
total.num_ticks++;
break;
default:
break;
}
op.add_stats_to(total, this);
}

return total;
}

Expand Down
11 changes: 0 additions & 11 deletions src/stim/circuit/circuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ namespace stim {
uint64_t add_saturate(uint64_t a, uint64_t b);
uint64_t mul_saturate(uint64_t a, uint64_t b);

/// Stores a variety of circuit quantities relevant for sizing memory.
struct CircuitStats {
uint64_t num_detectors = 0;
uint64_t num_observables = 0;
uint64_t num_measurements = 0;
uint32_t num_qubits = 0;
uint32_t num_ticks = 0;
uint32_t max_lookback = 0;
uint32_t num_sweep_bits = 0;
};

/// A description of a quantum computation.
struct Circuit {
/// Backing data stores for variable-sized target data referenced by operations.
Expand Down
61 changes: 61 additions & 0 deletions src/stim/circuit/circuit_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,67 @@ Circuit &CircuitInstruction::repeat_block_body(Circuit &host) const {
return host.blocks[b];
}

CircuitStats CircuitInstruction::compute_stats(const Circuit *host) const {
CircuitStats out;
add_stats_to(out, host);
return out;
}

void CircuitInstruction::add_stats_to(CircuitStats &out, const Circuit *host) const {
if (gate_type == REPEAT) {
if (host == nullptr) {
throw std::invalid_argument("gate_type == REPEAT && host == nullptr");
}
// Recurse into blocks.
auto sub = repeat_block_body(*host).compute_stats();
auto reps = repeat_block_rep_count();
out.num_observables = std::max(out.num_observables, sub.num_observables);
out.num_qubits = std::max(out.num_qubits, sub.num_qubits);
out.max_lookback = std::max(out.max_lookback, sub.max_lookback);
out.num_sweep_bits = std::max(out.num_sweep_bits, sub.num_sweep_bits);
out.num_detectors = add_saturate(out.num_detectors, mul_saturate(sub.num_detectors, reps));
out.num_measurements = add_saturate(out.num_measurements, mul_saturate(sub.num_measurements, reps));
out.num_ticks = add_saturate(out.num_ticks, mul_saturate(sub.num_ticks, reps));
return;
}

for (auto t : targets) {
auto v = t.data & TARGET_VALUE_MASK;
// Qubit counting.
if (!(t.data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) {
out.num_qubits = std::max(out.num_qubits, v + 1);
}
// Lookback counting.
if (t.data & TARGET_RECORD_BIT) {
out.max_lookback = std::max(out.max_lookback, v);
}
// Sweep bit counting.
if (t.data & TARGET_SWEEP_BIT) {
out.num_sweep_bits = std::max(out.num_sweep_bits, v + 1);
}
}

// Measurement counting.
out.num_measurements += count_measurement_results();

switch (gate_type) {
case GateType::DETECTOR:
// Detector counting.
out.num_detectors += out.num_detectors < UINT64_MAX;
break;
case GateType::OBSERVABLE_INCLUDE:
// Observable counting.
out.num_observables = std::max(out.num_observables, (uint64_t)args[0] + 1);
break;
case GateType::TICK:
// Tick counting.
out.num_ticks++;
break;
default:
break;
}
}

const Circuit &CircuitInstruction::repeat_block_body(const Circuit &host) const {
assert(targets.size() == 3);
auto b = targets[0].data;
Expand Down
28 changes: 28 additions & 0 deletions src/stim/circuit/circuit_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@ namespace stim {

struct Circuit;

/// Stores a variety of circuit quantities relevant for sizing memory.
struct CircuitStats {
uint64_t num_detectors = 0;
uint64_t num_observables = 0;
uint64_t num_measurements = 0;
uint32_t num_qubits = 0;
uint64_t num_ticks = 0;
uint32_t max_lookback = 0;
uint32_t num_sweep_bits = 0;

inline CircuitStats repeated(uint64_t repetitions) const {
return CircuitStats{
num_detectors * repetitions,
num_observables,
num_measurements * repetitions,
num_qubits,
(uint32_t)(num_ticks * repetitions),
max_lookback,
num_sweep_bits,
};
}
};

/// The data that describes how a gate is being applied to qubits (or other targets).
///
/// A gate applied to targets.
Expand All @@ -49,6 +72,11 @@ struct CircuitInstruction {
CircuitInstruction() = delete;
CircuitInstruction(GateType gate_type, SpanRef<const double> args, SpanRef<const GateTarget> targets);

/// Computes number of qubits, number of measurements, etc.
CircuitStats compute_stats(const Circuit *host) const;
/// Computes number of qubits, number of measurements, etc and adds them into a target.
void add_stats_to(CircuitStats &out, const Circuit *host) const;

/// Determines if two operations can be combined into one operation (with combined targeting data).
///
/// For example, `H 1` then `H 2 1` is equivalent to `H 1 2 1` so those instructions are fusable.
Expand Down
10 changes: 10 additions & 0 deletions src/stim/mem/simd_bit_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ struct simd_bit_table {
/// Resizes the table. Doesn't clear to zero. Does nothing if already the target size.
void destructive_resize(size_t new_min_bits_major, size_t new_min_bits_minor);

/// Copies the table into another table.
///
/// It's safe for the other table to have a different size.
/// When the other table has a different size, only the data at locations common to both
/// tables are copied over.
void copy_into_different_size_table(simd_bit_table<W> &other) const;

/// Resizes the table, keeping any data common to the old and new size and otherwise zeroing data.
void resize(size_t new_min_bits_major, size_t new_min_bits_minor);

/// Equality.
bool operator==(const simd_bit_table &other) const;
/// Inequality.
Expand Down
42 changes: 42 additions & 0 deletions src/stim/mem/simd_bit_table.inl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ void simd_bit_table<W>::destructive_resize(size_t new_min_bits_major, size_t new
data.destructive_resize(num_simd_words_minor * num_simd_words_major * W * W);
}

template <size_t W>
void simd_bit_table<W>::copy_into_different_size_table(simd_bit_table<W> &other) const {
size_t ni = num_simd_words_minor;
size_t na = num_simd_words_major;
size_t mi = other.num_simd_words_minor;
size_t ma = other.num_simd_words_major;
size_t num_min_bytes = std::min(ni, mi) * (W / 8);
size_t num_maj = std::min(na, ma) * W;

if (ni == mi) {
memcpy(other.data.ptr_simd, data.ptr_simd, num_min_bytes * num_maj);
} else {
for (size_t maj = 0; maj < num_maj; maj++) {
memcpy(other[maj].ptr_simd, (*this)[maj].ptr_simd, num_min_bytes);
}
}
}

template <size_t W>
void simd_bit_table<W>::resize(size_t new_min_bits_major, size_t new_min_bits_minor) {
auto new_num_simd_words_minor = min_bits_to_num_simd_words<W>(new_min_bits_minor);
auto new_num_simd_words_major = min_bits_to_num_simd_words<W>(new_min_bits_major);
if (new_num_simd_words_major == num_simd_words_major && new_num_simd_words_minor == num_simd_words_minor) {
return;
}
auto new_table = simd_bit_table<W>(new_min_bits_major, new_min_bits_minor);
copy_into_different_size_table(new_table);
*this = std::move(new_table);
}

template <size_t W>
void simd_bit_table<W>::do_square_transpose() {
assert(num_simd_words_minor == num_simd_words_major);
Expand Down Expand Up @@ -138,6 +168,18 @@ simd_bit_table<W> simd_bit_table<W>::transposed() const {
return result;
}

template <size_t W>
simd_bits<W> simd_bit_table<W>::read_across_majors_at_minor_index(size_t major_start, size_t major_stop, size_t minor_index) const {
assert(major_stop >= major_start);
assert(major_stop <= num_major_bits_padded());
assert(minor_index < num_minor_bits_padded());
simd_bits<W> result(major_stop - major_start);
for (size_t maj = major_start; maj < major_stop; maj++) {
result[maj - major_start] = (*this)[maj][minor_index];
}
return result;
}

template <size_t W>
simd_bit_table<W> simd_bit_table<W>::slice_maj(size_t maj_start_bit, size_t maj_stop_bit) const {
simd_bit_table<W> result(maj_stop_bit - maj_start_bit, num_minor_bits_padded());
Expand Down
106 changes: 105 additions & 1 deletion src/stim/mem/simd_bit_table.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ TEST(simd_bit_table, lg) {

TEST_EACH_WORD_SIZE_W(simd_bit_table, destructive_resize, {
auto rng = INDEPENDENT_TEST_RNG();
simd_bit_table<W> table = table.random(5, 7, rng);
simd_bit_table<W> table = simd_bit_table<W>::random(5, 7, rng);
const uint8_t *prev_pointer = table.data.u8;
table.destructive_resize(5, 7);
ASSERT_EQ(table.data.u8, prev_pointer);
Expand All @@ -302,3 +302,107 @@ TEST_EACH_WORD_SIZE_W(simd_bit_table, destructive_resize, {
ASSERT_GE(table.num_major_bits_padded(), 1025);
ASSERT_GE(table.num_minor_bits_padded(), 7);
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, read_across_majors_at_minor_index, {
auto rng = INDEPENDENT_TEST_RNG();
simd_bit_table<W> table = simd_bit_table<W>::random(5, 7, rng);
simd_bits<W> slice = table.read_across_majors_at_minor_index(2, 5, 1);
ASSERT_GE(slice.num_bits_padded(), 4);
ASSERT_EQ(slice[0], table[2][1]);
ASSERT_EQ(slice[1], table[3][1]);
ASSERT_EQ(slice[2], table[4][1]);
ASSERT_EQ(slice[3], false);
})

template <size_t W>
bool is_table_overlap_identical(const simd_bit_table<W> &a, const simd_bit_table<W> &b) {
size_t w_min = std::min(a.num_simd_words_minor, b.num_simd_words_minor);
size_t n_maj = std::min(a.num_major_bits_padded(), b.num_major_bits_padded());
for (size_t k_maj = 0; k_maj < n_maj; k_maj++) {
if (a[k_maj].word_range_ref(0, w_min) != b[k_maj].word_range_ref(0, w_min)) {
return false;
}
}
return true;
}

template <size_t W>
bool is_table_zero_outside(const simd_bit_table<W> &a, size_t num_major_bits, size_t num_minor_bits) {
size_t num_major_words = min_bits_to_num_simd_words<W>(num_major_bits);
size_t num_minor_words = min_bits_to_num_simd_words<W>(num_minor_bits);
if (a.num_simd_words_minor > num_minor_words) {
for (size_t k = 0; k < a.num_simd_words_major; k++) {
if (a[k].word_range_ref(num_minor_words, a.num_simd_words_minor - num_minor_words).not_zero()) {
return false;
}
}
}
for (size_t k = a.num_simd_words_major; k < num_major_words; k++) {
if (a[k].not_zero()) {
return false;
}
}
return true;
}

TEST_EACH_WORD_SIZE_W(simd_bit_table, copy_into_different_size_table, {
auto rng = INDEPENDENT_TEST_RNG();

auto check_size = [&](size_t w1, size_t h1, size_t w2, size_t h2) {
simd_bit_table<W> src = simd_bit_table<W>::random(w1, h1, rng);
simd_bit_table<W> dst = simd_bit_table<W>::random(w1, h1, rng);
src.copy_into_different_size_table(dst);
return is_table_overlap_identical(src, dst);
};

EXPECT_TRUE(check_size(0, 0, 0, 0));

EXPECT_TRUE(check_size(64, 0, 0, 0));
EXPECT_TRUE(check_size(0, 64, 0, 0));
EXPECT_TRUE(check_size(0, 0, 64, 0));
EXPECT_TRUE(check_size(0, 0, 0, 64));

EXPECT_TRUE(check_size(64, 64, 64, 64));
EXPECT_TRUE(check_size(512, 64, 64, 64));
EXPECT_TRUE(check_size(64, 512, 64, 64));
EXPECT_TRUE(check_size(64, 64, 512, 64));
EXPECT_TRUE(check_size(64, 64, 64, 512));

EXPECT_TRUE(check_size(512, 512, 64, 64));
EXPECT_TRUE(check_size(512, 64, 512, 64));
EXPECT_TRUE(check_size(512, 64, 64, 512));
EXPECT_TRUE(check_size(64, 512, 512, 64));
EXPECT_TRUE(check_size(64, 512, 64, 512));
EXPECT_TRUE(check_size(64, 64, 512, 512));
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, resize, {
auto rng = INDEPENDENT_TEST_RNG();

auto check_size = [&](size_t w1, size_t h1, size_t w2, size_t h2) {
simd_bit_table<W> src = simd_bit_table<W>::random(w1, h1, rng);
simd_bit_table<W> dst = src;
dst.resize(w2, h2);
return is_table_overlap_identical(src, dst) && is_table_zero_outside(dst, std::min(w1, w2), std::min(h1, h2));
};

EXPECT_TRUE(check_size(0, 0, 0, 0));

EXPECT_TRUE(check_size(64, 0, 0, 0));
EXPECT_TRUE(check_size(0, 64, 0, 0));
EXPECT_TRUE(check_size(0, 0, 64, 0));
EXPECT_TRUE(check_size(0, 0, 0, 64));

EXPECT_TRUE(check_size(64, 64, 64, 64));
EXPECT_TRUE(check_size(512, 64, 64, 64));
EXPECT_TRUE(check_size(64, 512, 64, 64));
EXPECT_TRUE(check_size(64, 64, 512, 64));
EXPECT_TRUE(check_size(64, 64, 64, 512));

EXPECT_TRUE(check_size(512, 512, 64, 64));
EXPECT_TRUE(check_size(512, 64, 512, 64));
EXPECT_TRUE(check_size(512, 64, 64, 512));
EXPECT_TRUE(check_size(64, 512, 512, 64));
EXPECT_TRUE(check_size(64, 512, 64, 512));
EXPECT_TRUE(check_size(64, 64, 512, 512));
})
3 changes: 2 additions & 1 deletion src/stim/py/compiled_detector_sampler.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ pybind11::object CompiledDetectorSampler::sample_to_numpy(
}

frame_sim.configure_for(circuit_stats, FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots);
frame_sim.reset_all_and_run(circuit);
frame_sim.reset_all();
frame_sim.do_circuit(circuit);

const auto &det_data = frame_sim.det_record.storage;
const auto &obs_data = frame_sim.obs_record;
Expand Down
Loading

0 comments on commit 836a564

Please sign in to comment.