Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ERC721Enumerable helper function #1196

Merged
merged 17 commits into from
Nov 6, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `is_valid_p256_signature` utility function to `openzeppelin_account::utils::signature` (#1189)
- `Secp256r1KeyPair` type and helpers to `openzeppelin_testing::signing` (#1189)
- `all_tokens_of_owner` function to `ERC721EnumerableComponent` fetching all owner's tokens in a single call (#1196)
- Embeddable impls for ERC2981 component (#1173)
- `ERC2981Info` with read functions for discovering the component's state
- `ERC2981AdminOwnable` providing admin functions for a token that implements Ownable component
Expand Down
13 changes: 13 additions & 0 deletions docs/modules/ROOT/pages/api/erc721.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ mod ERC721EnumerableContract {
.InternalImpl
* xref:#ERC721EnumerableComponent-initializer[`++initializer(self)++`]
* xref:#ERC721EnumerableComponent-before_update[`++before_update(self, to, token_id)++`]
* xref:#ERC721EnumerableComponent-all_tokens_of_owner[`++all_tokens_of_owner(self, owner)++`]
* xref:#ERC721EnumerableComponent-_add_token_to_owner_enumeration[`++_add_token_to_owner_enumeration(self, to, token_id)++`]
* xref:#ERC721EnumerableComponent-_add_token_to_all_tokens_enumeration[`++_add_token_to_all_tokens_enumeration(self, token_id)++`]
* xref:#ERC721EnumerableComponent-_remove_token_from_owner_enumeration[`++_remove_token_from_owner_enumeration(self, from, token_id)++`]
Expand Down Expand Up @@ -943,6 +944,18 @@ When a token is transferred, minted, or burned, the ownership-tracking data stru

This must be added to the implementing contract's xref:ERC721Component-before_update[ERC721HooksTrait::before_update] hook.

[.contract-item]
[[ERC721EnumerableComponent-all_tokens_of_owner]]
==== `[.contract-item-name]#++all_tokens_of_owner++#++(self: @ContractState, owner: ContractAddress) → Span<u256>++` [.item-kind]#internal#

Returns a list of all token ids owned by the specified `owner`.
This function provides a more efficient alternative to calling `ERC721::balance_of`
and iterating through tokens with `ERC721Enumerable::token_of_owner_by_index`.

Requirements:

- `owner` is not the zero address.

[.contract-item]
[[ERC721EnumerableComponent-_add_token_to_owner_enumeration]]
==== `[.contract-item-name]#++_add_token_to_owner_enumeration++#++(ref self: ContractState, to: ContractAddress, token_id: u256)++` [.item-kind]#internal#
Expand Down
10 changes: 10 additions & 0 deletions packages/test_common/src/mocks/erc721.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ pub mod DualCaseERC721ReceiverMock {
pub mod ERC721EnumerableMock {
use openzeppelin_introspection::src5::SRC5Component;
use openzeppelin_token::erc721::ERC721Component;
use openzeppelin_token::erc721::extensions::ERC721EnumerableComponent::InternalTrait;
use openzeppelin_token::erc721::extensions::ERC721EnumerableComponent;
use starknet::ContractAddress;

Expand Down Expand Up @@ -238,6 +239,15 @@ pub mod ERC721EnumerableMock {
}
}

#[generate_trait]
#[abi(per_item)]
impl ExternalImpl of ExternalTrait {
#[external(v0)]
fn all_tokens_of_owner(self: @ContractState, owner: ContractAddress) -> Span<u256> {
self.erc721_enumerable.all_tokens_of_owner(owner)
}
}

#[constructor]
fn constructor(
ref self: ContractState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub mod ERC721EnumerableComponent {
use crate::erc721::ERC721Component::ERC721Impl;
use crate::erc721::ERC721Component::InternalImpl as ERC721InternalImpl;
use crate::erc721::ERC721Component;
use crate::erc721::extensions::erc721_enumerable::interface::IERC721Enumerable;
use crate::erc721::extensions::erc721_enumerable::interface;
use openzeppelin_introspection::src5::SRC5Component::InternalTrait as SRC5InternalTrait;
use openzeppelin_introspection::src5::SRC5Component;
Expand Down Expand Up @@ -51,7 +50,7 @@ pub mod ERC721EnumerableComponent {
+ERC721Component::ERC721HooksTrait<TContractState>,
+SRC5Component::HasComponent<TContractState>,
+Drop<TContractState>
> of IERC721Enumerable<ComponentState<TContractState>> {
> of interface::IERC721Enumerable<ComponentState<TContractState>> {
/// Returns the total amount of tokens stored by the contract.
fn total_supply(self: @ComponentState<TContractState>) -> u256 {
self.ERC721Enumerable_all_tokens_len.read()
Expand Down Expand Up @@ -133,6 +132,26 @@ pub mod ERC721EnumerableComponent {
}
}

/// Returns a list of all token ids owned by the specified `owner`.
/// This function provides a more efficient alternative to calling `ERC721::balance_of`
/// and iterating through tokens with `ERC721Enumerable::token_of_owner_by_index`.
///
/// Requirements:
///
/// - `owner` is not the zero address.
fn all_tokens_of_owner(
self: @ComponentState<TContractState>, owner: ContractAddress
) -> Span<u256> {
let mut result = array![];
let balance = get_dep_component!(self, ERC721).balance_of(owner);
for index in 0
..balance {
result.append(self.ERC721Enumerable_owned_tokens.read((owner, index)));
};
let result = result.span();
result
}

/// Adds token to this extension's ownership-tracking data structures.
fn _add_token_to_owner_enumeration(
ref self: ComponentState<TContractState>, to: ContractAddress, token_id: u256
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ pub trait IERC721Enumerable<TState> {
fn token_by_index(self: @TState, index: u256) -> u256;
fn token_of_owner_by_index(self: @TState, owner: ContractAddress, index: u256) -> u256;
}

#[starknet::interface]
pub trait ERC721EnumerableABI<TState> {
fn total_supply(self: @TState) -> u256;
fn token_by_index(self: @TState, index: u256) -> u256;
fn token_of_owner_by_index(self: @TState, owner: ContractAddress, index: u256) -> u256;
fn all_tokens_of_owner(self: @TState, owner: ContractAddress) -> Span<u256>;
}
81 changes: 68 additions & 13 deletions packages/token/src/tests/erc721/test_erc721_enumerable.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::erc721::extensions::erc721_enumerable::ERC721EnumerableComponent::{
ERC721EnumerableImpl, InternalImpl
};
use crate::erc721::extensions::erc721_enumerable::ERC721EnumerableComponent;
use crate::erc721::extensions::erc721_enumerable::interface;
use crate::erc721::extensions::erc721_enumerable::interface::IERC721ENUMERABLE_ID;
use openzeppelin_introspection::interface::ISRC5_ID;
use openzeppelin_introspection::src5::SRC5Component::SRC5Impl;
use openzeppelin_test_common::mocks::erc721::ERC721EnumerableMock;
Expand Down Expand Up @@ -57,7 +57,7 @@ fn test_initializer() {

state.initializer();

let supports_ierc721_enum = mock_state.supports_interface(interface::IERC721ENUMERABLE_ID);
let supports_ierc721_enum = mock_state.supports_interface(IERC721ENUMERABLE_ID);
assert!(supports_ierc721_enum);

let supports_isrc5 = mock_state.supports_interface(ISRC5_ID);
Expand Down Expand Up @@ -102,7 +102,7 @@ fn test_token_by_index() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_by_index_equal_to_supply() {
let (state, token_list) = setup();
let supply = token_list.len().into();
Expand All @@ -111,7 +111,7 @@ fn test_token_by_index_equal_to_supply() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_by_index_greater_than_supply() {
let (state, token_list) = setup();
let supply_plus_one = token_list.len().into() + 1;
Expand All @@ -121,7 +121,7 @@ fn test_token_by_index_greater_than_supply() {

#[test]
fn test_token_by_index_burn_last_token() {
let (_, _) = setup();
let _ = setup();
let mut contract_state = CONTRACT_STATE();
let last_token = TOKEN_3;

Expand All @@ -133,7 +133,7 @@ fn test_token_by_index_burn_last_token() {

#[test]
fn test_token_by_index_burn_first_token() {
let (_, _) = setup();
let _ = setup();
let mut contract_state = CONTRACT_STATE();
let first_token = TOKEN_1;

Expand Down Expand Up @@ -177,7 +177,7 @@ fn test_token_of_owner_by_index() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_index_equals_owned_tokens() {
let (state, tokens_list) = setup();
let owned_token_len = tokens_list.len().into();
Expand All @@ -186,7 +186,7 @@ fn test_token_of_owner_by_index_when_index_equals_owned_tokens() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_index_exceeds_owned_tokens() {
let (state, tokens_list) = setup();
let owned_tokens_len_plus_one = tokens_list.len().into() + 1;
Expand All @@ -195,15 +195,15 @@ fn test_token_of_owner_by_index_when_index_exceeds_owned_tokens() {
}

#[test]
#[should_panic(expected: ('ERC721Enum: out of bounds index',))]
#[should_panic(expected: 'ERC721Enum: out of bounds index')]
fn test_token_of_owner_by_index_when_target_has_no_tokens() {
let (state, _) = setup();

state.token_of_owner_by_index(OTHER(), 0);
}

#[test]
#[should_panic(expected: ('ERC721: invalid account',))]
#[should_panic(expected: 'ERC721: invalid account')]
fn test_token_of_owner_by_index_when_owner_is_zero() {
let (state, _) = setup();

Expand Down Expand Up @@ -503,6 +503,58 @@ fn test__remove_token_from_all_tokens_enumeration_with_first_token() {
assert_eq!(initial_supply - 1, new_supply);
}

//
// all_tokens_of_owner
//

#[test]
fn test_all_tokens_of_owner() {
let (_, tokens_list) = setup();
assert_all_tokens_of_owner(OWNER(), tokens_list);
}

#[test]
fn test_all_tokens_of_owner_after_transfer_first_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.transfer(OWNER(), RECIPIENT(), TOKEN_1);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_3, TOKEN_2].span());
assert_all_tokens_of_owner(RECIPIENT(), array![TOKEN_1].span());
}

#[test]
fn test_all_tokens_of_owner_after_transfer_last_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.transfer(OWNER(), RECIPIENT(), TOKEN_3);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_1, TOKEN_2].span());
assert_all_tokens_of_owner(RECIPIENT(), array![TOKEN_3].span());
}

#[test]
fn test_all_tokens_of_owner_after_burn_first_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.burn(TOKEN_1);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_3, TOKEN_2].span());
}

#[test]
fn test_all_tokens_of_owner_after_burn_last_token() {
let _ = setup();
let mut contract_state = CONTRACT_STATE();

contract_state.erc721.burn(TOKEN_3);

assert_all_tokens_of_owner(OWNER(), array![TOKEN_1, TOKEN_2].span());
}

//
// Helpers
//
Expand Down Expand Up @@ -568,21 +620,24 @@ fn assert_all_tokens_index_to_id(index: u256, exp_token_id: u256) {

fn assert_all_tokens_id_to_index(token_id: u256, exp_index: u256) {
let state = @COMPONENT_STATE();

let id_to_index = state.ERC721Enumerable_all_tokens_index.read(token_id);
assert_eq!(id_to_index, exp_index);
}

fn assert_owner_tokens_index_to_id(owner: ContractAddress, index: u256, exp_token_id: u256) {
let state = @COMPONENT_STATE();

let index_to_id = state.ERC721Enumerable_owned_tokens.read((owner, index));
assert_eq!(index_to_id, exp_token_id);
}

fn assert_owner_tokens_id_to_index(token_id: u256, exp_index: u256) {
let state = @COMPONENT_STATE();

let id_to_index = state.ERC721Enumerable_owned_tokens_index.read(token_id);
assert_eq!(id_to_index, exp_index);
}

fn assert_all_tokens_of_owner(owner: ContractAddress, exp_tokens: Span<u256>) {
let state = @COMPONENT_STATE();
let tokens = state.all_tokens_of_owner(owner);
assert_eq!(tokens, exp_tokens);
}