Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better upgrade #63

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/contracts/gift_factory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ mod GiftFactory {
#[abi(embed_v0)]
impl OwnableImpl = OwnableComponent::OwnableImpl<ContractState>;
impl InternalImpl = OwnableComponent::InternalImpl<ContractState>;
impl TimelockUpgradeInternalImpl = TimelockUpgradeComponent::TimelockUpgradeInternalImpl<ContractState>;

// Pausable
component!(path: PausableComponent, storage: pausable, event: PausableEvent);
Expand Down Expand Up @@ -227,11 +228,13 @@ mod GiftFactory {

#[abi(embed_v0)]
impl TimelockUpgradeCallbackImpl of ITimelockUpgradeCallback<ContractState> {
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Span<felt252>) {
// This should do some sanity checks
// We should check that the new implementation is a valid implementation
// Execute the upgrade using replace_class_syscall(...)
panic_with_felt252('gift-fac/downgrade-not-allowed');
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Array<felt252>) {
self.timelock_upgrade.assert_and_reset_lock();
// This should do some sanity checks and ensure that the new implementation is a valid implementation,
// then it can call replace_class_syscall and emit the UpgradeExecuted event
panic_with_felt252(
'gift-fac/downgrade-not-allowed'
); // since this is the first version nobody should be calling this method
}
}

Expand Down
31 changes: 28 additions & 3 deletions src/contracts/timelock_upgrade.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait ITimelockUpgradeCallback<TContractState> {
/// @dev Currently empty as the upgrade logic will be handled in the contract we upgrade to
/// @param new_implementation The class hash of the new implementation
/// @param data The data to be used for the upgrade
fn perform_upgrade(ref self: TContractState, new_implementation: ClassHash, data: Span<felt252>);
fn perform_upgrade(ref self: TContractState, new_implementation: ClassHash, data: Array<felt252>);
}

#[starknet::component]
Expand All @@ -62,7 +62,9 @@ pub mod TimelockUpgradeComponent {

#[storage]
pub struct Storage {
pending_upgrade: PendingUpgrade
pending_upgrade: PendingUpgrade,
/// true only during the upgrade call
upgrade_lock: bool,
}

#[event]
Expand Down Expand Up @@ -141,14 +143,36 @@ pub mod TimelockUpgradeComponent {
assert(current_timestamp < ready_at + VALID_WINDOW_PERIOD, 'upgrade/upgrade-too-late');

self.pending_upgrade.write(Default::default());

self.upgrade_lock.write(true);
ITimelockUpgradeCallbackLibraryDispatcher { class_hash: implementation }
.perform_upgrade(implementation, calldata.span());
.perform_upgrade(implementation, calldata);
self.upgrade_lock.write(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the upgrade should call assert_and_reset_lock(...) shouldn't we just assert!(!self.upgrade_lock.read(),'some message')?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

}


fn get_pending_upgrade(self: @ComponentState<TContractState>) -> PendingUpgrade {
self.pending_upgrade.read()
}
}

#[generate_trait]
pub impl TimelockUpgradeInternalImpl<
TContractState, +HasComponent<TContractState>
> of ITimelockUpgradeInternal<TContractState> {
/// @notice Should be called by the `perform_upgrade` method to make sure this method can only by called when upgrading
fn assert_and_reset_lock(ref self: ComponentState<TContractState>) {
assert(self.upgrade_lock.read(), 'upgrade/only-during-upgrade');
self.upgrade_lock.write(false);
}
fn emit_upgrade_executed(
ref self: ComponentState<TContractState>, new_implementation: ClassHash, calldata: Array<felt252>
) {
self.emit(UpgradeExecuted { new_implementation, calldata });
}
}


#[generate_trait]
impl PrivateImpl<
TContractState, impl Ownable: OwnableComponent::HasComponent<TContractState>, +HasComponent<TContractState>
Expand All @@ -159,6 +183,7 @@ pub mod TimelockUpgradeComponent {
}
}


impl DefaultClassHash of Default<ClassHash> {
fn default() -> ClassHash {
Zero::zero()
Expand Down
6 changes: 5 additions & 1 deletion src/mocks/future_factory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod FutureFactory {
component!(path: TimelockUpgradeComponent, storage: timelock_upgrade, event: TimelockUpgradeEvent);
#[abi(embed_v0)]
impl TimelockUpgradeImpl = TimelockUpgradeComponent::TimelockUpgradeImpl<ContractState>;
impl TimelockUpgradeInternalImpl = TimelockUpgradeComponent::TimelockUpgradeInternalImpl<ContractState>;

#[storage]
struct Storage {
Expand Down Expand Up @@ -48,8 +49,11 @@ mod FutureFactory {

#[abi(embed_v0)]
impl TimelockUpgradeCallbackImpl of ITimelockUpgradeCallback<ContractState> {
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Span<felt252>) {
fn perform_upgrade(ref self: ContractState, new_implementation: ClassHash, data: Array<felt252>) {
self.timelock_upgrade.assert_and_reset_lock();
starknet::syscalls::replace_class_syscall(new_implementation).unwrap();
self.timelock_upgrade.emit_upgrade_executed(new_implementation, data);
}
}
}

13 changes: 13 additions & 0 deletions tests-integration/upgrade.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ describe("Test Factory Upgrade", function () {
newFactory.connect(deployer);
await newFactory.get_num().should.eventually.equal(1n);

// we can't call the perform_upgrade method directly
await expectRevertWithErrorMessage("upgrade/only-during-upgrade", () =>
factory.perform_upgrade(newFactoryClassHash, []),
);

// clear deployment cache
delete protocolCache["GiftFactory"];
});
Expand Down Expand Up @@ -79,6 +84,14 @@ describe("Test Factory Upgrade", function () {
await expectRevertWithErrorMessage("Caller is not the owner", () => factory.upgrade([]));
});

it("no calls to perform_upgrade", async function () {
const { factory } = await setupGiftProtocol();
const newFactoryClassHash = "0x1";
await expectRevertWithErrorMessage("upgrade/only-during-upgrade", () =>
factory.perform_upgrade(newFactoryClassHash, []),
);
});

it("Invalid Calldata", async function () {
const { factory } = await setupGiftProtocol();
const newFactoryClassHash = "0x1";
Expand Down
Loading