Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Commit

Permalink
Estimate fee (#62) [skip ci]
Browse files Browse the repository at this point in the history
* Use execution_info.actual_fee for fee estimation

* Use calculate_tx_fee_by_cairo_usage

* Add account+fee testing through plugin

* Add fee check to invoke

* Add error catching to fee estimation

* Fix error handling for invoke and estimate_fee

* Use cp instead of mv for custom hardhat config

* Use fee weights as in alpha

* Make origin public - fix lint error

* Readability changes [skip ci]
  • Loading branch information
FabijanC authored Mar 24, 2022
1 parent 4b675ea commit ab69a92
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 26 deletions.
7 changes: 7 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ jobs:
environment:
HARDHAT_CONFIG_FILE: ../test/hardhat.config.venv.ts
TEST_FILE: test/sample-test.ts
- run:
name: Test account+fee interaction (through plugin with venv)
command: ./test/test_plugin.sh
no_output_timeout: 1m
environment:
HARDHAT_CONFIG_FILE: ../test/hardhat.config.venv.ts
TEST_FILE: test/oz-account-test.ts
package_build_and_publish:
docker:
- image: cimg/python:3.7
Expand Down
17 changes: 13 additions & 4 deletions starknet_devnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ async def add_transaction():
contract_address, transaction_hash = await starknet_wrapper.deploy(transaction)
result_dict = {}
elif tx_type == TransactionType.INVOKE_FUNCTION.name:
contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction)
try:
contract_address, transaction_hash, result_dict = await starknet_wrapper.invoke(transaction)
except StarkException as stark_exception:
abort(Response(stark_exception.message, 500))
else:
abort(Response(f"Invalid tx_type: {tx_type}.", 400))

Expand Down Expand Up @@ -228,10 +231,16 @@ def get_state_update():
return jsonify(state_update)

@app.route("/feeder_gateway/estimate_fee", methods=["POST"])
def estimate_fee():
async def estimate_fee():
"""Currently a dummy implementation, always returning 0."""
transaction = validate_transaction(request.data, InvokeFunction)
try:
actual_fee = await starknet_wrapper.calculate_actual_fee(transaction)
except StarkException as stark_exception:
abort(Response(stark_exception.message, 500))

return jsonify({
"amount": 0,
"amount": actual_fee,
"unit": "wei"
})

Expand Down Expand Up @@ -302,7 +311,7 @@ def main():

# Uncomment this once fork support is added
# origin = Origin(args.fork) if args.fork else NullOrigin()
# starknet_wrapper.set_origin(origin)
# starknet_wrapper.origin = origin

if args.load_path:
try:
Expand Down
64 changes: 44 additions & 20 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dill as pickle
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.business_logic.state import CarriedState
from starkware.starknet.business_logic.transaction_fee import calculate_tx_fee_by_cairo_usage
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Deploy, Transaction
Expand All @@ -21,7 +22,11 @@
from starkware.starknet.services.api.feeder_gateway.block_hash import calculate_block_hash

from .origin import NullOrigin, Origin
from .util import Choice, StarknetDevnetException, TxStatus, fixed_length_hex, DummyExecutionInfo, enable_pickling, generate_state_update
from .util import (
DEFAULT_GENERAL_CONFIG,
Choice, StarknetDevnetException, TxStatus, DummyExecutionInfo,
fixed_length_hex, enable_pickling, generate_state_update
)
from .contract_wrapper import ContractWrapper
from .transaction_wrapper import TransactionWrapper, DeployTransactionWrapper, InvokeTransactionWrapper
from .postman_wrapper import LocalPostmanWrapper
Expand All @@ -37,7 +42,7 @@ class StarknetWrapper:
"""

def __init__(self):
self.__origin: Origin = NullOrigin()
self.origin: Origin = NullOrigin()
"""Origin chain that this devnet was forked from."""

self.__address2contract_wrapper: Dict[int, ContractWrapper] = {}
Expand Down Expand Up @@ -81,7 +86,7 @@ async def get_starknet(self):
Returns the underlying Starknet instance, creating it first if necessary.
"""
if not self.__starknet:
self.__starknet = await Starknet.empty()
self.__starknet = await Starknet.empty(general_config=DEFAULT_GENERAL_CONFIG)
await self.__preserve_current_state(self.__starknet.state.state)
return self.__starknet

Expand Down Expand Up @@ -121,10 +126,6 @@ def __get_contract_wrapper(self, address: int) -> ContractWrapper:

