diff --git a/cryptoki/src/session/mod.rs b/cryptoki/src/session/mod.rs index 245c461e..18b69bb0 100644 --- a/cryptoki/src/session/mod.rs +++ b/cryptoki/src/session/mod.rs @@ -23,6 +23,7 @@ mod session_management; mod signing_macing; mod slot_token_management; +pub use object_management::FindObjects; pub use session_info::{SessionInfo, SessionState}; /// Type that identifies a session @@ -126,11 +127,57 @@ impl Session { session_management::get_session_info(self) } - /// Search for session objects matching a template - pub fn find_objects(&self, template: &[Attribute]) -> Result> { + /// Search for token and session objects matching a template + pub fn find_objects(&mut self, template: &[Attribute]) -> Result> { object_management::find_objects(self, template) } + /// Initiate a search for token and session objects matching a template + /// + /// # Arguments + /// + /// * `template` - The list of attributes to match + /// + /// # Returns + /// + /// This function returns a [FindObjects], which represents an ongoing search. The + /// lifetime of this search is tied to a mutable borrow of the session, so that there + /// may only be one search per session at once. When the [FindObjects] is dropped, + /// the search is ended. + /// + /// # Examples + /// + /// ```no_run + /// use cryptoki::error::Result; + /// use cryptoki::object::{Attribute, AttributeType}; + /// use cryptoki::session::Session; + /// + /// const BATCH_SIZE: usize = 10; + /// + /// fn print_object_labels(session: &mut Session, template: &[Attribute]) -> Result<()> { + /// // Initiate the search. + /// let mut search = session.find_objects_init(template)?; + /// + /// // Iterate over batches of results, while find_next returns a non-empty batch + /// while let ref objects @ [_, ..] = search.find_next(BATCH_SIZE)?[..] { + /// // Iterate over objects in the batch. + /// for &object in objects { + /// // Look up the label for the object. We can't use `session` directly here, + /// // since it's mutably borrowed by search. Instead, use `search.session()`. + /// let attrs = search.session().get_attributes(object, &[AttributeType::Label])?; + /// if let Some(Attribute::Label(label)) = attrs.get(0) { + /// println!("Found object: {}", String::from_utf8_lossy(&label)); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn find_objects_init<'a>(&'a mut self, template: &[Attribute]) -> Result> { + object_management::find_objects_init(self, template) + } + /// Create a new object pub fn create_object(&self, template: &[Attribute]) -> Result { object_management::create_object(self, template) @@ -175,7 +222,7 @@ impl Session { /// pkcs11.initialize(CInitializeArgs::OsThreads).unwrap(); /// let slot = pkcs11.get_slots_with_token().unwrap().remove(0); /// - /// let session = pkcs11.open_ro_session(slot).unwrap(); + /// let mut session = pkcs11.open_ro_session(slot).unwrap(); /// session.login(UserType::User, Some("fedcba")); /// /// let empty_attrib= vec![]; diff --git a/cryptoki/src/session/object_management.rs b/cryptoki/src/session/object_management.rs index db811ed4..64fe77e2 100644 --- a/cryptoki/src/session/object_management.rs +++ b/cryptoki/src/session/object_management.rs @@ -6,62 +6,107 @@ use crate::error::{Result, Rv, RvError}; use crate::object::{Attribute, AttributeInfo, AttributeType, ObjectHandle}; use crate::session::Session; use cryptoki_sys::*; +use log::error; use std::collections::HashMap; use std::convert::TryInto; -// Search 10 elements at a time -const MAX_OBJECT_COUNT: usize = 10; - -// See public docs on stub in parent mod.rs -#[inline(always)] -pub(super) fn find_objects(session: &Session, template: &[Attribute]) -> Result> { - let mut template: Vec = template.iter().map(|attr| attr.into()).collect(); - - unsafe { - Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)( - session.handle(), - template.as_mut_ptr(), - template.len().try_into()?, - )) - .into_result()?; - } - - let mut object_handles = [0; MAX_OBJECT_COUNT]; - let mut object_count = 0; - let mut objects = Vec::new(); +/// Represents an ongoing object search +/// +/// See the documentation for [Session::find_objects_init]. +#[derive(Debug)] +pub struct FindObjects<'a> { + session: &'a mut Session, +} - unsafe { - Rv::from(get_pkcs11!(session.client(), C_FindObjects)( - session.handle(), - object_handles.as_mut_ptr() as CK_OBJECT_HANDLE_PTR, - MAX_OBJECT_COUNT.try_into()?, - &mut object_count, - )) - .into_result()?; - } +impl<'a> FindObjects<'a> { + /// Continue an ongoing object search + /// + /// # Arguments + /// + /// * `max_objects` - The maximum number of objects to return + /// + /// # Returns + /// + /// This function returns up to `max_objects` objects. If there are no remaining + /// objects, or `max_objects` is 0, then it returns an empty vector. + pub fn find_next(&mut self, max_objects: usize) -> Result> { + if max_objects == 0 { + return Ok(vec![]); + } - while object_count > 0 { - objects.extend_from_slice(&object_handles[..object_count.try_into()?]); + let mut object_handles = Vec::with_capacity(max_objects); + let mut object_count = 0; unsafe { - Rv::from(get_pkcs11!(session.client(), C_FindObjects)( - session.handle(), - object_handles.as_mut_ptr() as CK_OBJECT_HANDLE_PTR, - MAX_OBJECT_COUNT.try_into()?, + Rv::from(get_pkcs11!(self.session.client(), C_FindObjects)( + self.session.handle(), + object_handles.as_mut_ptr(), + max_objects.try_into()?, &mut object_count, )) .into_result()?; + object_handles.set_len(object_count.try_into()?) } + + Ok(object_handles.into_iter().map(ObjectHandle::new).collect()) } + /// Get the session associated to the search + pub fn session(&self) -> &Session { + self.session + } +} + +impl<'a> Drop for FindObjects<'a> { + fn drop(&mut self) { + if let Err(e) = find_objects_final_private(self.session) { + error!("Failed to terminate object search: {}", e); + } + } +} + +fn find_objects_final_private(session: &Session) -> Result<()> { unsafe { Rv::from(get_pkcs11!(session.client(), C_FindObjectsFinal)( session.handle(), )) + .into_result() + } +} + +// See public docs on stub in parent mod.rs +#[inline(always)] +pub(super) fn find_objects_init<'a>( + session: &'a mut Session, + template: &[Attribute], +) -> Result> { + let mut template: Vec = template.iter().map(|attr| attr.into()).collect(); + unsafe { + Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)( + session.handle(), + template.as_mut_ptr(), + template.len().try_into()?, + )) .into_result()?; } + Ok(FindObjects { session }) +} - let objects = objects.into_iter().map(ObjectHandle::new).collect(); +// Search 10 elements at a time +const MAX_OBJECT_COUNT: usize = 10; + +// See public docs on stub in parent mod.rs +#[inline(always)] +pub(super) fn find_objects( + session: &mut Session, + template: &[Attribute], +) -> Result> { + let mut search = session.find_objects_init(template)?; + let mut objects = Vec::new(); + + while let ref new_objects @ [_, ..] = search.find_next(MAX_OBJECT_COUNT)?[..] { + objects.extend_from_slice(new_objects) + } Ok(objects) } diff --git a/cryptoki/tests/basic.rs b/cryptoki/tests/basic.rs index b798c5db..e7fbcdb1 100644 --- a/cryptoki/tests/basic.rs +++ b/cryptoki/tests/basic.rs @@ -6,8 +6,10 @@ use crate::common::{SO_PIN, USER_PIN}; use common::init_pins; use cryptoki::error::{Error, RvError}; use cryptoki::mechanism::Mechanism; -use cryptoki::object::{Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass}; -use cryptoki::session::{SessionState, UserType}; +use cryptoki::object::{ + Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass, ObjectHandle, +}; +use cryptoki::session::{Session, SessionState, UserType}; use serial_test::serial; use std::collections::HashMap; use std::thread; @@ -220,7 +222,7 @@ fn import_export() -> Result<()> { let (pkcs11, slot) = init_pins(); // open a session - let session = pkcs11.open_rw_session(slot)?; + let mut session = pkcs11.open_rw_session(slot)?; // log in the session session.login(UserType::User, Some(USER_PIN))?; @@ -781,3 +783,270 @@ fn ro_rw_session_test() -> Result<()> { Ok(()) } + +// Generate some AES keys with the given labels +fn generate_sample_objects(session: &Session, labels: I) -> Result<()> +where + I: IntoIterator, + I::Item: AsRef<[u8]>, +{ + for label in labels { + session.generate_key( + &Mechanism::AesKeyGen, + &[ + Attribute::ValueLen(16.into()), + Attribute::Label(label.as_ref().to_owned()), + ], + )?; + } + Ok(()) +} + +// Fetch the labels for the given objects and sort them +fn object_labels_sorted(session: &Session, objects: I) -> Result> +where + I: IntoIterator, +{ + let mut labels = objects + .into_iter() + .map(|obj| { + let mut attrs = session.get_attributes(obj, &[AttributeType::Label])?; + if let Some(Attribute::Label(label)) = attrs.pop() { + Ok(String::from_utf8(label)?) + } else { + panic!("Expected label attribute"); + } + }) + .collect::>>()?; + labels.sort(); + Ok(labels) +} + +// Find all objects (empty template). +#[test] +#[serial] +fn find_objects_all() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // Generate some sample objects + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // Find all objects + let objects = session.find_objects(&[])?; + + // Check that we get the same objects back + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + + Ok(()) +} + +// Find objects matching a template when none match. +#[test] +#[serial] +fn find_objects_none() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // Generate a sample object labeled "foo" + generate_sample_objects(&session, ["foo"])?; + + // Search for objects labeled "bar" + let objects = session.find_objects(&[Attribute::Label(b"bar".to_vec())])?; + assert_eq!(&objects, &[]); + Ok(()) +} + +#[test] +#[serial] +// Find objects matching a template when a few (<10) match. +fn find_objects_few() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // generate some sample AES keys to match + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // generate a key that shouldn't match (DES3 vs AES) + session.generate_key( + &Mechanism::Des3KeyGen, + &[Attribute::Label(b"quux".to_vec())], + )?; + + // search for all AES keys + let objects = session.find_objects(&[Attribute::KeyType(KeyType::AES)])?; + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + Ok(()) +} + +// Find objects matching a template when many (>10) match. +#[test] +#[serial] +fn find_objects_many() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // generate some sample AES keys to match + let expected_labels = (1..=20) + .map(|x| format!("key {:02}", x)) + .collect::>(); + generate_sample_objects(&session, &expected_labels)?; + + // generate a key that shouldn't match (DES3 vs AES) + session.generate_key( + &Mechanism::Des3KeyGen, + &[Attribute::Label(b"quux".to_vec())], + )?; + + // search for all AES keys + let objects = session.find_objects(&[Attribute::KeyType(KeyType::AES)])?; + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + Ok(()) +} + +// Find all objects incrementally (i.e. empty template) +#[test] +#[serial] +fn find_objects_incrementally_all() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // Generate some sample objects + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // Find all objects + let mut search = session.find_objects_init(&[])?; + let objects = search.find_next(10)?; + assert!(search.find_next(10)?.is_empty()); + drop(search); + + // Check that we get the same objects back + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + + Ok(()) +} + +// Find objects incrementally when none match. +#[test] +#[serial] +fn find_objects_incrementally_none() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // Generate a sample object labeled "foo" + generate_sample_objects(&session, ["foo"])?; + + // Search for objects labeled "bar" + let mut search = session.find_objects_init(&[Attribute::Label(b"bar".to_vec())])?; + assert!(search.find_next(10)?.is_empty()); + Ok(()) +} + +// Find objects incrementally in a single batch. +#[test] +#[serial] +fn find_objects_incrementally_single_batch() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // generate some sample AES keys to match + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // generate a key that shouldn't match (DES3 vs AES) + session.generate_key( + &Mechanism::Des3KeyGen, + &[Attribute::Label(b"quux".to_vec())], + )?; + + // search for all AES keys + let mut search = session.find_objects_init(&[Attribute::KeyType(KeyType::AES)])?; + let objects = search.find_next(10)?; + assert!(search.find_next(10)?.is_empty()); + drop(search); + + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + Ok(()) +} + +// Find objects incrementally in multiple batches. +#[test] +#[serial] +fn find_objects_incrementally_multiple_batches() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // generate some sample AES keys to match + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // generate a key that shouldn't match (DES3 vs AES) + session.generate_key( + &Mechanism::Des3KeyGen, + &[Attribute::Label(b"quux".to_vec())], + )?; + + // search for all AES keys + let mut search = session.find_objects_init(&[Attribute::KeyType(KeyType::AES)])?; + let objects1 = search.find_next(2)?; + assert_eq!(2, objects1.len()); + let objects2 = search.find_next(2)?; + assert_eq!(1, objects2.len()); + assert!(search.find_next(2)?.is_empty()); + drop(search); + + let objects = objects1.into_iter().chain(objects2); + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + Ok(()) +} + +// Find objects incrementally with a zero-sized batch. +#[test] +#[serial] +fn find_objects_incrementally_zero_batch() -> Result<()> { + let (pkcs11, slot) = init_pins(); + let mut session = pkcs11.open_rw_session(slot)?; + session.login(UserType::User, Some(USER_PIN))?; + + // generate some sample AES keys to match + let expected_labels = vec!["bar", "baz", "foo"]; + generate_sample_objects(&session, &expected_labels)?; + + // generate a key that shouldn't match (DES3 vs AES) + session.generate_key( + &Mechanism::Des3KeyGen, + &[Attribute::Label(b"quux".to_vec())], + )?; + + // search for all AES keys + let mut search = session.find_objects_init(&[Attribute::KeyType(KeyType::AES)])?; + + // get a batch of size 0 + assert!(search.find_next(0)?.is_empty()); + + // continue the search + let objects = search.find_next(10)?; + assert!(search.find_next(10)?.is_empty()); + drop(search); + + let labels = object_labels_sorted(&session, objects)?; + assert_eq!(expected_labels, labels); + Ok(()) +}