From 386b8604b224d62f498f3246837771fcd9d73556 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Fri, 8 Dec 2023 11:30:42 -0700 Subject: [PATCH] zcash_client_backend: Generalize `ScanningKey` This change allows the `ScanningKey` type to represent Orchard keys as well as Sapling keys. --- zcash_client_backend/src/data_api/chain.rs | 63 +++---- zcash_client_backend/src/scanning.rs | 182 +++++++++++++-------- 2 files changed, 135 insertions(+), 110 deletions(-) diff --git a/zcash_client_backend/src/data_api/chain.rs b/zcash_client_backend/src/data_api/chain.rs index 1555785f85..f35b15b449 100644 --- a/zcash_client_backend/src/data_api/chain.rs +++ b/zcash_client_backend/src/data_api/chain.rs @@ -146,10 +146,7 @@ use std::ops::Range; use sapling::note_encryption::PreparedIncomingViewingKey; -use zcash_primitives::{ - consensus::{self, BlockHeight}, - zip32::Scope, -}; +use zcash_primitives::consensus::{self, BlockHeight}; use crate::{ data_api::{NullifierQuery, WalletWrite}, @@ -284,57 +281,43 @@ where .map_err(Error::Wallet)?; // TODO: Change `scan_block` to also scan Orchard. // https://github.com/zcash/librustzcash/issues/403 - let dfvks: Vec<_> = ufvks + let ivks: Vec<_> = ufvks .iter() .filter_map(|(account, ufvk)| ufvk.sapling().map(move |k| (account, k))) - .collect(); - // Precompute the IVKs instead of doing so per block. - let ivks = dfvks - .iter() - .flat_map(|(account, dfvk)| { - dfvk.to_sapling_keys() - .into_iter() - .map(|key| (*account, key)) - }) + .flat_map(|(account, dfvk)| dfvk.to_ivks().into_iter().map(move |key| (account, key))) .collect::>(); - // Get the nullifiers for the unspent notes we are tracking - let mut sapling_nullifiers = data_db - .get_sapling_nullifiers(NullifierQuery::Unspent) - .map_err(Error::Wallet)?; - - let mut batch_runner = BatchRunner::<_, _, _, _, ()>::new( + let mut sapling_runner = BatchRunner::<_, _, _, _, ()>::new( 100, - dfvks - .iter() - .flat_map(|(account, dfvk)| { - [ - ((**account, Scope::External), dfvk.to_ivk(Scope::External)), - ((**account, Scope::Internal), dfvk.to_ivk(Scope::Internal)), - ] - }) - .map(|(tag, ivk)| (tag, PreparedIncomingViewingKey::new(&ivk))), + ivks.iter().map(|(account, (scope, ivk, _))| { + ((**account, *scope), PreparedIncomingViewingKey::new(ivk)) + }), ); - let mut prior_block_metadata = if from_height > BlockHeight::from(0) { - data_db - .block_metadata(from_height - 1) - .map_err(Error::Wallet)? - } else { - None - }; - block_source.with_blocks::<_, DbT::Error>( Some(from_height), Some(limit), |block: CompactBlock| { - add_block_to_runner(params, block, &mut batch_runner); + add_block_to_runner(params, block, &mut sapling_runner); Ok(()) }, )?; - batch_runner.flush(); + sapling_runner.flush(); + + let mut prior_block_metadata = if from_height > BlockHeight::from(0) { + data_db + .block_metadata(from_height - 1) + .map_err(Error::Wallet)? + } else { + None + }; + + // Get the nullifiers for the unspent notes we are tracking + let mut sapling_nullifiers = data_db + .get_sapling_nullifiers(NullifierQuery::Unspent) + .map_err(Error::Wallet)?; let mut scanned_blocks = vec![]; let mut scan_end_height = from_height; @@ -351,7 +334,7 @@ where &ivks, &sapling_nullifiers, prior_block_metadata.as_ref(), - Some(&mut batch_runner), + Some(&mut sapling_runner), ) .map_err(Error::Scan)?; diff --git a/zcash_client_backend/src/scanning.rs b/zcash_client_backend/src/scanning.rs index 38d03881fb..70b6fbf023 100644 --- a/zcash_client_backend/src/scanning.rs +++ b/zcash_client_backend/src/scanning.rs @@ -12,11 +12,8 @@ use sapling::{ }; use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption}; use zcash_note_encryption::batch; -use zcash_primitives::consensus::{BlockHeight, NetworkUpgrade}; -use zcash_primitives::{ - consensus, - zip32::{AccountId, Scope}, -}; +use zcash_primitives::consensus::{self, BlockHeight, NetworkUpgrade}; +use zip32::{AccountId, Scope}; use crate::data_api::{BlockMetadata, ScannedBlock, ScannedBundles}; use crate::{ @@ -43,50 +40,73 @@ pub trait ScanningKey { /// The type representing the scope of the scanning key. type Scope: Clone + Eq + std::hash::Hash + Send + 'static; - /// The type of key that is used to decrypt Sapling outputs; - type SaplingNk: Clone; + /// The type of key that is used to decrypt outputs belonging to the wallet. + type IncomingViewingKey: Clone; - type SaplingKeys: IntoIterator; + /// The type of key that is used to derive nullifiers. + type NullifierDerivingKey: Clone; - /// The type of nullifier extracted when a note is successfully - /// obtained by trial decryption. + /// The type of nullifier extracted when a note is successfully obtained by trial decryption. type Nf; - /// Obtain the underlying Sapling incoming viewing key(s) for this scanning key. - fn to_sapling_keys(&self) -> Self::SaplingKeys; + /// The type of notes obtained by trial decryption. + type Note; + + /// Obtain the underlying incoming viewing key(s) for this scanning key. + fn to_ivks( + &self, + ) -> Vec<( + Self::Scope, + Self::IncomingViewingKey, + Self::NullifierDerivingKey, + )>; /// Produces the nullifier for the specified note and witness, if possible. /// /// IVK-based implementations of this trait cannot successfully derive /// nullifiers, in which case `Self::Nf` should be set to the unit type /// and this function is a no-op. - fn sapling_nf(key: &Self::SaplingNk, note: &sapling::Note, note_position: Position) + fn nf(key: &Self::NullifierDerivingKey, note: &Self::Note, note_position: Position) -> Self::Nf; } impl ScanningKey for &K { type Scope = K::Scope; - type SaplingNk = K::SaplingNk; - type SaplingKeys = K::SaplingKeys; + type IncomingViewingKey = K::IncomingViewingKey; + type NullifierDerivingKey = K::NullifierDerivingKey; type Nf = K::Nf; - - fn to_sapling_keys(&self) -> Self::SaplingKeys { - (*self).to_sapling_keys() + type Note = K::Note; + + fn to_ivks( + &self, + ) -> Vec<( + Self::Scope, + Self::IncomingViewingKey, + Self::NullifierDerivingKey, + )> { + (*self).to_ivks() } - fn sapling_nf(key: &Self::SaplingNk, note: &sapling::Note, position: Position) -> Self::Nf { - K::sapling_nf(key, note, position) + fn nf(key: &Self::NullifierDerivingKey, note: &Self::Note, position: Position) -> Self::Nf { + K::nf(key, note, position) } } impl ScanningKey for DiversifiableFullViewingKey { type Scope = Scope; - type SaplingNk = sapling::NullifierDerivingKey; - type SaplingKeys = [(Self::Scope, SaplingIvk, Self::SaplingNk); 2]; + type IncomingViewingKey = SaplingIvk; + type NullifierDerivingKey = sapling::NullifierDerivingKey; type Nf = sapling::Nullifier; - - fn to_sapling_keys(&self) -> Self::SaplingKeys { - [ + type Note = sapling::Note; + + fn to_ivks( + &self, + ) -> Vec<( + Self::Scope, + Self::IncomingViewingKey, + Self::NullifierDerivingKey, + )> { + vec![ ( Scope::External, self.to_ivk(Scope::External), @@ -100,22 +120,29 @@ impl ScanningKey for DiversifiableFullViewingKey { ] } - fn sapling_nf(key: &Self::SaplingNk, note: &sapling::Note, position: Position) -> Self::Nf { + fn nf(key: &Self::NullifierDerivingKey, note: &Self::Note, position: Position) -> Self::Nf { note.nf(key, position.into()) } } impl ScanningKey for (Scope, SaplingIvk, sapling::NullifierDerivingKey) { type Scope = Scope; - type SaplingNk = sapling::NullifierDerivingKey; - type SaplingKeys = [(Self::Scope, SaplingIvk, Self::SaplingNk); 1]; + type IncomingViewingKey = SaplingIvk; + type NullifierDerivingKey = sapling::NullifierDerivingKey; type Nf = sapling::Nullifier; - - fn to_sapling_keys(&self) -> Self::SaplingKeys { - [self.clone()] + type Note = sapling::Note; + + fn to_ivks( + &self, + ) -> Vec<( + Self::Scope, + Self::IncomingViewingKey, + Self::NullifierDerivingKey, + )> { + vec![self.clone()] } - fn sapling_nf(key: &Self::SaplingNk, note: &sapling::Note, position: Position) -> Self::Nf { + fn nf(key: &Self::NullifierDerivingKey, note: &Self::Note, position: Position) -> Self::Nf { note.nf(key, position.into()) } } @@ -126,15 +153,22 @@ impl ScanningKey for (Scope, SaplingIvk, sapling::NullifierDerivingKey) { /// [`SaplingIvk`]: sapling::SaplingIvk impl ScanningKey for SaplingIvk { type Scope = (); - type SaplingNk = (); - type SaplingKeys = [(Self::Scope, SaplingIvk, Self::SaplingNk); 1]; + type IncomingViewingKey = SaplingIvk; + type NullifierDerivingKey = (); type Nf = (); - - fn to_sapling_keys(&self) -> Self::SaplingKeys { - [((), self.clone(), ())] + type Note = sapling::Note; + + fn to_ivks( + &self, + ) -> Vec<( + Self::Scope, + Self::IncomingViewingKey, + Self::NullifierDerivingKey, + )> { + vec![((), self.clone(), ())] } - fn sapling_nf(_key: &Self::SaplingNk, _note: &sapling::Note, _position: Position) {} + fn nf(_key: &Self::NullifierDerivingKey, _note: &Self::Note, _position: Position) -> Self::Nf {} } /// Errors that may occur in chain scanning @@ -251,17 +285,20 @@ impl fmt::Display for ScanError { /// [`IncrementalWitness`]: sapling::IncrementalWitness /// [`WalletSaplingOutput`]: crate::wallet::WalletSaplingOutput /// [`WalletTx`]: crate::wallet::WalletTx -pub fn scan_block( +pub fn scan_block< + P: consensus::Parameters + Send + 'static, + SK: ScanningKey, +>( params: &P, block: CompactBlock, - vks: &[(&AccountId, &K)], + sapling_keys: &[(&AccountId, &SK)], sapling_nullifiers: &[(AccountId, sapling::Nullifier)], prior_block_metadata: Option<&BlockMetadata>, -) -> Result, ScanError> { +) -> Result, ScanError> { scan_block_with_runner::<_, _, ()>( params, block, - vks, + sapling_keys, sapling_nullifiers, prior_block_metadata, None, @@ -332,16 +369,16 @@ fn check_hash_continuity( #[tracing::instrument(skip_all, fields(height = block.height))] pub(crate) fn scan_block_with_runner< P: consensus::Parameters + Send + 'static, - K: ScanningKey, - T: Tasks> + Sync, + SK: ScanningKey, + T: Tasks> + Sync, >( params: &P, block: CompactBlock, - vks: &[(&AccountId, K)], - nullifiers: &[(AccountId, sapling::Nullifier)], + sapling_keys: &[(&AccountId, SK)], + sapling_nullifiers: &[(AccountId, sapling::Nullifier)], prior_block_metadata: Option<&BlockMetadata>, - mut batch_runner: Option<&mut TaggedBatchRunner>, -) -> Result, ScanError> { + mut batch_runner: Option<&mut TaggedBatchRunner>, +) -> Result, ScanError> { if let Some(scan_error) = check_hash_continuity(&block, prior_block_metadata) { return Err(scan_error); } @@ -444,7 +481,7 @@ pub(crate) fn scan_block_with_runner< )?; let compact_block_tx_count = block.vtx.len(); - let mut wtxs: Vec> = vec![]; + let mut wtxs: Vec> = vec![]; let mut sapling_nullifier_map = Vec::with_capacity(block.vtx.len()); let mut sapling_note_commitments: Vec<(sapling::Node, Retention)> = vec![]; for (tx_idx, tx) in block.vtx.into_iter().enumerate() { @@ -465,7 +502,7 @@ pub(crate) fn scan_block_with_runner< // Find the first tracked nullifier that matches this spend, and produce // a WalletShieldedSpend if there is a match, in constant time. - let spend = nullifiers + let spend = sapling_nullifiers .iter() .map(|&(account, nf)| CtOption::new(account, nf.ct_eq(&spend_nf))) .fold(CtOption::new(AccountId::ZERO, 0.into()), |first, next| { @@ -498,7 +535,7 @@ pub(crate) fn scan_block_with_runner< u32::try_from(tx.actions.len()).expect("Orchard action count cannot exceed a u32"); // Check for incoming notes while incrementing tree and witnesses - let mut shielded_outputs: Vec> = vec![]; + let mut shielded_outputs: Vec> = vec![]; { let decoded = &tx .outputs @@ -513,10 +550,10 @@ pub(crate) fn scan_block_with_runner< .collect::>(); let decrypted: Vec<_> = if let Some(runner) = batch_runner.as_mut() { - let vks = vks + let sapling_keys = sapling_keys .iter() .flat_map(|(a, k)| { - k.to_sapling_keys() + k.to_ivks() .into_iter() .map(move |(scope, _, nk)| ((**a, scope), nk)) }) @@ -525,37 +562,36 @@ pub(crate) fn scan_block_with_runner< let mut decrypted = runner.collect_results(cur_hash, txid); (0..decoded.len()) .map(|i| { - decrypted.remove(&(txid, i)).map(|d_out| { - let a = d_out.ivk_tag.0; - let nk = vks.get(&d_out.ivk_tag).expect( + decrypted.remove(&(txid, i)).map(|d_note| { + let a = d_note.ivk_tag.0; + let nk = sapling_keys.get(&d_note.ivk_tag).expect( "The batch runner and scan_block must use the same set of IVKs.", ); - (d_out.note, a, d_out.ivk_tag.1, (*nk).clone()) + (d_note.note, a, d_note.ivk_tag.1, (*nk).clone()) }) }) .collect() } else { - let vks = vks + let sapling_keys = sapling_keys .iter() .flat_map(|(a, k)| { - k.to_sapling_keys() + k.to_ivks() .into_iter() .map(move |(scope, ivk, nk)| (**a, scope, ivk, nk)) }) .collect::>(); - let ivks = vks + let ivks = sapling_keys .iter() - .map(|(_, _, ivk, _)| ivk) - .map(PreparedIncomingViewingKey::new) + .map(|(_, _, ivk, _)| PreparedIncomingViewingKey::new(ivk)) .collect::>(); batch::try_compact_note_decryption(&ivks, &decoded[..]) .into_iter() .map(|v| { v.map(|((note, _), ivk_idx)| { - let (account, scope, _, nk) = &vks[ivk_idx]; + let (account, scope, _, nk) = &sapling_keys[ivk_idx]; (note, *account, scope.clone(), (*nk).clone()) }) }) @@ -588,7 +624,7 @@ pub(crate) fn scan_block_with_runner< let note_commitment_tree_position = Position::from(u64::from( sapling_commitment_tree_size + u32::try_from(output_idx).unwrap(), )); - let nf = K::sapling_nf(&nk, ¬e, note_commitment_tree_position); + let nf = SK::nf(&nk, ¬e, note_commitment_tree_position); shielded_outputs.push(WalletSaplingOutput::from_parts( output_idx, @@ -837,7 +873,7 @@ mod tests { let mut batch_runner = if scan_multithreaded { let mut runner = BatchRunner::<_, _, _, _, ()>::new( 10, - dfvk.to_sapling_keys() + dfvk.to_ivks() .iter() .map(|(scope, ivk, _)| ((account, *scope), ivk)) .map(|(tag, ivk)| (tag, PreparedIncomingViewingKey::new(ivk))), @@ -924,7 +960,7 @@ mod tests { let mut batch_runner = if scan_multithreaded { let mut runner = BatchRunner::<_, _, _, _, ()>::new( 10, - dfvk.to_sapling_keys() + dfvk.to_ivks() .iter() .map(|(scope, ivk, _)| ((account, *scope), ivk)) .map(|(tag, ivk)| (tag, PreparedIncomingViewingKey::new(ivk))), @@ -997,10 +1033,16 @@ mod tests { Some((0, 0)), ); assert_eq!(cb.vtx.len(), 2); - let vks: Vec<(&AccountId, &SaplingIvk)> = vec![]; - - let scanned_block = - scan_block(&Network::TestNetwork, cb, &vks[..], &[(account, nf)], None).unwrap(); + let sapling_keys: Vec<(&AccountId, &SaplingIvk)> = vec![]; + + let scanned_block = scan_block( + &Network::TestNetwork, + cb, + &sapling_keys[..], + &[(account, nf)], + None, + ) + .unwrap(); let txs = scanned_block.transactions(); assert_eq!(txs.len(), 1);