Skip to content

Commit

Permalink
use NonZeroUSize for capturing cache size
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Devolder <[email protected]>
  • Loading branch information
keldonin committed Sep 9, 2024
1 parent 0c33eff commit fd10cc6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
29 changes: 15 additions & 14 deletions cryptoki/src/session/object_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::session::Session;
use cryptoki_sys::*;
use std::collections::HashMap;
use std::convert::TryInto;
use std::num::NonZeroUsize;

// Search 10 elements at a time
const MAX_OBJECT_COUNT: usize = 10;
Expand Down Expand Up @@ -86,12 +87,8 @@ impl<'a> ObjectHandleIterator<'a> {
fn new(
session: &'a Session,
mut template: Vec<CK_ATTRIBUTE>,
cache_size: usize,
cache_size: NonZeroUsize,
) -> Result<Self> {
if cache_size == 0 {
return Err(Error::InvalidValue);
}

unsafe {
Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)(
session.handle(),
Expand All @@ -101,11 +98,11 @@ impl<'a> ObjectHandleIterator<'a> {
.into_result(Function::FindObjectsInit)?;
}

let cache: Vec<CK_OBJECT_HANDLE> = vec![0; cache_size];
let cache: Vec<CK_OBJECT_HANDLE> = vec![0; cache_size.get()];
Ok(ObjectHandleIterator {
session,
object_count: cache_size,
index: cache_size,
object_count: cache_size.get(),
index: cache_size.get(),
cache,
})
}
Expand Down Expand Up @@ -187,6 +184,7 @@ impl Session {
/// Iterate over session objects matching a template.
///
/// # Arguments
///
/// * `template` - The template to match objects against
///
/// # Returns
Expand All @@ -195,31 +193,34 @@ impl Session {
/// matching the template. Note that the cache size is managed internally and set to a default value (10)
///
/// # See also
///
/// * [`ObjectHandleIterator`] for more information on how to use the iterator
/// * [`Session::iter_objects_with_cache_size`] for a way to specify the cache size
#[inline(always)]
pub fn iter_objects(&self, template: &[Attribute]) -> Result<ObjectHandleIterator> {
self.iter_objects_with_cache_size(template, MAX_OBJECT_COUNT)
self.iter_objects_with_cache_size(template, NonZeroUsize::new(MAX_OBJECT_COUNT).unwrap())
}

/// Iterate over session objects matching a template, with cache size
///
/// # Arguments
///
/// * `template` - The template to match objects against
/// * `cache_size` - The number of objects to cache. Note that 0 is an invalid value and will return an error.
/// * `cache_size` - The number of objects to cache (type is [`NonZeroUsize`])
///
/// # Returns
///
/// This function will return a [`Result<ObjectHandleIterator>`] that can be used to iterate over the objects
/// matching the template. The cache size corresponds to the size of the array provided to `C_FindObjects()`.
///
/// # See also
///
/// * [`ObjectHandleIterator`] for more information on how to use the iterator
/// * [`Session::iter_objects`] for a simpler way to iterate over objects
pub fn iter_objects_with_cache_size(
&self,
template: &[Attribute],
cache_size: usize,
cache_size: NonZeroUsize,
) -> Result<ObjectHandleIterator> {
let template: Vec<CK_ATTRIBUTE> = template.iter().map(Into::into).collect();
ObjectHandleIterator::new(self, template, cache_size)
Expand All @@ -229,12 +230,12 @@ impl Session {
///
/// # Arguments
///
/// * `template` - A [Attribute] of search parameters that will be used
/// to find objects.
/// * `template` - A reference to [Attribute] of search parameters that will be used
/// to find objects.
///
/// # Returns
///
/// Upon success, a vector of [ObjectHandle] wrapped in a Result.
/// Upon success, a vector of [`ObjectHandle`] wrapped in a Result.
/// Upon failure, the first error encountered.
///
/// # Note
Expand Down
22 changes: 12 additions & 10 deletions cryptoki/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use cryptoki::session::{SessionState, UserType};
use cryptoki::types::AuthPin;
use serial_test::serial;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::thread;

use cryptoki::mechanism::ekdf::AesCbcDeriveParams;
Expand Down Expand Up @@ -399,27 +400,26 @@ fn session_objecthandle_iterator() -> testresult::TestResult {

// test iter_objects_with_cache_size()
// count keys with cache size of 20
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 20)?;
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(20).unwrap())?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// count keys with cache size of 0 => should result in an error
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 0);
assert!(found_keys.is_err());

// count keys with cache size of 1
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 1)?;
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(1).unwrap())?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// count keys with cache size of 10
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// fetch keys into a vector
let found_keys: Vec<ObjectHandle> = session
.iter_objects_with_cache_size(&key_search_template, 10)?
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?
.map_while(|key| key.ok())
.collect();
assert_eq!(found_keys.len(), 11);
Expand All @@ -428,13 +428,15 @@ fn session_objecthandle_iterator() -> testresult::TestResult {
let key1 = found_keys[1];

session.destroy_object(key0).unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 10);

// destroy another key
session.destroy_object(key1).unwrap();
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?;
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?;
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 9);

Expand Down

0 comments on commit fd10cc6

Please sign in to comment.