Use a cache eviction logic that matches our BlsCache usage pattern. (#…
AmineKhaldi authored Nov 15, 2024
1 parent 2d6f511 commit d988747
Showing 7 changed files with 174 additions and 28 deletions.
8 changes: 7 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion Cargo.toml
@@ -131,7 +131,6 @@ hex = "0.4.3"
thiserror = "1.0.69"
pyo3 = "0.22.6"
arbitrary = "1.4.1"
lru = "0.12.5"
rand = "0.8.5"
criterion = "0.5.1"
rstest = "0.22.0"
2 changes: 1 addition & 1 deletion crates/chia-bls/Cargo.toml
@@ -26,7 +26,7 @@ hex = { workspace = true }
thiserror = { workspace = true }
pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true }
arbitrary = { workspace = true, optional = true }
lru = { workspace = true }
linked-hash-map = "0.5.6"

[dev-dependencies]
rand = { workspace = true }
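
The dependency swap above (dropping the lru crate in favor of linked-hash-map) is what enables the new eviction behavior: instead of an LRU cache that reorders entries on every lookup, the cache keeps entries in insertion order and supports removing specific entries on demand. As a rough, illustrative sketch of that pattern (not part of this commit; the FifoCache name and the u64 value type are placeholders):

use linked_hash_map::LinkedHashMap;

struct FifoCache {
    // Entries stay in insertion order; the front is always the oldest.
    items: LinkedHashMap<[u8; 32], u64>,
    capacity: usize,
}

impl FifoCache {
    fn put(&mut self, key: [u8; 32], value: u64) {
        // When full, drop the oldest inserted entry rather than the
        // least-recently-used one.
        if self.items.len() == self.capacity {
            self.items.pop_front();
        }
        self.items.insert(key, value);
    }

    fn evict(&mut self, keys: impl IntoIterator<Item = [u8; 32]>) {
        // Remove specific entries on demand once they are known to no
        // longer be needed.
        for key in keys {
            self.items.remove(&key);
        }
    }
}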
39 changes: 38 additions & 1 deletion crates/chia-bls/benches/cache.rs
@@ -2,7 +2,7 @@ use chia_bls::aggregate_verify
use chia_bls::{sign, BlsCache, SecretKey, Signature};
use criterion::{criterion_group, criterion_main, Criterion};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::{seq::SliceRandom, Rng, SeedableRng};

fn cache_benchmark(c: &mut Criterion) {
let mut rng = StdRng::seed_from_u64(1337);
@@ -76,6 +76,43 @@ fn cache_benchmark(c: &mut Criterion) {
));
});
});

// Add more pairs to the cache so we can evict a relatively large number of them
for i in 1_000..20_000 {
let derived = sk.derive_hardened(i);
let pk = derived.public_key();
let sig = sign(&derived, msg);
agg_sig.aggregate(&sig);
pks.push(pk);
}
bls_cache.aggregate_verify(
pks[1_000..20_000].iter().zip([&msg].iter().cycle()),
&agg_sig,
);

c.bench_function("bls_cache.evict 5% of the items", |b| {
let mut cache = bls_cache.clone();
let mut pks_shuffled = pks.clone();
pks_shuffled.shuffle(&mut rng);
b.iter(|| {
if cache.is_empty() {
return;
}
cache.evict(pks_shuffled.iter().take(1_000).zip([&msg].iter().cycle()));
});
});

c.bench_function("bls_cache.evict 100% of the items", |b| {
let mut cache = bls_cache.clone();
let mut pks_shuffled = pks.clone();
pks_shuffled.shuffle(&mut rng);
b.iter(|| {
if cache.is_empty() {
return;
}
cache.evict(pks_shuffled.iter().zip([&msg].iter().cycle()));
});
});
}

criterion_group!(cache, cache_benchmark);
150 changes: 126 additions & 24 deletions crates/chia-bls/src/bls_cache.rs
@@ -2,7 +2,7 @@ use std::borrow::Borrow;
use std::num::NonZeroUsize;

use chia_sha2::Sha256;
use lru::LruCache;
use linked_hash_map::LinkedHashMap;
use std::sync::Mutex;

use crate::{aggregate_verify_gt, hash_to_g2};
@@ -17,16 +17,35 @@ use crate::{GTElement, PublicKey, Signature};
/// However, validating a signature where we have no cached GT elements, the
/// aggregate_verify() primitive is faster. When long-syncing, that's
/// preferable.
#[derive(Debug, Clone)]
struct BlsCacheData {
// sha256(pubkey + message) -> GTElement
items: LinkedHashMap<[u8; 32], GTElement>,
capacity: NonZeroUsize,
}

impl BlsCacheData {
pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
// If the cache is full, remove the oldest item.
if self.items.len() == self.capacity.get() {
if let Some((oldest_key, _)) = self.items.pop_front() {
self.items.remove(&oldest_key);
}
}
self.items.insert(hash, pairing);
}
}

#[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
#[derive(Debug)]
pub struct BlsCache {
// sha256(pubkey + message) -> GTElement
cache: Mutex<LruCache<[u8; 32], GTElement>>,
cache: Mutex<BlsCacheData>,
}

impl Default for BlsCache {
fn default() -> Self {
Self::new(NonZeroUsize::new(50000).unwrap())
Self::new(NonZeroUsize::new(50_000).unwrap())
}
}

@@ -39,18 +58,21 @@ impl Clone for BlsCache {
}

impl BlsCache {
pub fn new(cache_size: NonZeroUsize) -> Self {
pub fn new(capacity: NonZeroUsize) -> Self {
Self {
cache: Mutex::new(LruCache::new(cache_size)),
cache: Mutex::new(BlsCacheData {
items: LinkedHashMap::new(),
capacity,
}),
}
}

pub fn len(&self) -> usize {
self.cache.lock().expect("cache").len()
self.cache.lock().expect("cache").items.len()
}

pub fn is_empty(&self) -> bool {
self.cache.lock().expect("cache").is_empty()
self.cache.lock().expect("cache").items.is_empty()
}

pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
@@ -67,7 +89,7 @@ impl BlsCache {
let hash: [u8; 32] = hasher.finalize();

// If the pairing is in the cache, we don't need to recalculate it.
if let Some(pairing) = self.cache.lock().expect("cache").get(&hash).cloned() {
if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
return pairing;
}

@@ -88,6 +110,22 @@ impl BlsCache {
let hash: [u8; 32] = hasher.finalize();
self.cache.lock().expect("cache").put(hash, gt);
}

pub fn evict<Pk, Msg>(&self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
where
Pk: Borrow<PublicKey>,
Msg: AsRef<[u8]>,
{
let mut c = self.cache.lock().expect("cache");
for (pk, msg) in pks_msgs {
let mut hasher = Sha256::new();
let mut aug_msg = pk.borrow().to_bytes().to_vec();
aug_msg.extend_from_slice(msg.as_ref());
hasher.update(&aug_msg);
let hash: [u8; 32] = hasher.finalize();
c.items.remove(&hash);
}
}
}

#[cfg(feature = "py-bindings")]
@@ -148,7 +186,7 @@ impl BlsCache {
use pyo3::types::PyBytes;
let ret = PyList::empty_bound(py);
let c = self.cache.lock().expect("cache");
for (key, value) in &*c {
for (key, value) in &c.items {
ret.append((PyBytes::new_bound(py, key), value.clone().into_py(py)))?;
}
Ok(ret.into())
@@ -167,6 +205,20 @@ impl BlsCache {
}
Ok(())
}

#[pyo3(name = "evict")]
pub fn py_evict(&self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
let pks = pks
.iter()?
.map(|item| item?.extract())
.collect::<PyResult<Vec<PublicKey>>>()?;
let msgs = msgs
.iter()?
.map(|item| item?.extract())
.collect::<PyResult<Vec<PyBackedBytes>>>()?;
self.evict(pks.into_iter().zip(msgs));
Ok(())
}
}

#[cfg(test)]
@@ -261,21 +313,24 @@ pub mod tests {
}

// The cache should be full now.
assert_eq!(bls_cache.cache.lock().expect("cache").len(), 3);

// Recreate first key.
let sk = SecretKey::from_seed(&[1; 32]);
let pk = sk.public_key();
let msg = [106; 32];

let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();

let mut hasher = Sha256::new();
hasher.update(aug_msg);
let hash: [u8; 32] = hasher.finalize();
assert_eq!(bls_cache.len(), 3);

// The first key should have been removed, since it's the oldest that's been accessed.
assert!(!bls_cache.cache.lock().expect("cache").contains(&hash));
// Recreate first two keys and make sure they got removed.
for i in 1..=2 {
let sk = SecretKey::from_seed(&[i; 32]);
let pk = sk.public_key();
let msg = [106; 32];
let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
let mut hasher = Sha256::new();
hasher.update(aug_msg);
let hash: [u8; 32] = hasher.finalize();
assert!(!bls_cache
.cache
.lock()
.expect("cache")
.items
.contains_key(&hash));
}
}

#[test]
@@ -286,4 +341,51 @@

assert!(bls_cache.aggregate_verify(pks_msgs, &Signature::default()));
}

#[test]
fn test_evict() {
let mut bls_cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
// Create 5 pk msg pairs and add them to the cache.
let mut pks_msgs = Vec::new();
for i in 1..=5 {
let sk = SecretKey::from_seed(&[i; 32]);
let pk = sk.public_key();
let msg = [42; 32];
let sig = sign(&sk, msg);
pks_msgs.push((pk, msg));
assert!(bls_cache.aggregate_verify([(pk, msg)], &sig));
}
assert_eq!(bls_cache.len(), 5);
// Evict the first and third entries.
let pks_msgs_to_evict = vec![pks_msgs[0], pks_msgs[2]];
bls_cache.evict(pks_msgs_to_evict.iter().copied());
// The cache should have 3 items now.
assert_eq!(bls_cache.len(), 3);
// Check that the evicted entries are no longer in the cache.
for (pk, msg) in &pks_msgs_to_evict {
let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
let mut hasher = Sha256::new();
hasher.update(aug_msg);
let hash: [u8; 32] = hasher.finalize();
assert!(!bls_cache
.cache
.lock()
.expect("cache")
.items
.contains_key(&hash));
}
// Check that the remaining entries are still in the cache.
for (pk, msg) in &[pks_msgs[1], pks_msgs[3], pks_msgs[4]] {
let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
let mut hasher = Sha256::new();
hasher.update(aug_msg);
let hash: [u8; 32] = hasher.finalize();
assert!(bls_cache
.cache
.lock()
.expect("cache")
.items
.contains_key(&hash));
}
}
}
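
Taken together with the put() logic above, the intended flow (my reading of the commit title rather than anything spelled out in this diff) is that callers verify aggregate signatures through the cache and later evict exactly the (public key, message) pairs they are done with, instead of relying on LRU ordering. A minimal caller-side sketch against the API added here (the example_usage function and its seed and message values are made up):

use chia_bls::{sign, BlsCache, SecretKey};

fn example_usage() {
    let mut cache = BlsCache::default();

    let sk = SecretKey::from_seed(&[7u8; 32]);
    let pk = sk.public_key();
    let msg = [1u8; 32];
    let sig = sign(&sk, msg);

    // A successful verification caches the GT element for this (pk, msg) pair.
    assert!(cache.aggregate_verify([(pk, msg)], &sig));
    assert_eq!(cache.len(), 1);

    // Once the pair is known to no longer be needed, evict exactly that
    // entry instead of waiting for capacity-based eviction.
    cache.evict([(pk, msg)]);
    assert!(cache.is_empty());
}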
1 change: 1 addition & 0 deletions wheel/generate_type_stubs.py
@@ -368,6 +368,7 @@ def len(self) -> int: ...
def aggregate_verify(self, pks: list[G1Element], msgs: list[bytes], sig: G2Element) -> bool: ...
def items(self) -> list[tuple[bytes, GTElement]]: ...
def update(self, other: Sequence[tuple[bytes, GTElement]]) -> None: ...
def evict(self, pks: list[G1Element], msgs: list[bytes]) -> None: ...
@final
class AugSchemeMPL:
1 change: 1 addition & 0 deletions wheel/python/chia_rs/chia_rs.pyi
@@ -99,6 +99,7 @@ class BLSCache:
def aggregate_verify(self, pks: list[G1Element], msgs: list[bytes], sig: G2Element) -> bool: ...
def items(self) -> list[tuple[bytes, GTElement]]: ...
def update(self, other: Sequence[tuple[bytes, GTElement]]) -> None: ...
def evict(self, pks: list[G1Element], msgs: list[bytes]) -> None: ...

@final
class AugSchemeMPL:
