From d15cc4e8cd6109a5f141e6db2f5bae6f6cc04590 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Wed, 21 Aug 2024 11:47:00 -0400 Subject: [PATCH] implement second bitmap, ~2x speedup for train (#21) ![image](https://github.com/user-attachments/assets/5a30710f-8025-4708-ad74-eee6dd14187b) ^ this is a sad flamegraph. `bzero` is not a fun place to be spending 60% of your time. ![image](https://github.com/user-attachments/assets/fbe55d00-d7fd-4999-83be-3d9b8e6b48b5) How did we get from one to two? 1. Avoid initializing our `counts1` and `counts1` vectors 2. Implement a second bitmap index that limits our outer-loop iterations in `optimize`. Because `counts1` and `counts2` are not initialized, we check the bitmap before all accesses This also gives us another 2x speedup on the train benchmark, which is nice --- src/builder.rs | 61 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 1dda9c9..cc58190 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -26,6 +26,14 @@ impl CodesBitmap { self.codes[map] |= 1 << (index % 64); } + /// Check if `index` is present in the bitmap + pub(crate) fn is_set(&self, index: usize) -> bool { + debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); + + let map = index >> 6; + self.codes[map] & 1 << (index % 64) != 0 + } + /// Get all codes set in this bitmap pub(crate) fn codes(&self) -> CodesIterator { CodesIterator { @@ -82,6 +90,9 @@ struct Counter { /// Frequency count for each code-pair. counts2: Vec, + /// Bitmap index for codes that appear in counts1 + code1_index: CodesBitmap, + /// Bitmap index of pairs that have been set. /// /// `pair_index[code1].codes()` yields an iterator that can @@ -96,36 +107,70 @@ const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE; impl Counter { fn new() -> Self { + let mut counts1 = Vec::with_capacity(COUNTS1_SIZE); + let mut counts2 = Vec::with_capacity(COUNTS2_SIZE); + // SAFETY: all accesses to the vector go through the bitmap to ensure no uninitialized + // data is ever read from these vectors. + unsafe { + counts1.set_len(COUNTS1_SIZE); + counts2.set_len(COUNTS2_SIZE); + } + Self { - counts1: vec![0; COUNTS1_SIZE], - counts2: vec![0; COUNTS2_SIZE], + counts1, + counts2, + code1_index: CodesBitmap::default(), pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE], } } #[inline] fn record_count1(&mut self, code1: u16) { - self.counts1[code1 as usize] += 1; + if self.code1_index.is_set(code1 as usize) { + self.counts1[code1 as usize] += 1; + } else { + self.counts1[code1 as usize] = 1; + } + self.code1_index.set(code1 as usize); } #[inline] fn record_count2(&mut self, code1: u16, code2: u16) { + debug_assert!(self.code1_index.is_set(code1 as usize)); + debug_assert!(self.code1_index.is_set(code2 as usize)); + let idx = (code1 as usize) * 511 + (code2 as usize); - self.counts2[idx] += 1; + if self.pair_index[code1 as usize].is_set(code2 as usize) { + self.counts2[idx] += 1; + } else { + self.counts2[idx] = 1; + } self.pair_index[code1 as usize].set(code2 as usize); } #[inline] - fn count1(&self, code: u16) -> usize { - self.counts1[code as usize] + fn count1(&self, code1: u16) -> usize { + debug_assert!(self.code1_index.is_set(code1 as usize)); + + self.counts1[code1 as usize] } #[inline] fn count2(&self, code1: u16, code2: u16) -> usize { + debug_assert!(self.code1_index.is_set(code1 as usize)); + debug_assert!(self.code1_index.is_set(code2 as usize)); + debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize)); + let idx = (code1 as usize) * 511 + (code2 as usize); self.counts2[idx] } + /// Returns an ordered iterator over the codes that were observed + /// in a call to [`Self::count1`]. + fn first_codes(&self) -> CodesIterator { + self.code1_index.codes() + } + /// Returns an iterator over the codes that have been observed /// to follow `code1`. /// @@ -217,7 +262,7 @@ impl Compressor { let mut res = Compressor::default(); let mut pqueue = BinaryHeap::with_capacity(65_536); - for code1 in 0u16..(256u16 + self.n_symbols as u16) { + for code1 in counters.first_codes() { let symbol1 = self.symbols[code1 as usize]; let count = counters.count1(code1); // If count is zero, we can skip the whole inner loop. @@ -375,7 +420,7 @@ mod test { // empty case let map = CodesBitmap::default(); - assert_eq!(map.codes().collect::>(), vec![]); + assert!(map.codes().collect::>().is_empty()); // edge case: first bit in each block is set let mut map = CodesBitmap::default();