diff --git a/token/client/src/token.rs b/token/client/src/token.rs index 8233c57bce0..ecefd01dcf6 100644 --- a/token/client/src/token.rs +++ b/token/client/src/token.rs @@ -3680,7 +3680,8 @@ where let account = self.get_account(self.pubkey).await?; let account_lamports = account.lamports; let mint_state = self.unpack_mint_info(account)?; - let new_account_len = mint_state.try_get_new_account_len(token_metadata)?; + let new_account_len = + mint_state.try_get_new_account_len_for_variable_len_extension(token_metadata)?; let new_rent_exempt_minimum = self .client .get_minimum_balance_for_rent_exemption(new_account_len) @@ -3762,7 +3763,8 @@ where let mint_state = self.unpack_mint_info(account)?; let mut token_metadata = mint_state.get_variable_len_extension::()?; token_metadata.update(field, value); - let new_account_len = mint_state.try_get_new_account_len(&token_metadata)?; + let new_account_len = + mint_state.try_get_new_account_len_for_variable_len_extension(&token_metadata)?; let new_rent_exempt_minimum = self .client .get_minimum_balance_for_rent_exemption(new_account_len) diff --git a/token/program-2022/src/extension/mod.rs b/token/program-2022/src/extension/mod.rs index 46a43a4562f..cd420e85dc6 100644 --- a/token/program-2022/src/extension/mod.rs +++ b/token/program-2022/src/extension/mod.rs @@ -404,12 +404,12 @@ pub trait BaseStateWithExtensions { /// /// Provides the correct answer regardless if the extension is already present /// in the TLV data. - fn try_get_new_account_len( + fn try_get_new_account_len_for_variable_len_extension_from_new_extension_len( &self, - new_extension: &V, + new_extension_len: usize, ) -> Result { // get the new length used by the extension - let new_extension_len = add_type_and_length_to_len(new_extension.get_packed_len()?); + let new_extension_tlv_len = add_type_and_length_to_len(new_extension_len); let tlv_info = get_tlv_data_info(self.get_tlv_data())?; // If we're adding an extension, then we must have at least BASE_ACCOUNT_LENGTH // and account type @@ -417,7 +417,7 @@ pub trait BaseStateWithExtensions { .used_len .saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH); let new_len = if tlv_info.extension_types.is_empty() { - current_len.saturating_add(new_extension_len) + current_len.saturating_add(new_extension_tlv_len) } else { // get the current length used by the extension let current_extension_len = self @@ -426,10 +426,29 @@ pub trait BaseStateWithExtensions { .unwrap_or(0); current_len .saturating_sub(current_extension_len) - .saturating_add(new_extension_len) + .saturating_add(new_extension_tlv_len) }; Ok(adjust_len_for_multisig(new_len)) } + + /// Calculate the new expected size if the state allocates the required + /// number of bytes for the given extension type. + fn try_get_new_account_len(&self) -> Result { + self.try_get_new_account_len_for_variable_len_extension_from_new_extension_len::( + pod_get_packed_len::(), + ) + } + + /// Calculate the new expected size if the state allocates the required + /// number of bytes for the given variable-length extension type. + fn try_get_new_account_len_for_variable_len_extension( + &self, + new_extension: &V, + ) -> Result { + self.try_get_new_account_len_for_variable_len_extension_from_new_extension_len::( + new_extension.get_packed_len()?, + ) + } } /// Encapsulates owned immutable base state data (mint or account) with possible extensions @@ -1178,6 +1197,58 @@ impl Extension for AccountPaddingTest { const TYPE: ExtensionType = ExtensionType::AccountPaddingTest; } +/// Packs a fixed-length extension into a TLV space +/// +/// This function reallocates the account as needed to accommodate for the +/// change in space. +/// +/// If the extension already exists, it will overwrite the existing extension +/// if `overwrite` is `true`, otherwise it will return an error. +/// +/// If the extension does not exist, it will reallocate the account and write +/// the extension into the TLV buffer. +/// +/// NOTE: Since this function deals with fixed-size extensions, it does not +/// handle _decreasing_ the size of an account's data buffer, like the function +/// `alloc_and_serialize_variable_len_extension` does. +pub fn alloc_and_serialize( + account_info: &AccountInfo, + new_extension: &V, + overwrite: bool, +) -> Result<(), ProgramError> { + let previous_account_len = account_info.try_data_len()?; + let (new_account_len, extension_already_exists) = { + let data = account_info.try_borrow_data()?; + let state = StateWithExtensions::::unpack(&data)?; + let new_account_len = state.try_get_new_account_len::()?; + let extension_already_exists = state.get_extension_bytes::().is_ok(); + (new_account_len, extension_already_exists) + }; + + if extension_already_exists { + if !overwrite { + return Err(TokenError::ExtensionAlreadyInitialized.into()); + } else { + // Overwrite the extension + let mut buffer = account_info.try_borrow_mut_data()?; + let mut state = StateWithExtensionsMut::::unpack(&mut buffer)?; + let extension = state.get_extension_mut::()?; + *extension = *new_extension; + } + } else { + // Realloc the account, then write the new extension + account_info.realloc(new_account_len, false)?; + let mut buffer = account_info.try_borrow_mut_data()?; + if previous_account_len <= BASE_ACCOUNT_LENGTH { + set_account_type::(*buffer)?; + } + let mut state = StateWithExtensionsMut::::unpack(&mut buffer)?; + let extension = state.init_extension::(false)?; + *extension = *new_extension; + } + Ok(()) +} + /// Packs a variable-length extension into a TLV space /// /// This function reallocates the account as needed to accommodate for the @@ -1186,7 +1257,7 @@ impl Extension for AccountPaddingTest { /// /// NOTE: Unlike the `reallocate` instruction, this function will reduce the /// size of an account if it has too many bytes allocated for the given value. -pub fn alloc_and_serialize( +pub fn alloc_and_serialize_variable_len_extension( account_info: &AccountInfo, new_extension: &V, overwrite: bool, @@ -1195,7 +1266,8 @@ pub fn alloc_and_serialize( let (new_account_len, extension_already_exists) = { let data = account_info.try_borrow_data()?; let state = StateWithExtensions::::unpack(&data)?; - let new_account_len = state.try_get_new_account_len(new_extension)?; + let new_account_len = + state.try_get_new_account_len_for_variable_len_extension(new_extension)?; let extension_already_exists = state.get_extension_bytes::().is_ok(); (new_account_len, extension_already_exists) }; @@ -2282,7 +2354,9 @@ mod test { let current_len = state.try_get_account_len().unwrap(); assert_eq!(current_len, Mint::LEN); let new_len = state - .try_get_new_account_len::(&variable_len) + .try_get_new_account_len_for_variable_len_extension::( + &variable_len, + ) .unwrap(); assert_eq!( new_len, @@ -2297,19 +2371,25 @@ mod test { // Reduce the extension size let new_len = state - .try_get_new_account_len::(&small_variable_len) + .try_get_new_account_len_for_variable_len_extension::( + &small_variable_len, + ) .unwrap(); assert_eq!(current_len.checked_sub(new_len).unwrap(), 1); // Increase the extension size let new_len = state - .try_get_new_account_len::(&big_variable_len) + .try_get_new_account_len_for_variable_len_extension::( + &big_variable_len, + ) .unwrap(); assert_eq!(new_len.checked_sub(current_len).unwrap(), 1); // Maintain the extension size let new_len = state - .try_get_new_account_len::(&variable_len) + .try_get_new_account_len_for_variable_len_extension::( + &variable_len, + ) .unwrap(); assert_eq!(new_len, current_len); } @@ -2382,7 +2462,8 @@ mod test { let key = Pubkey::new_unique(); let account_info = (&key, &mut data).into_account_info(); - alloc_and_serialize::(&account_info, &variable_len, false).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, false) + .unwrap(); let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len); assert_eq!(data.len(), new_account_len); let state = StateWithExtensions::::unpack(data.data()).unwrap(); @@ -2395,12 +2476,18 @@ mod test { // alloc again succeeds with "overwrite" let account_info = (&key, &mut data).into_account_info(); - alloc_and_serialize::(&account_info, &variable_len, true).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, true) + .unwrap(); // alloc again fails without "overwrite" let account_info = (&key, &mut data).into_account_info(); assert_eq!( - alloc_and_serialize::(&account_info, &variable_len, false).unwrap_err(), + alloc_and_serialize_variable_len_extension::( + &account_info, + &variable_len, + false, + ) + .unwrap_err(), TokenError::ExtensionAlreadyInitialized.into() ); } @@ -2429,7 +2516,8 @@ mod test { let key = Pubkey::new_unique(); let account_info = (&key, &mut data).into_account_info(); - alloc_and_serialize::(&account_info, &variable_len, false).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, false) + .unwrap(); let new_account_len = BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(value_len) + add_type_and_length_to_len(size_of::()); @@ -2447,12 +2535,18 @@ mod test { // alloc again succeeds with "overwrite" let account_info = (&key, &mut data).into_account_info(); - alloc_and_serialize::(&account_info, &variable_len, true).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, true) + .unwrap(); // alloc again fails without "overwrite" let account_info = (&key, &mut data).into_account_info(); assert_eq!( - alloc_and_serialize::(&account_info, &variable_len, false).unwrap_err(), + alloc_and_serialize_variable_len_extension::( + &account_info, + &variable_len, + false, + ) + .unwrap_err(), TokenError::ExtensionAlreadyInitialized.into() ); } @@ -2488,7 +2582,8 @@ mod test { let key = Pubkey::new_unique(); let account_info = (&key, &mut data).into_account_info(); let variable_len = VariableLenMintTest { data: vec![1, 2] }; - alloc_and_serialize::(&account_info, &variable_len, true).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, true) + .unwrap(); let state = StateWithExtensions::::unpack(data.data()).unwrap(); let extension = state.get_extension::().unwrap(); @@ -2505,7 +2600,8 @@ mod test { let variable_len = VariableLenMintTest { data: vec![1, 2, 3, 4, 5, 6, 7], }; - alloc_and_serialize::(&account_info, &variable_len, true).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, true) + .unwrap(); let state = StateWithExtensions::::unpack(data.data()).unwrap(); let extension = state.get_extension::().unwrap(); @@ -2522,7 +2618,8 @@ mod test { let variable_len = VariableLenMintTest { data: vec![7, 6, 5, 4, 3, 2, 1], }; - alloc_and_serialize::(&account_info, &variable_len, true).unwrap(); + alloc_and_serialize_variable_len_extension::(&account_info, &variable_len, true) + .unwrap(); let state = StateWithExtensions::::unpack(data.data()).unwrap(); let extension = state.get_extension::().unwrap(); diff --git a/token/program-2022/src/extension/token_metadata/processor.rs b/token/program-2022/src/extension/token_metadata/processor.rs index 88f58d8a9ec..a6cce12bc98 100644 --- a/token/program-2022/src/extension/token_metadata/processor.rs +++ b/token/program-2022/src/extension/token_metadata/processor.rs @@ -5,8 +5,8 @@ use { check_program_account, error::TokenError, extension::{ - alloc_and_serialize, metadata_pointer::MetadataPointer, BaseStateWithExtensions, - StateWithExtensions, + alloc_and_serialize_variable_len_extension, metadata_pointer::MetadataPointer, + BaseStateWithExtensions, StateWithExtensions, }, state::Mint, }, @@ -98,7 +98,7 @@ pub fn process_initialize( // allocate a TLV entry for the space and write it in, assumes that there's // enough SOL for the new rent-exemption - alloc_and_serialize::(metadata_info, &token_metadata, false)?; + alloc_and_serialize_variable_len_extension::(metadata_info, &token_metadata, false)?; Ok(()) } @@ -127,7 +127,7 @@ pub fn process_update_field( token_metadata.update(data.field, data.value); // Update / realloc the account - alloc_and_serialize::(metadata_info, &token_metadata, true)?; + alloc_and_serialize_variable_len_extension::(metadata_info, &token_metadata, true)?; Ok(()) } @@ -154,7 +154,7 @@ pub fn process_remove_key( if !token_metadata.remove_key(&data.key) && !data.idempotent { return Err(TokenMetadataError::KeyNotFound.into()); } - alloc_and_serialize::(metadata_info, &token_metadata, true)?; + alloc_and_serialize_variable_len_extension::(metadata_info, &token_metadata, true)?; Ok(()) } @@ -179,7 +179,7 @@ pub fn process_update_authority( check_update_authority(update_authority_info, &token_metadata.update_authority)?; token_metadata.update_authority = data.new_authority; // Update the account, no realloc needed! - alloc_and_serialize::(metadata_info, &token_metadata, true)?; + alloc_and_serialize_variable_len_extension::(metadata_info, &token_metadata, true)?; Ok(()) }