Skip to content

Commit

Permalink
token 2022: add support for _reading_ repeating fixed-length extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
buffalojoec committed Nov 14, 2023
1 parent 3cf8ea8 commit 5114e94
Showing 1 changed file with 176 additions and 0 deletions.
176 changes: 176 additions & 0 deletions token/program-2022/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,35 @@ fn get_extension_indices<V: Extension>(
Err(ProgramError::InvalidAccountData)
}

fn get_all_extension_indices<V: Extension>(
tlv_data: &[u8],
) -> Result<Vec<TlvIndices>, ProgramError> {
let mut indices_list = vec![];
let mut start_index = 0;
let v_account_type = V::TYPE.get_account_type();
while start_index < tlv_data.len() {
let tlv_indices = get_tlv_indices(start_index);
if tlv_data.len() < tlv_indices.value_start {
return Err(ProgramError::InvalidAccountData);
}
let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
let account_type = extension_type.get_account_type();

let length =
pod_from_bytes::<Length>(&tlv_data[tlv_indices.length_start..tlv_indices.value_start])?;
let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
start_index = value_end_index;

if extension_type == V::TYPE {
indices_list.push(tlv_indices);
} else if v_account_type != account_type && extension_type != ExtensionType::Uninitialized {
return Err(TokenError::ExtensionTypeMismatch.into());
}
}
Ok(indices_list)
}

/// Basic information about the TLV buffer, collected from iterating through all
/// entries
#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -339,6 +368,30 @@ fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[
Ok(&tlv_data[value_start..value_end])
}

fn get_all_extension_bytes<S: BaseState, V: Extension>(
tlv_data: &[u8],
) -> Result<Vec<&[u8]>, ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}
let all_extension_indices = get_all_extension_indices::<V>(tlv_data)?;
let mut all_extension_bytes = vec![];
for TlvIndices {
type_start: _,
length_start,
value_start,
} in all_extension_indices.iter()
{
let length = pod_from_bytes::<Length>(&tlv_data[*length_start..*value_start])?;
let value_end = value_start.saturating_add(usize::from(*length));
if tlv_data.len() < value_end {
return Err(ProgramError::InvalidAccountData);
}
all_extension_bytes.push(&tlv_data[*value_start..value_end]);
}
Ok(all_extension_bytes)
}

fn get_extension_bytes_mut<S: BaseState, V: Extension>(
tlv_data: &mut [u8],
) -> Result<&mut [u8], ProgramError> {
Expand Down Expand Up @@ -397,11 +450,46 @@ pub trait BaseStateWithExtensions<S: BaseState> {
get_extension_bytes::<S, V>(self.get_tlv_data())
}

/// Fetch the bytes for a TLV entry, where repetitions are allowed
fn get_repeating_extension_bytes<V: Extension>(
&self,
repetition: usize,
) -> Result<&[u8], ProgramError> {
get_all_extension_bytes::<S, V>(self.get_tlv_data()).map(|x| {
*x.get(repetition.saturating_sub(1))
.ok_or::<ProgramError>(ProgramError::InvalidAccountData)
.unwrap()
})
}

/// Fetch all bytes for each entry of a particular TLV entry, where
/// repetitions are allowed
fn get_all_extension_bytes<V: Extension>(&self) -> Result<Vec<&[u8]>, ProgramError> {
get_all_extension_bytes::<S, V>(self.get_tlv_data())
}

/// Unpack a portion of the TLV data as the desired type
fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
pod_from_bytes::<V>(self.get_extension_bytes::<V>()?)
}

/// Unpack a portion of the TLV data as the desired type, where repetitions
/// are allowed
fn get_repeating_extension<V: Extension + Pod>(
&self,
repetition: usize,
) -> Result<&V, ProgramError> {
pod_from_bytes::<V>(self.get_repeating_extension_bytes::<V>(repetition)?)
}

/// Unpack all extensions for the desired type
fn get_all_extensions<V: Extension + Pod>(&self) -> Result<Vec<&V>, ProgramError> {
self.get_all_extension_bytes::<V>()?
.iter()
.map(|bytes| pod_from_bytes::<V>(bytes))
.collect()
}

/// Unpacks a portion of the TLV data as the desired variable-length type
fn get_variable_len_extension<V: Extension + VariableLenPack>(
&self,
Expand Down Expand Up @@ -1410,6 +1498,54 @@ mod test {
1, 1, // data
];

const MINT_WITH_DUPLICATED_EXTENSION: &[u8] = &[
1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1, // account type
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
];

const MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA: &[u8] = &[
1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
1, // account type
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
3, 0, // extension type
32, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, // data
14, 0, // extension type
64, 0, // length
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
];

#[test]
fn unpack_opaque_buffer() {
let state = StateWithExtensions::<Mint>::unpack(MINT_WITH_EXTENSION).unwrap();
Expand All @@ -1435,6 +1571,46 @@ mod test {
assert_eq!(state.base, TEST_MINT);
}

#[test]
fn unpack_opaque_buffer_with_duplicates() {
let state = StateWithExtensions::<Mint>::unpack(MINT_WITH_DUPLICATED_EXTENSION).unwrap();
assert_eq!(state.base, TEST_MINT);
let all_extensions = state.get_all_extensions::<MintCloseAuthority>().unwrap();
assert_eq!(all_extensions.len(), 3);
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::InvalidAccountData)
);
assert_eq!(
StateWithExtensions::<Account>::unpack(MINT_WITH_DUPLICATED_EXTENSION),
Err(ProgramError::InvalidAccountData)
);

// test we can get a single entry
let state = StateWithExtensions::<Mint>::unpack(MINT_WITH_DUPLICATED_EXTENSION).unwrap();
let extension = state
.get_repeating_extension::<MintCloseAuthority>(1)
.unwrap();
let close_authority =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
assert_eq!(extension.close_authority, close_authority);

let state =
StateWithExtensions::<Mint>::unpack(MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA)
.unwrap();
assert_eq!(state.base, TEST_MINT);
let all_extensions = state.get_all_extensions::<MintCloseAuthority>().unwrap();
assert_eq!(all_extensions.len(), 3);
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::InvalidAccountData)
);
assert_eq!(
StateWithExtensions::<Account>::unpack(MINT_WITH_DUPLICATED_EXTENSION_AND_ONE_EXTRA),
Err(ProgramError::InvalidAccountData)
);
}

#[test]
fn fail_unpack_opaque_buffer() {
// input buffer too small
Expand Down

0 comments on commit 5114e94

Please sign in to comment.