Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and fix AtomicCounter #3302

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightning/src/ln/functional_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7668,8 +7668,8 @@ fn test_bump_penalty_txn_on_revoked_htlcs() {
assert_ne!(node_txn[0].input[0].previous_output, node_txn[2].input[0].previous_output);
assert_ne!(node_txn[1].input[0].previous_output, node_txn[2].input[0].previous_output);

assert_eq!(node_txn[0].input[0].previous_output, revoked_htlc_txn[1].input[0].previous_output);
assert_eq!(node_txn[1].input[0].previous_output, revoked_htlc_txn[0].input[0].previous_output);
assert_eq!(node_txn[1].input[0].previous_output, revoked_htlc_txn[1].input[0].previous_output);
assert_eq!(node_txn[0].input[0].previous_output, revoked_htlc_txn[0].input[0].previous_output);

// node_txn[3] spends the revoked outputs from the revoked_htlc_txn (which only have one
// output, checked above).
Expand Down
2 changes: 1 addition & 1 deletion lightning/src/ln/interactivetxs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ mod tests {
impl EntropySource for TestEntropySource {
fn get_secure_random_bytes(&self) -> [u8; 32] {
let mut res = [0u8; 32];
let increment = self.0.get_increment();
let increment = self.0.next();
for (i, byte) in res.iter_mut().enumerate() {
// Rotate the increment value by 'i' bits to the right, to avoid clashes
// when `generate_local_serial_id` does a parity flip on consecutive calls for the
Expand Down
7 changes: 7 additions & 0 deletions lightning/src/ln/max_payment_path_len_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,13 @@ fn blinded_path_with_custom_tlv() {
create_announced_chan_between_nodes(&nodes, 1, 2);
let chan_upd_2_3 = create_announced_chan_between_nodes_with_value(&nodes, 2, 3, 1_000_000, 0).0.contents;

// Ensure all nodes are at the same height
let node_max_height = nodes.iter().map(|node| node.blocks.lock().unwrap().len()).max().unwrap() as u32;
connect_blocks(&nodes[0], node_max_height - nodes[0].best_block_info().1);
connect_blocks(&nodes[1], node_max_height - nodes[1].best_block_info().1);
connect_blocks(&nodes[2], node_max_height - nodes[2].best_block_info().1);
connect_blocks(&nodes[3], node_max_height - nodes[3].best_block_info().1);

// Construct the route parameters for sending to nodes[3]'s blinded path.
let amt_msat = 100_000;
let (payment_preimage, payment_hash, payment_secret) = get_payment_preimage_hash(&nodes[3], Some(amt_msat), None);
Expand Down
4 changes: 4 additions & 0 deletions lightning/src/ln/monitor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,10 @@ fn do_test_restored_packages_retry(check_old_monitor_retries_after_upgrade: bool

let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);

// Reset our RNG counters to mirror the RNG output from when this test was written.
tnull marked this conversation as resolved.
Show resolved Hide resolved
nodes[0].keys_manager.backing.inner.entropy_source.set_counter(0x1_0000_0004);
nodes[1].keys_manager.backing.inner.entropy_source.set_counter(0x1_0000_0004);

// Open a channel, lock in an HTLC, and immediately broadcast the commitment transaction. This
// ensures that the HTLC timeout package is held until we reach its expiration height.
let (_, _, chan_id, funding_tx) = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100_000, 50_000_000);
Expand Down
2 changes: 1 addition & 1 deletion lightning/src/ln/peer_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM

fn get_ephemeral_key(&self) -> SecretKey {
let mut ephemeral_hash = self.ephemeral_key_midstate.clone();
let counter = self.peer_counter.get_increment();
let counter = self.peer_counter.next();
ephemeral_hash.input(&counter.to_le_bytes());
SecretKey::from_slice(&Sha256::from_engine(ephemeral_hash).to_byte_array()).expect("You broke SHA-256!")
}
Expand Down
17 changes: 14 additions & 3 deletions lightning/src/sign/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,6 @@ pub trait ChangeDestinationSource {
///
/// This implementation performs no policy checks and is insufficient by itself as
/// a secure external signer.
#[derive(Debug)]
pub struct InMemorySigner {
/// Holder secret key in the 2-of-2 multisig script of a channel. This key also backs the
/// holder's anchor output in a commitment transaction, if one is present.
Expand Down Expand Up @@ -1854,6 +1853,9 @@ pub struct KeysManager {
channel_master_key: Xpriv,
channel_child_index: AtomicUsize,

#[cfg(test)]
pub(crate) entropy_source: RandomBytes,
#[cfg(not(test))]
entropy_source: RandomBytes,

seed: [u8; 32],
Expand Down Expand Up @@ -2310,6 +2312,9 @@ impl SignerProvider for KeysManager {
/// Switching between this struct and [`KeysManager`] will invalidate any previously issued
/// invoices and attempts to pay previous invoices will fail.
pub struct PhantomKeysManager {
#[cfg(test)]
pub(crate) inner: KeysManager,
#[cfg(not(test))]
inner: KeysManager,
inbound_payment_key: KeyMaterial,
phantom_secret: SecretKey,
Expand Down Expand Up @@ -2475,7 +2480,6 @@ impl PhantomKeysManager {
}

/// An implementation of [`EntropySource`] using ChaCha20.
#[derive(Debug)]
pub struct RandomBytes {
/// Seed from which all randomness produced is derived from.
seed: [u8; 32],
Expand All @@ -2489,11 +2493,18 @@ impl RandomBytes {
pub fn new(seed: [u8; 32]) -> Self {
Self { seed, index: AtomicCounter::new() }
}

#[cfg(test)]
/// Force the counter to a value to produce the same output again. Mostly useful in tests where
/// we need to maintain behavior with a previous version which didn't use as much RNG output.
pub(crate) fn set_counter(&self, count: u64) {
self.index.set_counter(count);
}
}

impl EntropySource for RandomBytes {
fn get_secure_random_bytes(&self) -> [u8; 32] {
let index = self.index.get_increment();
let index = self.index.next();
let mut nonce = [0u8; 16];
nonce[..8].copy_from_slice(&index.to_be_bytes());
ChaCha20::get_single_block(&self.seed, &nonce)
Expand Down
52 changes: 32 additions & 20 deletions lightning/src/util/atomic_counter.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
//! A simple atomic counter that uses AtomicUsize to give a u64 counter.
//! A simple atomic counter that uses mutexes if the platform doesn't support atomic u64s.

#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
compile_error!("We need at least 32-bit pointers for atomic counter (and to have enough memory to run LDK)");
#[cfg(target_has_atomic = "64")]
use core::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(target_has_atomic = "64"))]
use crate::sync::Mutex;

use core::sync::atomic::{AtomicUsize, Ordering};

#[derive(Debug)]
pub(crate) struct AtomicCounter {
// Usize needs to be at least 32 bits to avoid overflowing both low and high. If usize is 64
// bits we will never realistically count into high:
counter_low: AtomicUsize,
counter_high: AtomicUsize,
#[cfg(target_has_atomic = "64")]
counter: AtomicU64,
#[cfg(not(target_has_atomic = "64"))]
counter: Mutex<u64>,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the intention is to produce some unique values that maybe don't need to be sequential you can still have a sensible lock-free implementation. Roughly like this:

let mut low = self.low.load(Relaxed);
let mut high = self.high.load(Relaxed);
loop {
    let new_low = if low == u32::MAX {
        // don't use fetch_add to avoid incrementing high by more than 1
        if let Err(new) = self.high.compare_exchange(high, high + 1, Relaxed, Relaxed) {
            high = new;
        }
        0
    } else {
        low + 1
    }
     // FTR this cannot be weak
     match self.low.compare_exchange(low, new_low, Relaxed, Relaxed) {
        Ok(_) => break,
        Err(new) => low = new,
     }
}
u64::from(high) << 32 | u64::from(low)

This assumes that a thread doesn't get scheduled-out after incrementing high for so long that other thread(s) manage to increment the counter by 2^32, which I think is a reasonable assumption. There's still a chance that high gets bumped by more than one though if a thread managed to bump it and before it updates low another thread reads both of them. This is quite unfrequent and could be dealt with by sacrificing one bit of high which gets set first and then reset at the end.


impl AtomicCounter {
pub(crate) fn new() -> Self {
Self {
counter_low: AtomicUsize::new(0),
counter_high: AtomicUsize::new(0),
#[cfg(target_has_atomic = "64")]
counter: AtomicU64::new(0),
#[cfg(not(target_has_atomic = "64"))]
counter: Mutex::new(0),
}
}
pub(crate) fn next(&self) -> u64 {
#[cfg(target_has_atomic = "64")] {
self.counter.fetch_add(1, Ordering::AcqRel)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have AcqRel however to my understanding this is not a synchronization primitive so Relaxed is appropriate.

#[cfg(not(target_has_atomic = "64"))] {
let mut mtx = self.counter.lock().unwrap();
*mtx += 1;
*mtx - 1
}
}
pub(crate) fn get_increment(&self) -> u64 {
let low = self.counter_low.fetch_add(1, Ordering::AcqRel) as u64;
let high = if low == 0 {
self.counter_high.fetch_add(1, Ordering::AcqRel) as u64
} else {
self.counter_high.load(Ordering::Acquire) as u64
};
(high << 32) | low
#[cfg(test)]
pub(crate) fn set_counter(&self, count: u64) {
#[cfg(target_has_atomic = "64")] {
self.counter.store(count, Ordering::Release);
}
#[cfg(not(target_has_atomic = "64"))] {
let mut mtx = self.counter.lock().unwrap();
*mtx = count;
}
}
}
Loading