diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml new file mode 100644 index 0000000..068177a --- /dev/null +++ b/.github/workflows/miri.yml @@ -0,0 +1,52 @@ +name: Miri + +on: + push: + branches: ["develop"] + pull_request: {} + workflow_dispatch: {} + +permissions: + actions: read + contents: read + +jobs: + miri: + name: "miri" + runs-on: ubuntu-latest + env: + RUST_BACKTRACE: 1 + MIRIFLAGS: -Zmiri-strict-provenance -Zmiri-symbolic-alignment-check -Zmiri-backtrace=full + steps: + - uses: actions/checkout@v4 + + - name: Rust Version + id: rust-version + shell: bash + run: echo "version=$(cat rust-toolchain.toml | grep channel | awk -F'\"' '{print $2}')" >> $GITHUB_OUTPUT + + - name: Rust Toolchain + id: rust-toolchain + uses: dtolnay/rust-toolchain@master + if: steps.rustup-cache.outputs.cache-hit != 'true' + with: + toolchain: "${{ steps.rust-version.outputs.version }}" + components: miri + + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/develop' }} + shared-key: "shared" # To allow reuse across jobs + + - name: Rust Compile Cache + uses: mozilla-actions/sccache-action@v0.0.5 + - name: Rust Compile Cache Config + shell: bash + run: | + echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV + echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV + echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV + + - name: Run tests with Miri + run: cargo miri test diff --git a/src/builder.rs b/src/builder.rs index c3272ae..42f2fb5 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -9,6 +9,75 @@ use std::collections::BinaryHeap; use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE}; +/// Bitmap that only works for values up to 512 +#[derive(Clone, Copy, Debug, Default)] +struct CodesBitmap { + codes: [u64; 8], +} + +assert_sizeof!(CodesBitmap => 64); + +impl CodesBitmap { + /// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. + pub(crate) fn set(&mut self, index: usize) { + debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); + + let map = index >> 6; + self.codes[map] |= 1 << (index % 64); + } + + /// Get all codes set in this bitmap + pub(crate) fn codes(&self) -> CodesIterator { + CodesIterator { + inner: self, + index: 0, + block: self.codes[0], + reference: 0, + } + } +} + +struct CodesIterator<'a> { + inner: &'a CodesBitmap, + index: usize, + block: u64, + reference: usize, +} + +impl<'a> Iterator for CodesIterator<'a> { + type Item = u16; + + fn next(&mut self) -> Option { + // If current is zero, advance to next non-zero block + while self.block == 0 { + self.index += 1; + if self.index >= 8 { + return None; + } + self.block = self.inner.codes[self.index]; + self.reference = self.index * 64; + } + + if self.block == 0 { + return None; + } + + // Find the next set bit in the current block. + let position = self.block.trailing_zeros() as usize; + let code = self.reference + position; + + // The next iteration will calculate with reference to the returned code + 1 + self.reference = code + 1; + self.block = if position == 63 { + 0 + } else { + self.block >> (1 + position) + }; + + Some(code as u16) + } +} + #[derive(Debug, Clone)] struct Counter { /// Frequency count for each code. @@ -16,6 +85,12 @@ struct Counter { /// Frequency count for each code-pair. counts2: Vec, + + /// Bitmap index of pairs that have been set. + /// + /// `pair_index[code1].codes()` yields an iterator that can + /// be used to find all possible codes that follow `codes1`. + pair_index: Vec, } const COUNTS1_SIZE: usize = MAX_CODE as usize; @@ -28,16 +103,7 @@ impl Counter { Self { counts1: vec![0; COUNTS1_SIZE], counts2: vec![0; COUNTS2_SIZE], - } - } - - /// reset - pub fn reset(&mut self) { - for idx in 0..COUNTS1_SIZE { - self.counts1[idx] = 0; - } - for idx in 0..COUNTS2_SIZE { - self.counts2[idx] = 0; + pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE], } } @@ -50,6 +116,7 @@ impl Counter { fn record_count2(&mut self, code1: u16, code2: u16) { let idx = (code1 as usize) * 511 + (code2 as usize); self.counts2[idx] += 1; + self.pair_index[code1 as usize].set(code2 as usize); } #[inline] @@ -62,12 +129,24 @@ impl Counter { let idx = (code1 as usize) * 511 + (code2 as usize); self.counts2[idx] } + + /// Returns an iterator over the codes that have been observed + /// to follow `code1`. + /// + /// This is the set of all values `code2` where there was + /// previously a call to `self.record_count2(code1, code2)`. + fn second_codes(&self, code1: u16) -> CodesIterator { + self.pair_index[code1 as usize].codes() + } } /// The number of generations used for training. This is taken from the [FSST paper]. /// /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf +#[cfg(not(miri))] const MAX_GENERATIONS: usize = 5; +#[cfg(miri)] +const MAX_GENERATIONS: usize = 2; impl Compressor { /// Build and train a `Compressor` from a sample corpus of text. @@ -87,14 +166,13 @@ impl Compressor { return compressor; } - let mut counter = Counter::new(); - for _generation in 0..(MAX_GENERATIONS - 1) { + let mut counter = Counter::new(); compressor.compress_count(sample, &mut counter); compressor = compressor.optimize(&counter, true); - counter.reset(); } + let mut counter = Counter::new(); compressor.compress_count(sample, &mut counter); compressor.optimize(&counter, true) } @@ -142,9 +220,16 @@ impl Compressor { fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self { let mut res = Compressor::default(); let mut pqueue = BinaryHeap::with_capacity(65_536); + for code1 in 0u16..(256u16 + self.n_symbols as u16) { let symbol1 = self.symbols[code1 as usize]; - let mut gain = counters.count1(code1) * symbol1.len(); + let count = counters.count1(code1); + // If count is zero, we can skip the whole inner loop. + if count == 0 { + continue; + } + + let mut gain = count * symbol1.len(); // NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols. // This helps to reduce exception counts. if code1 < 256 { @@ -157,10 +242,10 @@ impl Compressor { }); } - for code2 in 0u16..(256u16 + self.n_symbols as u16) { + for code2 in counters.second_codes(code1) { let symbol2 = &self.symbols[code2 as usize]; - // If either symbol is zero-length, or if merging would yield a symbol of - // length greater than 8, skip. + + // If merging would yield a symbol of length greater than 8, skip. if symbol1.len() + symbol2.len() > 8 { continue; } @@ -247,8 +332,7 @@ impl Ord for Candidate { #[cfg(test)] mod test { - - use crate::{Compressor, ESCAPE_CODE}; + use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE}; #[test] fn test_builder() { @@ -282,4 +366,44 @@ mod test { ] ); } + + #[test] + fn test_bitmap() { + let mut map = CodesBitmap::default(); + map.set(10); + map.set(100); + map.set(500); + + let codes: Vec = map.codes().collect(); + assert_eq!(codes, vec![10u16, 100, 500]); + + // empty case + let map = CodesBitmap::default(); + assert_eq!(map.codes().collect::>(), vec![]); + + // edge case: first bit in each block is set + let mut map = CodesBitmap::default(); + (0..8).for_each(|i| map.set(64 * i)); + assert_eq!( + map.codes().collect::>(), + (0u16..8).map(|i| 64 * i).collect::>(), + ); + + // Full bitmap case. There are only 512 values, so test them all + let mut map = CodesBitmap::default(); + for i in 0..512 { + map.set(i); + } + assert_eq!( + map.codes().collect::>(), + (0u16..512u16).collect::>() + ); + } + + #[test] + #[should_panic(expected = "code cannot exceed")] + fn test_bitmap_invalid() { + let mut map = CodesBitmap::default(); + map.set(512); + } } diff --git a/tests/correctness.rs b/tests/correctness.rs index 5a68cb1..64f3ba7 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -47,16 +47,13 @@ fn test_one_byte() { #[test] fn test_zeros() { - println!("training zeros"); let training_data: Vec = vec![0, 1, 2, 3, 4, 0]; let trained = Compressor::train(&training_data); - println!("compressing with zeros"); let compressed = trained.compress(&[4, 0]); - println!("decomperssing with zeros"); assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]); - println!("done"); } +#[cfg_attr(miri, ignore)] #[test] fn test_large() { let corpus: Vec = DECLARATION.bytes().cycle().take(10_240).collect();