From 07d04edb4d1cab74d7903eace9f90c29c6283d5c Mon Sep 17 00:00:00 2001 From: Sergio Garcia Date: Wed, 26 Jun 2024 15:04:57 +0100 Subject: [PATCH] tweaks to upgrade --- src/contracts/timelock_upgrade.cairo | 115 +++++++++++++++------------ 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/src/contracts/timelock_upgrade.cairo b/src/contracts/timelock_upgrade.cairo index 9b332aa..5a9fa7e 100644 --- a/src/contracts/timelock_upgrade.cairo +++ b/src/contracts/timelock_upgrade.cairo @@ -1,4 +1,15 @@ -use starknet::ClassHash; +use core::num::traits::Zero; +use starknet::{ClassHash}; + +#[derive(Serde, Drop, Copy, starknet::Store)] +struct PendingUpgrade { + // Gets the classhash after + implementation: ClassHash, + // Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing + ready_at: u64, + // Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing + calldata_hash: felt252, +} #[starknet::interface] pub trait ITimelockUpgrade { @@ -19,14 +30,8 @@ pub trait ITimelockUpgrade { /// @param calldata The calldata to be used for the upgrade fn upgrade(ref self: TContractState, calldata: Array); - /// @notice Gets the proposed implementation - fn get_proposed_implementation(self: @TContractState) -> ClassHash; - - /// @notice Gets the timestamp when the upgrade is ready to be performed, 0 if no upgrade ongoing - fn get_upgrade_ready_at(self: @TContractState) -> u64; - - /// @notice Gets the hash of the calldata used for the upgrade, 0 if no upgrade ongoing - fn get_calldata_hash(self: @TContractState) -> felt252; + /// @notice Gets the proposed upgrade + fn get_pending_upgrade(self: @TContractState) -> PendingUpgrade; } #[starknet::interface] @@ -46,7 +51,7 @@ pub mod TimelockUpgradeComponent { use starknet::{get_block_timestamp, ClassHash}; use super::{ ITimelockUpgrade, ITimelockUpgradeCallback, ITimelockUpgradeCallbackLibraryDispatcher, - ITimelockUpgradeCallbackDispatcherTrait + ITimelockUpgradeCallbackDispatcherTrait, PendingUpgrade, PendingUpgradeZero }; /// Time before the upgrade can be performed @@ -54,11 +59,10 @@ pub mod TimelockUpgradeComponent { /// Time window during which the upgrade can be performed const VALID_WINDOW_PERIOD: u64 = consteval_int!(7 * 24 * 60 * 60); // 7 days + #[storage] pub struct Storage { - pending_implementation: ClassHash, - ready_at: u64, - calldata_hash: felt252, + pending_upgrade: PendingUpgrade } #[event] @@ -78,12 +82,12 @@ pub mod TimelockUpgradeComponent { #[derive(Drop, starknet::Event)] struct UpgradeCancelled { - cancelled_implementation: ClassHash + cancelled_upgrade: PendingUpgrade } #[derive(Drop, starknet::Event)] struct Upgraded { - new_implementation: ClassHash + executed_upgrade: PendingUpgrade } #[embeddable_as(TimelockUpgradeImpl)] @@ -99,53 +103,51 @@ pub mod TimelockUpgradeComponent { self.assert_only_owner(); assert(new_implementation.is_non_zero(), 'upgrade/new-implementation-null'); - let pending_implementation = self.pending_implementation.read(); - if pending_implementation.is_non_zero() { - self.emit(UpgradeCancelled { cancelled_implementation: pending_implementation }) + let pending_upgrade = self.pending_upgrade.read(); + if pending_upgrade.is_non_zero() { + self.emit(UpgradeCancelled { cancelled_upgrade: pending_upgrade }) } - self.pending_implementation.write(new_implementation); let ready_at = get_block_timestamp() + MIN_SECURITY_PERIOD; - self.ready_at.write(ready_at); - let calldata_hash = poseidon_hash_span(calldata.span()); - self.calldata_hash.write(calldata_hash); + self + .pending_upgrade + .write( + PendingUpgrade { + implementation: new_implementation, ready_at, calldata_hash: poseidon_hash_span(calldata.span()) + } + ); self.emit(UpgradeProposed { new_implementation, ready_at, calldata }); } fn cancel_upgrade(ref self: ComponentState) { self.assert_only_owner(); - let proposed_implementation = self.pending_implementation.read(); - assert(proposed_implementation.is_non_zero(), 'upgrade/no-new-implementation'); - assert(self.ready_at.read() != 0, 'upgrade/not-ready'); - self.emit(UpgradeCancelled { cancelled_implementation: proposed_implementation }); - self.reset_storage(); + let proposed_implementation = self.pending_upgrade.read(); + assert(proposed_implementation.is_non_zero(), 'upgrade/no-pending-upgrade'); + self.pending_upgrade.write(Zero::zero()); + self.emit(UpgradeCancelled { cancelled_upgrade: proposed_implementation }); } fn upgrade(ref self: ComponentState, calldata: Array) { self.assert_only_owner(); - let new_implementation = self.pending_implementation.read(); - let ready_at = self.ready_at.read(); - let block_timestamp = get_block_timestamp(); - let calldata_hash = poseidon_hash_span(calldata.span()); - assert(calldata_hash == self.calldata_hash.read(), 'upgrade/invalid-calldata'); - assert(new_implementation.is_non_zero(), 'upgrade/no-pending-upgrade'); - assert(block_timestamp >= ready_at, 'upgrade/too-early'); - assert(block_timestamp < ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late'); - self.reset_storage(); - ITimelockUpgradeCallbackLibraryDispatcher { class_hash: new_implementation } - .perform_upgrade(new_implementation, calldata.span()); - } - - fn get_proposed_implementation(self: @ComponentState) -> ClassHash { - self.pending_implementation.read() + let proposed_implementation = self.pending_upgrade.read(); + assert(proposed_implementation.is_non_zero(), 'upgrade/no-pending-upgrade'); + + let current_timestamp = get_block_timestamp(); + assert( + proposed_implementation.calldata_hash == poseidon_hash_span(calldata.span()), 'upgrade/invalid-calldata' + ); + + assert(current_timestamp >= proposed_implementation.ready_at, 'upgrade/too-early'); + assert( + current_timestamp < proposed_implementation.ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late' + ); + self.pending_upgrade.write(Zero::zero()); + ITimelockUpgradeCallbackLibraryDispatcher { class_hash: proposed_implementation.implementation } + .perform_upgrade(proposed_implementation.implementation, calldata.span()); } - fn get_upgrade_ready_at(self: @ComponentState) -> u64 { - self.ready_at.read() - } - - fn get_calldata_hash(self: @ComponentState) -> felt252 { - self.calldata_hash.read() + fn get_pending_upgrade(self: @ComponentState) -> PendingUpgrade { + self.pending_upgrade.read() } } #[generate_trait] @@ -155,11 +157,18 @@ pub mod TimelockUpgradeComponent { fn assert_only_owner(self: @ComponentState) { get_dep_component!(self, Ownable).assert_only_owner(); } + } +} - fn reset_storage(ref self: ComponentState) { - self.pending_implementation.write(Zero::zero()); - self.ready_at.write(0); - self.calldata_hash.write(0); - } + +impl PendingUpgradeZero of core::num::traits::Zero { + fn zero() -> PendingUpgrade { + PendingUpgrade { implementation: Zero::zero(), ready_at: 0, calldata_hash: 0 } + } + fn is_zero(self: @PendingUpgrade) -> bool { + *self.calldata_hash == 0 + } + fn is_non_zero(self: @PendingUpgrade) -> bool { + !self.is_zero() } }