-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: make Compressor::train 2x faster with bitmap index #16
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
b374b5d
feat: make Compressor::train 2x faster with bitmap index
a10y 720d506
add miri action
a10y fecde16
final cleanups
a10y 17ac1be
only run miri on develop
a10y aea4ae3
turn miri back on for CI
a10y f74d185
fix small bug in iterator, more tests
a10y File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
name: Miri | ||
|
||
on: | ||
push: | ||
branches: ["develop"] | ||
pull_request: {} | ||
workflow_dispatch: {} | ||
|
||
permissions: | ||
actions: read | ||
contents: read | ||
|
||
jobs: | ||
miri: | ||
name: "miri" | ||
runs-on: ubuntu-latest | ||
env: | ||
RUST_BACKTRACE: 1 | ||
MIRIFLAGS: -Zmiri-strict-provenance -Zmiri-symbolic-alignment-check -Zmiri-backtrace=full | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Rust Version | ||
id: rust-version | ||
shell: bash | ||
run: echo "version=$(cat rust-toolchain.toml | grep channel | awk -F'\"' '{print $2}')" >> $GITHUB_OUTPUT | ||
|
||
- name: Rust Toolchain | ||
id: rust-toolchain | ||
uses: dtolnay/rust-toolchain@master | ||
if: steps.rustup-cache.outputs.cache-hit != 'true' | ||
with: | ||
toolchain: "${{ steps.rust-version.outputs.version }}" | ||
components: miri | ||
|
||
- name: Rust Dependency Cache | ||
uses: Swatinem/rust-cache@v2 | ||
with: | ||
save-if: ${{ github.ref == 'refs/heads/develop' }} | ||
shared-key: "shared" # To allow reuse across jobs | ||
|
||
- name: Rust Compile Cache | ||
uses: mozilla-actions/[email protected] | ||
- name: Rust Compile Cache Config | ||
shell: bash | ||
run: | | ||
echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV | ||
echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV | ||
echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV | ||
|
||
- name: Run tests with Miri | ||
run: cargo miri test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,88 @@ use std::collections::BinaryHeap; | |
|
||
use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE}; | ||
|
||
/// Bitmap that only works for values up to 512 | ||
#[derive(Clone, Copy, Debug, Default)] | ||
struct CodesBitmap { | ||
codes: [u64; 8], | ||
} | ||
|
||
assert_sizeof!(CodesBitmap => 64); | ||
|
||
impl CodesBitmap { | ||
/// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. | ||
pub(crate) fn set(&mut self, index: usize) { | ||
debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); | ||
|
||
let map = index >> 6; | ||
self.codes[map] |= 1 << (index % 64); | ||
} | ||
|
||
/// Get all codes set in this bitmap | ||
pub(crate) fn codes(&self) -> CodesIterator { | ||
CodesIterator { | ||
inner: self, | ||
index: 0, | ||
block: self.codes[0], | ||
reference: 0, | ||
} | ||
} | ||
} | ||
|
||
struct CodesIterator<'a> { | ||
inner: &'a CodesBitmap, | ||
index: usize, | ||
block: u64, | ||
reference: usize, | ||
} | ||
|
||
impl<'a> Iterator for CodesIterator<'a> { | ||
type Item = u16; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
// If current is zero, advance to next non-zero block | ||
while self.block == 0 { | ||
self.index += 1; | ||
if self.index >= 8 { | ||
return None; | ||
} | ||
self.block = self.inner.codes[self.index]; | ||
self.reference = self.index * 64; | ||
} | ||
|
||
if self.block == 0 { | ||
return None; | ||
} | ||
|
||
// Find the next set bit in the current block. | ||
let position = self.block.trailing_zeros() as usize; | ||
let code = self.reference + position; | ||
|
||
// The next iteration will calculate with reference to the returned code + 1 | ||
self.reference = code + 1; | ||
self.block = if position == 63 { | ||
0 | ||
} else { | ||
self.block >> (1 + position) | ||
}; | ||
|
||
Some(code as u16) | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
struct Counter { | ||
/// Frequency count for each code. | ||
counts1: Vec<usize>, | ||
|
||
/// Frequency count for each code-pair. | ||
counts2: Vec<usize>, | ||
|
||
/// Bitmap index of pairs that have been set. | ||
/// | ||
/// `pair_index[code1].codes()` yields an iterator that can | ||
/// be used to find all possible codes that follow `codes1`. | ||
pair_index: Vec<CodesBitmap>, | ||
} | ||
|
||
const COUNTS1_SIZE: usize = MAX_CODE as usize; | ||
|
@@ -28,16 +103,7 @@ impl Counter { | |
Self { | ||
counts1: vec![0; COUNTS1_SIZE], | ||
counts2: vec![0; COUNTS2_SIZE], | ||
} | ||
} | ||
|
||
/// reset | ||
pub fn reset(&mut self) { | ||
for idx in 0..COUNTS1_SIZE { | ||
self.counts1[idx] = 0; | ||
} | ||
for idx in 0..COUNTS2_SIZE { | ||
self.counts2[idx] = 0; | ||
Comment on lines
-35
to
-40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was slower than just building a new |
||
pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE], | ||
} | ||
} | ||
|
||
|
@@ -50,6 +116,7 @@ impl Counter { | |
fn record_count2(&mut self, code1: u16, code2: u16) { | ||
let idx = (code1 as usize) * 511 + (code2 as usize); | ||
self.counts2[idx] += 1; | ||
self.pair_index[code1 as usize].set(code2 as usize); | ||
} | ||
|
||
#[inline] | ||
|
@@ -62,12 +129,24 @@ impl Counter { | |
let idx = (code1 as usize) * 511 + (code2 as usize); | ||
self.counts2[idx] | ||
} | ||
|
||
/// Returns an iterator over the codes that have been observed | ||
/// to follow `code1`. | ||
/// | ||
/// This is the set of all values `code2` where there was | ||
/// previously a call to `self.record_count2(code1, code2)`. | ||
fn second_codes(&self, code1: u16) -> CodesIterator { | ||
self.pair_index[code1 as usize].codes() | ||
} | ||
} | ||
|
||
/// The number of generations used for training. This is taken from the [FSST paper]. | ||
/// | ||
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf | ||
#[cfg(not(miri))] | ||
const MAX_GENERATIONS: usize = 5; | ||
#[cfg(miri)] | ||
const MAX_GENERATIONS: usize = 2; | ||
|
||
impl Compressor { | ||
/// Build and train a `Compressor` from a sample corpus of text. | ||
|
@@ -87,14 +166,13 @@ impl Compressor { | |
return compressor; | ||
} | ||
|
||
let mut counter = Counter::new(); | ||
|
||
for _generation in 0..(MAX_GENERATIONS - 1) { | ||
let mut counter = Counter::new(); | ||
compressor.compress_count(sample, &mut counter); | ||
compressor = compressor.optimize(&counter, true); | ||
counter.reset(); | ||
} | ||
|
||
let mut counter = Counter::new(); | ||
compressor.compress_count(sample, &mut counter); | ||
compressor.optimize(&counter, true) | ||
} | ||
|
@@ -142,9 +220,16 @@ impl Compressor { | |
fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self { | ||
let mut res = Compressor::default(); | ||
let mut pqueue = BinaryHeap::with_capacity(65_536); | ||
|
||
for code1 in 0u16..(256u16 + self.n_symbols as u16) { | ||
let symbol1 = self.symbols[code1 as usize]; | ||
let mut gain = counters.count1(code1) * symbol1.len(); | ||
let count = counters.count1(code1); | ||
// If count is zero, we can skip the whole inner loop. | ||
if count == 0 { | ||
continue; | ||
} | ||
|
||
let mut gain = count * symbol1.len(); | ||
// NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols. | ||
// This helps to reduce exception counts. | ||
if code1 < 256 { | ||
|
@@ -157,10 +242,10 @@ impl Compressor { | |
}); | ||
} | ||
|
||
for code2 in 0u16..(256u16 + self.n_symbols as u16) { | ||
for code2 in counters.second_codes(code1) { | ||
let symbol2 = &self.symbols[code2 as usize]; | ||
// If either symbol is zero-length, or if merging would yield a symbol of | ||
// length greater than 8, skip. | ||
|
||
// If merging would yield a symbol of length greater than 8, skip. | ||
if symbol1.len() + symbol2.len() > 8 { | ||
continue; | ||
} | ||
|
@@ -247,8 +332,7 @@ impl Ord for Candidate { | |
|
||
#[cfg(test)] | ||
mod test { | ||
|
||
use crate::{Compressor, ESCAPE_CODE}; | ||
use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE}; | ||
|
||
#[test] | ||
fn test_builder() { | ||
|
@@ -282,4 +366,44 @@ mod test { | |
] | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_bitmap() { | ||
let mut map = CodesBitmap::default(); | ||
map.set(10); | ||
map.set(100); | ||
map.set(500); | ||
|
||
let codes: Vec<u16> = map.codes().collect(); | ||
assert_eq!(codes, vec![10u16, 100, 500]); | ||
|
||
// empty case | ||
let map = CodesBitmap::default(); | ||
assert_eq!(map.codes().collect::<Vec<_>>(), vec![]); | ||
|
||
// edge case: first bit in each block is set | ||
let mut map = CodesBitmap::default(); | ||
(0..8).for_each(|i| map.set(64 * i)); | ||
assert_eq!( | ||
map.codes().collect::<Vec<_>>(), | ||
(0u16..8).map(|i| 64 * i).collect::<Vec<_>>(), | ||
); | ||
|
||
// Full bitmap case. There are only 512 values, so test them all | ||
let mut map = CodesBitmap::default(); | ||
for i in 0..512 { | ||
map.set(i); | ||
} | ||
assert_eq!( | ||
map.codes().collect::<Vec<_>>(), | ||
(0u16..512u16).collect::<Vec<_>>() | ||
); | ||
} | ||
|
||
#[test] | ||
#[should_panic(expected = "code cannot exceed")] | ||
fn test_bitmap_invalid() { | ||
let mut map = CodesBitmap::default(); | ||
map.set(512); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't it be possible to skip this check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! #18