From 3100ac25234dbea22400d89dcfc381ddecefce6e Mon Sep 17 00:00:00 2001 From: Harsh Bajpai Date: Wed, 18 Oct 2023 09:58:36 +0530 Subject: [PATCH] feat: add upgradable component (#431) * feat: add upgradable component * dev: change updradable to updradeable + change updatable to upgradeable * dev: change upgratable to upgradeable * dev: minor changes on casing * dev: rename function parameter to `new_class_hash` from `class_hash` for `upgrade_contract` * dev: remove unused code + fix casing * dev: fix casing * dev: fix casing --------- Co-authored-by: Harsh Bajpai --- .../src/components/upgradeable.cairo | 38 +++++++- .../contracts/src/kakarot_core/kakarot.cairo | 14 ++- .../src/tests/test_kakarot_core.cairo | 19 ++++ .../src/tests/test_upgradeable.cairo | 92 ++++++++++++++++++- 4 files changed, 157 insertions(+), 6 deletions(-) diff --git a/crates/contracts/src/components/upgradeable.cairo b/crates/contracts/src/components/upgradeable.cairo index 60088299d..373f57c35 100644 --- a/crates/contracts/src/components/upgradeable.cairo +++ b/crates/contracts/src/components/upgradeable.cairo @@ -1,3 +1,39 @@ -// TODO +use starknet::{replace_class_syscall, ClassHash}; +#[starknet::interface] +trait IUpgradeable { + fn upgrade_contract(ref self: TContractState, new_class_hash: ClassHash); +} + +#[starknet::component] +mod upgradeable_component { + use starknet::ClassHash; + use starknet::info::get_caller_address; + + #[storage] + struct Storage {} + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + ContractUpgraded: ContractUpgraded + } + + #[derive(Drop, starknet::Event)] + struct ContractUpgraded { + new_class_hash: ClassHash + } + + #[embeddable_as(Upgradeable)] + impl UpgradeableImpl< + TContractState, +HasComponent + > of super::IUpgradeable> { + fn upgrade_contract( + ref self: ComponentState, new_class_hash: starknet::ClassHash + ) { + starknet::replace_class_syscall(new_class_hash); + self.emit(ContractUpgraded { new_class_hash: new_class_hash }); + } + } +} diff --git a/crates/contracts/src/kakarot_core/kakarot.cairo b/crates/contracts/src/kakarot_core/kakarot.cairo index a68d16d42..4a9581769 100644 --- a/crates/contracts/src/kakarot_core/kakarot.cairo +++ b/crates/contracts/src/kakarot_core/kakarot.cairo @@ -19,6 +19,7 @@ struct ContractAccountStorage { mod KakarotCore { use contracts::components::ownable::ownable_component::InternalTrait; use contracts::components::ownable::{ownable_component}; + use contracts::components::upgradeable::{IUpgradeable, upgradeable_component}; use contracts::kakarot_core::interface::IKakarotCore; use contracts::kakarot_core::interface; use core::hash::{HashStateExTrait, HashStateTrait}; @@ -35,12 +36,15 @@ mod KakarotCore { use utils::traits::U256TryIntoContractAddress; component!(path: ownable_component, storage: ownable, event: OwnableEvent); + component!(path: upgradeable_component, storage: upgradeable, event: UpgradeableEvent); #[abi(embed_v0)] impl OwnableImpl = ownable_component::Ownable; impl OwnableInternalImpl = ownable_component::InternalImpl; + impl UpgradeableImpl = upgradeable_component::Upgradeable; + #[storage] struct Storage { /// Kakarot storage for accounts: Externally Owned Accounts (EOA) and Contract Accounts (CA) @@ -64,12 +68,15 @@ mod KakarotCore { // Components #[substorage(v0)] ownable: ownable_component::Storage, + #[substorage(v0)] + upgradeable: upgradeable_component::Storage, } #[event] #[derive(Drop, starknet::Event)] enum Event { OwnableEvent: ownable_component::Event, + UpgradeableEvent: upgradeable_component::Event, EOADeployed: EOADeployed, } @@ -248,9 +255,9 @@ mod KakarotCore { /// Upgrade the KakarotCore smart contract /// Using replace_class_syscall - fn upgrade( - ref self: ContractState, new_class_hash: ClassHash - ) { //TODO: implement upgrade logic + fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { + self.ownable.assert_only_owner(); + self.upgradeable.upgrade_contract(new_class_hash); } } @@ -269,4 +276,3 @@ mod KakarotCore { } } } - diff --git a/crates/contracts/src/tests/test_kakarot_core.cairo b/crates/contracts/src/tests/test_kakarot_core.cairo index d7f992d5e..bd70aa56c 100644 --- a/crates/contracts/src/tests/test_kakarot_core.cairo +++ b/crates/contracts/src/tests/test_kakarot_core.cairo @@ -1,5 +1,9 @@ use contracts::components::ownable::ownable_component; use contracts::kakarot_core::{interface::IExtendedKakarotCoreDispatcherImpl, KakarotCore}; +use contracts::tests::test_upgradeable::{ + MockContractUpgradeableV1, IMockContractUpgradeableDispatcher, + IMockContractUpgradeableDispatcherTrait +}; use contracts::tests::utils; use debug::PrintTrait; use eoa::externally_owned_account::ExternallyOwnedAccount; @@ -97,3 +101,18 @@ fn test_kakarot_core_compute_starknet_address() { assert(eoa_starknet_address == expected_starknet_address, 'wrong starknet address'); } +#[test] +#[available_gas(20000000)] +fn test_kakarot_core_upgrade_contract() { + let kakarot_core = utils::deploy_kakarot_core(test_utils::native_token()); + let class_hash: ClassHash = MockContractUpgradeableV1::TEST_CLASS_HASH.try_into().unwrap(); + + testing::set_contract_address(utils::other_starknet_address()); + kakarot_core.upgrade(class_hash); + + let version = IMockContractUpgradeableDispatcher { + contract_address: kakarot_core.contract_address + } + .version(); + assert(version == 1, 'version is not 1'); +} diff --git a/crates/contracts/src/tests/test_upgradeable.cairo b/crates/contracts/src/tests/test_upgradeable.cairo index 60088299d..8dc06d977 100644 --- a/crates/contracts/src/tests/test_upgradeable.cairo +++ b/crates/contracts/src/tests/test_upgradeable.cairo @@ -1,3 +1,93 @@ -// TODO +use MockContractUpgradeableV0::HasComponentImpl_upgradeable_component; +use contracts::components::upgradeable::{IUpgradeableDispatcher, IUpgradeableDispatcherTrait}; +use contracts::components::upgradeable::{upgradeable_component}; +use contracts::tests::utils; +use debug::PrintTrait; +use serde::Serde; +use starknet::{deploy_syscall, ClassHash, ContractAddress, testing}; +use upgradeable_component::{UpgradeableImpl}; +#[starknet::interface] +trait IMockContractUpgradeable { + fn version(self: @TContractState) -> felt252; +} + +#[starknet::contract] +mod MockContractUpgradeableV0 { + use contracts::components::upgradeable::{upgradeable_component}; + use super::IMockContractUpgradeable; + component!(path: upgradeable_component, storage: upgradeable, event: UpgradeableEvent); + + #[abi(embed_v0)] + impl UpgradeableImpl = upgradeable_component::Upgradeable; + + #[storage] + struct Storage { + #[substorage(v0)] + upgradeable: upgradeable_component::Storage + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + UpgradeableEvent: upgradeable_component::Event + } + + #[external(v0)] + impl MockContractUpgradeableImpl of IMockContractUpgradeable { + fn version(self: @ContractState) -> felt252 { + 0 + } + } +} + +#[starknet::contract] +mod MockContractUpgradeableV1 { + use contracts::components::upgradeable::{upgradeable_component}; + use super::IMockContractUpgradeable; + component!(path: upgradeable_component, storage: upgradeable, event: upgradeableEvent); + + #[storage] + struct Storage { + #[substorage(v0)] + upgradeable: upgradeable_component::Storage + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + upgradeableEvent: upgradeable_component::Event + } + + #[external(v0)] + impl MockContractUpgradeableImpl of IMockContractUpgradeable { + fn version(self: @ContractState) -> felt252 { + 1 + } + } +} + +#[test] +#[available_gas(500000)] +fn test_upgradeable_update_contract() { + let (contract_address, _) = deploy_syscall( + MockContractUpgradeableV0::TEST_CLASS_HASH.try_into().unwrap(), 0, array![].span(), false + ) + .unwrap(); + + let version = IMockContractUpgradeableDispatcher { contract_address: contract_address } + .version(); + + assert(version == 0, 'version is not 0'); + + let mut call_data: Array = array![]; + + let new_class_hash: ClassHash = MockContractUpgradeableV1::TEST_CLASS_HASH.try_into().unwrap(); + + IUpgradeableDispatcher { contract_address: contract_address }.upgrade_contract(new_class_hash); + + let version = IMockContractUpgradeableDispatcher { contract_address: contract_address } + .version(); + assert(version == 1, 'version is not 1'); +}