return self.__address2contract_wrapper[address]

def set_origin(self, origin: Origin):
"""Set the origin chain."""
self.__origin = origin

async def deploy(self, deploy_transaction: Deploy):
"""
Deploys the contract specified with `transaction`.
Expand Down Expand Up @@ -178,6 +179,13 @@ async def invoke(self, transaction: InvokeFunction):
invoke_transaction: InternalInvokeFunction = InternalInvokeFunction.from_external(transaction, state.general_config)

try:
# This check might not be needed in future versions which will interact with the token contract
if invoke_transaction.max_fee: # handle only if non-zero
actual_fee = await self.calculate_actual_fee(transaction)
if actual_fee > invoke_transaction.max_fee:
message = f"Actual fee exceeded max fee.\n{actual_fee} > {invoke_transaction.max_fee}"
raise StarknetDevnetException(message=message)

contract_wrapper = self.__get_contract_wrapper(invoke_transaction.contract_address)
adapted_result, execution_info = await contract_wrapper.call_or_invoke(
Choice.INVOKE,
Expand Down Expand Up @@ -244,7 +252,7 @@ def get_transaction_status(self, transaction_hash: str):

return ret

return self.__origin.get_transaction_status(transaction_hash)
return self.origin.get_transaction_status(transaction_hash)

def get_transaction(self, transaction_hash: str):
"""Returns the transaction identified by `transaction_hash`."""
Expand All @@ -253,7 +261,7 @@ def get_transaction(self, transaction_hash: str):
if tx_hash_int in self.__transaction_wrappers:
return self.__transaction_wrappers[tx_hash_int].transaction

return self.__origin.get_transaction(transaction_hash)
return self.origin.get_transaction(transaction_hash)

def get_transaction_receipt(self, transaction_hash: str):
"""Returns the transaction receipt of the transaction identified by `transaction_hash`."""
Expand All @@ -262,7 +270,7 @@ def get_transaction_receipt(self, transaction_hash: str):
if tx_hash_int in self.__transaction_wrappers:
return self.__transaction_wrappers[tx_hash_int].receipt

return self.__origin.get_transaction_receipt(transaction_hash)
return self.origin.get_transaction_receipt(transaction_hash)

def get_transaction_trace(self, transaction_hash:str):
"""Returns the transaction trace of the tranasction indetified by `transaction_hash`"""
Expand All @@ -279,11 +287,11 @@ def get_transaction_trace(self, transaction_hash:str):

return transaction_wrapper.trace

return self.__origin.get_transaction_trace(transaction_hash)
return self.origin.get_transaction_trace(transaction_hash)

def get_number_of_blocks(self) -> int:
"""Returns the number of blocks stored so far."""
return len(self.__num2block) + self.__origin.get_number_of_blocks()
return len(self.__num2block) + self.origin.get_number_of_blocks()

async def __generate_block(self, tx_wrapper: TransactionWrapper):
"""
Expand Down Expand Up @@ -342,14 +350,14 @@ def get_block_by_hash(self, block_hash: str):
block_hash_int = int(block_hash, 16)
if block_hash_int in self.__hash2block:
return self.__hash2block[block_hash_int]
return self.__origin.get_block_by_hash(block_hash=block_hash)
return self.origin.get_block_by_hash(block_hash=block_hash)

def get_block_by_number(self, block_number: int):
"""Returns the block whose block_number is provided"""
if block_number is None:
if self.__num2block:
return self.__get_last_block()
return self.__origin.get_block_by_number(block_number)
return self.origin.get_block_by_number(block_number)

if block_number < 0:
message = f"Block number must be a non-negative integer; got: {block_number}."
Expand All @@ -362,7 +370,7 @@ def get_block_by_number(self, block_number: int):
if block_number in self.__num2block:
return self.__num2block[block_number]

return self.__origin.get_block_by_number(block_number)
return self.origin.get_block_by_number(block_number)

