diff --git a/starknet_devnet/server.py b/starknet_devnet/server.py index 423c58b66..0b3863694 100644 --- a/starknet_devnet/server.py +++ b/starknet_devnet/server.py @@ -9,9 +9,15 @@ from .util import TxStatus, parse_args app = Flask(__name__) + address2contract = {} +"""Maps contract address to contract instance.""" + address2types = {} +"""Maps contract address to a dict of types (structs) used in that contract.""" + transactions = [] +"""A chronological list of transactions.""" class StarknetWrapper: def __init__(self): @@ -31,20 +37,29 @@ def is_alive(): return "Alive!!!" async def deploy(contract_definition: ContractDefinition): + """ + Deploys the contract whose definition is provided and returns deployment address in hex form. + The other returned object is present to conform to a past version of call_or_invoke, but will be removed in future versions. + """ + starknet = await starknet_wrapper.get_starknet() contract = await starknet.deploy(contract_def=contract_definition) hex_address = hex(contract.contract_address) address2contract[hex_address] = contract return hex_address, {} -def attempt_hex(x): - try: - return hex(x) - except: - pass - return x +def generate_complex(calldata, calldata_i: int, input_type: str, types): + """ + Converts members of `calldata` to a more complex type specified by `input_type`: + - puts members of a struct into a tuple + - puts members of a tuple into a tuple + + The `calldata_i` is incremented according to how many `calldata` members were consumed. + `types` is a dict that maps a type's name to its specification. + + Returns the `calldata` converted to the type specified by `input_type` (tuple if struct or tuple, number). Also returns the incremented `calldata_i`. + """ -def generate_complex(calldata, calldata_i, input_type, types): if input_type == "felt": return calldata[calldata_i], calldata_i + 1 @@ -64,6 +79,16 @@ def generate_complex(calldata, calldata_i, input_type, types): return tuple(arr), calldata_i def adapt_calldata(calldata, expected_inputs, types): + """ + Simulatenously iterates over `calldata` and `expected_inputs`. + + The `calldata` is converted to types specified by `expected_inputs`. + + `types` is a dict that maps a type's name to its specification. + + Returns a list representing adapted `calldata`. + """ + last_name = None last_value = None calldata_i = 0 @@ -105,6 +130,18 @@ def adapt_calldata(calldata, expected_inputs, types): return adapted_calldata def adapt_output(received, ret): + """ + Adapts the `received` object to format expected by client (list of hex strings). + If `received` is an instance of `list`, it is understood that it corresponds to a felt*, so first its length is appended. + If `received` is iterable, it is either a struct, a tuple or a felt*. + Otherwise it is a `felt`. + `ret` is recursively populated (and should probably be empty on first call). + + Example: + >>> L = []; adapt_output((1, [5, 10]), L); print(L) + ['0x1', '0x2', '0x5', '0xa'] + """ + if isinstance(received, list): ret.append(hex(len(received))) try: @@ -148,6 +185,11 @@ def is_transaction_hash_legal(transaction_hash: int) -> bool: return 0 <= transaction_hash < len(transactions) def store_types(contract_address: str, abi): + """ + Stores the types (structs) used in a contract. + The types are read from `abi`, and stored to a global map under the key `contract_address` which is expected to be a hex string. + """ + structs = [x for x in abi if x["type"] == "struct"] type_dict = { struct["name"]: struct for struct in structs } address2types[contract_address] = type_dict @@ -171,6 +213,10 @@ def store_transaction(contract_address: str, tx_type: str) -> str: @app.route("/gateway/add_transaction", methods=["POST"]) async def add_transaction(): + """ + Endpoint for accepting DEPLOY and INVOKE_FUNCTION transactions. + """ + raw_data = request.get_data() transaction = Transaction.loads(raw_data) # TODO transaction.calculate_hash() @@ -207,6 +253,10 @@ def get_contract_addresses(): @app.route("/feeder_gateway/call_contract", methods=["POST"]) async def call_contract(): + """ + Endpoint for receiving calls (not invokes) of contract functions. + """ + raw_data = request.get_data() call_specifications = InvokeFunction.loads(raw_data) result_dict = await call_or_invoke("call", @@ -257,6 +307,10 @@ def get_transaction_status(): @app.route("/feeder_gateway/get_transaction", methods=["GET"]) def get_transaction(): + """ + Returns the transaction identified by the transactionHash argument in the GET request. + """ + transaction_hash = request.args.get("transactionHash", type=lambda x: int(x, 16)) if is_transaction_hash_legal(transaction_hash): return jsonify(transactions[transaction_hash]) @@ -271,4 +325,4 @@ def main(): app.run(**vars(args)) if __name__ == "__main__": - main() \ No newline at end of file + main()