diff --git a/src/base/constants/types.cairo b/src/base/constants/types.cairo index 9a19640..1f3ee9d 100644 --- a/src/base/constants/types.cairo +++ b/src/base/constants/types.cairo @@ -207,7 +207,7 @@ pub struct JoltParams { #[derive(Drop, Serde, starknet::Store)] pub struct RenewalData { - pub renewal_duration: u256, + pub renewal_iterations: u256, pub renewal_amount: u256, pub erc20_contract_address: ContractAddress } diff --git a/src/interfaces.cairo b/src/interfaces.cairo index e918ff6..d30ad84 100644 --- a/src/interfaces.cairo +++ b/src/interfaces.cairo @@ -10,3 +10,4 @@ pub mod IHandleRegistry; pub mod IHub; pub mod IJolt; pub mod ICollectNFT; +pub mod IUpgradeable; diff --git a/src/interfaces/IJolt.cairo b/src/interfaces/IJolt.cairo index 9b76dfd..1c8cc43 100644 --- a/src/interfaces/IJolt.cairo +++ b/src/interfaces/IJolt.cairo @@ -1,5 +1,5 @@ use starknet::ContractAddress; -use karst::base::constants::types::{JoltParams, JoltData}; +use karst::base::constants::types::{JoltParams, JoltData, RenewalData}; #[starknet::interface] pub trait IJolt { @@ -7,12 +7,13 @@ pub trait IJolt { // EXTERNALS // ************************************************************************* fn jolt(ref self: TState, jolt_params: JoltParams) -> u256; - fn set_fee_address(ref self: TState, _fee_address: ContractAddress); - fn auto_renew(ref self: TState, profile: ContractAddress, renewal_id: u256) -> bool; + fn auto_renew(ref self: TState, profile: ContractAddress, jolt_id: u256) -> bool; fn fulfill_request(ref self: TState, jolt_id: u256) -> bool; + fn set_fee_address(ref self: TState, _fee_address: ContractAddress); // ************************************************************************* // GETTERS // ************************************************************************* fn get_jolt(self: @TState, jolt_id: u256) -> JoltData; + fn get_renewal_data(self: @TState, profile: ContractAddress, jolt_id: u256) -> RenewalData; fn get_fee_address(self: @TState) -> ContractAddress; } diff --git a/src/interfaces/IUpgradeable.cairo b/src/interfaces/IUpgradeable.cairo new file mode 100644 index 0000000..e96214b --- /dev/null +++ b/src/interfaces/IUpgradeable.cairo @@ -0,0 +1,9 @@ +// ************************************************************************* +// UPGRADEABLE INTERFACE +// ************************************************************************* +use starknet::ClassHash; + +#[starknet::interface] +pub trait IUpgradeable { + fn upgrade(ref self: TContractState, new_class_hash: ClassHash); +} diff --git a/src/jolt/jolt.cairo b/src/jolt/jolt.cairo index ea93216..91d8cb5 100644 --- a/src/jolt/jolt.cairo +++ b/src/jolt/jolt.cairo @@ -111,6 +111,9 @@ pub mod Jolt { // ************************************************************************* // EXTERNALS // ************************************************************************* + + /// @notice multi-faceted transfer logic + /// @param jolt_params required jolting parameters fn jolt(ref self: ContractState, jolt_params: JoltParams) -> u256 { let sender = get_caller_address(); let tx_info = get_tx_info().unbox(); @@ -198,11 +201,8 @@ pub mod Jolt { return jolt_id; } - fn set_fee_address(ref self: ContractState, _fee_address: ContractAddress) { - self.ownable.assert_only_owner(); - self.fee_address.write(_fee_address); - } - + /// @notice fulfills a pending jolt request + /// @param jolt_id id of jolt request to be fulfilled fn fulfill_request(ref self: ContractState, jolt_id: u256) -> bool { // get jolt details let mut jolt_details = self.jolt.read(jolt_id); @@ -224,17 +224,43 @@ pub mod Jolt { self._fulfill_request(jolt_id, sender, jolt_details) } - fn auto_renew(ref self: ContractState, profile: ContractAddress, renewal_id: u256) -> bool { - self._auto_renew(profile, renewal_id) + /// @notice contains logic for auto renewal of subscriptions + /// @dev can be automated using cron jobs in a backend service + /// @param jolt_id id of jolt subscription to auto-renew + fn auto_renew(ref self: ContractState, profile: ContractAddress, jolt_id: u256) -> bool { + self._auto_renew(profile, jolt_id) + } + + /// @notice sets the fee address which receives subscription payments and maybe actual fees + /// in the future? + /// @param _fee_address address to be set + fn set_fee_address(ref self: ContractState, _fee_address: ContractAddress) { + self.ownable.assert_only_owner(); + self.fee_address.write(_fee_address); } // ************************************************************************* // GETTERS // ************************************************************************* + + /// @notice gets the associated data for a jolt id + /// @param jolt_id id of jolt who's data is to be gotten + /// @returns JoltData struct containing jolt details fn get_jolt(self: @ContractState, jolt_id: u256) -> JoltData { self.jolt.read(jolt_id) } + /// @notice gets the renewal data for a particular jolt id + /// @param jolt_id id of jolt who's renewal data is to be gotten + /// @returns RenewalData struct containing jolt renewal details + fn get_renewal_data( + self: @ContractState, profile: ContractAddress, jolt_id: u256 + ) -> RenewalData { + self.renewals.read((profile, jolt_id)) + } + + /// @notice gets the fee address + /// @returns the fee address for contract fn get_fee_address(self: @ContractState) -> ContractAddress { self.fee_address.read() } @@ -245,6 +271,8 @@ pub mod Jolt { // ************************************************************************* #[abi(embed_v0)] impl UpgradeableImpl of IUpgradeable { + /// @notice upgrades the contract + /// @param new_class_hash the class hash to upgrade to fn upgrade(ref self: ContractState, new_class_hash: ClassHash) { self.ownable.assert_only_owner(); self.upgradeable.upgrade(new_class_hash); @@ -256,6 +284,13 @@ pub mod Jolt { // ************************************************************************* #[generate_trait] impl Private of PrivateTrait { + /// @notice contains the tipping logic + /// @param jolt_id id of txn + /// @param sender the profile performing the tipping + /// @param recipient the profile being tipped + /// @param amount the amount to be tipped + /// @param erc20_contract_address the address of token used in tipping + /// @returns JoltStatus status of the txn fn _tip( ref self: ContractState, jolt_id: u256, @@ -287,6 +322,13 @@ pub mod Jolt { JoltStatus::SUCCESSFUL } + /// @notice contains the transfer logic + /// @param jolt_id id of txn + /// @param sender the profile performing the transfer + /// @param recipient the profile being transferred to + /// @param amount the amount to be transferred + /// @param erc20_contract_address the address of token used + /// @returns JoltStatus status of the txn fn _transfer( ref self: ContractState, jolt_id: u256, @@ -318,6 +360,13 @@ pub mod Jolt { JoltStatus::SUCCESSFUL } + /// @notice contains the subscription logic + /// @param jolt_id id of txn + /// @param sender the profile performing the subscription + /// @param amount the amount to pay + /// @param auto_renewal a tuple containing renewal status and renewal_iterations + /// @param erc20_contract_address the address of token used + /// @returns JoltStatus status of the txn fn _subscribe( ref self: ContractState, jolt_id: u256, @@ -326,34 +375,22 @@ pub mod Jolt { auto_renewal: (bool, u256), erc20_contract_address: ContractAddress ) -> JoltStatus { - let (renewal_status, renewal_duration) = auto_renewal; + let (renewal_status, renewal_iterations) = auto_renewal; let dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; let this_contract = get_contract_address(); - let tx_info = get_tx_info().unbox(); if (renewal_status == true) { - // check allowances match auto-renew duration + // check allowances match auto-renew iterations let allowance = dispatcher.allowance(sender, this_contract); - assert(allowance >= renewal_duration * amount, Errors::INSUFFICIENT_ALLOWANCE); - - // generate renewal ID - let renewal_hash = PedersenTrait::new(0) - .update(sender.into()) - .update(jolt_id.low.into()) - .update(jolt_id.high.into()) - .update(tx_info.nonce) - .update(4) - .finalize(); - - let renewal_id: u256 = renewal_hash.try_into().unwrap(); + assert(allowance >= renewal_iterations * amount, Errors::INSUFFICIENT_ALLOWANCE); // write renewal details to storage let renewal_data = RenewalData { - renewal_duration: renewal_duration, + renewal_iterations: renewal_iterations, renewal_amount: amount, erc20_contract_address }; - self.renewals.write((sender, renewal_id), renewal_data); + self.renewals.write((sender, jolt_id), renewal_data); } // send subscription amount to fee address @@ -376,6 +413,14 @@ pub mod Jolt { JoltStatus::SUCCESSFUL } + /// @notice contains the request logic + /// @param jolt_id id of txn + /// @param sender the profile performing the request + /// @param recipient the profile being requested of + /// @param amount the amount to be tipped + /// @param expiration_timestamp timestamp of when the request will expire + /// @param erc20_contract_address the address of token used + /// @returns JoltStatus status of the txn fn _request( ref self: ContractState, jolt_id: u256, @@ -408,6 +453,10 @@ pub mod Jolt { JoltStatus::PENDING } + /// @notice internal logic to fulfill a request + /// @param sender the profile fulfilling the request + /// @param jolt_details details of the jolt to be fulfilled + /// @returns bool status of the txn fn _fulfill_request( ref self: ContractState, jolt_id: u256, sender: ContractAddress, jolt_details: JoltData ) -> bool { @@ -440,17 +489,21 @@ pub mod Jolt { return true; } + /// @notice internal logic to auto renew a subscription + /// @param sender the profile renewing a subscription + /// @param renewal_id id jolt to be renewed + /// @returns bool status of the txn fn _auto_renew(ref self: ContractState, sender: ContractAddress, renewal_id: u256) -> bool { let tx_info = get_tx_info().unbox(); let amount = self.renewals.read((sender, renewal_id)).renewal_amount; - let duration = self.renewals.read((sender, renewal_id)).renewal_duration; + let iteration = self.renewals.read((sender, renewal_id)).renewal_iterations; let erc20_contract_address = self .renewals .read((sender, renewal_id)) .erc20_contract_address; - // check duration is greater than 0 else shouldn't auto renew - assert(duration > 0, Errors::AUTO_RENEW_DURATION_ENDED); + // check iteration is greater than 0 else shouldn't auto renew + assert(iteration > 0, Errors::AUTO_RENEW_DURATION_ENDED); // send subscription amount to fee address let fee_address = self.fee_address.read(); @@ -467,9 +520,9 @@ pub mod Jolt { let jolt_id: u256 = jolt_hash.try_into().unwrap(); - // reduce duration by one month + // reduce iteration by one month let renewal_data = RenewalData { - renewal_duration: duration - 1, renewal_amount: amount, erc20_contract_address + renewal_iterations: iteration - 1, renewal_amount: amount, erc20_contract_address }; self.renewals.write((sender, renewal_id), renewal_data); @@ -506,6 +559,11 @@ pub mod Jolt { return true; } + /// @notice internal logic to perform an ERC20 transfer + /// @param erc20_contract_address address of the token to be transferred + /// @param sender profile sending the token + /// @param recipient profile receiving the token + /// @param amount amount to be transferred fn _transfer_helper( ref self: ContractState, erc20_contract_address: ContractAddress, diff --git a/src/mocks.cairo b/src/mocks.cairo index 231805c..e637541 100644 --- a/src/mocks.cairo +++ b/src/mocks.cairo @@ -1,3 +1,4 @@ pub mod registry; pub mod interfaces; pub mod ERC20; +pub mod jolt_upgrade; diff --git a/src/mocks/interfaces.cairo b/src/mocks/interfaces.cairo index 8468de9..b0cdc97 100644 --- a/src/mocks/interfaces.cairo +++ b/src/mocks/interfaces.cairo @@ -1 +1,2 @@ pub mod IComposable; +pub mod IJoltUpgrade; diff --git a/src/mocks/interfaces/IJoltUpgrade.cairo b/src/mocks/interfaces/IJoltUpgrade.cairo new file mode 100644 index 0000000..85087a5 --- /dev/null +++ b/src/mocks/interfaces/IJoltUpgrade.cairo @@ -0,0 +1,13 @@ +use karst::base::constants::types::{JoltParams}; + +#[starknet::interface] +pub trait IJoltUpgrade { + // ************************************************************************* + // EXTERNALS + // ************************************************************************* + fn jolt(ref self: TState, jolt_params: JoltParams) -> u256; + // ************************************************************************* + // GETTERS + // ************************************************************************* + fn version(self: @TState) -> u256; +} diff --git a/src/mocks/jolt_upgrade.cairo b/src/mocks/jolt_upgrade.cairo new file mode 100644 index 0000000..1d65ddf --- /dev/null +++ b/src/mocks/jolt_upgrade.cairo @@ -0,0 +1,52 @@ +#[starknet::contract] +pub mod JoltUpgrade { + // ************************************************************************* + // IMPORTS + // ************************************************************************* + use core::hash::HashStateTrait; + use core::pedersen::PedersenTrait; + use starknet::get_tx_info; + use karst::base::{constants::types::{JoltParams}}; + use karst::mocks::interfaces::IJoltUpgrade::IJoltUpgrade; + + // ************************************************************************* + // STORAGE + // ************************************************************************* + #[storage] + struct Storage {} + + // ************************************************************************* + // EVENTS + // ************************************************************************* + #[event] + #[derive(Drop, starknet::Event)] + pub enum Event {} + + + // ************************************************************************* + // EXTERNALS + // ************************************************************************* + #[abi(embed_v0)] + impl JoltImpl of IJoltUpgrade { + fn jolt(ref self: ContractState, jolt_params: JoltParams) -> u256 { + let tx_info = get_tx_info().unbox(); + + // generate jolt_id + let jolt_hash = PedersenTrait::new(0) + .update(jolt_params.recipient.into()) + .update(jolt_params.amount.low.into()) + .update(jolt_params.amount.high.into()) + .update(tx_info.nonce) + .update(4) + .finalize(); + + let jolt_id: u256 = jolt_hash.try_into().unwrap(); + + return jolt_id; + } + + fn version(self: @ContractState) -> u256 { + 2 + } + } +} diff --git a/tests/test_jolt.cairo b/tests/test_jolt.cairo index 2b59287..591ca52 100644 --- a/tests/test_jolt.cairo +++ b/tests/test_jolt.cairo @@ -2,6 +2,7 @@ use core::traits::TryInto; use core::hash::HashStateTrait; use core::pedersen::PedersenTrait; use starknet::{ContractAddress, contract_address_const}; + use snforge_std::{ declare, DeclareResultTrait, ContractClassTrait, start_cheat_caller_address, stop_cheat_caller_address, start_cheat_nonce, stop_cheat_nonce, start_cheat_block_timestamp, @@ -9,14 +10,20 @@ use snforge_std::{ }; use karst::interfaces::IJolt::{IJoltDispatcher, IJoltDispatcherTrait}; use karst::interfaces::IERC20::{IERC20Dispatcher, IERC20DispatcherTrait}; -use karst::jolt::jolt::Jolt::{Event as JoltEvent, Jolted}; -use karst::jolt::jolt::Jolt::{Event as JoltRequestEvent, JoltRequested}; -use karst::jolt::jolt::Jolt::{Event as JoltRequestFulfillEvent, JoltRequestFullfilled}; +use karst::interfaces::IUpgradeable::{IUpgradeableDispatcher, IUpgradeableDispatcherTrait}; + +use karst::jolt::jolt::Jolt::{ + {Event as JoltEvent, Jolted}, {Event as JoltRequestEvent, JoltRequested}, + {Event as JoltRequestFulfillEvent, JoltRequestFullfilled}, +}; + use karst::base::{constants::types::{JoltParams, JoltType, JoltStatus}}; +use karst::mocks::interfaces::IJoltUpgrade::{IJoltUpgradeDispatcher, IJoltUpgradeDispatcherTrait}; const ADMIN: felt252 = 5382942; const ADDRESS1: felt252 = 254290; const ADDRESS2: felt252 = 525616; +const FEE_ADDRESS: felt252 = 250322; // ************************************************************************* // SETUP @@ -849,3 +856,380 @@ fn test_jolt_event_is_emitted_on_request_fulfillment() { stop_cheat_block_timestamp(jolt_contract_address); stop_cheat_caller_address(jolt_contract_address); } + +// ************************************************************************* +// TEST - SUBSCRIPTION +// ************************************************************************* +#[test] +fn test_jolt_subscription() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 1), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 2000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + start_cheat_nonce(jolt_contract_address, 23); + + let jolt_id = dispatcher.jolt(jolt_params); + + // check jolt data + let jolt_data = dispatcher.get_jolt(jolt_id); + assert(jolt_data.jolt_type == JoltType::Subscription, 'invalid jolt type'); + assert(jolt_data.sender == ADDRESS1.try_into().unwrap(), 'invalid sender'); + assert(jolt_data.memo == "hey first subscription ever!", 'invalid memo'); + assert(jolt_data.amount == 2000000000000000000, 'invalid amount'); + assert(jolt_data.status == JoltStatus::SUCCESSFUL, 'invalid status'); + assert(jolt_data.block_timestamp == 36000, 'invalid block stamp'); + assert(jolt_data.erc20_contract_address == erc20_contract_address, 'invalid address'); + + // check that fee_address received sub amount + let balance = erc20_dispatcher.balance_of(FEE_ADDRESS.try_into().unwrap()); + assert(balance == 2000000000000000000, 'incorrect balance'); + + // check that renewal data was updated + let renewal_data = dispatcher.get_renewal_data(ADDRESS1.try_into().unwrap(), jolt_id); + assert(renewal_data.renewal_iterations == 1, 'invalid iteration count'); + assert(renewal_data.renewal_amount == 2000000000000000000, 'invalid renewal amount'); + assert(renewal_data.erc20_contract_address == erc20_contract_address, 'invalid erc20'); + + stop_cheat_nonce(jolt_contract_address); + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +#[test] +#[should_panic(expected: ('Karst: insufficient allowance!',))] +fn test_jolt_subscription_fails_if_insufficient_allowance() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 5), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 4000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + start_cheat_nonce(jolt_contract_address, 23); + + dispatcher.jolt(jolt_params); + + stop_cheat_nonce(jolt_contract_address); + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +#[test] +fn test_jolt_event_is_emitted_on_subscription() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 5), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 10000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + + let mut spy = spy_events(); + let jolt_id = dispatcher.jolt(jolt_params); + + // check for events + let expected_event = JoltEvent::Jolted( + Jolted { + jolt_id: jolt_id, + jolt_type: 'SUBSCRIPTION', + sender: ADDRESS1.try_into().unwrap(), + recipient: FEE_ADDRESS.try_into().unwrap(), + block_timestamp: 36000, + } + ); + spy.assert_emitted(@array![(jolt_contract_address, expected_event)]); + + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +#[test] +fn test_auto_renewal() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + // user first need to subscribe + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 5), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 10000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + start_cheat_nonce(jolt_contract_address, 23); + + let jolt_id = dispatcher.jolt(jolt_params); + + // try to auto renew thrice + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + + // check if auto renewal worked + let renewal_data = dispatcher.get_renewal_data(ADDRESS1.try_into().unwrap(), jolt_id); + assert(renewal_data.renewal_iterations == 2, 'invalid iteration count'); + assert(renewal_data.renewal_amount == 2000000000000000000, 'invalid renewal amount'); + assert(renewal_data.erc20_contract_address == erc20_contract_address, 'invalid erc20'); + + // check that fee_address received sub amount plus renewal amounts + let balance = erc20_dispatcher.balance_of(FEE_ADDRESS.try_into().unwrap()); + assert(balance == 8000000000000000000, 'incorrect balance'); + + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +#[test] +#[should_panic(expected: ('Karst: auto renew ended!',))] +fn test_auto_renewal_fails_once_iteration_count_is_zero() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + // user first need to subscribe + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 2), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 10000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + start_cheat_nonce(jolt_contract_address, 23); + + let jolt_id = dispatcher.jolt(jolt_params); + + // try to auto renew thrice - should fail on third try + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +#[test] +fn test_auto_renewal_emits_susbcription_event() { + let (jolt_contract_address, erc20_contract_address) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + let erc20_dispatcher = IERC20Dispatcher { contract_address: erc20_contract_address }; + + // user first need to subscribe + let jolt_params = JoltParams { + jolt_type: JoltType::Subscription, + recipient: contract_address_const::<0>(), + memo: "hey first subscription ever!", + amount: 2000000000000000000, + expiration_stamp: 0, + auto_renewal: (true, 2), + erc20_contract_address: erc20_contract_address + }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // approve contract to spend amount + start_cheat_caller_address(erc20_contract_address, ADDRESS1.try_into().unwrap()); + erc20_dispatcher.approve(jolt_contract_address, 10000000000000000000); + stop_cheat_caller_address(erc20_contract_address); + + // jolt + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + start_cheat_block_timestamp(jolt_contract_address, 36000); + start_cheat_nonce(jolt_contract_address, 23); + + let jolt_id = dispatcher.jolt(jolt_params); + + // try to auto renew + let mut spy = spy_events(); + dispatcher.auto_renew(ADDRESS1.try_into().unwrap(), jolt_id); + + // generate expected renewal jolt_id + let renewal_jolt_hash = PedersenTrait::new(0) + .update(FEE_ADDRESS.try_into().unwrap()) + .update(2000000000000000000) + .update(0) + .update(23) + .update(4) + .finalize(); + + let renewal_jolt_id: u256 = renewal_jolt_hash.try_into().unwrap(); + + // check for events + let expected_event = JoltEvent::Jolted( + Jolted { + jolt_id: renewal_jolt_id, + jolt_type: 'SUBSCRIPTION', + sender: ADDRESS1.try_into().unwrap(), + recipient: FEE_ADDRESS.try_into().unwrap(), + block_timestamp: 36000, + } + ); + spy.assert_emitted(@array![(jolt_contract_address, expected_event)]); + + stop_cheat_nonce(jolt_contract_address); + stop_cheat_block_timestamp(jolt_contract_address); + stop_cheat_caller_address(jolt_contract_address); +} + +// ************************************************************************* +// TEST - FEE ADDRESS +// ************************************************************************* +#[test] +fn test_set_fee_address() { + let (jolt_contract_address, _) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); + + // check fee address + let fee_address = dispatcher.get_fee_address(); + assert(fee_address == FEE_ADDRESS.try_into().unwrap(), 'invalid fee address'); +} + +#[test] +#[should_panic(expected: ('Caller is not the owner',))] +fn test_only_admin_can_set_fee_address() { + let (jolt_contract_address, _) = __setup__(); + let dispatcher = IJoltDispatcher { contract_address: jolt_contract_address }; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + dispatcher.set_fee_address(FEE_ADDRESS.try_into().unwrap()); + stop_cheat_caller_address(jolt_contract_address); +} + +// ************************************************************************* +// TEST - UPGRADE +// ************************************************************************* +#[test] +fn test_upgrade() { + let (jolt_contract_address, _) = __setup__(); + let dispatcher = IJoltUpgradeDispatcher { contract_address: jolt_contract_address }; + let upgrade_dispatcher = IUpgradeableDispatcher { contract_address: jolt_contract_address }; + let upgraded_class = declare("JoltUpgrade").unwrap().contract_class(); + let new_class_hash = *upgraded_class.class_hash; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADMIN.try_into().unwrap()); + upgrade_dispatcher.upgrade(new_class_hash); + stop_cheat_caller_address(jolt_contract_address); + + // check if upgrade worked by calling version which didn't previously exist + let version = dispatcher.version(); + assert(version == 2, 'failed to upgrade'); +} + +#[test] +#[should_panic(expected: ('Caller is not the owner',))] +fn test_upgrade_fails_if_not_admin() { + let (jolt_contract_address, _) = __setup__(); + let upgrade_dispatcher = IUpgradeableDispatcher { contract_address: jolt_contract_address }; + let upgraded_class = declare("JoltUpgrade").unwrap().contract_class(); + let new_class_hash = *upgraded_class.class_hash; + + // set fee address + start_cheat_caller_address(jolt_contract_address, ADDRESS1.try_into().unwrap()); + upgrade_dispatcher.upgrade(new_class_hash); + stop_cheat_caller_address(jolt_contract_address); +}