From 09d6355837b913a9dc96230fdd9e67744448affa Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Mon, 6 Nov 2023 17:37:42 +1100 Subject: [PATCH] Add `BitList::is_disjoint` --- Cargo.toml | 5 +++++ benches/bitfield.rs | 30 +++++++++++++++++++++++++ src/bitfield.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 benches/bitfield.rs diff --git a/Cargo.toml b/Cargo.toml index 8b311f6..87844f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,8 @@ itertools = "0.10.0" [dev-dependencies] serde_json = "1.0.0" tree_hash_derive = "0.5.0" +criterion = "0.5" + +[[bench]] +name = "bitfield" +harness = false diff --git a/benches/bitfield.rs b/benches/bitfield.rs new file mode 100644 index 0000000..cafb523 --- /dev/null +++ b/benches/bitfield.rs @@ -0,0 +1,30 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use smallvec::smallvec; +use ssz_types::BitList; +use typenum::U2048; + +type BitList2048 = BitList; + +fn is_disjoint(c: &mut Criterion) { + let x = BitList2048::from_raw_bytes(smallvec![0xff; 2048 / 8], 2048).unwrap(); + let y = BitList2048::from_raw_bytes(smallvec![0x00; 2048 / 8], 2048).unwrap(); + + c.bench_with_input( + BenchmarkId::new("bitfield_is_disjoint", 2048), + &(x.clone(), y.clone()), + |b, &(ref x, ref y)| { + b.iter(|| assert!(x.is_disjoint(&y))); + }, + ); + + c.bench_with_input( + BenchmarkId::new("bitfield_is_disjoint_by_intersection", 2048), + &(x.clone(), y.clone()), + |b, &(ref x, ref y)| { + b.iter(|| assert!(x.intersection(&y).is_zero())); + }, + ); +} + +criterion_group!(benches, is_disjoint); +criterion_main!(benches); diff --git a/src/bitfield.rs b/src/bitfield.rs index d090833..e58dab3 100644 --- a/src/bitfield.rs +++ b/src/bitfield.rs @@ -240,6 +240,17 @@ impl Bitfield> { pub fn is_subset(&self, other: &Self) -> bool { self.difference(other).is_zero() } + + /// Returns `true` is `self` is disjoint from `other` and `false` otherwise. + /// + /// This method is a faster alternative to `self.intersection(other).is_zero()` and does not + /// allocate! + pub fn is_disjoint(&self, other: &Self) -> bool { + self.bytes + .iter() + .zip(other.bytes.iter()) + .all(|(self_byte, other_byte)| self_byte & other_byte == 0) + } } impl Bitfield> { @@ -391,7 +402,10 @@ impl Bitfield { /// - `bytes` is not the minimal required bytes to represent a bitfield of `bit_len` bits. /// - `bit_len` is not a multiple of 8 and `bytes` contains set bits that are higher than, or /// equal to `bit_len`. - fn from_raw_bytes(bytes: SmallVec<[u8; SMALLVEC_LEN]>, bit_len: usize) -> Result { + pub fn from_raw_bytes( + bytes: SmallVec<[u8; SMALLVEC_LEN]>, + bit_len: usize, + ) -> Result { if bit_len == 0 { if bytes.len() == 1 && bytes[0] == 0 { // A bitfield with `bit_len` 0 can only be represented by a single zero byte. @@ -1243,6 +1257,44 @@ mod bitlist { assert_eq!(c.intersection(&c), c); } + #[test] + fn is_disjoint() { + let a = BitList1024::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap(); + let b = BitList1024::from_raw_bytes(smallvec![0b0100, 0b0001], 16).unwrap(); + let c = BitList1024::from_raw_bytes(smallvec![0b0111, 0b1110], 16).unwrap(); + let d = BitList1024::from_raw_bytes(smallvec![0b0101, 0b0001, 0b0000], 24).unwrap(); + let e = BitList1024::from_raw_bytes(smallvec![0b0101, 0b0000, 0b1011], 24).unwrap(); + + assert_eq!(a.len(), 16); + assert_eq!(b.len(), 16); + assert_eq!(c.len(), 16); + assert_eq!(d.len(), 24); + assert_eq!(e.len(), 24); + + // a bitfield is never disjoint from itself + assert!(!a.is_disjoint(&a)); + assert!(!b.is_disjoint(&b)); + assert!(!c.is_disjoint(&c)); + assert!(!d.is_disjoint(&d)); + assert!(!e.is_disjoint(&e)); + + // same length, not disjoint + assert!(!a.is_disjoint(&b)); + assert!(!b.is_disjoint(&a)); + + // same length, disjoint + assert!(a.is_disjoint(&c)); + assert!(c.is_disjoint(&a)); + + // different length, not disjoint + assert!(!a.is_disjoint(&d)); + assert!(!d.is_disjoint(&a)); + + // different length, disjoint + assert!(a.is_disjoint(&e)); + assert!(e.is_disjoint(&a)); + } + #[test] fn subset() { let a = BitList1024::from_raw_bytes(smallvec![0b1000, 0b0001], 16).unwrap();