From e9b41bc9f2a58646cf180c2a6fdab74f545a637d Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Thu, 15 Aug 2024 10:12:09 -0400 Subject: [PATCH] handle zero bytes in input properly --- rust-toolchain.toml | 2 +- src/lib.rs | 13 ++++++++++--- tests/correctness.rs | 22 +++++++++++----------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 23591c9..2296533 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,4 +1,4 @@ [toolchain] -channel = "stable" +channel = "nightly-2024-08-14" components = ["rust-src", "rustfmt", "clippy"] profile = "minimal" diff --git a/src/lib.rs b/src/lib.rs index c5896b6..7191b00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,14 @@ impl Symbol { // For little-endian platforms, this counts the number of *trailing* zeros let null_bytes = (numeric.leading_zeros() >> 3) as usize; - size_of::() - null_bytes + // Special case handling of a symbol with all-zeros. This is actually + // a 1-byte symbol containing 0x00. + let len = size_of::() - null_bytes; + if len == 0 { + 1 + } else { + len + } } /// Returns true if the symbol does not encode any bytes. @@ -298,9 +305,9 @@ impl SymbolTable { /// /// # Safety /// - /// `in_ptr` and `out_ptr` must never be NULL or otherwise point to invalid memory. + /// `out_ptr` must never be NULL or otherwise point to invalid memory. // NOTE(aduffy): uncomment this line to make the function appear in profiles - // #[inline(never)] + #[inline(never)] pub(crate) unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) { // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and // if it isn't, it will be overwritten anyway. diff --git a/tests/correctness.rs b/tests/correctness.rs index f4f752f..8773bc7 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -29,17 +29,17 @@ fn test_train_on_empty() { ); } -// #[test] -// fn test_zeros() { -// println!("training zeros"); -// let training_data: Vec = vec![0, 1, 2, 3, 4]; -// let trained = fsst_rs::train(&training_data); -// println!("compressing with zeros"); -// let compressed = trained.compress(&[0, 4]); -// println!("decomperssing with zeros"); -// assert_eq!(trained.decompress(&compressed), &[0, 4]); -// println!("done"); -// } +#[test] +fn test_zeros() { + println!("training zeros"); + let training_data: Vec = vec![0, 1, 2, 3, 4]; + let trained = fsst_rs::train(&training_data); + println!("compressing with zeros"); + let compressed = trained.compress(&[0, 4]); + println!("decomperssing with zeros"); + assert_eq!(trained.decompress(&compressed), &[0, 4]); + println!("done"); +} #[test] fn test_large() {