From 518865a8754af4c1c1ed6afe490b30253325f2ce Mon Sep 17 00:00:00 2001 From: Eric Devolder Date: Sat, 7 Sep 2024 03:09:54 +0200 Subject: [PATCH] use NonZeroUSize for capturing cache size Signed-off-by: Eric Devolder --- cryptoki/src/session/object_management.rs | 28 +++++++++++++---------- cryptoki/tests/basic.rs | 22 ++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/cryptoki/src/session/object_management.rs b/cryptoki/src/session/object_management.rs index 934d9fac..a0262e8e 100644 --- a/cryptoki/src/session/object_management.rs +++ b/cryptoki/src/session/object_management.rs @@ -9,6 +9,7 @@ use crate::session::Session; use cryptoki_sys::*; use std::collections::HashMap; use std::convert::TryInto; +use std::num::NonZeroUsize; // Search 10 elements at a time const MAX_OBJECT_COUNT: usize = 10; @@ -86,12 +87,8 @@ impl<'a> ObjectHandleIterator<'a> { fn new( session: &'a Session, mut template: Vec, - cache_size: usize, + cache_size: NonZeroUsize, ) -> Result { - if cache_size == 0 { - return Err(Error::InvalidValue); - } - unsafe { Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)( session.handle(), @@ -101,11 +98,11 @@ impl<'a> ObjectHandleIterator<'a> { .into_result(Function::FindObjectsInit)?; } - let cache: Vec = vec![0; cache_size]; + let cache: Vec = vec![0; cache_size.get()]; Ok(ObjectHandleIterator { session, - object_count: cache_size, - index: cache_size, + object_count: cache_size.get(), + index: cache_size.get(), cache, }) } @@ -187,6 +184,7 @@ impl Session { /// Iterate over session objects matching a template. /// /// # Arguments + /// /// * `template` - The template to match objects against /// /// # Returns @@ -195,18 +193,20 @@ impl Session { /// matching the template. Note that the cache size is managed internally and set to a default value (10) /// /// # See also + /// /// * [`ObjectHandleIterator`] for more information on how to use the iterator /// * [`Session::iter_objects_with_cache_size`] for a way to specify the cache size #[inline(always)] pub fn iter_objects(&self, template: &[Attribute]) -> Result { - self.iter_objects_with_cache_size(template, MAX_OBJECT_COUNT) + self.iter_objects_with_cache_size(template, NonZeroUsize::new(MAX_OBJECT_COUNT).unwrap()) } /// Iterate over session objects matching a template, with cache size /// /// # Arguments + /// /// * `template` - The template to match objects against - /// * `cache_size` - The number of objects to cache. Note that 0 is an invalid value and will return an error. + /// * `cache_size` - The number of objects to cache (type is [`NonZeroUsize`]) /// /// # Returns /// @@ -214,29 +214,33 @@ impl Session { /// matching the template. The cache size corresponds to the size of the array provided to `C_FindObjects()`. /// /// # See also + /// /// * [`ObjectHandleIterator`] for more information on how to use the iterator /// * [`Session::iter_objects`] for a simpler way to iterate over objects pub fn iter_objects_with_cache_size( &self, template: &[Attribute], - cache_size: usize, + cache_size: NonZeroUsize, ) -> Result { let template: Vec = template.iter().map(Into::into).collect(); ObjectHandleIterator::new(self, template, cache_size) } /// Search for session objects matching a template + /// /// # Arguments + /// /// * `template` - The template to match objects against /// /// # Returns /// - /// Upon success, a vector of [ObjectHandle] wrapped in a Result. + /// Upon success, a vector of [`ObjectHandle`] wrapped in a Result. /// Upon failure, the first error encountered. /// /// It is a convenience function that will call [`Session::iter_objects`] and collect the results. /// /// # See also + /// /// * [`Session::iter_objects`] for a way to specify the cache size #[inline(always)] pub fn find_objects(&self, template: &[Attribute]) -> Result> { diff --git a/cryptoki/tests/basic.rs b/cryptoki/tests/basic.rs index 0ed6cb01..bc9d5059 100644 --- a/cryptoki/tests/basic.rs +++ b/cryptoki/tests/basic.rs @@ -16,6 +16,7 @@ use cryptoki::session::{SessionState, UserType}; use cryptoki::types::AuthPin; use serial_test::serial; use std::collections::HashMap; +use std::num::NonZeroUsize; use std::thread; use cryptoki::mechanism::ekdf::AesCbcDeriveParams; @@ -399,27 +400,26 @@ fn session_objecthandle_iterator() -> testresult::TestResult { // test iter_objects_with_cache_size() // count keys with cache size of 20 - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 20)?; + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(20).unwrap())?; let found_keys = found_keys.map_while(|key| key.ok()).count(); assert_eq!(found_keys, 11); - // count keys with cache size of 0 => should result in an error - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 0); - assert!(found_keys.is_err()); - // count keys with cache size of 1 - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 1)?; + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(1).unwrap())?; let found_keys = found_keys.map_while(|key| key.ok()).count(); assert_eq!(found_keys, 11); // count keys with cache size of 10 - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?; + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; let found_keys = found_keys.map_while(|key| key.ok()).count(); assert_eq!(found_keys, 11); // fetch keys into a vector let found_keys: Vec = session - .iter_objects_with_cache_size(&key_search_template, 10)? + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())? .map_while(|key| key.ok()) .collect(); assert_eq!(found_keys.len(), 11); @@ -428,13 +428,15 @@ fn session_objecthandle_iterator() -> testresult::TestResult { let key1 = found_keys[1]; session.destroy_object(key0).unwrap(); - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?; + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; let found_keys = found_keys.map_while(|key| key.ok()).count(); assert_eq!(found_keys, 10); // destroy another key session.destroy_object(key1).unwrap(); - let found_keys = session.iter_objects_with_cache_size(&key_search_template, 10)?; + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; let found_keys = found_keys.map_while(|key| key.ok()).count(); assert_eq!(found_keys, 9);