diff --git a/.circleci/config.yml b/.circleci/config.yml index e6fa5ee08..73d86c6d0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -64,6 +64,13 @@ jobs: environment: HARDHAT_CONFIG_FILE: ../test/hardhat.config.venv.ts TEST_FILE: test/sample-test.ts + - run: + name: Test account+fee interaction (through plugin with venv) + command: ./test/test_plugin.sh + no_output_timeout: 1m + environment: + HARDHAT_CONFIG_FILE: ../test/hardhat.config.venv.ts + TEST_FILE: test/oz-account-test.ts package_build_and_publish: docker: - image: cimg/python:3.7 diff --git a/starknet_devnet/server.py b/starknet_devnet/server.py index d7bae3e5c..04201288c 100644 --- a/starknet_devnet/server.py +++ b/starknet_devnet/server.py @@ -43,7 +43,10 @@ async def add_transaction(): contract_address, transaction_hash = await starknet_wrapper.deploy(transaction) result_dict = {} elif tx_type == TransactionType.INVOKE_FUNCTION.name: - contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction) + try: + contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction) + except StarkException as stark_exception: + abort(Response(stark_exception.message, 500)) else: abort(Response(f"Invalid tx_type: {tx_type}.", 400)) @@ -228,10 +231,16 @@ def get_state_update(): return jsonify(state_update) @app.route("/feeder_gateway/estimate_fee", methods=["POST"]) -def estimate_fee(): +async def estimate_fee(): """Currently a dummy implementation, always returning 0.""" + transaction = validate_transaction(request.data, InvokeFunction) + try: + actual_fee = await starknet_wrapper.calculate_actual_fee(transaction) + except StarkException as stark_exception: + abort(Response(stark_exception.message, 500)) + return jsonify({ - "amount": 0, + "amount": actual_fee, "unit": "wei" }) @@ -302,7 +311,7 @@ def main(): # Uncomment this once fork support is added # origin = Origin(args.fork) if args.fork else NullOrigin() - # starknet_wrapper.set_origin(origin) + # starknet_wrapper.origin = origin if args.load_path: try: diff --git a/starknet_devnet/starknet_wrapper.py b/starknet_devnet/starknet_wrapper.py index d08137210..50ceb6436 100644 --- a/starknet_devnet/starknet_wrapper.py +++ b/starknet_devnet/starknet_wrapper.py @@ -12,6 +12,7 @@ import dill as pickle from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction from starkware.starknet.business_logic.state import CarriedState +from starkware.starknet.business_logic.transaction_fee import calculate_tx_fee_by_cairo_usage from starkware.starknet.definitions.transaction_type import TransactionType from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Deploy, Transaction @@ -21,7 +22,11 @@ from starkware.starknet.services.api.feeder_gateway.block_hash import calculate_block_hash from .origin import NullOrigin, Origin -from .util import Choice, StarknetDevnetException, TxStatus, fixed_length_hex, DummyExecutionInfo, enable_pickling, generate_state_update +from .util import ( + DEFAULT_GENERAL_CONFIG, + Choice, StarknetDevnetException, TxStatus, DummyExecutionInfo, + fixed_length_hex, enable_pickling, generate_state_update +) from .contract_wrapper import ContractWrapper from .transaction_wrapper import TransactionWrapper, DeployTransactionWrapper, InvokeTransactionWrapper from .postman_wrapper import LocalPostmanWrapper @@ -37,7 +42,7 @@ class StarknetWrapper: """ def __init__(self): - self.__origin: Origin = NullOrigin() + self.origin: Origin = NullOrigin() """Origin chain that this devnet was forked from.""" self.__address2contract_wrapper: Dict[int, ContractWrapper] = {} @@ -81,7 +86,7 @@ async def get_starknet(self): Returns the underlying Starknet instance, creating it first if necessary. """ if not self.__starknet: - self.__starknet = await Starknet.empty() + self.__starknet = await Starknet.empty(general_config=DEFAULT_GENERAL_CONFIG) await self.__preserve_current_state(self.__starknet.state.state) return self.__starknet @@ -121,10 +126,6 @@ def __get_contract_wrapper(self, address: int) -> ContractWrapper: return self.__address2contract_wrapper[address] - def set_origin(self, origin: Origin): - """Set the origin chain.""" - self.__origin = origin - async def deploy(self, deploy_transaction: Deploy): """ Deploys the contract specified with `transaction`. @@ -178,6 +179,13 @@ async def invoke(self, transaction: InvokeFunction): invoke_transaction: InternalInvokeFunction = InternalInvokeFunction.from_external(transaction, state.general_config) try: + # This check might not be needed in future versions which will interact with the token contract + if invoke_transaction.max_fee: # handle only if non-zero + actual_fee = await self.calculate_actual_fee(transaction) + if actual_fee > invoke_transaction.max_fee: + message = f"Actual fee exceeded max fee.\n{actual_fee} > {invoke_transaction.max_fee}" + raise StarknetDevnetException(message=message) + contract_wrapper = self.__get_contract_wrapper(invoke_transaction.contract_address) adapted_result, execution_info = await contract_wrapper.call_or_invoke( Choice.INVOKE, @@ -244,7 +252,7 @@ def get_transaction_status(self, transaction_hash: str): return ret - return self.__origin.get_transaction_status(transaction_hash) + return self.origin.get_transaction_status(transaction_hash) def get_transaction(self, transaction_hash: str): """Returns the transaction identified by `transaction_hash`.""" @@ -253,7 +261,7 @@ def get_transaction(self, transaction_hash: str): if tx_hash_int in self.__transaction_wrappers: return self.__transaction_wrappers[tx_hash_int].transaction - return self.__origin.get_transaction(transaction_hash) + return self.origin.get_transaction(transaction_hash) def get_transaction_receipt(self, transaction_hash: str): """Returns the transaction receipt of the transaction identified by `transaction_hash`.""" @@ -262,7 +270,7 @@ def get_transaction_receipt(self, transaction_hash: str): if tx_hash_int in self.__transaction_wrappers: return self.__transaction_wrappers[tx_hash_int].receipt - return self.__origin.get_transaction_receipt(transaction_hash) + return self.origin.get_transaction_receipt(transaction_hash) def get_transaction_trace(self, transaction_hash:str): """Returns the transaction trace of the tranasction indetified by `transaction_hash`""" @@ -279,11 +287,11 @@ def get_transaction_trace(self, transaction_hash:str): return transaction_wrapper.trace - return self.__origin.get_transaction_trace(transaction_hash) + return self.origin.get_transaction_trace(transaction_hash) def get_number_of_blocks(self) -> int: """Returns the number of blocks stored so far.""" - return len(self.__num2block) + self.__origin.get_number_of_blocks() + return len(self.__num2block) + self.origin.get_number_of_blocks() async def __generate_block(self, tx_wrapper: TransactionWrapper): """ @@ -342,14 +350,14 @@ def get_block_by_hash(self, block_hash: str): block_hash_int = int(block_hash, 16) if block_hash_int in self.__hash2block: return self.__hash2block[block_hash_int] - return self.__origin.get_block_by_hash(block_hash=block_hash) + return self.origin.get_block_by_hash(block_hash=block_hash) def get_block_by_number(self, block_number: int): """Returns the block whose block_number is provided""" if block_number is None: if self.__num2block: return self.__get_last_block() - return self.__origin.get_block_by_number(block_number) + return self.origin.get_block_by_number(block_number) if block_number < 0: message = f"Block number must be a non-negative integer; got: {block_number}." @@ -362,7 +370,7 @@ def get_block_by_number(self, block_number: int): if block_number in self.__num2block: return self.__num2block[block_number] - return self.__origin.get_block_by_number(block_number) + return self.origin.get_block_by_number(block_number) # pylint: disable=too-many-arguments async def __store_transaction(self, transaction: Transaction, contract_address: int, tx_hash: int, status: TxStatus, @@ -397,7 +405,7 @@ def get_code(self, contract_address: int) -> dict: if self.__is_contract_deployed(contract_address): contract_wrapper = self.__get_contract_wrapper(contract_address) return contract_wrapper.code - return self.__origin.get_code(contract_address) + return self.origin.get_code(contract_address) def get_full_contract(self, contract_address: int) -> dict: """Returns a `dict` contract definition of the contract at `contract_address`.""" @@ -415,7 +423,7 @@ async def get_storage_at(self, contract_address: int, key: int) -> str: contract_state = contract_states[contract_address] if key in contract_state.storage_updates: return hex(contract_state.storage_updates[key].value) - return self.__origin.get_storage_at(self, contract_address, key) + return self.origin.get_storage_at(contract_address, key) async def load_messaging_contract_in_l1(self, network_url: str, contract_address: str, network_id: str) -> dict: """Creates a Postman Wrapper instance and loads an already deployed Messaging contract in the L1 network""" @@ -496,7 +504,7 @@ def get_state_update(self, block_hash=None, block_number=None): if numeric_hash in self.__hash2block: return self.__hash2state_update[numeric_hash] - return self.__origin.get_state_update(block_hash=block_hash) + return self.origin.get_state_update(block_hash=block_hash) if block_number is not None: if block_number in self.__num2block: @@ -505,6 +513,22 @@ def get_state_update(self, block_hash=None, block_number=None): return self.__hash2state_update[numeric_hash] - return self.__origin.get_state_update(block_number=block_number) + return self.origin.get_state_update(block_number=block_number) + + return self.__last_state_update or self.origin.get_state_update() + + async def calculate_actual_fee(self, transaction: InvokeFunction): + """Calculates actual fee""" + state = await self.__get_state() + internal_tx = InternalInvokeFunction.from_external(transaction, state.general_config) - return self.__last_state_update or self.__origin.get_state_update() + state_copy = state.state._copy() # pylint: disable=protected-access + execution_info = await internal_tx.apply_state_updates(state_copy, state.general_config) + + cairo_resource_usage = execution_info.call_info.execution_resources.to_dict() + + return calculate_tx_fee_by_cairo_usage( + general_config=state.general_config, + cairo_resource_usage=cairo_resource_usage, + l1_gas_usage=0 + ) diff --git a/starknet_devnet/util.py b/starknet_devnet/util.py index 9e2375504..10c5a8569 100644 --- a/starknet_devnet/util.py +++ b/starknet_devnet/util.py @@ -10,6 +10,7 @@ from starkware.starkware_utils.error_handling import StarkException from starkware.starknet.testing.contract import StarknetContract from starkware.starknet.business_logic.state import CarriedState +from starkware.starknet.definitions.general_config import StarknetGeneralConfig from . import __version__ @@ -226,3 +227,26 @@ def generate_state_update(previous_state: CarriedState, current_state: CarriedSt "storage_diffs": storage_diffs } } + +DEFAULT_GENERAL_CONFIG = StarknetGeneralConfig.load({ + "event_commitment_tree_height": 64, + "global_state_commitment_tree_height": 251, + 'gas_price': 100000000000, + 'starknet_os_config': { + 'chain_id': 'TESTNET', + 'fee_token_address': '0x20abcf49dad3e9813d65bf1b8d54c5a0c9e6049a3027bd8c2ab315475c0a5c1' + }, + 'contract_storage_commitment_tree_height': 251, + 'cairo_resource_fee_weights': { + 'n_steps': 0.05, + 'pedersen_builtin': 0.4, + 'range_check_builtin': 0.4, + 'ecdsa_builtin': 25.6, + 'bitwise_builtin': 12.8, + 'output_builtin': 0.0, + 'ec_op_builtin': 0.0 + }, 'invoke_tx_max_n_steps': 1000000, + 'sequencer_address': '0x37b2cd6baaa515f520383bee7b7094f892f4c770695fc329a8973e841a971ae', + 'tx_version': 0, + 'tx_commitment_tree_height': 64 +}) diff --git a/test/test_endpoints.py b/test/test_endpoints.py index dd80ad47e..110e6b518 100644 --- a/test/test_endpoints.py +++ b/test/test_endpoints.py @@ -30,7 +30,7 @@ def send_call(req_dict: dict): ) def assert_deploy_resp(resp: bytes): - """Asserts the validity of invoke response body.""" + """Asserts the validity of deploy response body.""" resp_dict = json.loads(resp.data.decode("utf-8")) assert set(resp_dict.keys()) == set(["address", "code", "transaction_hash"]) assert resp_dict["code"] == "TRANSACTION_RECEIVED" diff --git a/test/test_plugin.sh b/test/test_plugin.sh index 98cead04a..000187e49 100755 --- a/test/test_plugin.sh +++ b/test/test_plugin.sh @@ -31,7 +31,7 @@ function result_assertion() { } cd starknet-hardhat-example -mv "$HARDHAT_CONFIG_FILE" hardhat.config.ts +cp "$HARDHAT_CONFIG_FILE" hardhat.config.ts # npx hardhat starknet-compile <- Already executed in setup_example.sh # devnet already defined in config as localhost:5000 npx hardhat starknet-deploy \