Skip to content

Commit

Permalink
ERC721Enumerable helper function (#1196)
Browse files Browse the repository at this point in the history
* Implement helper function for ERC721Enumerable

* Change function name

* Run linter

* Add changelog entry

* Add documentation

* Remove unnecessary Enumerable mocks

* Add ERC721EnumerableExtended impl to ERC721Enumerable mock

* Remove IERC721EnumerableExtended interface and make all_tokens_of_owner an internal function

* Update doc

* Update changelog

* Add more test cases

* Try to fix coverage issue

* Revert test coverage fix changes
  • Loading branch information
immrsd authored Nov 6, 2024
1 parent 9a37244 commit 02c9ce6
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `is_valid_p256_signature` utility function to `openzeppelin_account::utils::signature` (#1189)
- `Secp256r1KeyPair` type and helpers to `openzeppelin_testing::signing` (#1189)
- `all_tokens_of_owner` function to `ERC721EnumerableComponent` fetching all owner's tokens in a single call (#1196)
- Embeddable impls for ERC2981 component (#1173)
- `ERC2981Info` with read functions for discovering the component's state
- `ERC2981AdminOwnable` providing admin functions for a token that implements Ownable component
Expand Down
13 changes: 13 additions & 0 deletions docs/modules/ROOT/pages/api/erc721.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ mod ERC721EnumerableContract {
.InternalImpl
* xref:#ERC721EnumerableComponent-initializer[`++initializer(self)++`]
* xref:#ERC721EnumerableComponent-before_update[`++before_update(self, to, token_id)++`]
* xref:#ERC721EnumerableComponent-all_tokens_of_owner[`++all_tokens_of_owner(self, owner)++`]
* xref:#ERC721EnumerableComponent-_add_token_to_owner_enumeration[`++_add_token_to_owner_enumeration(self, to, token_id)++`]
* xref:#ERC721EnumerableComponent-_add_token_to_all_tokens_enumeration[`++_add_token_to_all_tokens_enumeration(self, token_id)++`]
* xref:#ERC721EnumerableComponent-_remove_token_from_owner_enumeration[`++_remove_token_from_owner_enumeration(self, from, token_id)++`]
Expand Down Expand Up @@ -943,6 +944,18 @@ When a token is transferred, minted, or burned, the ownership-tracking data stru

This must be added to the implementing contract's xref:ERC721Component-before_update[ERC721HooksTrait::before_update] hook.

[.contract-item]
[[ERC721EnumerableComponent-all_tokens_of_owner]]
==== `[.contract-item-name]#++all_tokens_of_owner++#++(self: @ContractState, owner: ContractAddress) → Span<u256>++` [.item-kind]#internal#

Returns a list of all token ids owned by the specified `owner`.
This function provides a more efficient alternative to calling `ERC721::balance_of`
and iterating through tokens with `ERC721Enumerable::token_of_owner_by_index`.

Requirements:

- `owner` is not the zero address.

[.contract-item]
[[ERC721EnumerableComponent-_add_token_to_owner_enumeration]]
==== `[.contract-item-name]#++_add_token_to_owner_enumeration++#++(ref self: ContractState, to: ContractAddress, token_id: u256)++` [.item-kind]#internal#
Expand Down
10 changes: 10 additions & 0 deletions packages/test_common/src/mocks/erc721.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ pub mod DualCaseERC721ReceiverMock {
pub mod ERC721EnumerableMock {
use openzeppelin_introspection::src5::SRC5Component;
use openzeppelin_token::erc721::ERC721Component;
use openzeppelin_token::erc721::extensions::ERC721EnumerableComponent::InternalTrait;
use openzeppelin_token::erc721::extensions::ERC721EnumerableComponent;
use starknet::ContractAddress;

Expand Down Expand Up @@ -238,6 +239,15 @@ pub mod ERC721EnumerableMock {
}
}

#[generate_trait]
#[abi(per_item)]
impl ExternalImpl of ExternalTrait {
#[external(v0)]
fn all_tokens_of_owner(self: @ContractState, owner: ContractAddress) -> Span<u256> {
self.erc721_enumerable.all_tokens_of_owner(owner)
}
}

#[constructor]
fn constructor(
ref self: ContractState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub mod ERC721EnumerableComponent {
use crate::erc721::ERC721Component::ERC721Impl;
use crate::erc721::ERC721Component::InternalImpl as ERC721InternalImpl;
use crate::erc721::ERC721Component;
use crate::erc721::extensions::erc721_enumerable::interface::IERC721Enumerable;
use crate::erc721::extensions::erc721_enumerable::interface;
use openzeppelin_introspection::src5::SRC5Component::InternalTrait as SRC5InternalTrait;
use openzeppelin_introspection::src5::SRC5Component;
Expand Down Expand Up @@ -51,7 +50,7 @@ pub mod ERC721EnumerableComponent {
+ERC721Component::ERC721HooksTrait<TContractState>,
+SRC5Component::HasComponent<TContractState>,
+Drop<TContractState>
> of IERC721Enumerable<ComponentState<TContractState>> {
> of interface::IERC721Enumerable<ComponentState<TContractState>> {
/// Returns the total amount of tokens stored by the contract.
fn total_supply(self: @ComponentState<TContractState>) -> u256 {
self.ERC721Enumerable_all_tokens_len.read()
Expand Down Expand Up @@ -133,6 +132,25 @@ pub mod ERC721EnumerableComponent {
}
}

/// Returns a list of all token ids owned by the specified `owner`.
/// This function provides a more efficient alternative to calling `ERC721::balance_of`
/// and iterating through tokens with `ERC721Enumerable::token_of_owner_by_index`.
///
/// Requirements:
///
/// - `owner` is not the zero address.
fn all_tokens_of_owner(
self: @ComponentState<TContractState>, owner: ContractAddress
) -> Span<u256> {
let mut result = array![];
let balance = get_dep_component!(self, ERC721).balance_of(owner);
for index in 0
..balance {
result.append(self.ERC721Enumerable_owned_tokens.read((owner, index)));
};
result.span()
}

/// Adds token to this extension's ownership-tracking data structures.
fn _add_token_to_owner_enumeration(
ref self: ComponentState<TContractState>, to: ContractAddress, token_id: u256
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ pub trait IERC721Enumerable<TState> {
fn token_by_index(self: @TState, index: u256) -> u256;
fn token_of_owner_by_index(self: @TState, owner: ContractAddress, index: u256) -> u256;
}

#[starknet::interface]
pub trait ERC721EnumerableABI<TState> {
fn total_supply(self: @TState) -> u256;
fn token_by_index(self: @TState, index: u256) -> u256;
fn token_of_owner_by_index(self: @TState, owner: ContractAddress, index: u256) -> u256;
fn all_tokens_of_owner(self: @TState, owner: ContractAddress) -> Span<u256>;
}
81 changes: 68 additions & 13 deletions packages/token/src/tests/erc721/test_erc721_enumerable.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::erc721::extensions::erc721_enumerable::ERC721EnumerableComponent::{
ERC721EnumerableImpl, InternalImpl
};
use crate::erc721::extensions::erc721_enumerable::ERC721EnumerableComponent;
use crate::erc721::extensions::erc721_enumerable::interface;
use crate::erc721::extensions::erc721_enumerable::interface::IERC721ENUMERABLE_ID;
use openzeppelin_introspection::interface::ISRC5_ID;
use openzeppelin_introspection::src5::SRC5Component::SRC5Impl;
use openzeppelin_test_common::mocks::erc721::ERC721EnumerableMock;
Expand Down Expand Up @@ -57,7 +57,7 @@ fn test_initializer() {

state.initializer();

let supports_ierc721_enum = mock_state.supports_interface(interface::IERC721ENUMERABLE_ID);
let supports_ierc721_enum = mock_state.supports_interface(IERC721ENUMERABLE_ID);
assert!(supports_ierc721_enum);

let supports_isrc5 = mock_state.supports_interface(ISRC5_ID);
Expand Down Expand Up @@ -102,7 +102,7 @@ fn test_token_by_index() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_by_index_equal_to_supply() {
let (state, token_list) = setup();
let supply = token_list.len().into();
Expand All @@ -111,7 +111,7 @@ fn test_token_by_index_equal_to_supply() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_by_index_greater_than_supply() {
let (state, token_list) = setup();
let supply_plus_one = token_list.len().into() + 1;
Expand All @@ -121,7 +121,7 @@ fn test_token_by_index_greater_than_supply() {

#[test]
fn test_token_by_index_burn_last_token() {
let (_, _) = setup();
let _ = setup();
let mut contract_state = CONTRACT_STATE();
let last_token = TOKEN_3;

Expand All @@ -133,7 +133,7 @@ fn test_token_by_index_burn_last_token() {

#[test]
fn test_token_by_index_burn_first_token() {
let (_, _) = setup();
let _ = setup();
let mut contract_state = CONTRACT_STATE();
let first_token = TOKEN_1;

Expand Down Expand Up @@ -177,7 +177,7 @@ fn test_token_of_owner_by_index() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_index_equals_owned_tokens() {
let (state, tokens_list) = setup();
let owned_token_len = tokens_list.len().into();
Expand All @@ -186,7 +186,7 @@ fn test_token_of_owner_by_index_when_index_equals_owned_tokens() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_index_exceeds_owned_tokens() {
let (state, tokens_list) = setup();
let owned_tokens_len_plus_one = tokens_list.len().into() + 1;
Expand All @@ -195,15 +195,15 @@ fn test_token_of_owner_by_index_when_index_exceeds_owned_tokens() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_target_has_no_tokens() {
let (state, _) = setup();

state.token_of_owner_by_index(OTHER(), 0);
}

#[test]
#[should_panic(expected: ('ERC721: invalid account',))]
#[should_panic(expected: 'ERC721: invalid account')]
fn test_token_of_owner_by_index_when_owner_is_zero() {
let (state, _) = setup();

Expand Down Expand Up @@ -503,6 +503,58 @@ fn test__remove_token_from_all_tokens_enumeration_with_first_token() {
assert_eq!(initial_supply - 1, new_supply);
}

//
// all_tokens_of_owner
//

#[test]
fn test_all_tokens_of_owner() {
let (_, tokens_list) = setup();
assert_all_tokens_of_owner(OWNER(), tokens_list);
}

#[test]
fn test_all_tokens_of_owner_after_transfer_first_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.transfer(OWNER(), RECIPIENT(), TOKEN_1);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_3, TOKEN_2].span());
assert_all_tokens_of_owner(RECIPIENT(), array![TOKEN_1].span());
}

#[test]
fn test_all_tokens_of_owner_after_transfer_last_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.transfer(OWNER(), RECIPIENT(), TOKEN_3);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_1, TOKEN_2].span());
assert_all_tokens_of_owner(RECIPIENT(), array![TOKEN_3].span());
}

#[test]
fn test_all_tokens_of_owner_after_burn_first_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.burn(TOKEN_1);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_3, TOKEN_2].span());
}

#[test]
fn test_all_tokens_of_owner_after_burn_last_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.burn(TOKEN_3);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_1, TOKEN_2].span());
}

//
// Helpers
//
Expand Down Expand Up @@ -568,21 +620,24 @@ fn assert_all_tokens_index_to_id(index: u256, exp_token_id: u256) {

fn assert_all_tokens_id_to_index(token_id: u256, exp_index: u256) {
let state = @COMPONENT_STATE();

let id_to_index = state.ERC721Enumerable_all_tokens_index.read(token_id);
assert_eq!(id_to_index, exp_index);
}

fn assert_owner_tokens_index_to_id(owner: ContractAddress, index: u256, exp_token_id: u256) {
let state = @COMPONENT_STATE();

let index_to_id = state.ERC721Enumerable_owned_tokens.read((owner, index));
assert_eq!(index_to_id, exp_token_id);
}

fn assert_owner_tokens_id_to_index(token_id: u256, exp_index: u256) {
let state = @COMPONENT_STATE();

let id_to_index = state.ERC721Enumerable_owned_tokens_index.read(token_id);
assert_eq!(id_to_index, exp_index);
}

fn assert_all_tokens_of_owner(owner: ContractAddress, exp_tokens: Span<u256>) {
let state = @COMPONENT_STATE();
let tokens = state.all_tokens_of_owner(owner);
assert_eq!(tokens, exp_tokens);
}

0 comments on commit 02c9ce6

Please sign in to comment.