Skip to content

Commit

Permalink
Feat/split receiver (#448)
Browse files Browse the repository at this point in the history
* Split receiver into logic and token

* WIP

* Fix some tests

* improve comment

* Fix tests

* Fix tests

* Use method to calculate bitmap

* fix hash

* Fix tx size test

* Fix tx sizing test

* Refactor token accounts method
  • Loading branch information
agusaldasoro authored Jan 21, 2025
1 parent 55385e8 commit e42d6d0
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::ocr3base::{ocr3_transmit, ReportContext};
use super::ocr3impl::{Ocr3ReportForCommit, Ocr3ReportForExecutionReportSingleChain};
use super::pools::{
calculate_token_pool_account_indices, get_balance, interact_with_pool,
validate_and_parse_token_accounts, CCIP_POOL_V1_RET_BYTES,
validate_and_parse_token_accounts, TokenAccounts, CCIP_POOL_V1_RET_BYTES,
};

use crate::{
Expand Down Expand Up @@ -403,14 +403,13 @@ fn internal_execute<'info>(
// note: indexes are used instead of counts in case more accounts need to be passed in remaining_accounts before token accounts
// token_indexes = [2, 4] where remaining_accounts is [custom_account, custom_account, token1_account1, token1_account2, token2_account1, token2_account2] for example
for (i, token_amount) in execution_report.message.token_amounts.iter().enumerate() {
let (start, end) =
calculate_token_pool_account_indices(i, token_indexes, ctx.remaining_accounts.len())?;
let acc_list = &ctx.remaining_accounts[start..end];
let accs = validate_and_parse_token_accounts(
execution_report.message.receiver,
execution_report.message.header.source_chain_selector,
let accs = get_token_accounts_for(
ctx.program_id.key(),
acc_list,
ctx.remaining_accounts,
execution_report.message.token_receiver,
execution_report.message.header.source_chain_selector,
token_indexes,
i,
)?;
let router_token_pool_signer = &ctx.accounts.token_pools_signer;

Expand All @@ -419,7 +418,7 @@ fn internal_execute<'info>(
// CPI: call lockOrBurn on token pool
let release_or_mint = ReleaseOrMintInV1 {
original_sender: execution_report.message.sender.clone(),
receiver: execution_report.message.receiver,
receiver: execution_report.message.token_receiver,
amount: token_amount.amount,
local_token: token_amount.dest_token_address,
remote_chain_selector: execution_report.message.header.source_chain_selector,
Expand Down Expand Up @@ -474,10 +473,13 @@ fn internal_execute<'info>(
// handle CPI call if there are extra accounts
// case: no tokens, but there are remaining_accounts passed in
// case: tokens, but the first token has a non-zero index (indicating extra accounts before token accounts)
if should_execute_messaging(token_indexes, ctx.remaining_accounts.is_empty()) {
if should_execute_messaging(
&execution_report.message.logic_receiver,
ctx.remaining_accounts.is_empty(),
) {
let (msg_program, msg_accounts) = parse_messaging_accounts(
token_indexes,
execution_report.message.receiver,
execution_report.message.logic_receiver,
&execution_report.message.extra_args.accounts,
&execution_report.message.extra_args.is_writable_bitmap,
ctx.remaining_accounts,
Expand Down Expand Up @@ -537,72 +539,79 @@ fn internal_execute<'info>(
Ok(())
}

// should_execute_messaging checks if there remaining_accounts that are not being used for token pools
// case: no tokens, but there are remaining_accounts passed in
// case: tokens, but the first token has a non-zero index (indicating extra accounts before token accounts)
fn should_execute_messaging(token_indexes: &[u8], remaining_accounts_empty: bool) -> bool {
(token_indexes.is_empty() && !remaining_accounts_empty)
|| (!token_indexes.is_empty() && token_indexes[0] != 0)
fn get_token_accounts_for<'a>(
router: Pubkey,
accounts: &'a [AccountInfo<'a>],
token_receiver: Pubkey,
chain_selector: u64,
token_indexes: &[u8],
i: usize,
) -> Result<TokenAccounts<'a>> {
let (start, end) = calculate_token_pool_account_indices(i, token_indexes, accounts.len())?;

let accs = validate_and_parse_token_accounts(
token_receiver,
chain_selector,
router,
&accounts[start..end],
)?;

Ok(accs)
}

// should_execute_messaging checks if:
// 1. There is at least one account used for messaging (the first subset of accounts). This is because the first account is the program id to do the CPI
// 2. AND the logic_receiver has a value different than zeros
fn should_execute_messaging(logic_receiver: &Pubkey, remaining_accounts_empty: bool) -> bool {
!remaining_accounts_empty && *logic_receiver != Pubkey::default()
}

/// parse_message_accounts returns all the accounts needed to execute the CPI instruction
/// It also validates that the accounts sent in the message match the ones sent in the source chain
/// Precondition: logic_receiver != 0 && remaining_accounts.len() > 0
///
/// # Arguments
/// * `token_indexes` - start indexes of token pool accounts, used to determine ending index for arbitrary messaging accounts
/// * `receiver` - receiver address from x-chain message, used to validate `accounts`
/// * `source_accounts` - arbitrary messaging accounts from the x-chain message, used to validate `accounts`. expected order is: [program, ...additional message accounts]
/// * `accounts` - accounts passed via `ctx.remaining_accounts`. expected order is: [program, receiver, ...additional message accounts]
/// * `logic_receiver` - receiver address from x-chain message, used to validate `remaining_accounts`
/// * `extra_args_accounts` - arbitrary messaging accounts from the x-chain message, used to validate `accounts`.
/// * `remaining_accounts` - accounts passed via `ctx.remaining_accounts`. expected order is: [program, receiver, ...additional message accounts]
fn parse_messaging_accounts<'info>(
token_indexes: &[u8],
receiver: Pubkey,
source_accounts: &[Pubkey],
logic_receiver: Pubkey,
extra_args_accounts: &[Pubkey],
source_bitmap: &u64,
accounts: &'info [AccountInfo<'info>],
remaining_accounts: &'info [AccountInfo<'info>],
) -> Result<(&'info AccountInfo<'info>, &'info [AccountInfo<'info>])> {
let end_ind = if token_indexes.is_empty() {
accounts.len()
let end_index = if token_indexes.is_empty() {
remaining_accounts.len()
} else {
token_indexes[0] as usize
};

let msg_program = &accounts[0];
let msg_accounts = &accounts[1..end_ind];
require!(
1 <= end_index && end_index <= remaining_accounts.len(), // program id and message accounts need to fit in remaining accounts
CcipRouterError::InvalidInputs
); // there could be other remaining accounts used for tokens

let source_program = &source_accounts[0];
let source_msg_accounts = &source_accounts[1..source_accounts.len()];
let msg_program = &remaining_accounts[0];
let msg_accounts = &remaining_accounts[1..end_index];

require!(
*source_program == msg_program.key(),
logic_receiver == msg_program.key(),
CcipRouterError::InvalidInputs,
);

require!(
msg_accounts[0].key() == receiver,
CcipRouterError::InvalidInputs
);

// assert same number of accounts passed from message and transaction (not including program)
// source_msg_accounts + 1 to account for separately passed receiver address
require!(
source_msg_accounts.len() + 1 == msg_accounts.len(),
msg_accounts.len() == extra_args_accounts.len(), // assert same number of accounts passed from message and transaction
CcipRouterError::InvalidInputs
);

// Validate the addresses of all the accounts match the ones in source chain
if msg_accounts.len() > 1 {
// Ignore the first account as it's the receiver
let accounts_to_validate = &msg_accounts[1..msg_accounts.len()];
require!(
accounts_to_validate.len() == source_msg_accounts.len(),
CcipRouterError::InvalidInputs
);
for (i, acc) in source_msg_accounts.iter().enumerate() {
let current_acc = &msg_accounts[i + 1]; // TODO: remove offset by 1 to skip receiver after receiver refactor
for (i, acc) in extra_args_accounts.iter().enumerate() {
let current_acc = &msg_accounts[i];
require!(*acc == current_acc.key(), CcipRouterError::InvalidInputs);
require!(
// TODO: remove offset by 1 to skip program after receiver refactor
is_writable(source_bitmap, (i + 1) as u8) == current_acc.is_writable,
is_writable(source_bitmap, (i) as u8) == current_acc.is_writable,
CcipRouterError::InvalidInputs
);
}
Expand All @@ -622,7 +631,6 @@ pub fn verify_merkle_root(execution_report: &ExecutionReportSingleChain) -> Resu
Ok(hashed_leaf)
}

// TODO: Refactor this to use the same structure as messages: execution_report.validate(..)
pub fn validate_execution_report<'info>(
execution_report: &ExecutionReportSingleChain,
source_chain_state: &Account<'info, SourceChain>,
Expand Down Expand Up @@ -688,7 +696,8 @@ fn hash(msg: &Any2SVMRampMessage) -> [u8; 32] {
&msg.on_ramp_address,
// message header
&msg.header.message_id,
&msg.receiver.to_bytes(),
&msg.token_receiver.to_bytes(),
&msg.logic_receiver.to_bytes(),
&header_sequence_number,
msg.extra_args.try_to_vec().unwrap().as_ref(), // borsh serialized
&header_nonce,
Expand Down Expand Up @@ -840,7 +849,10 @@ mod tests {
0, 0, 0, 0,
]
.to_vec(),
receiver: Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb").unwrap(),
token_receiver: Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb")
.unwrap(),
logic_receiver: Pubkey::try_from("C8WSPj3yyus1YN3yNB6YA5zStYtbjQWtpmKadmvyUXq8")
.unwrap(),
data: vec![4, 5, 6],
header: RampMessageHeader {
message_id: [
Expand All @@ -867,15 +879,15 @@ mod tests {
compute_units: 1000,
is_writable_bitmap: 1,
accounts: vec![
Pubkey::try_from("DS2tt4BX7YwCw7yrDNwbAdnYrxjeCPeGJbHmZEYC8RTb").unwrap(),
Pubkey::try_from("CtEVnHsQzhTNWav8skikiV2oF6Xx7r7uGGa8eCDQtTjH").unwrap(),
],
},
on_ramp_address: on_ramp_address.clone(),
};
let hash_result = hash(&message);

assert_eq!(
"60f412fe7c28ae6981b694f92677276f767a98e0314b9a31a3c38366223e7e52",
"266b8d99e64a52fdd325f67674f56d0005dbee5e9999ff22017d5b117fbedfa3",
hex::encode(hash_result)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub(super) struct TokenAccounts<'a> {
}

pub(super) fn validate_and_parse_token_accounts<'info>(
user: Pubkey,
token_receiver: Pubkey,
chain_selector: u64,
router: Pubkey,
accounts: &'info [AccountInfo<'info>],
Expand Down Expand Up @@ -122,7 +122,7 @@ pub(super) fn validate_and_parse_token_accounts<'info>(
require!(
user_token_account.key()
== get_associated_token_address_with_program_id(
&user,
&token_receiver,
&mint.key(),
&token_program.key()
)
Expand Down
14 changes: 9 additions & 5 deletions chains/solana/contracts/programs/ccip-router/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ pub struct Any2SVMRampMessage {
pub header: RampMessageHeader,
pub sender: Vec<u8>,
pub data: Vec<u8>,
// receiver is used as the target for the two main functionalities
// token transfers: recipient of token transfers (associated token addresses are validated against this address)
// arbitrary messaging: expected account in the declared arbitrary messaging accounts (2nd in the list of the accounts)
pub receiver: Pubkey,
// In EVM receiver means the address that all the listed tokens will transfer to and the address of the message execution.
// In Solana the receiver is split into two:
// Logic Receiver is the Program ID of the user's program that will execute the message
pub logic_receiver: Pubkey,
// Token Receiver is the address which the ATA will be calculated from.
// If token receiver and message execution, then the token receiver must be a PDA from the logic receiver
pub token_receiver: Pubkey,
pub token_amounts: Vec<Any2SVMTokenTransfer>,
pub extra_args: SVMExtraArgs,
pub on_ramp_address: Vec<u8>,
Expand All @@ -77,7 +80,8 @@ impl Any2SVMRampMessage {
self.header.len() // header
+ 4 + self.sender.len() // sender
+ 4 + self.data.len() // data
+ 32 // receiver
+ 32 // logic receiver
+ 32 // token receiver
+ 4 + token_len // token_amount
+ self.extra_args.len() // extra_args
+ 4 + self.on_ramp_address.len() // on_ramp_address
Expand Down
6 changes: 5 additions & 1 deletion chains/solana/contracts/target/idl/ccip_router.json
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,11 @@
"type": "bytes"
},
{
"name": "receiver",
"name": "logicReceiver",
"type": "publicKey"
},
{
"name": "tokenReceiver",
"type": "publicKey"
},
{
Expand Down
Loading

0 comments on commit e42d6d0

Please sign in to comment.