Skip to content

Commit

Permalink
Cleanup and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
batconjurer committed Apr 30, 2024
1 parent 8f8c238 commit 08a3b7f
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 23 deletions.
116 changes: 113 additions & 3 deletions crates/sdk/src/masp/shielded_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,6 @@ impl<U: ShieldedUtils + Send + Sync> ShieldedContext<U> {
sks: &[ExtendedSpendingKey],
fvks: &[ViewingKey],
) -> Result<(), Error> {
// add new viewing keys
// Reload the state from file to get the last confirmed state and
// discard any speculative data, we cannot fetch on top of a
// speculative state
Expand All @@ -1737,6 +1736,7 @@ impl<U: ShieldedUtils + Send + Sync> ShieldedContext<U> {
};
}

// add new viewing keys
for esk in sks {
let vk = to_viewing_key(esk).vk;
self.vk_heights.entry(vk).or_default();
Expand All @@ -1748,11 +1748,12 @@ impl<U: ShieldedUtils + Send + Sync> ShieldedContext<U> {
let _ = self.save().await;

let native_token = query_native_token(client).await?;
// the latest block height which has been added to the witness Merkle
// tree
// the height of the key that is least synced
let Some(least_idx) = self.vk_heights.values().min().cloned() else {
return Ok(());
};
// the latest block height which has been added to the witness Merkle
// tree
let last_witnessed_tx = self.tx_note_map.keys().max().cloned();
// get the bounds on the block heights to fetch
let start_height =
Expand All @@ -1766,6 +1767,8 @@ impl<U: ShieldedUtils + Send + Sync> ShieldedContext<U> {
let last_query_height = last_query_height.unwrap_or(last_block_height);
let last_query_height =
std::cmp::min(last_query_height, last_block_height);

// Update the commitment tree and witnesses
self.update_witness_map::<_, _, M>(
client,
progress.io(),
Expand Down Expand Up @@ -2485,4 +2488,111 @@ mod shielded_ctx_tests {
}];
assert_eq!(keys, expected);
}

/// Test that we cache and persist trial-decryptions
/// when the scanning process does not complete successfully.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_decrypted_cache() {
let temp_dir = tempdir().unwrap();
let mut shielded_ctx =
FsShieldedUtils::new(temp_dir.path().to_path_buf());
let (client, masp_tx_sender) = test_client(100.into());
let io = StdIo;
let progress = DefaultTracker::new(&io);
let vk = ExtendedFullViewingKey::from(
ExtendedViewingKey::from_str(AA_VIEWING_KEY).expect("Test failed"),
)
.fvk
.vk;

// Fetch a large number of MASP notes
let (masp_tx, changed_keys) = arbitrary_masp_tx();
for h in 1..20 {
masp_tx_sender
.send(Some((
IndexedTx {
height: h.into(),
index: TxIndex(1),
is_wrapper: false,
},
(Default::default(), changed_keys.clone(), masp_tx.clone()),
)))
.expect("Test failed");
}
masp_tx_sender.send(None).expect("Test failed");

// we expect this to fail.
let result = shielded_ctx
.fetch::<_, _, _, TestingMaspClient>(
&client,
&progress,
RetryStrategy::Times(1),
None,
None,
0,
&[],
&[vk],
)
.await
.unwrap_err();
match result {
Error::Other(msg) => assert_eq!(
msg.as_str(),
"After retrying, could not fetch all MASP txs."
),
other => panic!("{:?} does not match Error::Other(_)", other),
}

// reload the shielded context
shielded_ctx.load_confirmed().await.expect("Test failed");

// maliciously remove an entry from the shielded context
// so that one of the last fetched notes will fail to scan.
shielded_ctx.vk_heights.clear();
shielded_ctx.tx_note_map.remove(&IndexedTx {
height: 18.into(),
index: TxIndex(1),
is_wrapper: false,
});
shielded_ctx.save().await.expect("Test failed");

// refetch the same MASP notes
for h in 1..20 {
masp_tx_sender
.send(Some((
IndexedTx {
height: h.into(),
index: TxIndex(1),
is_wrapper: false,
},
(Default::default(), changed_keys.clone(), masp_tx.clone()),
)))
.expect("Test failed");
}
masp_tx_sender.send(None).expect("Test failed");

// we expect this to fail.
shielded_ctx
.fetch::<_, _, _, TestingMaspClient>(
&client,
&progress,
RetryStrategy::Times(1),
None,
None,
0,
&[],
&[vk],
)
.await
.unwrap_err();

// because of an error in scanning, there should be elements
// in the decrypted cache.
shielded_ctx.load_confirmed().await.expect("Test failed");
let result: HashMap<(IndexedTx, ViewingKey), DecryptedData> =
shielded_ctx.decrypted_note_cache.drain().collect();
// unfortunately we cannot easily assert what will be in this
// cache as scanning is done in parallel, introducing non-determinism
assert!(!result.is_empty());
}
}
18 changes: 12 additions & 6 deletions crates/sdk/src/masp/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ impl Client for TestingClient {
self.inner.perform(request).await
}
}

