Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token-2022: Add Pod-compatible versions of StateWithExtensions #6336

Merged
merged 6 commits into from
Mar 6, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 130 additions & 27 deletions token/program-2022/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,24 +515,42 @@ impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at(S::SIZE_OF);
let base = S::unpack(base_data)?;
if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
// type_and_tlv_indices() checks that returned indexes are within range
let account_type = AccountType::try_from(rest[account_type_index])
.map_err(|_| ProgramError::InvalidAccountData)?;
check_account_type::<S>(account_type)?;
Ok(Self {
base,
tlv_data: &rest[tlv_start_index..],
})
let tlv_data = unpack_tlv_data::<S>(rest)?;
Ok(Self { base, tlv_data })
}
}
impl<'a, S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
fn get_tlv_data(&self) -> &[u8] {
self.tlv_data
}
}

/// Encapsulates immutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
/// Unpacked base data
pub base: &'data S,
/// Slice of data containing all TLV data, deserialized on demand
tlv_data: &'data [u8],
}
impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
/// Unpack base state, leaving the extension data as a slice
///
/// Fails if the base state is not initialized.
pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at(S::SIZE_OF);
let base = pod_from_bytes::<S>(base_data)?;
if !base.is_initialized() {
Err(ProgramError::UninitializedAccount)
} else {
Ok(Self {
base,
tlv_data: &[],
})
let tlv_data = unpack_tlv_data::<S>(rest)?;
Ok(Self { base, tlv_data })
}
}
}
impl<'a, S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
impl<'a, S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'a, S> {
fn get_tlv_data(&self) -> &[u8] {
self.tlv_data
}
Expand Down Expand Up @@ -769,6 +787,18 @@ pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
}
Ok(())
}

/// Check that the account type on the account (if initialized) matches the
/// account type for any extensions initialized on the TLV data
fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
if let Some(extension_type) = self.get_first_extension_type()? {
let account_type = extension_type.get_account_type();
if account_type != S::ACCOUNT_TYPE {
return Err(TokenError::ExtensionBaseMismatch.into());
}
}
Ok(())
}
}

/// Encapsulates mutable base state data (mint or account) with possible
Expand All @@ -792,7 +822,7 @@ impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
let base = S::unpack(base_data)?;
let (account_type, tlv_data) = unpack_type_and_tlv_data::<S>(rest)?;
let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
Ok(Self {
base,
base_data,
Expand All @@ -812,19 +842,14 @@ impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
if base.is_initialized() {
return Err(TokenError::AlreadyInUse.into());
}
let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data::<S>(rest)?;
let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
let state = Self {
base,
base_data,
account_type,
tlv_data,
};
if let Some(extension_type) = state.get_first_extension_type()? {
let account_type = extension_type.get_account_type();
if account_type != S::ACCOUNT_TYPE {
return Err(TokenError::ExtensionBaseMismatch.into());
}
}
state.check_account_type_matches_extension_type()?;
Ok(state)
}

Expand All @@ -847,7 +872,85 @@ impl<'a, S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<
}
}

fn unpack_type_and_tlv_data_with_check<
/// Encapsulates mutable base state data (mint or account) with possible
/// extensions, where the base state is Pod for zero-copy serde.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
/// Unpacked base data
pub base: &'data mut S,
/// Writable account type
account_type: &'data mut [u8],
/// Slice of data containing all TLV data, deserialized on demand
tlv_data: &'data mut [u8],
}
impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
/// Unpack base state, leaving the extension data as a mutable slice
///
/// Fails if the base state is not initialized.
pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
let base = pod_from_bytes_mut::<S>(base_data)?;
if !base.is_initialized() {
Err(ProgramError::UninitializedAccount)
} else {
let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
Ok(Self {
base,
account_type,
tlv_data,
})
}
}

/// Unpack an uninitialized base state, leaving the extension data as a
/// mutable slice
///
/// Fails if the base state has already been initialized.
pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
let base = pod_from_bytes_mut::<S>(base_data)?;
if base.is_initialized() {
return Err(TokenError::AlreadyInUse.into());
}
let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
let state = Self {
base,
account_type,
tlv_data,
};
state.check_account_type_matches_extension_type()?;
Ok(state)
}
}

impl<'a, S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'a, S> {
fn get_tlv_data(&self) -> &[u8] {
self.tlv_data
}
}
impl<'a, S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'a, S> {
fn get_tlv_data_mut(&mut self) -> &mut [u8] {
self.tlv_data
}
fn get_account_type_mut(&mut self) -> &mut [u8] {
self.account_type
}
}

fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
let account_type = AccountType::try_from(rest[account_type_index])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you restore the comment here? I do think it was helpful so that we can feel confident indexing into rest.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, will do!

.map_err(|_| ProgramError::InvalidAccountData)?;
check_account_type::<S>(account_type)?;
Ok(&rest[tlv_start_index..])
} else {
Ok(&[])
}
}

fn unpack_type_and_tlv_data_with_check_mut<
S: BaseState,
F: Fn(AccountType) -> Result<(), ProgramError>,
>(
Expand All @@ -869,16 +972,16 @@ fn unpack_type_and_tlv_data_with_check<
}
}

fn unpack_type_and_tlv_data<S: BaseState>(
fn unpack_type_and_tlv_data_mut<S: BaseState>(
rest: &mut [u8],
) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
unpack_type_and_tlv_data_with_check::<S, _>(rest, check_account_type::<S>)
unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, check_account_type::<S>)
}

fn unpack_uninitialized_type_and_tlv_data<S: BaseState>(
fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
rest: &mut [u8],
) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
unpack_type_and_tlv_data_with_check::<S, _>(rest, |account_type| {
unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| {
if account_type != AccountType::Uninitialized {
Err(ProgramError::InvalidAccountData)
} else {
Expand Down
Loading