From b374b5d219338a56dd0d8b8876713f5dca6b020b Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 16:44:23 -0400 Subject: [PATCH 1/6] feat: make Compressor::train 2x faster with bitmap index The slowest part of Compressor::train is the double-nested loop over codes. Now, when compress_count records code pairs, it also populates a bitmap index, where `pair_index[code1].set(code2)` indicates that code2 followed code1 in compressed output. In the `optimize` loop, we can eliminate tight loop iterations by iterating `pair_index[code1].codes()` (exposed as `Counter::second_codes(code1)`), which yields only the code2 values that were observed. This results in a speedup from ~1ms -> ~500 micros. --- src/builder.rs | 135 +++++++++++++++++++++++++++++++++++++------ tests/correctness.rs | 5 +- 2 files changed, 117 insertions(+), 23 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index c3272ae..0da0f68 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -9,6 +9,80 @@ use std::collections::BinaryHeap; use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE}; +/// Bitmap that only works for values up to 512 +#[derive(Clone, Copy, Debug, Default)] +#[allow(dead_code)] +struct CodesBitmap { + codes: [u64; 8], +} + +assert_sizeof!(CodesBitmap => 64); + +#[allow(dead_code)] +impl CodesBitmap { + /// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. 
+ pub(crate) fn set(&mut self, index: usize) { + debug_assert!( + index <= MAX_CODE as usize, + "CodesBitmap only works on codes <= {MAX_CODE}" + ); + + let map = index >> 6; + self.codes[map] |= 1 << (index % 64); + } + + /// Get all codes set in this bitmap + pub(crate) fn codes(&self) -> CodesIterator { + CodesIterator { + inner: self, + index: 0, + block: self.codes[0], + reference: 0, + } + } +} + +struct CodesIterator<'a> { + inner: &'a CodesBitmap, + index: usize, + block: u64, + reference: usize, +} + +impl<'a> Iterator for CodesIterator<'a> { + type Item = u16; + + fn next(&mut self) -> Option { + // If current is zero, advance to next non-zero block + while self.block == 0 { + self.index += 1; + if self.index >= 8 { + return None; + } + self.block = self.inner.codes[self.index]; + self.reference = self.index * 64; + } + + if self.block == 0 { + return None; + } + + // Find the next set bit in the current block. + let position = self.block.trailing_zeros() as usize; + let code = self.reference + position; + + // Reference is advanced by the set position in the bit iterator. + self.reference += position; + self.block = if position == 63 { + 0 + } else { + self.block >> (1 + position) + }; + + Some(code as u16) + } +} + #[derive(Debug, Clone)] struct Counter { /// Frequency count for each code. @@ -16,6 +90,12 @@ struct Counter { /// Frequency count for each code-pair. counts2: Vec, + + /// Bitmap index of pairs that have been set. + /// + /// `pair_index[code1].codes()` yields an iterator that can + /// be used to find the values of codes in the outside iterator. 
+ pair_index: Vec, } const COUNTS1_SIZE: usize = MAX_CODE as usize; @@ -28,16 +108,7 @@ impl Counter { Self { counts1: vec![0; COUNTS1_SIZE], counts2: vec![0; COUNTS2_SIZE], - } - } - - /// reset - pub fn reset(&mut self) { - for idx in 0..COUNTS1_SIZE { - self.counts1[idx] = 0; - } - for idx in 0..COUNTS2_SIZE { - self.counts2[idx] = 0; + pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE], } } @@ -50,6 +121,7 @@ impl Counter { fn record_count2(&mut self, code1: u16, code2: u16) { let idx = (code1 as usize) * 511 + (code2 as usize); self.counts2[idx] += 1; + self.pair_index[code1 as usize].set(code2 as usize); } #[inline] @@ -62,6 +134,11 @@ impl Counter { let idx = (code1 as usize) * 511 + (code2 as usize); self.counts2[idx] } + + /// Access to the second-code in a code pair following `code1`. + fn second_codes(&self, code1: u16) -> CodesIterator { + self.pair_index[code1 as usize].codes() + } } /// The number of generations used for training. This is taken from the [FSST paper]. @@ -87,14 +164,13 @@ impl Compressor { return compressor; } - let mut counter = Counter::new(); - for _generation in 0..(MAX_GENERATIONS - 1) { + let mut counter = Counter::new(); compressor.compress_count(sample, &mut counter); compressor = compressor.optimize(&counter, true); - counter.reset(); } + let mut counter = Counter::new(); compressor.compress_count(sample, &mut counter); compressor.optimize(&counter, true) } @@ -142,9 +218,16 @@ impl Compressor { fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self { let mut res = Compressor::default(); let mut pqueue = BinaryHeap::with_capacity(65_536); + for code1 in 0u16..(256u16 + self.n_symbols as u16) { let symbol1 = self.symbols[code1 as usize]; - let mut gain = counters.count1(code1) * symbol1.len(); + let count = counters.count1(code1); + // If count is zero, we can skip the whole inner loop. 
+ if count == 0 { + continue; + } + + let mut gain = count * symbol1.len(); // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols. // This helps to reduce exception counts. if code1 < 256 { @@ -157,10 +240,10 @@ impl Compressor { }); } - for code2 in 0u16..(256u16 + self.n_symbols as u16) { + for code2 in counters.second_codes(code1) { let symbol2 = &self.symbols[code2 as usize]; - // If either symbol is zero-length, or if merging would yield a symbol of - // length greater than 8, skip. + + // If merging would yield a symbol of length greater than 8, skip. if symbol1.len() + symbol2.len() > 8 { continue; } @@ -247,8 +330,7 @@ impl Ord for Candidate { #[cfg(test)] mod test { - - use crate::{Compressor, ESCAPE_CODE}; + use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE}; #[test] fn test_builder() { @@ -282,4 +364,19 @@ mod test { ] ); } + + #[test] + fn test_bitmap() { + let mut map = CodesBitmap::default(); + map.set(10); + map.set(100); + map.set(500); + + let codes: Vec = map.codes().collect(); + assert_eq!(codes, vec![10u16, 100, 500]); + + // empty case + let map = CodesBitmap::default(); + assert_eq!(map.codes().collect::>(), vec![]); + } } diff --git a/tests/correctness.rs b/tests/correctness.rs index 5a68cb1..64f3ba7 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -47,16 +47,13 @@ fn test_one_byte() { #[test] fn test_zeros() { - println!("training zeros"); let training_data: Vec = vec![0, 1, 2, 3, 4, 0]; let trained = Compressor::train(&training_data); - println!("compressing with zeros"); let compressed = trained.compress(&[4, 0]); - println!("decomperssing with zeros"); assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]); - println!("done"); } +#[cfg_attr(miri, ignore)] #[test] fn test_large() { let corpus: Vec = DECLARATION.bytes().cycle().take(10_240).collect(); From 720d5064d07d63ae99f73eff4c509946ff20d242 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 16:49:16 
-0400 Subject: [PATCH 2/6] add miri action --- .github/workflows/miri.yml | 52 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 .github/workflows/miri.yml diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml new file mode 100644 index 0000000..068177a --- /dev/null +++ b/.github/workflows/miri.yml @@ -0,0 +1,52 @@ +name: Miri + +on: + push: + branches: ["develop"] + pull_request: {} + workflow_dispatch: {} + +permissions: + actions: read + contents: read + +jobs: + miri: + name: "miri" + runs-on: ubuntu-latest + env: + RUST_BACKTRACE: 1 + MIRIFLAGS: -Zmiri-strict-provenance -Zmiri-symbolic-alignment-check -Zmiri-backtrace=full + steps: + - uses: actions/checkout@v4 + + - name: Rust Version + id: rust-version + shell: bash + run: echo "version=$(cat rust-toolchain.toml | grep channel | awk -F'\"' '{print $2}')" >> $GITHUB_OUTPUT + + - name: Rust Toolchain + id: rust-toolchain + uses: dtolnay/rust-toolchain@master + if: steps.rustup-cache.outputs.cache-hit != 'true' + with: + toolchain: "${{ steps.rust-version.outputs.version }}" + components: miri + + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/develop' }} + shared-key: "shared" # To allow reuse across jobs + + - name: Rust Compile Cache + uses: mozilla-actions/sccache-action@v0.0.5 + - name: Rust Compile Cache Config + shell: bash + run: | + echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV + echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV + echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV + + - name: Run tests with Miri + run: cargo miri test From fecde16fe51e859a76ca2533eaa03c39c6573e33 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 16:54:09 -0400 Subject: [PATCH 3/6] final cleanups --- src/builder.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 0da0f68..3a78131 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ 
-11,14 +11,12 @@ use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE}; /// Bitmap that only works for values up to 512 #[derive(Clone, Copy, Debug, Default)] -#[allow(dead_code)] struct CodesBitmap { codes: [u64; 8], } assert_sizeof!(CodesBitmap => 64); -#[allow(dead_code)] impl CodesBitmap { /// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. pub(crate) fn set(&mut self, index: usize) { @@ -94,7 +92,7 @@ struct Counter { /// Bitmap index of pairs that have been set. /// /// `pair_index[code1].codes()` yields an iterator that can - /// be used to find the values of codes in the outside iterator. + /// be used to find all possible codes that follow `code1`. pair_index: Vec, } @@ -135,7 +133,11 @@ impl Counter { self.counts2[idx] } - /// Access to the second-code in a code pair following `code1`. + /// Returns an iterator over the codes that have been observed + /// to follow `code1`. + /// + /// This is the set of all values `code2` where there was + /// previously a call to `self.record_count2(code1, code2)`. 
fn second_codes(&self, code1: u16) -> CodesIterator { self.pair_index[code1 as usize].codes() } From 17ac1be3e0befe5d934f5c6932a68286df6d5339 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 16:58:06 -0400 Subject: [PATCH 4/6] only run miri on develop i don't want to lose my 30s CI checks --- .github/workflows/miri.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index 068177a..8de5d55 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -3,7 +3,6 @@ name: Miri on: push: branches: ["develop"] - pull_request: {} workflow_dispatch: {} permissions: From aea4ae38b8627f3f3de70bb3d4733419e46641b8 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 17:41:23 -0400 Subject: [PATCH 5/6] turn miri back on for CI --- .github/workflows/miri.yml | 1 + src/builder.rs | 3 +++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index 8de5d55..068177a 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -3,6 +3,7 @@ name: Miri on: push: branches: ["develop"] + pull_request: {} workflow_dispatch: {} permissions: diff --git a/src/builder.rs b/src/builder.rs index 3a78131..f47a7dc 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -146,7 +146,10 @@ impl Counter { /// The number of generations used for training. This is taken from the [FSST paper]. /// /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf +#[cfg(not(miri))] const MAX_GENERATIONS: usize = 5; +#[cfg(miri)] +const MAX_GENERATIONS: usize = 2; impl Compressor { /// Build and train a `Compressor` from a sample corpus of text. 
From f74d185903df725e09940fae349988a71e331088 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Tue, 20 Aug 2024 17:57:02 -0400 Subject: [PATCH 6/6] fix small bug in iterator, more tests --- src/builder.rs | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index f47a7dc..42f2fb5 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -20,10 +20,7 @@ assert_sizeof!(CodesBitmap => 64); impl CodesBitmap { /// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. pub(crate) fn set(&mut self, index: usize) { - debug_assert!( - index <= MAX_CODE as usize, - "CodesBitmap only works on codes <= {MAX_CODE}" - ); + debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); let map = index >> 6; self.codes[map] |= 1 << (index % 64); @@ -69,8 +66,8 @@ impl<'a> Iterator for CodesIterator<'a> { let position = self.block.trailing_zeros() as usize; let code = self.reference + position; - // Reference is advanced by the set position in the bit iterator. - self.reference += position; + // The next iteration will calculate with reference to the returned code + 1 + self.reference = code + 1; self.block = if position == 63 { 0 } else { @@ -383,5 +380,30 @@ mod test { // empty case let map = CodesBitmap::default(); assert_eq!(map.codes().collect::>(), vec![]); + + // edge case: first bit in each block is set + let mut map = CodesBitmap::default(); + (0..8).for_each(|i| map.set(64 * i)); + assert_eq!( + map.codes().collect::>(), + (0u16..8).map(|i| 64 * i).collect::>(), + ); + + // Full bitmap case. There are only 512 values, so test them all + let mut map = CodesBitmap::default(); + for i in 0..512 { + map.set(i); + } + assert_eq!( + map.codes().collect::>(), + (0u16..512u16).collect::>() + ); + } + + #[test] + #[should_panic(expected = "code cannot exceed")] + fn test_bitmap_invalid() { + let mut map = CodesBitmap::default(); + map.set(512); } }