Skip to content

Commit

Permalink
implements session object handle iterator, with caching
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Devolder <[email protected]>
  • Loading branch information
keldonin committed Sep 4, 2024
1 parent e0a9e17 commit 168b57e
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 2 deletions.
8 changes: 8 additions & 0 deletions cryptoki/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ macro_rules! get_pkcs11 {
};
}

/// Same as get_pkcs11! but does not attempt to apply '?' syntactic sugar.
/// Suitable only if the caller can't return a Result.
macro_rules! get_pkcs11_func {
($pkcs11:expr, $func_name:ident) => {
($pkcs11.impl_.function_list.$func_name)
};
}

mod general_purpose;
mod info;
mod locking;
Expand Down
4 changes: 4 additions & 0 deletions cryptoki/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod session_management;
mod signing_macing;
mod slot_token_management;

pub use object_management::ObjectHandleIterator;
pub use session_info::{SessionInfo, SessionState};

/// Type that identifies a session
Expand All @@ -31,6 +32,8 @@ pub use session_info::{SessionInfo, SessionState};
pub struct Session {
handle: CK_SESSION_HANDLE,
client: Pkcs11,
#[allow(dead_code)]
search_active: bool,
// This is not used but to prevent Session to automatically implement Send and Sync
_guard: PhantomData<*mut u32>,
}
Expand Down Expand Up @@ -62,6 +65,7 @@ impl Session {
Session {
handle,
client,
search_active: false,
_guard: PhantomData,
}
}
Expand Down
216 changes: 215 additions & 1 deletion cryptoki/src/session/object_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! Object management functions

use crate::context::Function;
use crate::error::{Result, Rv, RvError};
use crate::error::{Error, Result, Rv, RvError};
use crate::object::{Attribute, AttributeInfo, AttributeType, ObjectHandle};
use crate::session::Session;
use cryptoki_sys::*;
Expand All @@ -13,7 +13,221 @@ use std::convert::TryInto;
// Search 10 elements at a time
const MAX_OBJECT_COUNT: usize = 10;

/// Iterator over object handles, in an active session.
///
/// Used to iterate over the object handles returned by underlying calls to `C_FindObjects`.
/// The iterator is created by calling the `iter_objects` and `iter_objects_with_cache_size` methods on a `Session` object.
///
/// # Example
///
/// ```no_run
/// use cryptoki::context::CInitializeArgs;
/// use cryptoki::context::Pkcs11;
/// use cryptoki::error::Error;
/// use cryptoki::object::Attribute;
/// use cryptoki::object::AttributeType;
/// use cryptoki::session::UserType;
/// use cryptoki::types::AuthPin;
/// use std::env;
///
/// fn test() -> Result<(), Error> {
/// let pkcs11 = Pkcs11::new(
/// env::var("PKCS11_SOFTHSM2_MODULE")
/// .unwrap_or_else(|_| "/usr/local/lib/libsofthsm2.so".to_string()),
/// )?;
///
/// pkcs11.initialize(CInitializeArgs::OsThreads)?;
/// let slot = pkcs11.get_slots_with_token()?.remove(0);
///
/// let session = pkcs11.open_ro_session(slot).unwrap();
/// session.login(UserType::User, Some(&AuthPin::new("fedcba".into())))?;
///
/// let token_object = vec![Attribute::Token(true)];
/// let wanted_attr = vec![AttributeType::Label];
///
/// for (idx, obj) in session.iter_objects(&token_object)?.enumerate() {
/// let obj = obj?; // handle potential error condition
///
/// let attributes = session.get_attributes(obj, &wanted_attr)?;
///
/// match attributes.get(0) {
/// Some(Attribute::Label(l)) => {
/// println!(
/// "token object #{}: handle {}, label {}",
/// idx,
/// obj,
/// String::from_utf8(l.to_vec())
/// .unwrap_or_else(|_| "*** not valid utf8 ***".to_string())
/// );
/// }
/// _ => {
/// println!("token object #{}: handle {}, label not found", idx, obj);
/// }
/// }
/// }
/// Ok(())
/// }
///
/// pub fn main() {
/// test().unwrap();
/// }
/// ```
#[derive(Debug)]
pub struct ObjectHandleIterator<'a> {
session: &'a Session,
object_count: usize,
index: usize,
cache: Vec<CK_OBJECT_HANDLE>,
}