/// Creat a test client for unit testing as well
/// as a channel for communicating with it.
pub fn test_client(
last_height: BlockHeight,
) -> (TestingClient, flume::Sender<Option<IndexedNoteEntry>>) {
Expand All @@ -92,6 +95,9 @@ pub fn test_client(
)
}

/// A client for unit tests. It "fetches" a new note
/// when a channel controlled by the unit test sends
/// it one.
#[derive(Clone)]
pub struct TestingMaspClient<'a> {
client: &'a TestingClient,
Expand All @@ -109,16 +115,18 @@ impl<'a> MaspClient<'a, TestingClient> for TestingMaspClient<'a> {
&self,
_: &ShieldedContext<U>,
_: &IO,
_: IndexedTx,
last_witness_tx: IndexedTx,
_: BlockHeight,
) -> Result<CommitmentTreeUpdates, Error> {
let mut note_map_delta: BTreeMap<IndexedTx, usize> = Default::default();
let mut channel_temp = vec![];
let mut note_pos = 0;
for msg in self.client.next_masp_txs.drain() {
if let Some((ix, _)) = msg.as_ref() {
note_map_delta.insert(*ix, note_pos);
note_pos += 1;
if *ix >= last_witness_tx {
note_map_delta.insert(*ix, note_pos);
note_pos += 1;
}
}
channel_temp.push(msg);
}
Expand Down Expand Up @@ -163,9 +171,7 @@ impl<'a> MaspClient<'a, TestingClient> for TestingMaspClient<'a> {
}
}

/// An iterator that yields its first element
/// but runs forever on the second
/// `next` call.
/// An iterator that yields its first element only
struct YieldOnceIterator {
first: Option<IndexedNoteEntry>,
}
Expand Down
31 changes: 28 additions & 3 deletions crates/sdk/src/masp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ pub struct MaspTokenRewardData {
pub locked_amount_target: Uint,
}

