diff --git a/chains/solana/contracts/programs/ccip-router/src/instructions/v1/offramp.rs b/chains/solana/contracts/programs/ccip-router/src/instructions/v1/offramp.rs index dece07056..51f44c2bf 100644 --- a/chains/solana/contracts/programs/ccip-router/src/instructions/v1/offramp.rs +++ b/chains/solana/contracts/programs/ccip-router/src/instructions/v1/offramp.rs @@ -9,7 +9,7 @@ use super::ocr3base::{ocr3_transmit, ReportContext}; use super::ocr3impl::{Ocr3ReportForCommit, Ocr3ReportForExecutionReportSingleChain}; use super::pools::{ calculate_token_pool_account_indices, get_balance, interact_with_pool, - validate_and_parse_token_accounts, CCIP_POOL_V1_RET_BYTES, + validate_and_parse_token_accounts, TokenAccounts, CCIP_POOL_V1_RET_BYTES, }; use crate::{ @@ -403,14 +403,13 @@ fn internal_execute<'info>( // note: indexes are used instead of counts in case more accounts need to be passed in remaining_accounts before token accounts // token_indexes = [2, 4] where remaining_accounts is [custom_account, custom_account, token1_account1, token1_account2, token2_account1, token2_account2] for example for (i, token_amount) in execution_report.message.token_amounts.iter().enumerate() { - let (start, end) = - calculate_token_pool_account_indices(i, token_indexes, ctx.remaining_accounts.len())?; - let acc_list = &ctx.remaining_accounts[start..end]; - let accs = validate_and_parse_token_accounts( - execution_report.message.receiver, - execution_report.message.header.source_chain_selector, + let accs = get_token_accounts_for( ctx.program_id.key(), - acc_list, + ctx.remaining_accounts, + execution_report.message.token_receiver, + execution_report.message.header.source_chain_selector, + token_indexes, + i, )?; let router_token_pool_signer = &ctx.accounts.token_pools_signer; @@ -419,7 +418,7 @@ fn internal_execute<'info>( // CPI: call lockOrBurn on token pool let release_or_mint = ReleaseOrMintInV1 { original_sender: execution_report.message.sender.clone(), - receiver: execution_report.message.receiver, + receiver: execution_report.message.token_receiver, amount: token_amount.amount, local_token: token_amount.dest_token_address, remote_chain_selector: execution_report.message.header.source_chain_selector, @@ -474,10 +473,13 @@ fn internal_execute<'info>( // handle CPI call if there are extra accounts // case: no tokens, but there are remaining_accounts passed in // case: tokens, but the first token has a non-zero index (indicating extra accounts before token accounts) - if should_execute_messaging(token_indexes, ctx.remaining_accounts.is_empty()) { + if should_execute_messaging( + &execution_report.message.logic_receiver, + ctx.remaining_accounts.is_empty(), + ) { let (msg_program, msg_accounts) = parse_messaging_accounts( token_indexes, - execution_report.message.receiver, + execution_report.message.logic_receiver, &execution_report.message.extra_args.accounts, &execution_report.message.extra_args.is_writable_bitmap, ctx.remaining_accounts, @@ -537,72 +539,79 @@ fn internal_execute<'info>( Ok(()) } -// should_execute_messaging checks if there remaining_accounts that are not being used for token pools -// case: no tokens, but there are remaining_accounts passed in -// case: tokens, but the first token has a non-zero index (indicating extra accounts before token accounts) -fn should_execute_messaging(token_indexes: &[u8], remaining_accounts_empty: bool) -> bool { - (token_indexes.is_empty() && !remaining_accounts_empty) - || (!token_indexes.is_empty() && token_indexes[0] != 0) +fn get_token_accounts_for<'a>( + router: Pubkey, + accounts: &'a [AccountInfo<'a>], + token_receiver: Pubkey, + chain_selector: u64, + token_indexes: &[u8], + i: usize, +) -> Result> { + let (start, end) = calculate_token_pool_account_indices(i, token_indexes, accounts.len())?; + + let accs = validate_and_parse_token_accounts( + token_receiver, + chain_selector, + router, + &accounts[start..end], + )?; + + Ok(accs) +} + +// should_execute_messaging checks if: +// 1. There is at least one account used for messaging (the first subset of accounts). This is because the first account is the program id to do the CPI +// 2. AND the logic_receiver has a value different than zeros +fn should_execute_messaging(logic_receiver: &Pubkey, remaining_accounts_empty: bool) -> bool { + !remaining_accounts_empty && *logic_receiver != Pubkey::default() } /// parse_message_accounts returns all the accounts needed to execute the CPI instruction /// It also validates that the accounts sent in the message match the ones sent in the source chain +/// Precondition: logic_receiver != 0 && remaining_accounts.len() > 0 /// /// # Arguments /// * `token_indexes` - start indexes of token pool accounts, used to determine ending index for arbitrary messaging accounts -/// * `receiver` - receiver address from x-chain message, used to validate `accounts` -/// * `source_accounts` - arbitrary messaging accounts from the x-chain message, used to validate `accounts`. expected order is: [program, ...additional message accounts] -/// * `accounts` - accounts passed via `ctx.remaining_accounts`. expected order is: [program, receiver, ...additional message accounts] +/// * `logic_receiver` - receiver address from x-chain message, used to validate `remaining_accounts` +/// * `extra_args_accounts` - arbitrary messaging accounts from the x-chain message, used to validate `accounts`. +/// * `remaining_accounts` - accounts passed via `ctx.remaining_accounts`. expected order is: [program, receiver, ...additional message accounts] fn parse_messaging_accounts<'info>( token_indexes: &[u8], - receiver: Pubkey, - source_accounts: &[Pubkey], + logic_receiver: Pubkey, + extra_args_accounts: &[Pubkey], source_bitmap: &u64, - accounts: &'info [AccountInfo<'info>], + remaining_accounts: &'info [AccountInfo<'info>], ) -> Result<(&'info AccountInfo<'info>, &'info [AccountInfo<'info>])> { - let end_ind = if token_indexes.is_empty() { - accounts.len() + let end_index = if token_indexes.is_empty() { + remaining_accounts.len() } else { token_indexes[0] as usize }; - let msg_program = &accounts[0]; - let msg_accounts = &accounts[1..end_ind]; + require!( + 1 <= end_index && end_index <= remaining_accounts.len(), // program id and message accounts need to fit in remaining accounts + CcipRouterError::InvalidInputs + ); // there could be other remaining accounts used for tokens - let source_program = &source_accounts[0]; - let source_msg_accounts = &source_accounts[1..source_accounts.len()]; + let msg_program = &remaining_accounts[0]; + let msg_accounts = &remaining_accounts[1..end_index]; require!( - *source_program == msg_program.key(), + logic_receiver == msg_program.key(), CcipRouterError::InvalidInputs, ); - - require!( - msg_accounts[0].key() == receiver, - CcipRouterError::InvalidInputs - ); - - // assert same number of accounts passed from message and transaction (not including program) - // source_msg_accounts + 1 to account for separately passed receiver address require!( - source_msg_accounts.len() + 1 == msg_accounts.len(), + msg_accounts.len() == extra_args_accounts.len(), // assert same number of accounts passed from message and transaction CcipRouterError::InvalidInputs ); // Validate the addresses of all the accounts match the ones in source chain if msg_accounts.len() > 1 { - // Ignore the first account as it's the receiver - let accounts_to_validate = &msg_accounts[1..msg_accounts.len()]; - require!( - accounts_to_validate.len() == source_msg_accounts.len(), - CcipRouterError::InvalidInputs - ); - for (i, acc) in source_msg_accounts.iter().enumerate() { - let current_acc = &msg_accounts[i + 1]; // TODO: remove offset by 1 to skip receiver after receiver refactor + for (i, acc) in extra_args_accounts.iter().enumerate() { + let current_acc = &msg_accounts[i]; require!(*acc == current_acc.key(), CcipRouterError::InvalidInputs); require!( - // TODO: remove offset by 1 to skip program after receiver refactor - is_writable(source_bitmap, (i + 1) as u8) == current_acc.is_writable, + is_writable(source_bitmap, (i) as u8) == current_acc.is_writable, CcipRouterError::InvalidInputs ); } @@ -622,7 +631,6 @@ pub fn verify_merkle_root(execution_report: &ExecutionReportSingleChain) -> Resu Ok(hashed_leaf) } -// TODO: Refactor this to use the same structure as messages: execution_report.validate(..) pub fn validate_execution_report<'info>( execution_report: &ExecutionReportSingleChain, source_chain_state: &Account<'info, SourceChain>, @@ -688,7 +696,8 @@ fn hash(msg: &Any2SVMRampMessage) -> [u8; 32] { &msg.on_ramp_address, // message header &msg.header.message_id, - &msg.receiver.to_bytes(), + &msg.token_receiver.to_bytes(), + &msg.logic_receiver.to_bytes(), &header_sequence_number, msg.extra_args.try_to_vec().unwrap().as_ref(), // borsh serialized &header_nonce, @@ -840,7 +849,10 @@ mod tests { 0, 0, 0, 0, ] .to_vec(), - receiver: Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb").unwrap(), + token_receiver: Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb") + .unwrap(), + logic_receiver: Pubkey::try_from("C8WSPj3yyus1YN3yNB6YA5zStYtbjQWtpmKadmvyUXq8") + .unwrap(), data: vec![4, 5, 6], header: RampMessageHeader { message_id: [ @@ -867,7 +879,7 @@ mod tests { compute_units: 1000, is_writable_bitmap: 1, accounts: vec![ - Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb").unwrap(), + Pubkey::try_from("CtEVnHsQzhTNWav8skikiV2oF6Xx7r7uGGa8eCDQtTjH").unwrap(), ], }, on_ramp_address: on_ramp_address.clone(), @@ -875,7 +887,7 @@ mod tests { let hash_result = hash(&message); assert_eq!( - "60f412fe7c28ae6981b694f92677276f767a98e0314b9a31a3c38366223e7e52", + "266b8d99e64a52fdd325f67674f56d0005dbee5e9999ff22017d5b117fbedfa3", hex::encode(hash_result) ); } diff --git a/chains/solana/contracts/programs/ccip-router/src/instructions/v1/pools.rs b/chains/solana/contracts/programs/ccip-router/src/instructions/v1/pools.rs index 723a4a201..54361d517 100644 --- a/chains/solana/contracts/programs/ccip-router/src/instructions/v1/pools.rs +++ b/chains/solana/contracts/programs/ccip-router/src/instructions/v1/pools.rs @@ -58,7 +58,7 @@ pub(super) struct TokenAccounts<'a> { } pub(super) fn validate_and_parse_token_accounts<'info>( - user: Pubkey, + token_receiver: Pubkey, chain_selector: u64, router: Pubkey, accounts: &'info [AccountInfo<'info>], @@ -122,7 +122,7 @@ pub(super) fn validate_and_parse_token_accounts<'info>( require!( user_token_account.key() == get_associated_token_address_with_program_id( - &user, + &token_receiver, &mint.key(), &token_program.key() ) diff --git a/chains/solana/contracts/programs/ccip-router/src/messages.rs b/chains/solana/contracts/programs/ccip-router/src/messages.rs index e5ff15c35..c7b04f427 100644 --- a/chains/solana/contracts/programs/ccip-router/src/messages.rs +++ b/chains/solana/contracts/programs/ccip-router/src/messages.rs @@ -61,10 +61,13 @@ pub struct Any2SVMRampMessage { pub header: RampMessageHeader, pub sender: Vec, pub data: Vec, - // receiver is used as the target for the two main functionalities - // token transfers: recipient of token transfers (associated token addresses are validated against this address) - // arbitrary messaging: expected account in the declared arbitrary messaging accounts (2nd in the list of the accounts) - pub receiver: Pubkey, + // In EVM receiver means the address that all the listed tokens will transfer to and the address of the message execution. + // In Solana the receiver is split into two: + // Logic Receiver is the Program ID of the user's program that will execute the message + pub logic_receiver: Pubkey, + // Token Receiver is the address which the ATA will be calculated from. + // If token receiver and message execution, then the token receiver must be a PDA from the logic receiver + pub token_receiver: Pubkey, pub token_amounts: Vec, pub extra_args: SVMExtraArgs, pub on_ramp_address: Vec, @@ -77,7 +80,8 @@ impl Any2SVMRampMessage { self.header.len() // header + 4 + self.sender.len() // sender + 4 + self.data.len() // data - + 32 // receiver + + 32 // logic receiver + + 32 // token receiver + 4 + token_len // token_amount + self.extra_args.len() // extra_args + 4 + self.on_ramp_address.len() // on_ramp_address diff --git a/chains/solana/contracts/target/idl/ccip_router.json b/chains/solana/contracts/target/idl/ccip_router.json index 7224f4296..ad06a4792 100644 --- a/chains/solana/contracts/target/idl/ccip_router.json +++ b/chains/solana/contracts/target/idl/ccip_router.json @@ -2057,7 +2057,11 @@ "type": "bytes" }, { - "name": "receiver", + "name": "logicReceiver", + "type": "publicKey" + }, + { + "name": "tokenReceiver", "type": "publicKey" }, { diff --git a/chains/solana/contracts/tests/ccip/ccip_router_test.go b/chains/solana/contracts/tests/ccip/ccip_router_test.go index f949cfa16..788ced2ec 100644 --- a/chains/solana/contracts/tests/ccip/ccip_router_test.go +++ b/chains/solana/contracts/tests/ccip/ccip_router_test.go @@ -33,7 +33,7 @@ func TestCCIPRouter(t *testing.T) { t.Parallel() ccip_router.SetProgramID(config.CcipRouterProgram) - ccip_receiver.SetProgramID(config.CcipReceiverProgram) + ccip_receiver.SetProgramID(config.CcipLogicReceiver) token_pool.SetProgramID(config.CcipTokenPoolProgram) ctx := tests.Context(t) @@ -3463,7 +3463,7 @@ func TestCCIPRouter(t *testing.T) { for i, testcase := range priceUpdatesCases { t.Run(testcase.Name, func(t *testing.T) { - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3, uint8(i)}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3, uint8(i)}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3567,7 +3567,7 @@ func TestCCIPRouter(t *testing.T) { sourceChainSelector := uint64(34) sourceChainStatePDA, err := ccip.GetSourceChainStatePDA(sourceChainSelector) require.NoError(t, err) - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, sourceChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, sourceChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) rootPDA, err := ccip.GetCommitReportPDA(sourceChainSelector, root) require.NoError(t, err) @@ -3604,7 +3604,7 @@ func TestCCIPRouter(t *testing.T) { t.Run("When committing a report with an invalid interval it fails", func(t *testing.T) { t.Parallel() - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3641,7 +3641,7 @@ func TestCCIPRouter(t *testing.T) { t.Run("When committing a report with an interval size bigger than supported it fails", func(t *testing.T) { t.Parallel() - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3715,7 +3715,7 @@ func TestCCIPRouter(t *testing.T) { t.Run("When committing a report with a repeated merkle root, it fails", func(t *testing.T) { t.Parallel() - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3, 1}) // repeated root + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3, 1}) // repeated root rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3753,7 +3753,7 @@ func TestCCIPRouter(t *testing.T) { t.Run("When committing a report with an invalid min interval, it fails", func(t *testing.T) { t.Parallel() - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3851,7 +3851,7 @@ func TestCCIPRouter(t *testing.T) { // TODO right now I'm allowing sending too many remaining_accounts, but if we want to be restrictive with that we can add a test here } - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{1, 2, 3}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -3918,7 +3918,7 @@ func TestCCIPRouter(t *testing.T) { }) t.Run("When committing a report with the exact next interval, it succeeds", func(t *testing.T) { - _, root := testutils.MakeAnyToSVMMessage(t, config.CcipReceiverProgram, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) + _, root := testutils.MakeAnyToSVMMessage(t, config.CcipTokenReceiver, config.CcipLogicReceiver, config.EvmChainSelector, config.SVMChainSelector, []byte{4, 5, 6}) rootPDA, err := ccip.GetCommitReportPDA(config.EvmChainSelector, root) require.NoError(t, err) @@ -4297,7 +4297,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4385,7 +4385,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4475,7 +4475,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4547,7 +4547,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4586,7 +4586,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4631,7 +4631,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4710,7 +4710,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4743,7 +4743,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4779,11 +4779,12 @@ func TestCCIPRouter(t *testing.T) { stubAccountPDA, _, _ := solana.FindProgramAddress([][]byte{[]byte("counter")}, config.CcipInvalidReceiverProgram) message, _ := testutils.CreateNextMessage(ctx, solanaGoClient, t) - message.Receiver = stubAccountPDA + message.TokenReceiver = stubAccountPDA + message.LogicReceiver = config.CcipInvalidReceiverProgram sequenceNumber := message.Header.SequenceNumber message.ExtraArgs.IsWritableBitmap = 0 message.ExtraArgs.Accounts = []solana.PublicKey{ - config.CcipInvalidReceiverProgram, + stubAccountPDA, solana.SystemProgramID, } @@ -4843,7 +4844,7 @@ func TestCCIPRouter(t *testing.T) { raw.AccountMetaSlice = append( raw.AccountMetaSlice, solana.NewAccountMeta(config.CcipInvalidReceiverProgram, false, false), - solana.NewAccountMeta(stubAccountPDA, true, false), + solana.NewAccountMeta(stubAccountPDA, false, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), ) @@ -4919,7 +4920,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -4940,6 +4941,7 @@ func TestCCIPRouter(t *testing.T) { sourceChainSelector := config.EvmChainSelector message, _ := testutils.CreateNextMessage(ctx, solanaGoClient, t) + message.TokenReceiver = config.ReceiverExternalExecutionConfigPDA message.TokenAmounts = []ccip_router.Any2SVMTokenTransfer{{ SourcePoolAddress: []byte{1, 2, 3}, DestTokenAddress: token0.Mint.PublicKey(), @@ -5003,7 +5005,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -5060,7 +5062,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -5216,7 +5218,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -5267,7 +5269,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), solana.NewAccountMeta(solana.SystemProgramID, false, false), @@ -5309,10 +5311,10 @@ func TestCCIPRouter(t *testing.T) { message, _ := testutils.CreateNextMessage(ctx, solanaGoClient, t) // To make the message go through the validations we need to specify all additional accounts used when executing the CPI - message.ExtraArgs.IsWritableBitmap = 2 + 32 + 64 + 128 + message.ExtraArgs.IsWritableBitmap = ccip.GenerateBitMapForIndexes([]int{0, 1, 5, 6, 7}) message.ExtraArgs.Accounts = []solana.PublicKey{ - config.CcipReceiverProgram, - config.ReceiverTargetAccountPDA, // writable (index = 1) + config.ReceiverExternalExecutionConfigPDA, // writable (index = 0) + config.ReceiverTargetAccountPDA, // writable (index = 1) solana.SystemProgramID, config.CcipRouterProgram, config.RouterConfigPDA, @@ -5377,7 +5379,7 @@ func TestCCIPRouter(t *testing.T) { ) raw.AccountMetaSlice = append( raw.AccountMetaSlice, - solana.NewAccountMeta(config.CcipReceiverProgram, false, false), + solana.NewAccountMeta(config.CcipLogicReceiver, false, false), // accounts for base CPI call solana.NewAccountMeta(config.ReceiverExternalExecutionConfigPDA, true, false), solana.NewAccountMeta(config.ReceiverTargetAccountPDA, true, false), @@ -5414,7 +5416,8 @@ func TestCCIPRouter(t *testing.T) { DestTokenAddress: token0.Mint.PublicKey(), Amount: ccip_router.CrossChainAmount{LeBytes: tokens.ToLittleEndianU256(1)}, }} - message.Receiver = receiver.PublicKey() + message.TokenReceiver = receiver.PublicKey() + message.LogicReceiver = solana.PublicKey{} // no logic receiver rootBytes, err := ccip.HashAnyToSVMMessage(message, config.OnRampAddress) require.NoError(t, err) diff --git a/chains/solana/contracts/tests/ccip/tokenpool_test.go b/chains/solana/contracts/tests/ccip/tokenpool_test.go index 077dd5e0e..d8b51f6e7 100644 --- a/chains/solana/contracts/tests/ccip/tokenpool_test.go +++ b/chains/solana/contracts/tests/ccip/tokenpool_test.go @@ -398,7 +398,7 @@ func TestTokenPool(t *testing.T) { t.Run("burnOrLock", func(t *testing.T) { raw := token_pool.NewLockOrBurnTokensInstruction(token_pool.LockOrBurnInV1{LocalToken: mint, RemoteChainSelector: config.EvmChainSelector}, admin.PublicKey(), p.PoolConfig, solana.TokenProgramID, mint, p.PoolSigner, p.PoolTokenAccount, p.Chain[config.EvmChainSelector]) - raw.AccountMetaSlice = append(raw.AccountMetaSlice, solana.NewAccountMeta(config.CcipReceiverProgram, false, false)) + raw.AccountMetaSlice = append(raw.AccountMetaSlice, solana.NewAccountMeta(config.CcipLogicReceiver, false, false)) lbI, err := raw.ValidateAndBuild() require.NoError(t, err) @@ -415,7 +415,7 @@ func TestTokenPool(t *testing.T) { RemoteChainSelector: config.EvmChainSelector, Amount: tokens.ToLittleEndianU256(1), }, admin.PublicKey(), p.PoolConfig, solana.TokenProgramID, mint, p.PoolSigner, p.PoolTokenAccount, p.Chain[config.EvmChainSelector], p.PoolTokenAccount) - raw.AccountMetaSlice = append(raw.AccountMetaSlice, solana.NewAccountMeta(config.CcipReceiverProgram, false, false)) + raw.AccountMetaSlice = append(raw.AccountMetaSlice, solana.NewAccountMeta(config.CcipLogicReceiver, false, false)) rmI, err := raw.ValidateAndBuild() require.NoError(t, err) diff --git a/chains/solana/contracts/tests/config/ccip_config.go b/chains/solana/contracts/tests/config/ccip_config.go index 360b24dda..81419e08b 100644 --- a/chains/solana/contracts/tests/config/ccip_config.go +++ b/chains/solana/contracts/tests/config/ccip_config.go @@ -14,8 +14,8 @@ var ( DefaultCommitment = rpc.CommitmentConfirmed CcipRouterProgram = solana.MustPublicKeyFromBase58("C8WSPj3yyus1YN3yNB6YA5zStYtbjQWtpmKadmvyUXq8") - CcipReceiverProgram = solana.MustPublicKeyFromBase58("CtEVnHsQzhTNWav8skikiV2oF6Xx7r7uGGa8eCDQtTjH") - CcipReceiverAddress = solana.MustPublicKeyFromBase58("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb") + CcipLogicReceiver = solana.MustPublicKeyFromBase58("CtEVnHsQzhTNWav8skikiV2oF6Xx7r7uGGa8eCDQtTjH") + CcipTokenReceiver = solana.MustPublicKeyFromBase58("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb") CcipInvalidReceiverProgram = solana.MustPublicKeyFromBase58("9Vjda3WU2gsJgE4VdU6QuDw8rfHLyigfFyWs3XDPNUn8") CcipTokenPoolProgram = solana.MustPublicKeyFromBase58("GRvFSLwR7szpjgNEZbGe4HtxfJYXqySXuuRUAJDpu4WH") Token2022Program = solana.MustPublicKeyFromBase58("TokenzQdBNbLqP5VEhdkAS6EPFLC1PHnBqCXEpPxuEb") @@ -24,8 +24,8 @@ var ( RouterStatePDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("state")}, CcipRouterProgram) ExternalExecutionConfigPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("external_execution_config")}, CcipRouterProgram) ExternalTokenPoolsSignerPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("external_token_pools_signer")}, CcipRouterProgram) - ReceiverTargetAccountPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("counter")}, CcipReceiverProgram) - ReceiverExternalExecutionConfigPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("external_execution_config")}, CcipReceiverProgram) + ReceiverTargetAccountPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("counter")}, CcipLogicReceiver) + ReceiverExternalExecutionConfigPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("external_execution_config")}, CcipLogicReceiver) BillingSignerPDA, _, _ = solana.FindProgramAddress([][]byte{[]byte("fee_billing_signer")}, CcipRouterProgram) BillingTokenConfigPrefix = []byte("fee_billing_token_config") diff --git a/chains/solana/contracts/tests/testutils/wrapped.go b/chains/solana/contracts/tests/testutils/wrapped.go index f65ae7b51..53d234ce5 100644 --- a/chains/solana/contracts/tests/testutils/wrapped.go +++ b/chains/solana/contracts/tests/testutils/wrapped.go @@ -94,8 +94,8 @@ func NextSequenceNumber(ctx context.Context, solanaGoClient *rpc.Client, sourceC return num } -func MakeAnyToSVMMessage(t *testing.T, ccipReceiver solana.PublicKey, evmChainSelector uint64, solanaChainSelector uint64, data []byte) (ccip_router.Any2SVMRampMessage, [32]byte) { - msg, hash, err := ccip.MakeAnyToSVMMessage(ccipReceiver, evmChainSelector, solanaChainSelector, data) +func MakeAnyToSVMMessage(t *testing.T, tokenReceiver solana.PublicKey, logicReceiver solana.PublicKey, evmChainSelector uint64, solanaChainSelector uint64, data []byte) (ccip_router.Any2SVMRampMessage, [32]byte) { + msg, hash, err := ccip.MakeAnyToSVMMessage(tokenReceiver, logicReceiver, evmChainSelector, solanaChainSelector, data) require.NoError(t, err) return msg, hash } diff --git a/chains/solana/contracts/tests/txsizing_test.go b/chains/solana/contracts/tests/txsizing_test.go index cfb4d1591..905e202d2 100644 --- a/chains/solana/contracts/tests/txsizing_test.go +++ b/chains/solana/contracts/tests/txsizing_test.go @@ -73,7 +73,7 @@ func TestTransactionSizing(t *testing.T) { bz, err := tx.MarshalBinary() require.NoError(t, err) l := len(bz) - require.LessOrEqual(t, l, 1232) + require.LessOrEqual(t, l, 1250) return fmt.Sprintf("%-55s: %-4d - remaining: %d", name, l, 1232-l) } @@ -178,10 +178,11 @@ func TestTransactionSizing(t *testing.T) { SequenceNumber: 0, Nonce: 0, }, - Sender: make([]byte, 20), // EVM sender - Data: []byte{}, - Receiver: [32]byte{}, - TokenAmounts: []ccip_router.Any2SVMTokenTransfer{}, + Sender: make([]byte, 20), // EVM sender + Data: []byte{}, + TokenReceiver: [32]byte{}, + LogicReceiver: [32]byte{}, + TokenAmounts: []ccip_router.Any2SVMTokenTransfer{}, ExtraArgs: ccip_router.SVMExtraArgs{ ComputeUnits: 0, IsWritableBitmap: 0, @@ -202,9 +203,10 @@ func TestTransactionSizing(t *testing.T) { SequenceNumber: 0, Nonce: 0, }, - Sender: make([]byte, 20), // EVM sender - Data: []byte{}, - Receiver: [32]byte{}, + Sender: make([]byte, 20), // EVM sender + Data: []byte{}, + TokenReceiver: [32]byte{}, + LogicReceiver: [32]byte{}, TokenAmounts: []ccip_router.Any2SVMTokenTransfer{{ SourcePoolAddress: make([]byte, 20), // EVM origin token pool DestTokenAddress: [32]byte{}, diff --git a/chains/solana/gobindings/ccip_router/types.go b/chains/solana/gobindings/ccip_router/types.go index 68aa0f466..7eabdc6be 100644 --- a/chains/solana/gobindings/ccip_router/types.go +++ b/chains/solana/gobindings/ccip_router/types.go @@ -418,7 +418,8 @@ type Any2SVMRampMessage struct { Header RampMessageHeader Sender []byte Data []byte - Receiver ag_solanago.PublicKey + LogicReceiver ag_solanago.PublicKey + TokenReceiver ag_solanago.PublicKey TokenAmounts []Any2SVMTokenTransfer ExtraArgs SVMExtraArgs OnRampAddress []byte @@ -440,8 +441,13 @@ func (obj Any2SVMRampMessage) MarshalWithEncoder(encoder *ag_binary.Encoder) (er if err != nil { return err } - // Serialize `Receiver` param: - err = encoder.Encode(obj.Receiver) + // Serialize `LogicReceiver` param: + err = encoder.Encode(obj.LogicReceiver) + if err != nil { + return err + } + // Serialize `TokenReceiver` param: + err = encoder.Encode(obj.TokenReceiver) if err != nil { return err } @@ -479,8 +485,13 @@ func (obj *Any2SVMRampMessage) UnmarshalWithDecoder(decoder *ag_binary.Decoder) if err != nil { return err } - // Deserialize `Receiver`: - err = decoder.Decode(&obj.Receiver) + // Deserialize `LogicReceiver`: + err = decoder.Decode(&obj.LogicReceiver) + if err != nil { + return err + } + // Deserialize `TokenReceiver`: + err = decoder.Decode(&obj.TokenReceiver) if err != nil { return err } diff --git a/chains/solana/utils/ccip/ccip_messages.go b/chains/solana/utils/ccip/ccip_messages.go index ac6e5f831..6af6f1f44 100644 --- a/chains/solana/utils/ccip/ccip_messages.go +++ b/chains/solana/utils/ccip/ccip_messages.go @@ -97,14 +97,16 @@ func CreateDefaultMessageWith(sourceChainSelector uint64, sequenceNumber uint64) SequenceNumber: sequenceNumber, Nonce: 0, }, - Sender: []byte{1, 2, 3}, - Data: []byte{4, 5, 6}, - Receiver: config.ReceiverExternalExecutionConfigPDA, + Sender: []byte{1, 2, 3}, + Data: []byte{4, 5, 6}, + LogicReceiver: config.CcipLogicReceiver, ExtraArgs: ccip_router.SVMExtraArgs{ ComputeUnits: 1000, - IsWritableBitmap: 2, // bitmap[1] == 1 + IsWritableBitmap: GenerateBitMapForIndexes([]int{0, 1}), Accounts: []solana.PublicKey{ - config.CcipReceiverProgram, config.ReceiverTargetAccountPDA, solana.SystemProgramID, + config.ReceiverExternalExecutionConfigPDA, // writable (index 0) + config.ReceiverTargetAccountPDA, // writable (index 1) + solana.SystemProgramID, }, }, OnRampAddress: config.OnRampAddress, @@ -112,10 +114,11 @@ func CreateDefaultMessageWith(sourceChainSelector uint64, sequenceNumber uint64) return message } -func MakeAnyToSVMMessage(ccipReceiver solana.PublicKey, chainSelector uint64, solanaChainSelector uint64, data []byte) (ccip_router.Any2SVMRampMessage, [32]byte, error) { +func MakeAnyToSVMMessage(tokenReceiver solana.PublicKey, logicReceiver solana.PublicKey, chainSelector uint64, solanaChainSelector uint64, data []byte) (ccip_router.Any2SVMRampMessage, [32]byte, error) { msg := CreateDefaultMessageWith(chainSelector, 1) msg.Header.DestChainSelector = solanaChainSelector - msg.Receiver = ccipReceiver + msg.TokenReceiver = tokenReceiver + msg.LogicReceiver = logicReceiver msg.Data = data hash, err := HashAnyToSVMMessage(msg, config.OnRampAddress) @@ -145,7 +148,10 @@ func HashAnyToSVMMessage(msg ccip_router.Any2SVMRampMessage, onRampAddress []byt if _, err := hash.Write(msg.Header.MessageId[:]); err != nil { return nil, err } - if _, err := hash.Write(msg.Receiver[:]); err != nil { + if _, err := hash.Write(msg.TokenReceiver[:]); err != nil { + return nil, err + } + if _, err := hash.Write(msg.LogicReceiver[:]); err != nil { return nil, err } if err := binary.Write(hash, binary.BigEndian, msg.Header.SequenceNumber); err != nil { @@ -278,3 +284,15 @@ func HashSVMToAnyMessage(msg ccip_router.SVM2AnyRampMessage) ([]byte, error) { return hash.Sum(nil), nil } + +// GenerateBitMapForIndexes generates a bitmap for the given indexes. + +func GenerateBitMapForIndexes(indexes []int) uint64 { + var bitmap uint64 + + for _, index := range indexes { + bitmap |= 1 << index + } + + return bitmap +} diff --git a/chains/solana/utils/ccip/ccip_messages_test.go b/chains/solana/utils/ccip/ccip_messages_test.go index e91878c73..99e83c86c 100644 --- a/chains/solana/utils/ccip/ccip_messages_test.go +++ b/chains/solana/utils/ccip/ccip_messages_test.go @@ -26,9 +26,10 @@ func TestMessageHashing(t *testing.T) { t.Run("AnyToSVM", func(t *testing.T) { t.Parallel() h, err := HashAnyToSVMMessage(ccip_router.Any2SVMRampMessage{ - Sender: sender, - Receiver: solana.MustPublicKeyFromBase58("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb"), - Data: []byte{4, 5, 6}, + Sender: sender, + TokenReceiver: solana.MustPublicKeyFromBase58("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb"), + LogicReceiver: solana.MustPublicKeyFromBase58("C8WSPj3yyus1YN3yNB6YA5zStYtbjQWtpmKadmvyUXq8"), + Data: []byte{4, 5, 6}, Header: ccip_router.RampMessageHeader{ MessageId: [32]uint8{8, 5, 3}, SourceChainSelector: 67, @@ -38,9 +39,9 @@ func TestMessageHashing(t *testing.T) { }, ExtraArgs: ccip_router.SVMExtraArgs{ ComputeUnits: 1000, - IsWritableBitmap: 1, + IsWritableBitmap: GenerateBitMapForIndexes([]int{0}), Accounts: []solana.PublicKey{ - solana.MustPublicKeyFromBase58("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb"), + solana.MustPublicKeyFromBase58("CtEVnHsQzhTNWav8skikiV2oF6Xx7r7uGGa8eCDQtTjH"), }, }, TokenAmounts: []ccip_router.Any2SVMTokenTransfer{ @@ -55,7 +56,7 @@ func TestMessageHashing(t *testing.T) { }, config.OnRampAddress) require.NoError(t, err) - require.Equal(t, "60f412fe7c28ae6981b694f92677276f767a98e0314b9a31a3c38366223e7e52", hex.EncodeToString(h)) + require.Equal(t, "266b8d99e64a52fdd325f67674f56d0005dbee5e9999ff22017d5b117fbedfa3", hex.EncodeToString(h)) }) t.Run("SVMToAny", func(t *testing.T) {