From ccf02f96062a9d8190af29606abf570e695b3e82 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Fri, 16 Aug 2024 14:00:28 -0400 Subject: [PATCH] separate Compressor and Decompressor --- Cargo.toml | 3 + benches/compress.rs | 15 +-- examples/file_compressor.rs | 4 +- examples/round_trip.rs | 6 +- fuzz/fuzz_targets/fuzz_compress.rs | 7 +- fuzz/fuzz_targets/fuzz_train.rs | 2 +- src/builder.rs | 58 +++++------ src/find_longest/naive.rs | 4 +- src/lib.rs | 148 ++++++++++++++++++----------- tests/correctness.rs | 31 +++--- 10 files changed, 168 insertions(+), 110 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d49faf3..7108617 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,9 @@ license = "Apache-2.0" repository = "https://github.com/spiraldb/fsst" edition = "2021" +[lib] +name = "fsst" + [lints.rust] warnings = "deny" missing_docs = "deny" diff --git a/benches/compress.rs b/benches/compress.rs index 92725d9..97f8c76 100644 --- a/benches/compress.rs +++ b/benches/compress.rs @@ -8,7 +8,7 @@ use core::str; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use fsst_rs::{train, ESCAPE_CODE}; +use fsst::{Compressor, ESCAPE_CODE}; const CORPUS: &str = include_str!("dracula.txt"); const TEST: &str = "I found my smattering of German very useful here"; @@ -17,13 +17,13 @@ fn bench_fsst(c: &mut Criterion) { let mut group = c.benchmark_group("fsst"); group.bench_function("train", |b| { let corpus = CORPUS.as_bytes(); - b.iter(|| black_box(train(black_box(corpus)))); + b.iter(|| black_box(Compressor::train(black_box(corpus)))); }); - let table = train(CORPUS); + let compressor = Compressor::train(CORPUS); let plaintext = TEST.as_bytes(); - let compressed = table.compress(plaintext); + let compressed = compressor.compress(plaintext); let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count(); let ratio = (plaintext.len() as f64) / (compressed.len() as f64); println!( @@ -31,17 +31,18 @@ fn bench_fsst(c: &mut Criterion) { compressed.len() ); - let decompressed = table.decompress(&compressed); + let decompressor = compressor.decompressor(); + let decompressed = decompressor.decompress(&compressed); let decompressed = str::from_utf8(&decompressed).unwrap(); println!("DECODED: {}", decompressed); assert_eq!(decompressed, TEST); group.bench_function("compress-single", |b| { - b.iter(|| black_box(table.compress(black_box(plaintext)))); + b.iter(|| black_box(compressor.compress(black_box(plaintext)))); }); group.bench_function("decompress-single", |b| { - b.iter(|| black_box(table.decompress(black_box(&compressed)))); + b.iter(|| black_box(decompressor.decompress(black_box(&compressed)))); }); } diff --git a/examples/file_compressor.rs b/examples/file_compressor.rs index 3dab660..3314c92 100644 --- a/examples/file_compressor.rs +++ b/examples/file_compressor.rs @@ -19,6 +19,8 @@ use std::{ path::Path, }; +use fsst::Compressor; + fn main() { let args: Vec<_> = std::env::args().skip(1).collect(); assert!(args.len() >= 2, "args TRAINING and FILE must be provided"); @@ -33,7 +35,7 @@ fn main() { } println!("building the compressor from {train_path:?}..."); - let compressor = fsst_rs::train(&train_bytes); + let compressor = Compressor::train(&train_bytes); println!("compressing blocks of {input_path:?} with compressor..."); diff --git a/examples/round_trip.rs b/examples/round_trip.rs index 0f3fab7..038b932 100644 --- a/examples/round_trip.rs +++ b/examples/round_trip.rs @@ -2,14 +2,16 @@ use core::str; +use fsst::Compressor; + fn main() { // Train on a sample. let sample = "the quick brown fox jumped over the lazy dog"; - let trained = fsst_rs::train(sample.as_bytes()); + let trained = Compressor::train(sample.as_bytes()); let compressed = trained.compress(sample.as_bytes()); println!("compressed: {} => {}", sample.len(), compressed.len()); // decompress now - let decode = trained.decompress(&compressed); + let decode = trained.decompressor().decompress(&compressed); let output = str::from_utf8(&decode).unwrap(); println!( "decoded to the original: len={} text='{}'", diff --git a/fuzz/fuzz_targets/fuzz_compress.rs b/fuzz/fuzz_targets/fuzz_compress.rs index 23f73bb..a871293 100644 --- a/fuzz/fuzz_targets/fuzz_compress.rs +++ b/fuzz/fuzz_targets/fuzz_compress.rs @@ -3,8 +3,9 @@ use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { - let table = fsst_rs::train("the quick brown fox jumped over the lazy dog".as_bytes()); - let compress = table.compress(data); - let decompress = table.decompress(&compress); + let compressor = + fsst::Compressor::train("the quick brown fox jumped over the lazy dog".as_bytes()); + let compress = compressor.compress(data); + let decompress = compressor.decompressor().decompress(&compress); assert_eq!(&decompress, data); }); diff --git a/fuzz/fuzz_targets/fuzz_train.rs b/fuzz/fuzz_targets/fuzz_train.rs index fbb6618..5d3dada 100644 --- a/fuzz/fuzz_targets/fuzz_train.rs +++ b/fuzz/fuzz_targets/fuzz_train.rs @@ -3,5 +3,5 @@ use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { - let _ = fsst_rs::train(data); + let _ = fsst::Compressor::train(data); }); diff --git a/src/builder.rs b/src/builder.rs index 43a6fd4..84ed370 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,4 +1,4 @@ -//! Functions and types used for building a [`SymbolTable`] from a corpus of text. +//! Functions and types used for building a [`Compressor`] from a corpus of text. //! //! This module implements the logic from Algorithm 3 of the [FSST Paper]. //! @@ -8,7 +8,7 @@ use std::cmp::Ordering; use std::collections::BinaryHeap; use crate::find_longest::FindLongestSymbol; -use crate::{Symbol, SymbolTable, MAX_CODE}; +use crate::{Compressor, Symbol, MAX_CODE}; #[derive(Debug, Clone)] struct Counter { @@ -53,31 +53,33 @@ impl Counter { /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf pub const MAX_GENERATIONS: usize = 5; -/// Build and train a `SymbolTable` from a sample corpus of text. -/// -/// This function implements the generational algorithm described in the [FSST paper] Section -/// 4.3. Starting with an empty symbol table, it iteratively compresses the corpus, then attempts -/// to merge symbols when doing so would yield better compression than leaving them unmerged. The -/// resulting table will have at most 255 symbols (the 256th symbol is reserved for the escape -/// code). -/// -/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf -pub fn train(corpus: impl AsRef<[u8]>) -> SymbolTable { - let mut table = SymbolTable::default(); - // TODO(aduffy): handle truncating/sampling if corpus > requires sample size. - let sample = corpus.as_ref(); - if sample.is_empty() { - return table; - } - for _generation in 0..MAX_GENERATIONS { - let counter = table.compress_count(sample); - table = table.optimize(counter); - } +impl Compressor { + /// Build and train a `Compressor` from a sample corpus of text. + /// + /// This function implements the generational algorithm described in the [FSST paper] Section + /// 4.3. Starting with an empty symbol table, it iteratively compresses the corpus, then attempts + /// to merge symbols when doing so would yield better compression than leaving them unmerged. The + /// resulting table will have at most 255 symbols (the 256th symbol is reserved for the escape + /// code). + /// + /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf + pub fn train(corpus: impl AsRef<[u8]>) -> Self { + let mut compressor = Self::default(); + // TODO(aduffy): handle truncating/sampling if corpus > requires sample size. + let sample = corpus.as_ref(); + if sample.is_empty() { + return compressor; + } + for _generation in 0..MAX_GENERATIONS { + let counter = compressor.compress_count(sample); + compressor = compressor.optimize(counter); + } - table + compressor + } } -impl SymbolTable { +impl Compressor { /// Compress the text using the current symbol table. Count the code occurrences /// and code-pair occurrences to allow us to calculate apparent gain. fn compress_count(&self, sample: &[u8]) -> Counter { @@ -101,7 +103,7 @@ impl SymbolTable { /// Using a set of counters and the existing set of symbols, build a new /// set of symbols/codes that optimizes the gain over the distribution in `counter`. fn optimize(&self, counters: Counter) -> Self { - let mut res = SymbolTable::default(); + let mut res = Compressor::default(); let mut pqueue = BinaryHeap::new(); for code1 in 0u16..(256u16 + self.n_symbols as u16) { let symbol1 = self.symbols[code1 as usize]; @@ -186,13 +188,13 @@ impl Ord for Candidate { #[cfg(test)] mod test { - use crate::{train, ESCAPE_CODE}; + use crate::{Compressor, ESCAPE_CODE}; #[test] fn test_builder() { - // Train a SymbolTable on the toy string + // Train a Compressor on the toy string let text = "hello world"; - let table = train(text.as_bytes()); + let table = Compressor::train(text.as_bytes()); // Use the table to compress a string, see the values let compressed = table.compress(text.as_bytes()); diff --git a/src/find_longest/naive.rs b/src/find_longest/naive.rs index c9add2d..e6519c1 100644 --- a/src/find_longest/naive.rs +++ b/src/find_longest/naive.rs @@ -1,11 +1,11 @@ use crate::find_longest::FindLongestSymbol; -use crate::SymbolTable; +use crate::Compressor; // Find the code that maps to a symbol with longest-match to a piece of text. // // This is the naive algorithm that just scans the whole table and is very slow. -impl FindLongestSymbol for SymbolTable { +impl FindLongestSymbol for Compressor { // NOTE(aduffy): if you don't disable inlining, this function won't show up in profiles. #[inline(never)] fn find_longest_symbol(&self, text: &[u8]) -> u16 { diff --git a/src/lib.rs b/src/lib.rs index ff24827..86619ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,7 @@ mod builder; mod find_longest; mod lossy_pht; -/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`SymbolTable`] and +/// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and /// identified by an 8-bit code. #[derive(Copy, Clone)] pub union Symbol { @@ -206,29 +206,102 @@ impl Debug for CodeMeta { } } -/// The static symbol table used for compression and decompression. +/// Decompressor uses a symbol table to take a stream of 8-bit codes into a string. +#[derive(Clone)] +pub struct Decompressor<'a> { + /// Table mapping codes to symbols. + /// + /// The first 256 slots are escapes. The following slots (up to 254) + /// are for symbols with actual codes. + /// + /// This physical layout is important so that we can do straight-line execution in the decompress method. + pub(crate) symbols: &'a [Symbol], +} + +impl<'a> Decompressor<'a> { + /// Returns a new decompressor that uses the provided symbol table. + /// + /// # Panics + /// + /// If the provided symbol table has length greater than [`MAX_CODE`]. + pub fn new(symbols: &'a [Symbol]) -> Self { + assert!( + symbols.len() <= MAX_CODE as usize, + "symbol table cannot have size exceeding MAX_CODE" + ); + + Self { symbols } + } + + /// Decompress a byte slice that was previously returned by a compressor using + /// the same symbol table. + pub fn decompress(&self, compressed: &[u8]) -> Vec { + let mut decoded: Vec = Vec::with_capacity(size_of::() * (compressed.len() + 1)); + let ptr = decoded.as_mut_ptr(); + + let mut in_pos = 0; + let mut out_pos = 0; + + while in_pos < compressed.len() && out_pos < (decoded.capacity() - size_of::()) { + let code = compressed[in_pos]; + if code == ESCAPE_CODE { + // Advance by one, do raw write. + in_pos += 1; + // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer + unsafe { + let write_addr = ptr.byte_offset(out_pos as isize); + write_addr.write(compressed[in_pos]); + } + out_pos += 1; + in_pos += 1; + } else { + let symbol = self.symbols[256 + code as usize]; + // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer + unsafe { + let write_addr = ptr.byte_offset(out_pos as isize) as *mut u64; + // Perform 8 byte unaligned write. + write_addr.write_unaligned(symbol.num); + } + in_pos += 1; + out_pos += symbol.len(); + } + } + + assert!( + in_pos >= compressed.len(), + "decompression should exhaust input before output" + ); + + // SAFETY: we enforce in the loop condition that out_pos <= decoded.capacity() + unsafe { decoded.set_len(out_pos) }; + + decoded + } +} + +/// A compressor that uses a symbol table to greedily compress strings. /// -/// The `SymbolTable` is the central component of FSST. You can create a SymbolTable either by -/// default, or by [training][`crate::train`] it on an input corpus of text. +/// The `Compressor` is the central component of FSST. You can create a compressor either by +/// default (i.e. an empty compressor), or by [training][`Self::train`] it on an input corpus of text. /// /// Example usage: /// /// ``` -/// use fsst_rs::{Symbol, SymbolTable}; -/// let mut table = SymbolTable::default(); +/// use fsst::{Symbol, Compressor}; +/// let mut compressor = Compressor::default(); /// /// // Insert a new symbol -/// assert!(table.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]))); +/// assert!(compressor.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]))); /// -/// let compressed = table.compress("hello".as_bytes()); +/// let compressed = compressor.compress("hello".as_bytes()); /// assert_eq!(compressed, vec![0u8]); /// ``` #[derive(Clone)] -pub struct SymbolTable { +pub struct Compressor { /// Table mapping codes to symbols. pub(crate) symbols: [Symbol; 511], - /// Indicates the number of entries in the symbol table that have been populated, not counting + /// The number of entries in the symbol table that have been populated, not counting /// the escape values. pub(crate) n_symbols: u8, @@ -242,7 +315,7 @@ pub struct SymbolTable { lossy_pht: LossyPHT, } -impl Default for SymbolTable { +impl Default for Compressor { fn default() -> Self { let mut table = Self { symbols: [Symbol::ZERO; 511], @@ -264,7 +337,7 @@ impl Default for SymbolTable { /// /// The symbol table is trained on a corpus of data in the form of a single byte array, building up /// a mapping of 1-byte "codes" to sequences of up to `N` plaintext bytse, or "symbols". -impl SymbolTable { +impl Compressor { /// Attempt to insert a new symbol at the end of the table. /// /// # Panics @@ -434,48 +507,17 @@ impl SymbolTable { values } - /// Decompress a byte slice that was previously returned by [compression][Self::compress]. - pub fn decompress(&self, compressed: &[u8]) -> Vec { - let mut decoded: Vec = Vec::with_capacity(size_of::() * (compressed.len() + 1)); - let ptr = decoded.as_mut_ptr(); - - let mut in_pos = 0; - let mut out_pos = 0; - - while in_pos < compressed.len() && out_pos < (decoded.capacity() - size_of::()) { - let code = compressed[in_pos]; - if code == ESCAPE_CODE { - // Advance by one, do raw write. - in_pos += 1; - // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer - unsafe { - let write_addr = ptr.byte_offset(out_pos as isize); - write_addr.write(compressed[in_pos]); - } - out_pos += 1; - in_pos += 1; - } else { - let symbol = self.symbols[256 + code as usize]; - // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer - unsafe { - let write_addr = ptr.byte_offset(out_pos as isize) as *mut u64; - // Perform 8 byte unaligned write. - write_addr.write_unaligned(symbol.num); - } - in_pos += 1; - out_pos += symbol.len(); - } - } - - assert!( - in_pos >= compressed.len(), - "decompression should exhaust input before output" - ); - - // SAFETY: we enforce in the loop condition that out_pos <= decoded.capacity() - unsafe { decoded.set_len(out_pos) }; + /// Access the decompressor that can be used to decompress strings emitted from this + /// `Compressor` instance. + pub fn decompressor(&self) -> Decompressor { + Decompressor::new(self.symbol_table()) + } - decoded + /// Returns a readonly slice of the current symbol table. + /// + /// The returned slice will have length of `256 + n_symbols`. + pub fn symbol_table(&self) -> &[Symbol] { + &self.symbols[0..(256 + self.n_symbols as usize)] } } diff --git a/tests/correctness.rs b/tests/correctness.rs index a7c7599..d557f06 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -1,6 +1,6 @@ #![cfg(test)] -use fsst_rs::Symbol; +use fsst::{Compressor, Symbol}; static PREAMBLE: &str = r#" When in the Course of human events, it becomes necessary for one people to dissolve @@ -16,44 +16,44 @@ static ART_OF_WAR: &str = include_str!("./fixtures/art_of_war.txt"); #[test] fn test_basic() { // Roundtrip the declaration - let trained = fsst_rs::train(PREAMBLE); + let trained = Compressor::train(PREAMBLE); let compressed = trained.compress(PREAMBLE.as_bytes()); - let decompressed = trained.decompress(&compressed); + let decompressed = trained.decompressor().decompress(&compressed); assert_eq!(decompressed, PREAMBLE.as_bytes()); } #[test] fn test_train_on_empty() { - let trained = fsst_rs::train(""); + let trained = Compressor::train(""); // We can still compress with it, but the symbols are going to be empty. let compressed = trained.compress("the quick brown fox jumped over the lazy dog".as_bytes()); assert_eq!( - trained.decompress(&compressed), + trained.decompressor().decompress(&compressed), "the quick brown fox jumped over the lazy dog".as_bytes() ); } #[test] fn test_one_byte() { - let mut empty = fsst_rs::SymbolTable::default(); + let mut empty = Compressor::default(); // Assign code 0 to map to the symbol containing byte 0x01 empty.insert(Symbol::from_u8(0x01)); let compressed = empty.compress(&[0x01]); assert_eq!(compressed, vec![0u8]); - assert_eq!(empty.decompress(&compressed), vec![0x01]); + assert_eq!(empty.decompressor().decompress(&compressed), vec![0x01]); } #[test] fn test_zeros() { println!("training zeros"); let training_data: Vec = vec![0, 1, 2, 3, 4, 0]; - let trained = fsst_rs::train(&training_data); + let trained = Compressor::train(&training_data); println!("compressing with zeros"); let compressed = trained.compress(&[4, 0]); println!("decomperssing with zeros"); - assert_eq!(trained.decompress(&compressed), &[4, 0]); + assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]); println!("done"); } @@ -65,20 +65,25 @@ fn test_large() { corpus.push_str(DECLARATION); } - let trained = fsst_rs::train(&corpus); + let trained = Compressor::train(&corpus); let mut massive = String::new(); while massive.len() < 16 * 1_024 * 1_024 { massive.push_str(DECLARATION); } let compressed = trained.compress(massive.as_bytes()); - assert_eq!(trained.decompress(&compressed), massive.as_bytes()); + assert_eq!( + trained.decompressor().decompress(&compressed), + massive.as_bytes() + ); } #[test] fn test_chinese() { - let trained = fsst_rs::train(ART_OF_WAR.as_bytes()); + let trained = Compressor::train(ART_OF_WAR.as_bytes()); assert_eq!( ART_OF_WAR.as_bytes(), - trained.decompress(&trained.compress(ART_OF_WAR.as_bytes())) + trained + .decompressor() + .decompress(&trained.compress(ART_OF_WAR.as_bytes())) ); }