/// The MASP transaction(s) found in a Namada tx.
/// These transactions can appear in the fee payment
/// and / or the main payload.
#[derive(Debug, Clone)]
pub(super) struct ExtractedMaspTx {
pub(crate) fee_unshielding:
Expand Down Expand Up @@ -131,7 +134,7 @@ pub enum ContextSyncStatus {
Speculative,
}

/// a masp change
/// A MASP specific amount delta.
#[derive(BorshSerialize, BorshDeserialize, BorshDeserializer, Debug, Clone)]
pub struct MaspChange {
/// the token address
Expand All @@ -142,6 +145,10 @@ pub struct MaspChange {

#[derive(Debug, Default)]
/// Data returned by successfully scanning a tx
///
/// This is append-only data that will be sent
/// to a [`TaskManager`] to be applied to the
/// shielded context.
pub(super) struct ScannedData {
pub div_map: HashMap<usize, Diversifier>,
pub memo_map: HashMap<usize, MemoBytes>,
Expand All @@ -153,6 +160,7 @@ pub(super) struct ScannedData {
}

impl ScannedData {
/// Append `self` to a [`ShieldedContext`]
pub(super) fn apply_to<U: ShieldedUtils>(
mut self,
ctx: &mut ShieldedContext<U>,
Expand Down Expand Up @@ -181,6 +189,7 @@ impl ScannedData {
ctx.decrypted_note_cache.merge(self.decrypted_note_cache);
}

/// Merge to different instances of `Self`.
pub(super) fn merge(&mut self, mut other: Self) {
for (k, v) in other.note_map.drain(..) {
self.note_map.insert(k, v);
Expand Down Expand Up @@ -211,6 +220,11 @@ impl ScannedData {

#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)]
/// Data extracted from a successfully decrypted MASP note
///
/// These will be cached until the trial-decryption phase
/// of shielded-sync has finished. Then they will be
/// re-scanned as part of nullifying spent notes (which
/// is not parallelizable).
pub struct DecryptedData {
pub tx: Transaction,
pub keys: BTreeSet<namada_core::storage::Key>,
Expand All @@ -227,6 +241,7 @@ pub struct DecryptedDataCache {
}

impl DecryptedDataCache {
/// Add an entry to the cache
pub fn insert(
&mut self,
key: (IndexedTx, ViewingKey),
Expand All @@ -235,19 +250,23 @@ impl DecryptedDataCache {
self.inner.insert(key, value);
}

/// Merge another cache into `self`.
pub fn merge(&mut self, mut other: Self) {
for (k, v) in other.inner.drain(..) {
self.insert(k, v);
}
}

/// Check if the cache already contains an entry for a given IndexedTx and
/// viewing key.
pub fn contains(&self, ix: &IndexedTx, vk: &ViewingKey) -> bool {
self.inner
.keys()
.find_map(|(i, v)| (i == ix && v == vk).then_some(()))
.is_some()
}

/// Return an iterator over the cache that consumes it.
pub fn drain(
&mut self,
) -> impl Iterator<Item = ((IndexedTx, ViewingKey), DecryptedData)> + '_
Expand All @@ -258,8 +277,9 @@ impl DecryptedDataCache {

/// A cache of fetched indexed transactions.
///
/// The cache is designed so that it either contains
/// all transactions from a given height, or none.
/// An invariant that shielded-sync maintains is that
/// this cache either contains all transactions from
/// a given height, or none.
#[derive(Debug, Default, Clone)]
pub struct Unscanned {
pub(super) txs: Arc<Mutex<IndexedNoteData>>,
Expand All @@ -283,6 +303,7 @@ impl BorshDeserialize for Unscanned {
}

impl Unscanned {
/// Append elements to the cache from an iterator.
pub fn extend<I>(&self, items: I)
where
I: IntoIterator<Item = IndexedNoteEntry>,
Expand All @@ -291,11 +312,14 @@ impl Unscanned {
locked.extend(items);
}

/// Add a single entry to the cache.
pub fn insert(&self, (k, v): IndexedNoteEntry) {
let mut locked = self.txs.lock().unwrap();
locked.insert(k, v);
}

/// Check if this cache has already been populated for a given
/// block height.
pub fn contains_height(&self, height: u64) -> bool {
let locked = self.txs.lock().unwrap();
locked.keys().any(|k| k.height.0 == height)
Expand Down Expand Up @@ -327,6 +351,7 @@ impl Unscanned {
.unwrap_or_default()
}

/// Remove the first entry from the cache and return it.
pub fn pop_first(&self) -> Option<IndexedNoteEntry> {
let mut locked = self.txs.lock().unwrap();
locked.pop_first()
Expand Down
Loading

0 comments on commit 08a3b7f

Please sign in to comment.