impl<'a> ObjectHandleIterator<'a> {
fn new(
session: &'a Session,
mut template: Vec<CK_ATTRIBUTE>,
cache_size: usize,
) -> Result<Self> {
if cache_size == 0 {
return Err(Error::InvalidValue);
}

unsafe {
Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)(
session.handle(),
template.as_mut_ptr(),
template.len().try_into()?,
))
.into_result(Function::FindObjectsInit)?;
}

let cache: Vec<CK_OBJECT_HANDLE> = vec![0; cache_size];
Ok(ObjectHandleIterator {
session,
object_count: cache_size,
index: cache_size,
cache,
})
}
}

// In this implementation, we use object_count to keep track of the number of objects
// returned by the last C_FindObjects call; the index is used to keep track of
// the next object in the cache to be returned. The size of cache is never changed.
// In order to enter the loop for the first time, we set object_count to cache_size
// and index to cache_size. That allows to jump directly to the C_FindObjects call
// and start filling the cache.

impl<'a> Iterator for ObjectHandleIterator<'a> {
type Item = Result<ObjectHandle>;

fn next(&mut self) -> Option<Self::Item> {
// since the iterator is initialized with object_count and index both equal and > 0,
// we are guaranteed to enter the loop at least once
while self.object_count > 0 {
// if index<object_count, we have items in the cache to return
if self.index < self.object_count {
self.index += 1;
return Some(Ok(ObjectHandle::new(self.cache[self.index - 1])));
} else {
// reset counters and proceed to the next section
self.index = 0;

if self.object_count < self.cache.len() {
// if self.object_count is less than the cache size,
// it means our last call to C_FindObjects returned less than the cache size
// At this point, we have exhausted all objects in the cache
// and we can safely break the loop and return None
self.object_count = 0;
break;
} else {
// reset the counter - C_FindObjects will adjust that value.
self.object_count = 0;
}
}

let p11rv;

match get_pkcs11_func!(self.session.client(), C_FindObjects) {
Some(f) => {
p11rv = unsafe {
f(
self.session.handle(),
self.cache.as_mut_ptr(),
self.cache.len() as CK_ULONG,
&mut self.object_count as *mut usize as CK_ULONG_PTR,
)
};
}
None => {
// C_FindObjects() is not implemented on this implementation
// sort of unexpected. TODO: Consider panic!() instead?
log::error!("C_FindObjects() is not implemented on this library");
return Some(Err(Error::NullFunctionPointer) as Result<ObjectHandle>);
}
}

if let Rv::Error(error) = Rv::from(p11rv) {
return Some(
Err(Error::Pkcs11(error, Function::FindObjects)) as Result<ObjectHandle>
);
}
}
None
}
}

impl Drop for ObjectHandleIterator<'_> {
fn drop(&mut self) {
// silently pass if C_FindObjectsFinal() is not implemented on this implementation
// this is unexpected. TODO: Consider panic!() instead?
if let Some(f) = get_pkcs11_func!(self.session.client(), C_FindObjectsFinal) {
// swallow the return value, as we can't do anything about it
let _ = unsafe { f(self.session.handle()) };
}
}
}