# pylint: disable=too-many-arguments
async def __store_transaction(self, transaction: Transaction, contract_address: int, tx_hash: int, status: TxStatus,
Expand Down Expand Up @@ -397,7 +405,7 @@ def get_code(self, contract_address: int) -> dict:
if self.__is_contract_deployed(contract_address):
contract_wrapper = self.__get_contract_wrapper(contract_address)
return contract_wrapper.code
return self.__origin.get_code(contract_address)
return self.origin.get_code(contract_address)

def get_full_contract(self, contract_address: int) -> dict:
"""Returns a `dict` contract definition of the contract at `contract_address`."""
Expand All @@ -415,7 +423,7 @@ async def get_storage_at(self, contract_address: int, key: int) -> str:
contract_state = contract_states[contract_address]
if key in contract_state.storage_updates:
return hex(contract_state.storage_updates[key].value)
return self.__origin.get_storage_at(self, contract_address, key)
return self.origin.get_storage_at(contract_address, key)

async def load_messaging_contract_in_l1(self, network_url: str, contract_address: str, network_id: str) -> dict:
"""Creates a Postman Wrapper instance and loads an already deployed Messaging contract in the L1 network"""
Expand Down Expand Up @@ -496,7 +504,7 @@ def get_state_update(self, block_hash=None, block_number=None):
if numeric_hash in self.__hash2block:
return self.__hash2state_update[numeric_hash]

return self.__origin.get_state_update(block_hash=block_hash)
return self.origin.get_state_update(block_hash=block_hash)

if block_number is not None:
if block_number in self.__num2block:
Expand All @@ -505,6 +513,22 @@ def get_state_update(self, block_hash=None, block_number=None):

return self.__hash2state_update[numeric_hash]

return self.__origin.get_state_update(block_number=block_number)
return self.origin.get_state_update(block_number=block_number)

return self.__last_state_update or self.origin.get_state_update()

async def calculate_actual_fee(self, transaction: InvokeFunction):
"""Calculates actual fee"""
state = await self.__get_state()
internal_tx = InternalInvokeFunction.from_external(transaction, state.general_config)

return self.__last_state_update or self.__origin.get_state_update()
state_copy = state.state._copy() # pylint: disable=protected-access
execution_info = await internal_tx.apply_state_updates(state_copy, state.general_config)

cairo_resource_usage = execution_info.call_info.execution_resources.to_dict()

return calculate_tx_fee_by_cairo_usage(
general_config=state.general_config,
cairo_resource_usage=cairo_resource_usage,
l1_gas_usage=0
)
24 changes: 24 additions & 0 deletions starknet_devnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starkware.starkware_utils.error_handling import StarkException
from starkware.starknet.testing.contract import StarknetContract
from starkware.starknet.business_logic.state import CarriedState
from starkware.starknet.definitions.general_config import StarknetGeneralConfig

from . import __version__

Expand Down Expand Up @@ -226,3 +227,26 @@ def generate_state_update(previous_state: CarriedState, current_state: CarriedSt
"storage_diffs": storage_diffs
}
}

DEFAULT_GENERAL_CONFIG = StarknetGeneralConfig.load({
"event_commitment_tree_height": 64,
"global_state_commitment_tree_height": 251,
'gas_price': 100000000000,
'starknet_os_config': {
'chain_id': 'TESTNET',
'fee_token_address': '0x20abcf49dad3e9813d65bf1b8d54c5a0c9e6049a3027bd8c2ab315475c0a5c1'
},
'contract_storage_commitment_tree_height': 251,
'cairo_resource_fee_weights': {
'n_steps': 0.05,
'pedersen_builtin': 0.4,
'range_check_builtin': 0.4,
'ecdsa_builtin': 25.6,
'bitwise_builtin': 12.8,
'output_builtin': 0.0,
'ec_op_builtin': 0.0
}, 'invoke_tx_max_n_steps': 1000000,
'sequencer_address': '0x37b2cd6baaa515f520383bee7b7094f892f4c770695fc329a8973e841a971ae',
'tx_version': 0,
'tx_commitment_tree_height': 64
})
2 changes: 1 addition & 1 deletion test/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def send_call(req_dict: dict):
)

def assert_deploy_resp(resp: bytes):
"""Asserts the validity of invoke response body."""
"""Asserts the validity of deploy response body."""
resp_dict = json.loads(resp.data.decode("utf-8"))
assert set(resp_dict.keys()) == set(["address", "code", "transaction_hash"])
assert resp_dict["code"] == "TRANSACTION_RECEIVED"
Expand Down
2 changes: 1 addition & 1 deletion test/test_plugin.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function result_assertion() {
}

cd starknet-hardhat-example
mv "$HARDHAT_CONFIG_FILE" hardhat.config.ts
cp "$HARDHAT_CONFIG_FILE" hardhat.config.ts
# npx hardhat starknet-compile <- Already executed in setup_example.sh
# devnet already defined in config as localhost:5000
npx hardhat starknet-deploy \
Expand Down

0 comments on commit ab69a92

Please sign in to comment.