diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 7e011e65ef..d74d88e0d1 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -1083,19 +1083,29 @@ macro_rules! check_added_monitors { } } -/// Checks whether the claimed HTLC for the specified path has the correct channel information. -/// -/// This will panic if the path is empty, if the HTLC's channel ID is not actually a channel that -/// connects the final two nodes in the path, or if the `user_channel_id` is incorrect. -pub fn check_claimed_htlc_channel<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, path: &[&Node<'a, 'b, 'c>], htlc: &ClaimedHTLC) { +fn claimed_htlc_matches_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, path: &[&Node<'a, 'b, 'c>], htlc: &ClaimedHTLC) -> bool { let mut nodes = path.iter().rev(); let dest = nodes.next().expect("path should have a destination").node; let prev = nodes.next().unwrap_or(&origin_node).node; let dest_channels = dest.list_channels(); let ch = dest_channels.iter().find(|ch| ch.channel_id == htlc.channel_id) .expect("HTLC's channel should be one of destination node's channels"); - assert_eq!(htlc.user_channel_id, ch.user_channel_id); - assert_eq!(ch.counterparty.node_id, prev.get_our_node_id()); + htlc.user_channel_id == ch.user_channel_id && + ch.counterparty.node_id == prev.get_our_node_id() +} + +fn check_claimed_htlcs_match_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, route: &[&[&Node<'a, 'b, 'c>]], htlcs: &[ClaimedHTLC]) { + assert_eq!(route.len(), htlcs.len()); + for path in route { + let mut found_matching_htlc = false; + for htlc in htlcs { + if claimed_htlc_matches_path(origin_node, path, htlc) { + found_matching_htlc = true; + break; + } + } + assert!(found_matching_htlc); + } } pub fn _reload_node<'a, 'b, 'c>(node: &'a Node<'a, 'b, 'c>, default_config: UserConfig, chanman_encoded: &[u8], monitors_encoded: &[&[u8]]) -> TestChannelManager<'b, 'c> { @@ -2832,7 +2842,7 @@ pub fn pass_claimed_payment_along_route(args: ClaimAlongRouteArgs) -> u64 { assert_eq!(htlcs.len(), expected_paths.len()); // One per path. assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::(), amount_msat); assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs); - expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc)); + check_claimed_htlcs_match_route(origin_node, expected_paths, htlcs); fwd_amt_msat = amount_msat; }, Event::PaymentClaimed { @@ -2849,7 +2859,7 @@ pub fn pass_claimed_payment_along_route(args: ClaimAlongRouteArgs) -> u64 { assert_eq!(htlcs.len(), expected_paths.len()); // One per path. assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::(), amount_msat); assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs); - expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc)); + check_claimed_htlcs_match_route(origin_node, expected_paths, htlcs); fwd_amt_msat = amount_msat; } _ => panic!(),