Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify StrongholdAdapterBuilder #940

Merged
merged 7 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Security -->

## 1.1.0 - 2023-MM-DD

### Changed

- `StrongholdAdapterBuilder` updated to be slightly more ergonomic;

## 1.0.1 - 2023-07-25

### Fixed
Expand Down
122 changes: 59 additions & 63 deletions sdk/src/client/stronghold/mod.rs
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ mod migration;
mod secret;
mod storage;

use alloc::sync::Weak;
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};

use derive_builder::Builder;
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
use iota_stronghold::{KeyProvider, SnapshotPath, Stronghold};
use iota_stronghold::{ClientError, KeyProvider, SnapshotPath, Stronghold};
use log::{debug, error, warn};
use tokio::{
sync::{Mutex, MutexGuard},
Expand All @@ -74,11 +74,9 @@ use super::{storage::StorageAdapter, utils::Password};
/// A wrapper on [Stronghold].
///
/// See the [module-level documentation](self) for more details.
#[derive(Builder, Debug)]
thibault-martinez marked this conversation as resolved.
Show resolved Hide resolved
#[builder(pattern = "owned", build_fn(skip))]
#[derive(Debug)]
pub struct StrongholdAdapter {
/// A stronghold instance.
#[builder(field(type = "Option<Stronghold>"))]
stronghold: Arc<Mutex<Stronghold>>,

/// A key to open the Stronghold vault.
Expand All @@ -88,8 +86,6 @@ pub struct StrongholdAdapter {
/// derive a key from it.
///
/// [`password()`]: self::StrongholdAdapterBuilder::password()
#[builder(setter(custom))]
#[builder(field(type = "Option<KeyProvider>"))]
key_provider: Arc<Mutex<Option<KeyProvider>>>,

/// An interval of time, after which `key` will be cleared from the memory.
Expand All @@ -98,18 +94,12 @@ pub struct StrongholdAdapter {
/// timer will be spawned in the background to clear ([zeroize]) the key after `timeout`.
///
/// If a [`StrongholdAdapter`] is destroyed (dropped), then the timer will stop too.
#[builder(setter(strip_option))]
timeout: Option<Duration>,

/// A handle to the timeout task.
///
/// Note that this field doesn't actually have a custom setter; `setter(custom)` is only for skipping the setter
/// generation.
#[builder(setter(custom))]
timeout_task: Arc<Mutex<Option<JoinHandle<()>>>>,
timeout_task: Arc<Mutex<Option<JoinHandle<Result<(), ClientError>>>>>,

/// The path to a Stronghold snapshot file.
#[builder(setter(skip))]
pub(crate) snapshot_path: PathBuf,
}

Expand Down Expand Up @@ -146,8 +136,30 @@ fn check_or_create_snapshot(
Ok(())
}

#[derive(Default, Debug)]
pub struct StrongholdAdapterBuilder {
stronghold: Option<Stronghold>,
key_provider: Option<KeyProvider>,
timeout: Option<Duration>,
}

/// Extra / custom builder method implementations.
impl StrongholdAdapterBuilder {
pub fn stronghold(mut self, stronghold: impl Into<Option<Stronghold>>) -> Self {
self.stronghold = stronghold.into();
self
}

pub fn key_provider(mut self, key_provider: impl Into<Option<KeyProvider>>) -> Self {
self.key_provider = key_provider.into();
self
}

pub fn timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
self.timeout = timeout.into();
self
}

/// Use an user-input password string to derive a key to use Stronghold.
pub fn password(mut self, password: impl Into<Password>) -> Self {
let password = password.into();
Expand All @@ -172,7 +184,7 @@ impl StrongholdAdapterBuilder {
///
/// [`password()`]: Self::password()
/// [`timeout()`]: Self::timeout()
pub fn build<P: AsRef<Path>>(mut self, snapshot_path: P) -> Result<StrongholdAdapter, Error> {
pub fn build<P: AsRef<Path>>(self, snapshot_path: P) -> Result<StrongholdAdapter, Error> {
// In any case, Stronghold - as a necessary component - needs to be present at this point.
let stronghold = self.stronghold.unwrap_or_default();

Expand All @@ -186,39 +198,25 @@ impl StrongholdAdapterBuilder {
let has_key_provider = self.key_provider.is_some();
let key_provider = Arc::new(Mutex::new(self.key_provider));
let stronghold = Arc::new(Mutex::new(stronghold));
let mut timeout_task = Arc::new(Mutex::new(None));

// If both `key` and `timeout` are set, then we spawn the task and keep its join handle.
if let (true, Some(Some(timeout))) = (has_key_provider, self.timeout) {
let timeout_task = Arc::new(Mutex::new(None));

// The key clearing task, with the data it owns.
let task_self = timeout_task.clone();
let key_provider = key_provider.clone();

// To keep this function synchronous (`fn`), we spawn a task that spawns the key clearing task here. It'll
// however panic when this function is not in a Tokio runtime context (usually in an `async fn`), albeit it
// itself is a `fn`. There is also a small delay from the return of this function to the task actually being
// spawned and set in the `struct`.
let stronghold_clone = stronghold.clone();
tokio::spawn(async move {
*task_self.lock().await = Some(tokio::spawn(task_key_clear(
task_self.clone(), // LHS moves task_self
stronghold_clone,
key_provider,
timeout,
)));
});

// Keep the task handle in the builder; the code below checks this.
self.timeout_task = Some(timeout_task);
if let (true, Some(timeout)) = (has_key_provider, self.timeout) {
let weak = Arc::downgrade(&timeout_task);
*Arc::get_mut(&mut timeout_task).unwrap().get_mut() = Some(tokio::spawn(task_key_clear(
weak,
stronghold.clone(),
key_provider.clone(),
timeout,
)));
}

// Create the adapter as per configuration and return it.
Ok(StrongholdAdapter {
stronghold,
key_provider,
timeout: self.timeout.unwrap_or(None),
timeout_task: self.timeout_task.unwrap_or_else(|| Arc::new(Mutex::new(None))),
timeout: self.timeout,
timeout_task,
snapshot_path: snapshot_path.as_ref().to_path_buf(),
})
}
Expand Down Expand Up @@ -269,12 +267,10 @@ impl StrongholdAdapter {
timeout_task.abort();
}

// The key clearing task, with the data it owns.
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();

*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
Expand Down Expand Up @@ -329,12 +325,10 @@ impl StrongholdAdapter {

// Recover: restart the key clearing task
if let Some(timeout) = self.timeout {
// The key clearing task, with the data it owns.
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();

*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
Expand Down Expand Up @@ -372,12 +366,10 @@ impl StrongholdAdapter {

// Recover: restart key clearing task
if let Some(timeout) = self.timeout {
// The key clearing task, with the data it owns.
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();

*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
Expand All @@ -393,12 +385,10 @@ impl StrongholdAdapter {

// Restart the key clearing task.
if let Some(timeout) = self.timeout {
// The key clearing task, with the data it owns.
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();

*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
Expand Down Expand Up @@ -453,12 +443,10 @@ impl StrongholdAdapter {

// If a new timeout is set and the key is still in the memory, spawn a new task; otherwise we do nothing.
if let (Some(_), Some(timeout)) = (self.key_provider.lock().await.as_ref(), self.timeout) {
// The key clearing task, with the data it owns.
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();

*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
Expand Down Expand Up @@ -542,21 +530,29 @@ impl StrongholdAdapter {

/// The asynchronous key clearing task purging `key` after `timeout` spent in Tokio.
async fn task_key_clear(
task_self: Arc<Mutex<Option<JoinHandle<()>>>>,
task: Weak<Mutex<Option<JoinHandle<Result<(), ClientError>>>>>,
stronghold: Arc<Mutex<Stronghold>>,
key_provider: Arc<Mutex<Option<KeyProvider>>>,
timeout: Duration,
) {
) -> Result<(), ClientError> {
tokio::time::sleep(timeout).await;

debug!("StrongholdAdapter is purging the key");
key_provider.lock().await.take();
// If the weak pointer cannot upgrade, that means the secret manager has been dropped,
// so we can just exit.
if let Some(task) = task.upgrade() {
// Take the join handle, but hold the lock until we're done
let mut lock = task.lock().await;
lock.take();

// TODO handle error
stronghold.lock().await.clear().unwrap();
debug!("StrongholdAdapter is purging the key");
key_provider.lock().await.take();

// Take self, but do nothing (we're exiting anyways).
task_self.lock().await.take();
stronghold.lock().await.clear()?;

drop(lock);
}

Ok(())
}

#[cfg(test)]
Expand Down