Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rustfmt: Run on util/* (2/2) #3323

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions lightning/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ pub(crate) mod fuzz_wrappers;
#[macro_use]
pub mod ser_macros;

#[cfg(fuzzing)]
pub mod base32;
#[cfg(not(fuzzing))]
pub(crate) mod base32;
pub mod errors;
pub mod ser;
pub mod message_signing;
pub mod persist;
pub mod scid_utils;
pub mod ser;
pub mod sweep;
pub mod wakers;
#[cfg(fuzzing)]
pub mod base32;
#[cfg(not(fuzzing))]
pub(crate) mod base32;

pub(crate) mod atomic_counter;
pub(crate) mod async_poll;
pub(crate) mod atomic_counter;
pub(crate) mod byte_utils;
pub(crate) mod transaction_utils;
pub mod hash_tables;
pub(crate) mod transaction_utils;

#[cfg(feature = "std")]
pub(crate) mod time;
Expand All @@ -43,8 +43,8 @@ pub mod indexed_map;
pub(crate) mod macro_logger;

// These have to come after macro_logger to build
pub mod logger;
pub mod config;
pub mod logger;

#[cfg(any(test, feature = "_test_utils"))]
pub mod test_utils;
Expand Down
108 changes: 78 additions & 30 deletions lightning/src/util/scid_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ pub fn vout_from_scid(short_channel_id: u64) -> u16 {

/// Constructs a `short_channel_id` using the components pieces. Results in an error
/// if the block height, tx index, or vout index overflow the maximum sizes.
pub fn scid_from_parts(block: u64, tx_index: u64, vout_index: u64) -> Result<u64, ShortChannelIdError> {
pub fn scid_from_parts(
block: u64, tx_index: u64, vout_index: u64,
) -> Result<u64, ShortChannelIdError> {
if block > MAX_SCID_BLOCK {
return Err(ShortChannelIdError::BlockOverflow);
}
Expand All @@ -71,12 +73,12 @@ pub fn scid_from_parts(block: u64, tx_index: u64, vout_index: u64) -> Result<u64
/// 3) payments intended to be intercepted will route using a fake scid (this is typically used so
/// the forwarding node can open a JIT channel to the next hop)
pub(crate) mod fake_scid {
use bitcoin::constants::ChainHash;
use bitcoin::Network;
use crate::sign::EntropySource;
use crate::crypto::chacha20::ChaCha20;
use crate::util::scid_utils;
use crate::prelude::*;
use crate::sign::EntropySource;
use crate::util::scid_utils;
use bitcoin::constants::ChainHash;
use bitcoin::Network;

use core::ops::Deref;

Expand All @@ -89,7 +91,6 @@ pub(crate) mod fake_scid {
const BLOCKS_PER_MONTH: u32 = 144 /* blocks per day */ * 30 /* days per month */;
pub(crate) const MAX_SCID_BLOCKS_FROM_NOW: u32 = BLOCKS_PER_MONTH;


/// Fake scids are divided into namespaces, with each namespace having its own identifier between
/// [0..7]. This allows us to identify what namespace a fake scid corresponds to upon HTLC
/// receipt, and handle the HTLC accordingly. The namespace identifier is encrypted when encoded
Expand All @@ -101,44 +102,59 @@ pub(crate) mod fake_scid {
/// SCID aliases for outbound private channels
OutboundAlias,
/// Payment interception namespace
Intercept
Intercept,
}