impl Session {
/// Iterate over session objects matching a template.
///
/// # Arguments
/// * `template` - The template to match objects against
///
/// # Returns
///
/// This function will return a [`Result<ObjectHandleIterator>`] that can be used to iterate over the objects
/// matching the template. Note that the cache size is managed internally and set to a default value (10)
///
/// # See also
/// * [`ObjectHandleIterator`] for more information on how to use the iterator
/// * [`Session::iter_objects_with_cache_size`] for a way to specify the cache size
#[inline(always)]
pub fn iter_objects(&self, template: &[Attribute]) -> Result<ObjectHandleIterator> {
self.iter_objects_with_cache_size(template, MAX_OBJECT_COUNT)
}

/// Iterate over session objects matching a template, with cache size
///
/// # Arguments
/// * `template` - The template to match objects against
/// * `cache_size` - The number of objects to cache. Note that 0 is an invalid value and will return an error.
///
/// # Returns
///
/// This function will return a [`Result<ObjectHandleIterator>`] that can be used to iterate over the objects
/// matching the template. The cache size corresponds to the size of the array provided to `C_FindObjects()`.
///
/// # See also
/// * [`ObjectHandleIterator`] for more information on how to use the iterator
/// * [`Session::iter_objects`] for a simpler way to iterate over objects
pub fn iter_objects_with_cache_size(
&self,
template: &[Attribute],
cache_size: usize,
) -> Result<ObjectHandleIterator> {
let template: Vec<CK_ATTRIBUTE> = template.iter().map(|attr| attr.into()).collect();
ObjectHandleIterator::new(self, template, cache_size)
}

/// Search for session objects matching a template
pub fn find_objects(&self, template: &[Attribute]) -> Result<Vec<ObjectHandle>> {
let mut template: Vec<CK_ATTRIBUTE> = template.iter().map(|attr| attr.into()).collect();
Expand Down
98 changes: 97 additions & 1 deletion cryptoki/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use cryptoki::error::{Error, RvError};
use cryptoki::mechanism::aead::GcmParams;
use cryptoki::mechanism::rsa::{PkcsMgfType, PkcsOaepParams, PkcsOaepSource};
use cryptoki::mechanism::{Mechanism, MechanismType};
use cryptoki::object::{Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass};
use cryptoki::object::{
Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass, ObjectHandle,
};
use cryptoki::session::{SessionState, UserType};
use cryptoki::types::AuthPin;
use serial_test::serial;
Expand Down Expand Up @@ -364,6 +366,100 @@ fn session_find_objects() {
assert_eq!(found_keys.len(), 9);
}

#[test]
#[serial]
fn session_objecthandle_iterator() {
let (pkcs11, slot) = init_pins();
// open a session
let session = pkcs11.open_rw_session(slot).unwrap();

// log in the session
session
.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))
.unwrap();

// we generate 11 keys with the same CKA_ID

(1..=11).for_each(|i| {
let key_template = vec![
Attribute::Token(true),
Attribute::Encrypt(true),
Attribute::Label(format!("key_{}", i).as_bytes().to_vec()),
Attribute::Id("12345678".as_bytes().to_vec()), // reusing the same CKA_ID
];

// generate a secret key
let _key = session
.generate_key(&Mechanism::Des3KeyGen, &key_template)
.unwrap();
});

// retrieve these keys using this template
let key_search_template = vec![
Attribute::Token(true),
Attribute::Id("12345678".as_bytes().to_vec()),
Attribute::Class(ObjectClass::SECRET_KEY),
Attribute::KeyType(KeyType::DES3),
];

// test iter_objects_with_cache_size()
// count keys with cache size of 20
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 20)
.unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// count keys with cache size of 0 => should result in an error
let found_keys = session.iter_objects_with_cache_size(&key_search_template, 0);
assert!(found_keys.is_err());

// count keys with cache size of 1
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 1)
.unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// count keys with cache size of 10
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 11);

// fetch keys into a vector
let found_keys: Vec<ObjectHandle> = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap()
.map_while(|key| key.ok())
.collect();
assert_eq!(found_keys.len(), 11);

let key0 = found_keys[0];
let key1 = found_keys[1];

session.destroy_object(key0).unwrap();
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 10);

// destroy another key
session.destroy_object(key1).unwrap();
let found_keys = session
.iter_objects_with_cache_size(&key_search_template, 10)
.unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 9);

// test iter_objects()
let found_keys = session.iter_objects(&key_search_template).unwrap();
let found_keys = found_keys.map_while(|key| key.ok()).count();
assert_eq!(found_keys, 9);
}

#[test]
#[serial]
fn wrap_and_unwrap_key() {
Expand Down

0 comments on commit 168b57e

Please sign in to comment.