From 50a513999c553837c846751c36f39d3949cc9243 Mon Sep 17 00:00:00 2001 From: Thibault Martinez Date: Mon, 30 Oct 2023 11:51:39 +0100 Subject: [PATCH] ISA: handle restricted addresses (#1526) * ISA: handle restricted addresses * Avoid clone * Nit * Add 2 more restricted tests --------- Co-authored-by: /alex/ --- .../api/block_builder/input_selection/mod.rs | 7 +- .../input_selection/requirement/sender.rs | 5 + .../client/input_selection/basic_outputs.rs | 180 +++++++++++++++++- 3 files changed, 190 insertions(+), 2 deletions(-) diff --git a/sdk/src/client/api/block_builder/input_selection/mod.rs b/sdk/src/client/api/block_builder/input_selection/mod.rs index 445a7cc3cc..113513f103 100644 --- a/sdk/src/client/api/block_builder/input_selection/mod.rs +++ b/sdk/src/client/api/block_builder/input_selection/mod.rs @@ -68,6 +68,7 @@ impl InputSelection { Address::Account(account_address) => Ok(Some(Requirement::Account(*account_address.account_id()))), Address::Nft(nft_address) => Ok(Some(Requirement::Nft(*nft_address.nft_id()))), Address::Anchor(_) => Err(Error::UnsupportedAddressType(AnchorAddress::KIND)), + Address::Restricted(_) => Ok(None), _ => todo!("What do we do here?"), } } @@ -234,7 +235,11 @@ impl InputSelection { .unwrap() .0; - self.addresses.contains(&required_address) + if let Address::Restricted(restricted_address) = required_address { + self.addresses.contains(restricted_address.address()) + } else { + self.addresses.contains(&required_address) + } }) } diff --git a/sdk/src/client/api/block_builder/input_selection/requirement/sender.rs b/sdk/src/client/api/block_builder/input_selection/requirement/sender.rs index 0dfc1fc516..ecebdc5826 100644 --- a/sdk/src/client/api/block_builder/input_selection/requirement/sender.rs +++ b/sdk/src/client/api/block_builder/input_selection/requirement/sender.rs @@ -42,6 +42,11 @@ impl InputSelection { Err(e) => Err(e), } } + Address::Restricted(restricted_address) => { + log::debug!("Forwarding {address:?} sender requirement to inner address"); + + self.fulfill_sender_requirement(restricted_address.address()) + } _ => Err(Error::UnsupportedAddressType(address.kind())), } } diff --git a/sdk/tests/client/input_selection/basic_outputs.rs b/sdk/tests/client/input_selection/basic_outputs.rs index 201d6c0686..8eedd759ac 100644 --- a/sdk/tests/client/input_selection/basic_outputs.rs +++ b/sdk/tests/client/input_selection/basic_outputs.rs @@ -6,7 +6,7 @@ use std::str::FromStr; use iota_sdk::{ client::api::input_selection::{Error, InputSelection, Requirement}, types::block::{ - address::{AccountAddress, Address, Bech32Address, NftAddress}, + address::{AccountAddress, Address, Bech32Address, NftAddress, RestrictedAddress, ToBech32Ext}, output::{AccountId, NftId}, protocol::protocol_parameters, }, @@ -1389,3 +1389,181 @@ fn too_many_outputs_with_remainder() { iota_sdk::client::api::input_selection::Error::InvalidOutputCount(129) ) } + +#[test] +fn restricted_ed25519() { + let protocol_parameters = protocol_parameters(); + let address = Address::try_from_bech32(BECH32_ADDRESS_ED25519_1).unwrap(); + let restricted = RestrictedAddress::new(address.clone()).unwrap(); + let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string(); + + let inputs = build_inputs([ + Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(1_000_000, &restricted_bech32, None, None, None, None, None, None), + Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(1_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + ]); + let outputs = build_outputs([Basic( + 1_000_000, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + None, + None, + )]); + + let selected = InputSelection::new( + inputs.clone(), + outputs.clone(), + addresses([BECH32_ADDRESS_ED25519_1]), + protocol_parameters, + ) + .select() + .unwrap(); + + assert_eq!(selected.inputs.len(), 1); + assert_eq!(selected.inputs, [inputs[2].clone()]); + assert!(unsorted_eq(&selected.outputs, &outputs)); +} + +#[test] +fn restricted_nft() { + let protocol_parameters = protocol_parameters(); + let nft_id_1 = NftId::from_str(NFT_ID_1).unwrap(); + let nft_address = Address::from(nft_id_1); + let restricted = RestrictedAddress::new(nft_address.clone()).unwrap(); + let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string(); + + let inputs = build_inputs([ + Basic(2_000_000, &restricted_bech32, None, None, None, None, None, None), + Nft( + 2_000_000, + nft_id_1, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + None, + None, + ), + ]); + let outputs = build_outputs([Basic( + 3_000_000, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + None, + None, + )]); + + let selected = InputSelection::new( + inputs.clone(), + outputs.clone(), + addresses([BECH32_ADDRESS_ED25519_0]), + protocol_parameters, + ) + .select() + .unwrap(); + + assert!(unsorted_eq(&selected.inputs, &inputs)); + assert_eq!(selected.outputs.len(), 2); + assert!(selected.outputs.contains(&outputs[0])); +} + +#[test] +fn restricted_account() { + let protocol_parameters = protocol_parameters(); + let account_id_1 = AccountId::from_str(ACCOUNT_ID_1).unwrap(); + let account_address = Address::from(account_id_1); + let restricted = RestrictedAddress::new(account_address.clone()).unwrap(); + let restricted_bech32 = restricted.to_bech32_unchecked("rms").to_string(); + + let inputs = build_inputs([ + Basic(2_000_000, &restricted_bech32, None, None, None, None, None, None), + Account( + 2_000_000, + account_id_1, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + ), + ]); + + let outputs = build_outputs([Basic( + 3_000_000, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + None, + None, + )]); + + let selected = InputSelection::new( + inputs.clone(), + outputs.clone(), + addresses([BECH32_ADDRESS_ED25519_0]), + protocol_parameters, + ) + .select() + .unwrap(); + + assert!(unsorted_eq(&selected.inputs, &inputs)); + assert_eq!(selected.outputs.len(), 2); + assert!(selected.outputs.contains(&outputs[0])); +} + +#[test] +fn restricted_ed25519_sender() { + let protocol_parameters = protocol_parameters(); + let sender = Address::try_from_bech32(BECH32_ADDRESS_ED25519_1).unwrap(); + let restricted_sender = RestrictedAddress::new(sender.clone()).unwrap(); + let restricted_sender_bech32 = restricted_sender.to_bech32_unchecked("rms").to_string(); + + let inputs = build_inputs([ + Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(1_000_000, BECH32_ADDRESS_ED25519_1, None, None, None, None, None, None), + Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + Basic(2_000_000, BECH32_ADDRESS_ED25519_0, None, None, None, None, None, None), + ]); + let outputs = build_outputs([Basic( + 2_000_000, + BECH32_ADDRESS_ED25519_0, + None, + Some(&restricted_sender_bech32), + None, + None, + None, + None, + )]); + + let selected = InputSelection::new( + inputs.clone(), + outputs.clone(), + addresses([BECH32_ADDRESS_ED25519_0, BECH32_ADDRESS_ED25519_1]), + protocol_parameters, + ) + .select() + .unwrap(); + + // Sender + another for amount + assert_eq!(selected.inputs.len(), 2); + assert!( + selected + .inputs + .iter() + .any(|input| *input.output.as_basic().address() == sender) + ); + // Provided output + remainder + assert_eq!(selected.outputs.len(), 2); +}