diff --git a/packages/macros/src/attribute/with_components/diagnostics.rs b/packages/macros/src/attribute/with_components/diagnostics.rs index 7fefd0e5a..e2ac22a2a 100644 --- a/packages/macros/src/attribute/with_components/diagnostics.rs +++ b/packages/macros/src/attribute/with_components/diagnostics.rs @@ -13,6 +13,17 @@ pub mod errors { pub fn NO_CONTRACT_ATTRIBUTE(contract_attribute: &str) -> String { format!("Contract module must have the `#[{contract_attribute}]` attribute.\n") } + /// Error when there are duplicate components in the attribute. + pub fn DUPLICATE_COMPONENTS(components: &[&str]) -> String { + if components.len() == 1 { + format!("Component {} is specified multiple times. Each component must only be listed once.\n", components[0]) + } else { + let mut sorted_components = components.to_vec(); + sorted_components.sort(); + let components_str = sorted_components.join(", "); + format!("Components [{}] are specified multiple times. Each component must only be listed once.\n", components_str) + } + } } #[allow(non_snake_case)] diff --git a/packages/macros/src/attribute/with_components/parser.rs b/packages/macros/src/attribute/with_components/parser.rs index cd79cfac1..1227d3e86 100644 --- a/packages/macros/src/attribute/with_components/parser.rs +++ b/packages/macros/src/attribute/with_components/parser.rs @@ -97,22 +97,38 @@ fn validate_contract_module( ) -> (Vec, Vec) { let mut warnings = vec![]; + // 1. Check for duplicate components (error) + let mut component_counts = std::collections::HashMap::new(); + for component_info in components_info.iter() { + let component_name = component_info.short_name(); + *component_counts.entry(component_name).or_insert(0) += 1; + } + let duplicates: Vec<&str> = component_counts + .iter() + .filter(|(_, &count)| count > 1) + .map(|(&name, _)| name) + .collect(); + if !duplicates.is_empty() { + let error = Diagnostic::error(errors::DUPLICATE_COMPONENTS(&duplicates)); + return (vec![error], vec![]); + } + if let RewriteNode::Copied(copied) = node { let item = ast::ItemModule::from_syntax_node(db, *copied); - // 1. Check that the module has a body (error) + // 2. Check that the module has a body (error) let MaybeModuleBody::Some(body) = item.body(db) else { let error = Diagnostic::error(errors::NO_BODY); return (vec![error], vec![]); }; - // 2. Check that the module has the `#[starknet::contract]` attribute (error) + // 3. Check that the module has the `#[starknet::contract]` attribute (error) if !item.has_attr(db, CONTRACT_ATTRIBUTE) { let error = Diagnostic::error(errors::NO_CONTRACT_ATTRIBUTE(CONTRACT_ATTRIBUTE)); return (vec![error], vec![]); } - // 3. Check that the module has the corresponding initializers (warning) + // 4. Check that the module has the corresponding initializers (warning) let components_with_initializer = components_info .iter() .filter(|c| c.has_initializer) @@ -151,7 +167,7 @@ fn validate_contract_module( } } - // 4. Check that the contract has the corresponding immutable configs + // 5. Check that the contract has the corresponding immutable configs (warning) for component in components_info.iter().filter(|c| c.has_immutable_config) { // Get the body code (maybe we can do this without the builder) let body_ast = body.as_syntax_node(); diff --git a/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_three_times.snap b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_three_times.snap new file mode 100644 index 000000000..59ca55476 --- /dev/null +++ b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_three_times.snap @@ -0,0 +1,18 @@ +--- +source: src/tests/test_with_components.rs +assertion_line: 1692 +expression: result +--- +TokenStream: + +None + +Diagnostics: + +==== +Error: Component ERC20 is specified multiple times. Each component must only be listed once. +==== + +AuxData: + +None diff --git a/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_twice.snap b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_twice.snap new file mode 100644 index 000000000..01f0cc17b --- /dev/null +++ b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_component_used_twice.snap @@ -0,0 +1,18 @@ +--- +source: src/tests/test_with_components.rs +assertion_line: 1668 +expression: result +--- +TokenStream: + +None + +Diagnostics: + +==== +Error: Component Ownable is specified multiple times. Each component must only be listed once. +==== + +AuxData: + +None diff --git a/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_multiple_duplicate_components.snap b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_multiple_duplicate_components.snap new file mode 100644 index 000000000..811282ba9 --- /dev/null +++ b/packages/macros/src/tests/snapshots/openzeppelin_macros__tests__test_with_components__with_multiple_duplicate_components.snap @@ -0,0 +1,18 @@ +--- +source: src/tests/test_with_components.rs +assertion_line: 1717 +expression: result +--- +TokenStream: + +None + +Diagnostics: + +==== +Error: Components [ERC20, Ownable] are specified multiple times. Each component must only be listed once. +==== + +AuxData: + +None diff --git a/packages/macros/src/tests/test_with_components.rs b/packages/macros/src/tests/test_with_components.rs index eb3796fb3..daa5a448c 100644 --- a/packages/macros/src/tests/test_with_components.rs +++ b/packages/macros/src/tests/test_with_components.rs @@ -1645,6 +1645,78 @@ fn test_with_governor_integration() { assert_snapshot!(result); } +#[test] +fn test_with_component_used_twice() { + let attribute = "(Ownable, Ownable)"; + let item = indoc!( + " + #[starknet::contract] + pub mod MyContract { + use starknet::ContractAddress; + + #[storage] + pub struct Storage {} + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + self.ownable.initializer(owner); + } + } + " + ); + let result = get_string_result(attribute, item); + assert_snapshot!(result); +} + +#[test] +fn test_with_component_used_three_times() { + let attribute = "(ERC20, ERC20, ERC20)"; + let item = indoc!( + " + #[starknet::contract] + pub mod MyContract { + use openzeppelin_token::erc20::{ERC20HooksEmptyImpl, DefaultConfig}; + use starknet::ContractAddress; + + #[storage] + pub struct Storage {} + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + self.erc20.initializer(\"MyToken\", \"MTK\"); + } + } + " + ); + let result = get_string_result(attribute, item); + assert_snapshot!(result); +} + +#[test] +fn test_with_multiple_duplicate_components() { + let attribute = "(Ownable, ERC20, Ownable, ERC20, SRC5)"; + let item = indoc!( + " + #[starknet::contract] + pub mod MyContract { + use openzeppelin_token::erc20::ERC20HooksEmptyImpl; + use starknet::ContractAddress; + + #[storage] + pub struct Storage {} + + #[constructor] + fn constructor(ref self: ContractState, owner: ContractAddress) { + self.ownable.initializer(owner); + self.erc20.initializer(\"MyToken\", \"MTK\"); + } + } + " + ); + let result = get_string_result(attribute, item); + assert_snapshot!(result); +} + // // Helpers //