impl Namespace {
/// We generate "realistic-looking" random scids here, meaning the scid's block height is
/// between segwit activation and the current best known height, and the tx index and output
/// index are also selected from a "reasonable" range. We add this logic because it makes it
/// non-obvious at a glance that the scid is fake, e.g. if it appears in invoice route hints.
pub(crate) fn get_fake_scid<ES: Deref>(&self, highest_seen_blockheight: u32, chain_hash: &ChainHash, fake_scid_rand_bytes: &[u8; 32], entropy_source: &ES) -> u64
where ES::Target: EntropySource,
pub(crate) fn get_fake_scid<ES: Deref>(
&self, highest_seen_blockheight: u32, chain_hash: &ChainHash,
fake_scid_rand_bytes: &[u8; 32], entropy_source: &ES,
) -> u64
where
ES::Target: EntropySource,
{
// Ensure we haven't created a namespace that doesn't fit into the 3 bits we've allocated for
// namespaces.
assert!((*self as u8) < MAX_NAMESPACES);
let rand_bytes = entropy_source.get_secure_random_bytes();

let segwit_activation_height = segwit_activation_height(chain_hash);
let mut blocks_since_segwit_activation = highest_seen_blockheight.saturating_sub(segwit_activation_height);
let mut blocks_since_segwit_activation =
highest_seen_blockheight.saturating_sub(segwit_activation_height);

// We want to ensure that this fake channel won't conflict with any transactions we haven't
// seen yet, in case `highest_seen_blockheight` is updated before we get full information
// about transactions confirmed in the given block.
blocks_since_segwit_activation = blocks_since_segwit_activation.saturating_sub(MAX_SCID_BLOCKS_FROM_NOW);
blocks_since_segwit_activation =
blocks_since_segwit_activation.saturating_sub(MAX_SCID_BLOCKS_FROM_NOW);

let rand_for_height = u32::from_be_bytes(rand_bytes[..4].try_into().unwrap());
let fake_scid_height = segwit_activation_height + rand_for_height % (blocks_since_segwit_activation + 1);
let fake_scid_height =
segwit_activation_height + rand_for_height % (blocks_since_segwit_activation + 1);

let rand_for_tx_index = u32::from_be_bytes(rand_bytes[4..8].try_into().unwrap());
let fake_scid_tx_index = rand_for_tx_index % MAX_TX_INDEX;

// Put the scid in the given namespace.
let fake_scid_vout = self.get_encrypted_vout(fake_scid_height, fake_scid_tx_index, fake_scid_rand_bytes);
scid_utils::scid_from_parts(fake_scid_height as u64, fake_scid_tx_index as u64, fake_scid_vout as u64).unwrap()
let fake_scid_vout =
self.get_encrypted_vout(fake_scid_height, fake_scid_tx_index, fake_scid_rand_bytes);
scid_utils::scid_from_parts(
fake_scid_height as u64,
fake_scid_tx_index as u64,
fake_scid_vout as u64,
)
.unwrap()
}

/// We want to ensure that a 3rd party can't identify a payment as belong to a given
/// `Namespace`. Therefore, we encrypt it using a random bytes provided by `ChannelManager`.
fn get_encrypted_vout(&self, block_height: u32, tx_index: u32, fake_scid_rand_bytes: &[u8; 32]) -> u8 {
fn get_encrypted_vout(
&self, block_height: u32, tx_index: u32, fake_scid_rand_bytes: &[u8; 32],
) -> u8 {
let mut salt = [0 as u8; 8];
let block_height_bytes = block_height.to_be_bytes();
salt[0..4].copy_from_slice(&block_height_bytes);
Expand All @@ -161,7 +177,9 @@ pub(crate) mod fake_scid {
}

/// Returns whether the given fake scid falls into the phantom namespace.
pub fn is_valid_phantom(fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash) -> bool {
pub fn is_valid_phantom(
fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash,
) -> bool {
let block_height = scid_utils::block_from_scid(scid);
let tx_index = scid_utils::tx_index_from_scid(scid);
let namespace = Namespace::Phantom;
Expand All @@ -171,7 +189,9 @@ pub(crate) mod fake_scid {
}

/// Returns whether the given fake scid falls into the intercept namespace.
pub fn is_valid_intercept(fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash) -> bool {
pub fn is_valid_intercept(
fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash,
) -> bool {
let block_height = scid_utils::block_from_scid(scid);
let tx_index = scid_utils::tx_index_from_scid(scid);
let namespace = Namespace::Intercept;
Expand All @@ -182,12 +202,16 @@ pub(crate) mod fake_scid {

#[cfg(test)]
mod tests {
use bitcoin::constants::ChainHash;
use bitcoin::network::Network;
use crate::util::scid_utils::fake_scid::{is_valid_intercept, is_valid_phantom, MAINNET_SEGWIT_ACTIVATION_HEIGHT, MAX_TX_INDEX, MAX_NAMESPACES, Namespace, NAMESPACE_ID_BITMASK, segwit_activation_height, TEST_SEGWIT_ACTIVATION_HEIGHT};
use crate::sync::Arc;
use crate::util::scid_utils;
use crate::util::scid_utils::fake_scid::{
is_valid_intercept, is_valid_phantom, segwit_activation_height, Namespace,
MAINNET_SEGWIT_ACTIVATION_HEIGHT, MAX_NAMESPACES, MAX_TX_INDEX, NAMESPACE_ID_BITMASK,
TEST_SEGWIT_ACTIVATION_HEIGHT,
};
use crate::util::test_utils;
use crate::sync::Arc;
use bitcoin::constants::ChainHash;
use bitcoin::network::Network;

#[test]
fn namespace_identifier_is_within_range() {
Expand All @@ -203,7 +227,10 @@ pub(crate) mod fake_scid {
#[test]
fn test_segwit_activation_height() {
let mainnet_genesis = ChainHash::using_genesis_block(Network::Bitcoin);
assert_eq!(segwit_activation_height(&mainnet_genesis), MAINNET_SEGWIT_ACTIVATION_HEIGHT);
assert_eq!(
segwit_activation_height(&mainnet_genesis),
MAINNET_SEGWIT_ACTIVATION_HEIGHT
);

let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet);
assert_eq!(segwit_activation_height(&testnet_genesis), TEST_SEGWIT_ACTIVATION_HEIGHT);
Expand All @@ -221,7 +248,8 @@ pub(crate) mod fake_scid {
let fake_scid_rand_bytes = [0; 32];
let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet);
let valid_encrypted_vout = namespace.get_encrypted_vout(0, 0, &fake_scid_rand_bytes);
let valid_fake_scid = scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap();
let valid_fake_scid =
scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap();
assert!(is_valid_phantom(&fake_scid_rand_bytes, valid_fake_scid, &testnet_genesis));
let invalid_fake_scid = scid_utils::scid_from_parts(1, 0, 12).unwrap();
assert!(!is_valid_phantom(&fake_scid_rand_bytes, invalid_fake_scid, &testnet_genesis));
Expand All @@ -233,20 +261,31 @@ pub(crate) mod fake_scid {
let fake_scid_rand_bytes = [0; 32];
let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet);
let valid_encrypted_vout = namespace.get_encrypted_vout(0, 0, &fake_scid_rand_bytes);
let valid_fake_scid = scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap();
let valid_fake_scid =
scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap();
assert!(is_valid_intercept(&fake_scid_rand_bytes, valid_fake_scid, &testnet_genesis));
let invalid_fake_scid = scid_utils::scid_from_parts(1, 0, 12).unwrap();
assert!(!is_valid_intercept(&fake_scid_rand_bytes, invalid_fake_scid, &testnet_genesis));
assert!(!is_valid_intercept(
&fake_scid_rand_bytes,
invalid_fake_scid,
&testnet_genesis
));
}

#[test]
fn test_get_fake_scid() {
let mainnet_genesis = ChainHash::using_genesis_block(Network::Bitcoin);
let seed = [0; 32];
let fake_scid_rand_bytes = [1; 32];
let keys_manager = Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet));
let keys_manager =
Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet));
let namespace = Namespace::Phantom;
let fake_scid = namespace.get_fake_scid(500_000, &mainnet_genesis, &fake_scid_rand_bytes, &keys_manager);
let fake_scid = namespace.get_fake_scid(
500_000,
&mainnet_genesis,
&fake_scid_rand_bytes,
&keys_manager,
);

let fake_height = scid_utils::block_from_scid(fake_scid);
assert!(fake_height >= MAINNET_SEGWIT_ACTIVATION_HEIGHT);
Expand Down Expand Up @@ -298,8 +337,17 @@ mod tests {
assert_eq!(scid_from_parts(0x00000001, 0x00000002, 0x0003).unwrap(), 0x000001_000002_0003);
assert_eq!(scid_from_parts(0x00111111, 0x00222222, 0x3333).unwrap(), 0x111111_222222_3333);
assert_eq!(scid_from_parts(0x00ffffff, 0x00ffffff, 0xffff).unwrap(), 0xffffff_ffffff_ffff);
assert_eq!(scid_from_parts(0x01ffffff, 0x00000000, 0x0000).err().unwrap(), ShortChannelIdError::BlockOverflow);
assert_eq!(scid_from_parts(0x00000000, 0x01ffffff, 0x0000).err().unwrap(), ShortChannelIdError::TxIndexOverflow);
assert_eq!(scid_from_parts(0x00000000, 0x00000000, 0x010000).err().unwrap(), ShortChannelIdError::VoutIndexOverflow);
assert_eq!(
scid_from_parts(0x01ffffff, 0x00000000, 0x0000).err().unwrap(),
ShortChannelIdError::BlockOverflow
);
assert_eq!(
scid_from_parts(0x00000000, 0x01ffffff, 0x0000).err().unwrap(),
ShortChannelIdError::TxIndexOverflow
);
assert_eq!(
scid_from_parts(0x00000000, 0x00000000, 0x010000).err().unwrap(),
ShortChannelIdError::VoutIndexOverflow
);
}
}
Loading
Loading