Skip to content

Commit

Permalink
WIP: Using chunking in predicate add
Browse files Browse the repository at this point in the history
  • Loading branch information
deven96 committed Sep 23, 2024
1 parent f31e05d commit b4b9be8
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 92 deletions.
68 changes: 23 additions & 45 deletions ahnlich/ai/src/engine/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ use ahnlich_types::keyval::StoreValue;
use ahnlich_types::metadata::MetadataValue;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashSet as StdHashSet;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use utils::parallel;
use utils::persistence::AhnlichPersistenceUtils;

/// Contains all the stores that have been created in memory
Expand Down Expand Up @@ -155,61 +157,39 @@ impl AIStoreHandler {
}

/// Validates storeinputs against a store and checks storevalue for reservedkey.
#[tracing::instrument(skip(self, inputs), fields(input_length=inputs.len(), pool_size, chunk_size))]
pub(crate) async fn validate_and_prepare_store_data(
#[tracing::instrument(skip(self, inputs), fields(input_length=inputs.len(), num_threads = rayon::current_num_threads()))]
pub(crate) fn validate_and_prepare_store_data(
&self,
store_name: &StoreName,
inputs: Vec<(StoreInput, StoreValue)>,
) -> Result<StoreValidateResponse, AIProxyError> {
let store = self.get(store_name)?;
let mut output: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
let mut delete_hashset = StdHashSet::with_capacity(inputs.len());
let pool_size: usize = 64;
let chunk_size = (inputs.len() + std::cmp::min(inputs.len(), pool_size) - 1)
/ std::cmp::min(inputs.len(), pool_size);

tracing::Span::current().record("pool_size", pool_size);
tracing::Span::current().record("chunk_size", chunk_size);

let mut handles: Vec<_> = FallibleVec::try_with_capacity(pool_size)?;
let chunked_inputs = inputs.chunks(chunk_size);

for chunk in chunked_inputs.into_iter() {
let index_model = store.index_model;
let owned_chunk = chunk.to_vec();
let task =
tokio::spawn(
async move { Self::process_store_inputs(index_model, owned_chunk).await },
);
handles.try_push(task)?;
}

for task in handles {
let response = task
.await
.map_err(|err| AIProxyError::StandardError(err.to_string()))
.and_then(|inner| inner);
match response {
Ok((sub_output, sub_delete_hashset)) => {
output.extend(sub_output);
delete_hashset.extend(sub_delete_hashset);
}
Err(err) => return Err(err),
}
}

Ok((output, delete_hashset))
let index_model = store.index_model;
let chunk_size = parallel::chunk_size(inputs.len());
inputs
.into_par_iter()
.chunks(chunk_size)
.map(|input| Self::preprocess_store_input(index_model, input))
.try_reduce(
|| (Vec::new(), StdHashSet::new()),
|(mut acc_vec, mut acc_set), chunk_res| {
let (chunk_vec, chunk_set) = chunk_res;
acc_vec.extend(chunk_vec);
acc_set.extend(chunk_set);
Ok((acc_vec, acc_set))
},
)
}

#[tracing::instrument(skip(inputs))]
pub(crate) async fn process_store_inputs(
pub(crate) fn preprocess_store_input(
index_model: AIModel,
inputs: Vec<(StoreInput, StoreValue)>,
) -> Result<StoreValidateResponse, AIProxyError> {
let mut output: Vec<_> = FallibleVec::try_with_capacity(inputs.len())?;
let mut delete_hashset = StdHashSet::new();
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
for (store_input, mut store_value) in inputs.into_iter() {
for (store_input, mut store_value) in inputs {
if store_value.contains_key(metadata_key) {
return Err(AIProxyError::ReservedError(metadata_key.to_string()));
}
Expand All @@ -226,7 +206,6 @@ impl AIStoreHandler {
output.try_push((store_input, store_value))?;
delete_hashset.insert(metadata_value);
}

Ok((output, delete_hashset))
}

Expand All @@ -240,9 +219,8 @@ impl AIStoreHandler {
preprocess_action: PreprocessAction,
) -> Result<StoreSetResponse, AIProxyError> {
let store = self.get(store_name)?;
let (validated_data, delete_hashset) = self
.validate_and_prepare_store_data(store_name, inputs)
.await?;
let (validated_data, delete_hashset) =
self.validate_and_prepare_store_data(store_name, inputs)?;

let (store_inputs, store_values): (Vec<_>, Vec<_>) = validated_data.into_iter().unzip();
let store_keys = model_manager
Expand Down
6 changes: 6 additions & 0 deletions ahnlich/db/src/algorithm/non_linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ pub struct NonLinearAlgorithmIndices {
}

impl NonLinearAlgorithmIndices {
/// Returns `true` when the `algorithm_to_index` map currently holds no entries.
#[tracing::instrument]
pub fn is_empty(&self) -> bool {
    // The pin guard is only needed for the duration of the emptiness check,
    // so it lives as a temporary of this single expression.
    self.algorithm_to_index.pin().is_empty()
}

#[tracing::instrument]
pub fn create(input: HashSet<NonLinearAlgorithm>, dimension: NonZeroUsize) -> Self {
let algorithm_to_index = ConcurrentHashMap::new();
Expand Down
66 changes: 44 additions & 22 deletions ahnlich/db/src/engine/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@ use ahnlich_types::predicate::PredicateCondition;
use flurry::HashMap as ConcurrentHashMap;
use flurry::HashSet as ConcurrentHashSet;
use itertools::Itertools;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::HashSet as StdHashSet;
use std::mem::size_of_val;
use utils::parallel;

type InnerPredicateIndexVal = ConcurrentHashSet<StoreKeyId>;
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, InnerPredicateIndexVal>;
Expand Down Expand Up @@ -146,26 +151,35 @@ impl PredicateIndices {
/// Adds predicates if the key is within allowed_predicates
#[tracing::instrument(skip(self))]
pub(super) fn add(&self, new: Vec<(StoreKeyId, StoreValue)>) {
let predicate_values = self.inner.pin();
let iter = new
.into_iter()
.into_par_iter()
.flat_map(|(store_key_id, store_value)| {
store_value.into_iter().map(move |(key, val)| {
store_value.into_par_iter().map(move |(key, val)| {
let allowed_keys = self.allowed_predicates.pin();
allowed_keys
.contains(&key)
.then_some((store_key_id.clone(), key, val))
})
})
.flatten()
.map(|(store_key_id, key, val)| (key, (val, store_key_id)))
.into_group_map();
.map(|(store_key_id, key, val)| (key, (val.to_owned(), store_key_id)))
.fold(HashMap::new, |mut acc: HashMap<_, Vec<_>>, (k, v)| {
acc.entry(k).or_default().push(v);
acc
})
.reduce(HashMap::new, |mut acc, map| {
for (key, mut values) in map {
acc.entry(key).or_default().append(&mut values);
}
acc
});

let predicate_values = self.inner.pin();
for (key, val) in iter {
// If there exists a predicate index as we want to update it, just add to that
// predicate index instead
let pred = PredicateIndex::init(val.clone());
if let Err(existing_predicate) = predicate_values.try_insert(key.clone(), pred) {
if let Err(existing_predicate) = predicate_values.try_insert(key, pred) {
existing_predicate.current.add(val);
};
}
Expand Down Expand Up @@ -250,7 +264,7 @@ impl PredicateIndex {
.sum::<usize>()
}

#[tracing::instrument]
#[tracing::instrument(skip(init), fields(input_length = init.len()))]
fn init(init: Vec<(MetadataValue, StoreKeyId)>) -> Self {
let new = Self(InnerPredicateIndex::new());
new.add(init);
Expand All @@ -277,23 +291,31 @@ impl PredicateIndex {
if update.is_empty() {
return;
}
let pinned = self.0.pin();
for (predicate_value, store_key_id) in update {
if let Some((_, value)) = pinned.get_key_value(&predicate_value) {
value.insert(store_key_id, &value.guard());
} else {
// Use try_insert as it is very possible that the hashmap itself now has that key that
// was not previously there as it has been inserted on a different thread
let new_hashset = ConcurrentHashSet::new();
new_hashset.insert(store_key_id.clone(), &new_hashset.guard());
if let Err(error_current) = pinned.try_insert(predicate_value, new_hashset) {
error_current
.current
.insert(store_key_id, &error_current.current.guard());
let chunk_size = parallel::chunk_size(update.len());
update
.into_par_iter()
.chunks(chunk_size)
.for_each(|values| {
let pinned = self.0.pin();
for (predicate_value, store_key_id) in values {
if let Some((_, value)) = pinned.get_key_value(&predicate_value) {
value.insert(store_key_id, &value.guard());
} else {
// Use try_insert as it is very possible that the hashmap itself now has that key that
// was not previously there as it has been inserted on a different thread
let new_hashset = ConcurrentHashSet::new();
new_hashset.insert(store_key_id.clone(), &new_hashset.guard());
if let Err(error_current) = pinned.try_insert(predicate_value, new_hashset)
{
error_current
.current
.insert(store_key_id, &error_current.current.guard());
}
}
}
}
}
});
}

/// checks the predicate index for a predicate op and value. The return type is a StdHashSet<_>
/// because we do not modify it at any point so we do not need concurrency protection
#[tracing::instrument(skip(self))]
Expand Down
56 changes: 32 additions & 24 deletions ahnlich/db/src/engine/store.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::errors::ServerError;
use rayon::prelude::*;

use super::super::algorithm::non_linear::NonLinearAlgorithmIndices;
use super::super::algorithm::{AlgorithmByType, FindSimilarN};
Expand All @@ -14,16 +15,15 @@ use ahnlich_types::predicate::PredicateCondition;
use ahnlich_types::similarity::Algorithm;
use ahnlich_types::similarity::NonLinearAlgorithm;
use ahnlich_types::similarity::Similarity;
use fallible_collections::FallibleVec;
use flurry::HashMap as ConcurrentHashMap;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap as StdHashMap;
use std::collections::HashSet as StdHashSet;
use std::mem::size_of_val;
use std::num::NonZeroUsize;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::Arc;
use utils::persistence::AhnlichPersistenceUtils;
/// A hash of Store key, this is more preferable when passing around references as arrays can be
Expand Down Expand Up @@ -396,7 +396,7 @@ pub struct Store {
/// Making use of a concurrent hashmap, we should be able to create an engine that manages stores
id_to_value: ConcurrentHashMap<StoreKeyId, (StoreKey, StoreValue)>,
/// Indices to filter for the store
predicate_indices: PredicateIndices,
predicate_indices: Arc<PredicateIndices>,
/// Non linear Indices
non_linear_indices: NonLinearAlgorithmIndices,
}
Expand All @@ -411,7 +411,7 @@ impl Store {
Self {
dimension,
id_to_value: ConcurrentHashMap::new(),
predicate_indices: PredicateIndices::init(predicates),
predicate_indices: Arc::new(PredicateIndices::init(predicates)),
non_linear_indices: NonLinearAlgorithmIndices::create(non_linear_indices, dimension),
}
}
Expand Down Expand Up @@ -571,36 +571,44 @@ impl Store {
});
}
let store_dimension: usize = self.dimension.into();
let check_bounds = |(store_key, store_val): (StoreKey, StoreValue)| -> Result<(StoreKeyId, (StoreKey, StoreValue)), ServerError> {
let check_bounds = |(store_key, store_val): &(StoreKey, StoreValue)| -> Result<(StoreKeyId, (StoreKey, StoreValue)), ServerError> {
let input_dimension = store_key.0.len();
if input_dimension != store_dimension {
Err(ServerError::StoreDimensionMismatch { store_dimension, input_dimension })
} else {
Ok(((&store_key).into(), (store_key, store_val)))
Ok(((store_key).into(), (store_key.to_owned(), store_val.to_owned())))
}
};
let res: Vec<(StoreKeyId, (StoreKey, StoreValue))> = new
.into_iter()
.map(check_bounds)
.collect::<Result<_, _>>()?;
let res: Vec<(StoreKeyId, (StoreKey, StoreValue))> =
new.par_iter().map(check_bounds).collect::<Result<_, _>>()?;
let predicate_insert = res
.iter()
.par_iter()
.map(|(k, (_, v))| (k.clone(), v.clone()))
.collect();
let pinned = self.id_to_value.pin();
let (mut inserted, mut updated) = (0, 0);
let mut inserted_keys: Vec<_> = FallibleVec::try_with_capacity(res.len())?;
for (key, val) in res {
if pinned.insert(key, val.clone()).is_some() {
updated += 1;
} else {
inserted += 1;
inserted_keys.push(val.0 .0);
}
let inserted = AtomicUsize::new(0);
let updated = AtomicUsize::new(0);
let inserted_keys = res
.into_par_iter()
.flat_map_iter(|(k, v)| {
let pinned = self.id_to_value.pin();
if pinned.insert(k, v.clone()).is_some() {
updated.fetch_add(1, Ordering::SeqCst);
} else {
inserted.fetch_add(1, Ordering::SeqCst);
return Some(v.0 .0);
}
None
})
.collect();
let predicate_indices = self.predicate_indices.clone();
predicate_indices.add(predicate_insert);
if !self.non_linear_indices.is_empty() {
self.non_linear_indices.insert(inserted_keys);
}
self.predicate_indices.add(predicate_insert);
self.non_linear_indices.insert(inserted_keys);
Ok(StoreUpsert { inserted, updated })
Ok(StoreUpsert {
inserted: inserted.into_inner(),
updated: updated.into_inner(),
})
}

#[tracing::instrument(skip(self))]
Expand Down
1 change: 1 addition & 0 deletions ahnlich/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ log.workspace = true
cap = "0.1.2"
tokio-util.workspace = true
fallible_collections.workspace = true
rayon.workspace = true
1 change: 1 addition & 0 deletions ahnlich/utils/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod allocator;
pub mod client;
pub mod parallel;
pub mod persistence;
pub mod protocol;
pub mod server;
22 changes: 22 additions & 0 deletions ahnlich/utils/src/parallel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use rayon::ThreadPoolBuilder;
use std::sync::Once;

static INIT_THREADPOOL_ONCE: Once = Once::new();

/// Builds the global rayon thread pool with `num_threads` worker threads.
///
/// The build is routed through `INIT_THREADPOOL_ONCE`, so the builder runs at
/// most once; later calls return without rebuilding the pool.
///
/// # Panics
///
/// Panics with "Cannot build server threadpool" if rayon reports an error
/// while building the global pool.
pub(crate) fn init_threadpool(num_threads: usize) {
    let build_global_pool = || {
        let builder = ThreadPoolBuilder::new().num_threads(num_threads);
        builder
            .build_global()
            .expect("Cannot build server threadpool");
    };
    INIT_THREADPOOL_ONCE.call_once(build_global_pool);
}

/// Calculates the chunk size to use for an input of `input_length` items so
/// that the resulting chunks can be spread across the current rayon threads.
///
/// The result is `ceil(input_length / min(input_length, num_threads))`, where
/// `num_threads` is `rayon::current_num_threads()`.
///
/// The returned value is never zero: an empty input yields `1`. This keeps the
/// computation from dividing by zero and lets callers hand the result straight
/// to chunking adaptors that reject a zero chunk size.
pub fn chunk_size(input_length: usize) -> usize {
    let num_threads = rayon::current_num_threads();
    // Clamp the divisor to 1: with an empty input `min(0, num_threads)` is 0,
    // which previously caused a divide-by-zero panic.
    let minimum_factor = std::cmp::min(input_length, num_threads).max(1);
    // Ceiling division; clamp again so an empty input maps to a chunk size of
    // 1 rather than 0. Non-empty inputs are unaffected by either clamp.
    ((input_length + minimum_factor - 1) / minimum_factor).max(1)
}
Loading

0 comments on commit b4b9be8

Please sign in to comment.