Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make Compressor::train 2x faster with bitmap index #16

Merged
merged 6 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions .github/workflows/miri.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: Miri

on:
push:
branches: ["develop"]
pull_request: {}
workflow_dispatch: {}

permissions:
actions: read
contents: read

jobs:
miri:
name: "miri"
runs-on: ubuntu-latest
env:
RUST_BACKTRACE: 1
MIRIFLAGS: -Zmiri-strict-provenance -Zmiri-symbolic-alignment-check -Zmiri-backtrace=full
steps:
- uses: actions/checkout@v4

- name: Rust Version
id: rust-version
shell: bash
run: echo "version=$(cat rust-toolchain.toml | grep channel | awk -F'\"' '{print $2}')" >> $GITHUB_OUTPUT

- name: Rust Toolchain
id: rust-toolchain
uses: dtolnay/rust-toolchain@master
if: steps.rustup-cache.outputs.cache-hit != 'true'
with:
toolchain: "${{ steps.rust-version.outputs.version }}"
components: miri

- name: Rust Dependency Cache
uses: Swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/develop' }}
shared-key: "shared" # To allow reuse across jobs

- name: Rust Compile Cache
uses: mozilla-actions/[email protected]
- name: Rust Compile Cache Config
shell: bash
run: |
echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV
echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV
echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV

- name: Run tests with Miri
run: cargo miri test
162 changes: 143 additions & 19 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,88 @@ use std::collections::BinaryHeap;

use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE};

/// Bitmap that only works for values up to 512
#[derive(Clone, Copy, Debug, Default)]
struct CodesBitmap {
codes: [u64; 8],
}

assert_sizeof!(CodesBitmap => 64);

impl CodesBitmap {
/// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE].
pub(crate) fn set(&mut self, index: usize) {
debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}");

let map = index >> 6;
self.codes[map] |= 1 << (index % 64);
}

/// Get all codes set in this bitmap
pub(crate) fn codes(&self) -> CodesIterator {
CodesIterator {
inner: self,
index: 0,
block: self.codes[0],
reference: 0,
}
}
}

struct CodesIterator<'a> {
inner: &'a CodesBitmap,
index: usize,
block: u64,
reference: usize,
}

impl<'a> Iterator for CodesIterator<'a> {
type Item = u16;

fn next(&mut self) -> Option<Self::Item> {
// If current is zero, advance to next non-zero block
while self.block == 0 {
self.index += 1;
if self.index >= 8 {
return None;
}
self.block = self.inner.codes[self.index];
self.reference = self.index * 64;
}

if self.block == 0 {
return None;
}
Comment on lines +61 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be possible to skip this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! #18


// Find the next set bit in the current block.
let position = self.block.trailing_zeros() as usize;
let code = self.reference + position;

// The next iteration will calculate with reference to the returned code + 1
self.reference = code + 1;
self.block = if position == 63 {
0
} else {
self.block >> (1 + position)
};

Some(code as u16)
}
}

#[derive(Debug, Clone)]
struct Counter {
/// Frequency count for each code.
counts1: Vec<usize>,

/// Frequency count for each code-pair.
counts2: Vec<usize>,

/// Bitmap index of pairs that have been set.
///
/// `pair_index[code1].codes()` yields an iterator that can
/// be used to find all possible codes that follow `codes1`.
pair_index: Vec<CodesBitmap>,
}

const COUNTS1_SIZE: usize = MAX_CODE as usize;
Expand All @@ -28,16 +103,7 @@ impl Counter {
Self {
counts1: vec![0; COUNTS1_SIZE],
counts2: vec![0; COUNTS2_SIZE],
}
}

/// reset
pub fn reset(&mut self) {
for idx in 0..COUNTS1_SIZE {
self.counts1[idx] = 0;
}
for idx in 0..COUNTS2_SIZE {
self.counts2[idx] = 0;
Comment on lines -35 to -40
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was slower than just building a new Counter b/c of the vec![0] change made in the previous PR

pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
}
}

