From e11edb1f9b52ee1e0016970c7b6dbc98d52277c6 Mon Sep 17 00:00:00 2001 From: Andrew Duffy Date: Thu, 15 Aug 2024 13:30:43 -0400 Subject: [PATCH] fix docs + clippy --- src/lib.rs | 24 ++++++++---------------- tests/correctness.rs | 2 +- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index bdd2667..1098223 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -388,18 +388,9 @@ impl SymbolTable { ); let remaining_bytes = remaining_bytes as usize; - // Shift off the remaining bytes - // Read the remaining bytes - // Unroll and test multiple values being written here. - // let mut last_word = [0u8; 8]; - // for i in 0..remaining_bytes { - // last_word[i as usize] = unsafe { in_ptr.byte_add(i as usize).read() }; - // } - - // Shift on the words from the remaining bytes. - // let mut last_word = unsafe { (in_ptr as *const u64).read_unaligned() }; - // last_word = mask_prefix(last_word, remaining_bytes as usize); - // let mut last_word = u64::from_le_bytes(last_word); + // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above, + // but shift data out of this word rather than advancing an input pointer and potentially reading + // unowned memory. let mut last_word = unsafe { match remaining_bytes { 0 => 0, @@ -513,6 +504,9 @@ fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool { (left & mask) == right } +/// This is a function that will get monomorphized based on the value of `N` to do +/// a load of `N` values from the pointer in a minimum number of instructions into +/// an output `u64`. unsafe fn extract_u64(ptr: *const u8) -> u64 { match N { 1 => ptr.read() as u64, @@ -522,9 +516,7 @@ unsafe fn extract_u64(ptr: *const u8) -> u64 { let high = (ptr.byte_add(1) as *const u16).read_unaligned() as u64; high << 8 | low } - 4 => { - return (ptr as *const u32).read_unaligned() as u64; - } + 4 => (ptr as *const u32).read_unaligned() as u64, 5 => { let low = (ptr as *const u32).read_unaligned() as u64; let high = ptr.byte_add(4).read() as u64; @@ -541,7 +533,7 @@ unsafe fn extract_u64(ptr: *const u8) -> u64 { let high = ptr.byte_add(6).read() as u64; (high << 48) | (mid << 32) | low } - 8 => (ptr as *const u64).read_unaligned() as u64, + 8 => (ptr as *const u64).read_unaligned(), _ => unreachable!("N must be <= 8"), } } diff --git a/tests/correctness.rs b/tests/correctness.rs index e168dd6..2f2abf7 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -48,7 +48,7 @@ fn test_one_byte() { #[test] fn test_zeros() { println!("training zeros"); - let training_data: Vec = vec![0, 1, 2, 3, 4]; + let training_data: Vec = vec![0, 1, 2, 3, 4, 0]; let trained = fsst_rs::train(&training_data); println!("compressing with zeros"); let compressed = trained.compress(&[0, 4]);