From d653c519637345a5591faf602748e94ba6eef36d Mon Sep 17 00:00:00 2001 From: Duncan Dean Date: Fri, 12 Apr 2024 15:37:17 +0200 Subject: [PATCH] Split contributions into local/remote maps --- lightning/src/ln/interactivetxs.rs | 261 ++++++++++++++++++++--------- 1 file changed, 184 insertions(+), 77 deletions(-) diff --git a/lightning/src/ln/interactivetxs.rs b/lightning/src/ln/interactivetxs.rs index 1dbbb41750c..fd42060bc0e 100644 --- a/lightning/src/ln/interactivetxs.rs +++ b/lightning/src/ln/interactivetxs.rs @@ -78,7 +78,7 @@ impl SerialIdExt for SerialId { } #[derive(Debug, Clone, PartialEq)] -pub enum AbortReason { +pub(crate) enum AbortReason { InvalidStateTransition, UnexpectedCounterpartyMessage, ReceivedTooManyTxAddInputs, @@ -98,20 +98,121 @@ pub enum AbortReason { InvalidTx, } -#[derive(Debug)] -pub struct TxInputWithPrevOutput { +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InteractiveTxInput { + serial_id: SerialId, input: TxIn, prev_output: TxOut, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InteractiveTxOutput { + serial_id: SerialId, + tx_out: TxOut, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ConstructedTransaction { + local_inputs: Vec, + local_outputs: Vec, + + remote_inputs: Vec, + remote_outputs: Vec, + + local_inputs_value_satoshis: u64, + local_outputs_value_satoshis: u64, + + remote_inputs_value_satoshis: u64, + remote_outputs_value_satoshis: u64, + + remote_fees: u64, + local_fees: u64, + total_fees: u64, + + lock_time: AbsoluteLockTime, +} + +impl ConstructedTransaction { + pub fn new( + local_inputs: Vec, local_outputs: Vec, + remote_inputs: Vec, remote_outputs: Vec, + lock_time: AbsoluteLockTime, + ) -> Self { + let local_inputs_value_satoshis = + local_inputs.iter().fold(0u64, |acc, InteractiveTxInput { ref prev_output, .. }| { + acc.saturating_add(prev_output.value) + }); + let local_outputs_value_satoshis = + local_outputs.iter().fold(0u64, |acc, InteractiveTxOutput { ref tx_out, .. }| { + acc.saturating_add(tx_out.value) + }); + + let remote_inputs_value_satoshis = + remote_inputs.iter().fold(0u64, |acc, InteractiveTxInput { ref prev_output, .. }| { + acc.saturating_add(prev_output.value) + }); + let remote_outputs_value_satoshis = + remote_outputs.iter().fold(0u64, |acc, InteractiveTxOutput { ref tx_out, .. }| { + acc.saturating_add(tx_out.value) + }); + + let local_fees = local_inputs_value_satoshis.saturating_sub(local_outputs_value_satoshis); + let remote_fees = + remote_inputs_value_satoshis.saturating_sub(remote_outputs_value_satoshis); + + Self { + local_inputs, + local_outputs, + + remote_inputs, + remote_outputs, + + local_inputs_value_satoshis, + local_outputs_value_satoshis, + + remote_inputs_value_satoshis, + remote_outputs_value_satoshis, + + local_fees, + remote_fees, + total_fees: local_fees.saturating_add(remote_fees), + + lock_time, + } + } + + pub fn build_unsigned_tx(&self) -> Result { + // Inputs and outputs must be sorted by serial_id + let ConstructedTransaction { local_inputs: inputs, local_outputs: outputs, .. } = self; + + let mut inputs: Vec = + inputs.iter().chain(self.remote_inputs.iter()).map(|input| input.clone()).collect(); + let mut outputs: Vec = + outputs.iter().chain(self.remote_outputs.iter()).map(|input| input.clone()).collect(); + inputs.sort_unstable_by_key(|InteractiveTxInput { serial_id, .. }| *serial_id); + outputs.sort_unstable_by_key(|InteractiveTxOutput { serial_id, .. }| *serial_id); + + let input: Vec = + inputs.clone().into_iter().map(|InteractiveTxInput { input, .. }| input).collect(); + let output: Vec = + outputs.clone().into_iter().map(|InteractiveTxOutput { tx_out, .. }| tx_out).collect(); + + let unsigned_tx = Transaction { version: 2, lock_time: self.lock_time, input, output }; + + Ok(unsigned_tx) + } +} + #[derive(Debug)] struct NegotiationContext { holder_is_initiator: bool, received_tx_add_input_count: u16, received_tx_add_output_count: u16, - inputs: HashMap, + local_inputs: HashMap, + remote_inputs: HashMap, prevtx_outpoints: HashSet, - outputs: HashMap, + local_outputs: HashMap, + remote_outputs: HashMap, tx_locktime: AbsoluteLockTime, feerate_sat_per_kw: u32, to_remote_value_satoshis: u64, @@ -128,24 +229,16 @@ impl NegotiationContext { self.holder_is_initiator == serial_id.is_for_non_initiator() } - fn total_input_and_output_count(&self) -> usize { - self.inputs.len().saturating_add(self.outputs.len()) + fn total_input_count(&self) -> usize { + self.local_inputs.len().saturating_add(self.remote_inputs.len()) } - fn counterparty_inputs_contributed( - &self, - ) -> impl Iterator + Clone { - self.inputs - .iter() - .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id)) - .map(|(_, input_with_prevout)| input_with_prevout) + fn total_output_count(&self) -> usize { + self.local_outputs.len().saturating_add(self.remote_outputs.len()) } - fn counterparty_outputs_contributed(&self) -> impl Iterator + Clone { - self.outputs - .iter() - .filter(move |(serial_id, _)| self.is_serial_id_valid_for_counterparty(serial_id)) - .map(|(_, output)| output) + fn total_input_and_output_count(&self) -> usize { + self.total_input_count().saturating_add(self.total_output_count()) } fn received_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> { @@ -207,7 +300,7 @@ impl NegotiationContext { return Err(AbortReason::PrevTxOutInvalid); }; let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out }; - match self.inputs.entry(msg.serial_id) { + match self.remote_inputs.entry(msg.serial_id) { hash_map::Entry::Occupied(_) => { // The receiving node: // - MUST fail the negotiation if: @@ -215,7 +308,8 @@ impl NegotiationContext { return Err(AbortReason::DuplicateSerialId); }, hash_map::Entry::Vacant(entry) => { - entry.insert(TxInputWithPrevOutput { + entry.insert(InteractiveTxInput { + serial_id: msg.serial_id, input: TxIn { previous_output: prev_outpoint.clone(), sequence: Sequence(msg.sequence), @@ -234,7 +328,7 @@ impl NegotiationContext { return Err(AbortReason::IncorrectSerialIdParity); } - self.inputs + self.remote_inputs .remove(&msg.serial_id) // The receiving node: // - MUST fail the negotiation if: @@ -270,8 +364,8 @@ impl NegotiationContext { // Check that adding this output would not cause the total output value to exceed the total // bitcoin supply. let mut outputs_value: u64 = 0; - for output in self.outputs.iter() { - outputs_value = outputs_value.saturating_add(output.1.value); + for output in self.remote_outputs.iter() { + outputs_value = outputs_value.saturating_add(output.1.tx_out.value); } if outputs_value.saturating_add(msg.sats) > TOTAL_BITCOIN_SUPPLY_SATOSHIS { // The receiving node: @@ -300,8 +394,11 @@ impl NegotiationContext { return Err(AbortReason::InvalidOutputScript); } - let output = TxOut { value: msg.sats, script_pubkey: msg.script.clone() }; - match self.outputs.entry(msg.serial_id) { + let output = InteractiveTxOutput { + serial_id: msg.serial_id, + tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() }, + }; + match self.remote_outputs.entry(msg.serial_id) { hash_map::Entry::Occupied(_) => { // The receiving node: // - MUST fail the negotiation if: @@ -319,7 +416,7 @@ impl NegotiationContext { if !self.is_serial_id_valid_for_counterparty(&msg.serial_id) { return Err(AbortReason::IncorrectSerialIdParity); } - if let Some(_) = self.outputs.remove(&msg.serial_id) { + if let Some(_) = self.remote_outputs.remove(&msg.serial_id) { Ok(()) } else { // The receiving node: @@ -343,37 +440,46 @@ impl NegotiationContext { // We have added an input that already exists return Err(AbortReason::PrevTxOutInvalid); } - self.inputs.insert(msg.serial_id, TxInputWithPrevOutput { input, prev_output }); + self.local_inputs.insert( + msg.serial_id, + InteractiveTxInput { serial_id: msg.serial_id, input, prev_output }, + ); Ok(()) } fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> { - self.outputs - .insert(msg.serial_id, TxOut { value: msg.sats, script_pubkey: msg.script.clone() }); + self.local_outputs.insert( + msg.serial_id, + InteractiveTxOutput { + serial_id: msg.serial_id, + tx_out: TxOut { value: msg.sats, script_pubkey: msg.script.clone() }, + }, + ); Ok(()) } fn sent_tx_remove_input(&mut self, msg: &msgs::TxRemoveInput) -> Result<(), AbortReason> { - self.inputs.remove(&msg.serial_id); + self.local_inputs.remove(&msg.serial_id); Ok(()) } fn sent_tx_remove_output(&mut self, msg: &msgs::TxRemoveOutput) -> Result<(), AbortReason> { - self.outputs.remove(&msg.serial_id); + self.local_outputs.remove(&msg.serial_id); Ok(()) } fn check_counterparty_fees( - &self, counterparty_inputs_value: u64, counterparty_outputs_value: u64, + &self, counterparty_fees_contributed: u64, ) -> Result<(), AbortReason> { let mut counterparty_weight_contributed: u64 = self - .counterparty_outputs_contributed() - .map(|output| get_output_weight(&output.script_pubkey)) + .remote_outputs + .values() + .map(|output| get_output_weight(&output.tx_out.script_pubkey)) .sum(); // We don't know the counterparty's witnesses ahead of time obviously, so we use the lower bounds // specified in BOLT 3. let mut total_inputs_weight: u64 = 0; - for TxInputWithPrevOutput { prev_output, .. } in self.counterparty_inputs_contributed() { + for InteractiveTxInput { prev_output, .. } in self.remote_inputs.values() { total_inputs_weight = total_inputs_weight.saturating_add(if prev_output.script_pubkey.is_v0_p2wpkh() { P2WPKH_INPUT_WEIGHT_LOWER_BOUND @@ -387,8 +493,6 @@ impl NegotiationContext { } counterparty_weight_contributed = counterparty_weight_contributed.saturating_add(total_inputs_weight); - let counterparty_fees_contributed = - counterparty_inputs_value.saturating_sub(counterparty_outputs_value); let mut required_counterparty_contribution_fee = fee_for_weight(self.feerate_sat_per_kw, counterparty_weight_contributed); if !self.holder_is_initiator { @@ -405,56 +509,54 @@ impl NegotiationContext { Ok(()) } - fn build_transaction(self) -> Result { + fn validate_tx(self) -> Result { // The receiving node: // MUST fail the negotiation if: // - the peer's total input satoshis is less than their outputs - let mut counterparty_inputs_value: u64 = 0; - let mut counterparty_outputs_value: u64 = 0; - for input in self.counterparty_inputs_contributed() { - counterparty_inputs_value = - counterparty_inputs_value.saturating_add(input.prev_output.value); - } - for output in self.counterparty_outputs_contributed() { - counterparty_outputs_value = counterparty_outputs_value.saturating_add(output.value); - } + let remote_inputs_value = self + .remote_inputs + .values() + .fold(0u64, |acc, InteractiveTxInput { prev_output, .. }| { + acc.saturating_add(prev_output.value) + }); + let remote_outputs_value = self + .remote_outputs + .values() + .fold(0u64, |acc, InteractiveTxOutput { tx_out, .. }| acc.saturating_add(tx_out.value)); + // ...actually the counterparty might be splicing out, so that their balance also contributes // to the total input value. - if counterparty_inputs_value.saturating_add(self.to_remote_value_satoshis) - < counterparty_outputs_value + if remote_inputs_value.saturating_add(self.to_remote_value_satoshis) < remote_outputs_value { return Err(AbortReason::OutputsValueExceedsInputsValue); } // - there are more than 252 inputs // - there are more than 252 outputs - if self.inputs.len() > MAX_INPUTS_OUTPUTS_COUNT - || self.outputs.len() > MAX_INPUTS_OUTPUTS_COUNT + if self.total_input_count() > MAX_INPUTS_OUTPUTS_COUNT + || self.total_output_count() > MAX_INPUTS_OUTPUTS_COUNT { return Err(AbortReason::ExceededNumberOfInputsOrOutputs); } // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee). - self.check_counterparty_fees(counterparty_inputs_value, counterparty_outputs_value)?; - - // Inputs and outputs must be sorted by serial_id - let mut inputs = self.inputs.into_iter().collect::>(); - let mut outputs = self.outputs.into_iter().collect::>(); - inputs.sort_unstable_by_key(|(serial_id, _)| *serial_id); - outputs.sort_unstable_by_key(|(serial_id, _)| *serial_id); + self.check_counterparty_fees(remote_inputs_value.saturating_sub(remote_outputs_value))?; + + let constructed_tx = ConstructedTransaction::new( + self.local_inputs.into_values().collect(), + self.local_outputs.into_values().collect(), + self.remote_inputs.into_values().collect(), + self.remote_outputs.into_values().collect(), + self.tx_locktime, + ); - let tx_to_validate = Transaction { - version: 2, - lock_time: self.tx_locktime, - input: inputs.into_iter().map(|(_, input)| input.input).collect(), - output: outputs.into_iter().map(|(_, output)| output).collect(), - }; - if tx_to_validate.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 { + let unsigned_tx = constructed_tx.build_unsigned_tx()?; + if unsigned_tx.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 { return Err(AbortReason::TransactionTooLarge); } - Ok(tx_to_validate) + Ok(constructed_tx) } } @@ -542,7 +644,7 @@ define_state!( ReceivedTxComplete, "We have received a `tx_complete` message and the counterparty is awaiting ours." ); -define_state!(NegotiationComplete, Transaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete."); +define_state!(NegotiationComplete, ConstructedTransaction, "We have exchanged consecutive `tx_complete` messages with the counterparty and the transaction negotiation is complete."); define_state!( NegotiationAborted, AbortReason, @@ -584,7 +686,7 @@ macro_rules! define_state_transitions { impl StateTransition for $tx_complete_state { fn transition(self, _data: &msgs::TxComplete) -> StateTransitionResult { let context = self.into_negotiation_context(); - let tx = context.build_transaction()?; + let tx = context.validate_tx()?; Ok(NegotiationComplete(tx)) } } @@ -662,9 +764,11 @@ impl StateMachine { holder_is_initiator: is_initiator, received_tx_add_input_count: 0, received_tx_add_output_count: 0, - inputs: new_hash_map(), + local_inputs: new_hash_map(), + remote_inputs: new_hash_map(), prevtx_outpoints: new_hash_set(), - outputs: new_hash_map(), + local_outputs: new_hash_map(), + remote_outputs: new_hash_map(), feerate_sat_per_kw, to_remote_value_satoshis, }; @@ -726,14 +830,14 @@ impl StateMachine { ]); } -pub struct InteractiveTxConstructor { +pub(crate) struct InteractiveTxConstructor { state_machine: StateMachine, channel_id: ChannelId, inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>, outputs_to_contribute: Vec<(SerialId, TxOut)>, } -pub enum InteractiveTxMessageSend { +pub(crate) enum InteractiveTxMessageSend { TxAddInput(msgs::TxAddInput), TxAddOutput(msgs::TxAddOutput), TxComplete(msgs::TxComplete), @@ -765,10 +869,10 @@ where serial_id } -pub enum HandleTxCompleteValue { +pub(crate) enum HandleTxCompleteValue { SendTxMessage(InteractiveTxMessageSend), - SendTxComplete(InteractiveTxMessageSend, Transaction), - NegotiationComplete(Transaction), + SendTxComplete(InteractiveTxMessageSend, ConstructedTransaction), + NegotiationComplete(ConstructedTransaction), } impl InteractiveTxConstructor { @@ -1127,7 +1231,10 @@ mod tests { } assert!(message_send_a.is_none()); assert!(message_send_b.is_none()); - assert_eq!(final_tx_a, final_tx_b); + assert_eq!( + final_tx_a.unwrap().build_unsigned_tx().unwrap(), + final_tx_b.unwrap().build_unsigned_tx().unwrap() + ); assert!(session.expect_error.is_none(), "Test: {}", session.description); }