Expand All @@ -50,6 +116,7 @@ impl Counter {
fn record_count2(&mut self, code1: u16, code2: u16) {
let idx = (code1 as usize) * 511 + (code2 as usize);
self.counts2[idx] += 1;
self.pair_index[code1 as usize].set(code2 as usize);
}

#[inline]
Expand All @@ -62,12 +129,24 @@ impl Counter {
let idx = (code1 as usize) * 511 + (code2 as usize);
self.counts2[idx]
}

/// Returns an iterator over the codes that have been observed
/// to follow `code1`.
///
/// This is the set of all values `code2` where there was
/// previously a call to `self.record_count2(code1, code2)`.
fn second_codes(&self, code1: u16) -> CodesIterator {
self.pair_index[code1 as usize].codes()
}
}

/// The number of generations used for training. This is taken from the [FSST paper].
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
#[cfg(not(miri))]
const MAX_GENERATIONS: usize = 5;
#[cfg(miri)]
const MAX_GENERATIONS: usize = 2;

impl Compressor {
/// Build and train a `Compressor` from a sample corpus of text.
Expand All @@ -87,14 +166,13 @@ impl Compressor {
return compressor;
}

let mut counter = Counter::new();

for _generation in 0..(MAX_GENERATIONS - 1) {
let mut counter = Counter::new();
compressor.compress_count(sample, &mut counter);
compressor = compressor.optimize(&counter, true);
counter.reset();
}

let mut counter = Counter::new();
compressor.compress_count(sample, &mut counter);
compressor.optimize(&counter, true)
}
Expand Down Expand Up @@ -142,9 +220,16 @@ impl Compressor {
fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self {
let mut res = Compressor::default();
let mut pqueue = BinaryHeap::with_capacity(65_536);

for code1 in 0u16..(256u16 + self.n_symbols as u16) {
let symbol1 = self.symbols[code1 as usize];
let mut gain = counters.count1(code1) * symbol1.len();
let count = counters.count1(code1);
// If count is zero, we can skip the whole inner loop.
if count == 0 {
continue;
}

let mut gain = count * symbol1.len();
// NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
// This helps to reduce exception counts.
if code1 < 256 {
Expand All @@ -157,10 +242,10 @@ impl Compressor {
});
}

for code2 in 0u16..(256u16 + self.n_symbols as u16) {
for code2 in counters.second_codes(code1) {
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.

// If merging would yield a symbol of length greater than 8, skip.
if symbol1.len() + symbol2.len() > 8 {
continue;
}
Expand Down Expand Up @@ -247,8 +332,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {

use crate::{Compressor, ESCAPE_CODE};
use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand Down Expand Up @@ -282,4 +366,44 @@ mod test {
]
);
}

#[test]
fn test_bitmap() {
let mut map = CodesBitmap::default();
map.set(10);
map.set(100);
map.set(500);

let codes: Vec<u16> = map.codes().collect();
assert_eq!(codes, vec![10u16, 100, 500]);

// empty case
let map = CodesBitmap::default();
assert_eq!(map.codes().collect::<Vec<_>>(), vec![]);

// edge case: first bit in each block is set
let mut map = CodesBitmap::default();
(0..8).for_each(|i| map.set(64 * i));
assert_eq!(
map.codes().collect::<Vec<_>>(),
(0u16..8).map(|i| 64 * i).collect::<Vec<_>>(),
);

// Full bitmap case. There are only 512 values, so test them all
let mut map = CodesBitmap::default();
for i in 0..512 {
map.set(i);
}
assert_eq!(
map.codes().collect::<Vec<_>>(),
(0u16..512u16).collect::<Vec<_>>()
);
}

#[test]
#[should_panic(expected = "code cannot exceed")]
fn test_bitmap_invalid() {
let mut map = CodesBitmap::default();
map.set(512);
}
}
5 changes: 1 addition & 4 deletions tests/correctness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ fn test_one_byte() {

#[test]
fn test_zeros() {
println!("training zeros");
let training_data: Vec<u8> = vec![0, 1, 2, 3, 4, 0];
let trained = Compressor::train(&training_data);
println!("compressing with zeros");
let compressed = trained.compress(&[4, 0]);
println!("decomperssing with zeros");
assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]);
println!("done");
}

#[cfg_attr(miri, ignore)]
#[test]
fn test_large() {
let corpus: Vec<u8> = DECLARATION.bytes().cycle().take(10_240).collect();
Expand Down
Loading