diff --git a/README.md b/README.md index 7a51190f..8b8ad73d 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.4.1) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.4.2) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.4.1.zip . +> docker cp ${container_id}:/app/cairo-lang-0.4.2.zip . > docker rm -v ${container_id} ``` diff --git a/src/services/everest/business_logic/state.py b/src/services/everest/business_logic/state.py index c3ef5bb3..0d07a49c 100644 --- a/src/services/everest/business_logic/state.py +++ b/src/services/everest/business_logic/state.py @@ -136,6 +136,7 @@ def _apply(self): @contextlib.contextmanager def copy_and_apply(self) -> Iterator[TCarriedState]: copied_state = self._copy() + # The exit logic will not be called in case an exception is raised inside the context. yield copied_state copied_state._apply() # Apply to self. diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index d4426f76..d71f42fd 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -4,6 +4,7 @@ python_lib(cairo_common_lib alloc.cairo bitwise.cairo cairo_builtins.cairo + cairo_blake2s/blake2s_utils.py cairo_keccak/keccak_utils.py cairo_sha256/sha256_utils.py default_dict.cairo diff --git a/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py b/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py new file mode 100644 index 00000000..8c4b0507 --- /dev/null +++ b/src/starkware/cairo/common/cairo_blake2s/blake2s_utils.py @@ -0,0 +1,86 @@ +from typing import List, Tuple + +IV = [ + 0x6A09E667, + 0xBB67AE85, + 0x3C6EF372, + 0xA54FF53A, + 0x510E527F, + 0x9B05688C, + 0x1F83D9AB, + 0x5BE0CD19, +] + +SIGMA = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +] + + +def right_rot(value, n): + return (value >> n) | ((value & (2 ** n - 1)) << (32 - n)) + + +def blake2s_compress( + h: List[int], message: List[int], t0: int, t1: int, f0: int, f1: int +) -> List[int]: + """ + h is a list of 8 32-bit words. + message is a list of 16 32-bit words. + """ + state = h + IV[:4] + [x % 2 ** 32 for x in [IV[4] ^ t0, IV[5] ^ t1, IV[6] ^ f0, IV[7] ^ f1]] + for i in range(10): + state = blake_round(state, message, SIGMA[i]) + return [x ^ v0 ^ v1 for x, v0, v1 in zip(h, state[:8], state[8:])] + + +def blake_round(state: List[int], message: List[int], sigma: List[int]) -> List[int]: + state = list(state) + state[0], state[4], state[8], state[12] = mix( + state[0], state[4], state[8], state[12], message[sigma[0]], message[sigma[1]] + ) + state[1], state[5], state[9], state[13] = mix( + state[1], state[5], state[9], state[13], message[sigma[2]], message[sigma[3]] + ) + state[2], state[6], state[10], state[14] = mix( + state[2], state[6], state[10], state[14], message[sigma[4]], message[sigma[5]] + ) + state[3], state[7], state[11], state[15] = mix( + state[3], state[7], state[11], state[15], message[sigma[6]], message[sigma[7]] + ) + + state[0], state[5], state[10], state[15] = mix( + state[0], state[5], state[10], state[15], message[sigma[8]], message[sigma[9]] + ) + state[1], state[6], state[11], state[12] = mix( + state[1], state[6], state[11], state[12], message[sigma[10]], message[sigma[11]] + ) + state[2], state[7], state[8], state[13] = mix( + state[2], state[7], state[8], state[13], message[sigma[12]], message[sigma[13]] + ) + state[3], state[4], state[9], state[14] = mix( + state[3], state[4], state[9], state[14], message[sigma[14]], message[sigma[15]] + ) + + return state + + +def mix(a: int, b: int, c: int, d: int, m0: int, m1: int) -> Tuple[int, int, int, int]: + a = (a + b + m0) % 2 ** 32 + d = right_rot((d ^ a), 16) + c = (c + d) % 2 ** 32 + b = right_rot((b ^ c), 12) + a = (a + b + m1) % 2 ** 32 + d = right_rot((d ^ a), 8) + c = (c + d) % 2 ** 32 + b = right_rot((b ^ c), 7) + + return a, b, c, d diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 267577d4..2b7c5ae0 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.4.1 +0.4.2 diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py index 182ad330..43e78d9f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py @@ -11,8 +11,8 @@ ExprConst, ExprDeref, Expression, - ExprHint, ExprFutureLabel, + ExprHint, ExprIdentifier, ExprNeg, ExprOperator, diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index c4768c26..e0d37b82 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.4.1", + "version": "0.4.2", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/vm/vm.py b/src/starkware/cairo/lang/vm/vm.py index 0688f98d..d3223293 100644 --- a/src/starkware/cairo/lang/vm/vm.py +++ b/src/starkware/cairo/lang/vm/vm.py @@ -349,8 +349,8 @@ def load_hints(self, program: Program, program_base: MaybeRelocatable): compiled=self.compile_hint( hint.code, f"", hint_index=hint_index ), - # Use hint=hint in the lambda's arguments to capture this value (otherwise, it - # will use the same hint object for all iterations). + # Use hint=hint in the lambda's arguments to capture this value (otherwise, + # it will use the same hint object for all iterations). consts=lambda pc, ap, fp, memory, hint=hint: VmConsts( context=VmConstsContext( identifiers=program.identifiers, diff --git a/src/starkware/starknet/business_logic/CMakeLists.txt b/src/starkware/starknet/business_logic/CMakeLists.txt index 0827d1dd..9157b98d 100644 --- a/src/starkware/starknet/business_logic/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/CMakeLists.txt @@ -63,5 +63,6 @@ python_lib(starknet_internal_transaction_lib starkware_utils_lib pip_marshmallow pip_marshmallow_dataclass + pip_marshmallow_enum pip_marshmallow_oneofschema ) diff --git a/src/starkware/starknet/business_logic/internal_transaction_interface.py b/src/starkware/starknet/business_logic/internal_transaction_interface.py index ce87c023..fa0e8592 100644 --- a/src/starkware/starknet/business_logic/internal_transaction_interface.py +++ b/src/starkware/starknet/business_logic/internal_transaction_interface.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import logging from abc import abstractmethod from dataclasses import field from typing import Dict, Iterable, List, Optional, Set, Tuple, cast @@ -16,12 +17,16 @@ from starkware.cairo.lang.vm.utils import RunResources from starkware.starknet.business_logic.state import CarriedState, StateSelector from starkware.starknet.definitions import fields +from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.gateway.transaction import Transaction from starkware.starkware_utils.config_base import Config +from starkware.starkware_utils.error_handling import StarkException from starkware.starkware_utils.marshmallow_dataclass_fields import SetField from starkware.starkware_utils.validated_dataclass import ValidatedDataclass +logger = logging.getLogger(__name__) + @dataclasses.dataclass(frozen=True) class L2ToL1MessageInfo(ValidatedDataclass): @@ -190,9 +195,20 @@ async def apply_state_updates( assert isinstance(general_config, StarknetGeneralConfig) with state.copy_and_apply() as state_to_update: - execution_info = await self._apply_specific_state_updates( - state=state_to_update, general_config=general_config - ) + try: + execution_info = await self._apply_specific_state_updates( + state=state_to_update, general_config=general_config + ) + except StarkException: + # Raise StarkException-s as-is, so failure information is not lost. + raise + except Exception as exception: + # Wrap all exceptions with StarkException, so the Batcher can continue running + # even after unexpected errors. + logger.error(f"Unexpected failure; exception details: {exception}.", exc_info=True) + raise StarkException( + code=StarknetErrorCode.UNEXPECTED_FAILURE, message=str(exception) + ) return execution_info diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index 76c9d903..023dc405 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -6,6 +6,7 @@ python_lib(starknet_cli_lib LIBS cairo_compile_lib + cairo_tracer_lib cairo_version_lib cairo_vm_utils_lib services_external_api_lib diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index ef9d563b..7310f1dd 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -14,6 +14,7 @@ from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.type_system import mark_type_resolved from starkware.cairo.lang.compiler.type_utils import check_felts_only_type +from starkware.cairo.lang.tracer.tracer_data import field_element_repr from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.reconstruct_traceback import reconstruct_traceback from starkware.starknet.compiler.compile import get_selector_from_name @@ -26,6 +27,10 @@ from starkware.starkware_utils.error_handling import StarkErrorCode +def felt_formatter(hex_felt: str) -> str: + return field_element_repr(val=int(hex_felt, 16), prime=fields.FeltField.upper_bound) + + def get_gateway_client(args) -> GatewayClient: gateway_url = os.environ.get("STARKNET_GATEWAY_URL") if args.gateway_url is not None: @@ -184,7 +189,7 @@ async def invoke_or_call(args, command_args, call: bool): if call: feeder_client = get_feeder_gateway_client(args) gateway_response = await feeder_client.call_contract(tx, args.block_id) - print(*gateway_response["result"]) + print(*map(felt_formatter, gateway_response["result"])) else: gateway_client = get_gateway_client(args) gateway_response = await gateway_client.add_transaction(tx=tx) @@ -262,7 +267,7 @@ def handle_network_param(args): network = os.environ.get("STARKNET_NETWORK") if args.network is None else args.network if network is not None: if network != "alpha": - print(f"Unknown network '{network}'.") + print(f"Unknown network '{network}'.", file=sys.stderr) return 1 dns = "alpha2.starknet.io" diff --git a/src/starkware/starknet/compiler/data_encoder.py b/src/starkware/starknet/compiler/data_encoder.py index dc98b331..1a740030 100644 --- a/src/starkware/starknet/compiler/data_encoder.py +++ b/src/starkware/starknet/compiler/data_encoder.py @@ -1,6 +1,6 @@ import dataclasses from enum import Enum, auto -from typing import List, Optional, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, @@ -313,16 +313,36 @@ def decode_data( class DataEncoder(DataEncodingProcessor): + def __init__( + self, + arg_name_func: Callable[[ArgumentInfo], str], + encoding_type: EncodingType, + has_range_check_builtin: bool, + identifiers: IdentifierManager, + ): + """ + Constructs a DataEncoder instance. + + arg_name_func is a function that get ArgumentInfo and returns the name of the reference + containing that argument. + """ + super().__init__( + encoding_type=encoding_type, + has_range_check_builtin=has_range_check_builtin, + identifiers=identifiers, + ) + self.arg_name_func = arg_name_func + def process_felt(self, arg_info: ArgumentInfo): return f"""\ -assert [__{self.var_name}_ptr] = {arg_info.name} +assert [__{self.var_name}_ptr] = {self.arg_name_func(arg_info)} let __{self.var_name}_ptr = __{self.var_name}_ptr + 1 """ def process_felt_ptr(self, arg_info: ArgumentInfo): return f"""\ # Check that the length is non-negative. -assert [range_check_ptr] = {arg_info.name}_len +assert [range_check_ptr] = {self.arg_name_func(arg_info)}_len # Store the updated range_check_ptr as a local variable to keep it available after # the memcpy. local range_check_ptr = range_check_ptr + 1 @@ -330,8 +350,12 @@ def process_felt_ptr(self, arg_info: ArgumentInfo): let __{self.var_name}_ptr_copy = __{self.var_name}_ptr # Store the updated __{self.var_name}_ptr as a local variable to keep it available after # the memcpy. -local __{self.var_name}_ptr : felt* = __{self.var_name}_ptr + {arg_info.name}_len -memcpy(dst=__{self.var_name}_ptr_copy, src={arg_info.name}, len={arg_info.name}_len) +local __{self.var_name}_ptr : felt* = __{self.var_name}_ptr + \ +{self.arg_name_func(arg_info)}_len +memcpy( + dst=__{self.var_name}_ptr_copy, + src={self.arg_name_func(arg_info)}, + len={self.arg_name_func(arg_info)}_len) """ def process_felts_object(self, arg_info: ArgumentInfo, size: int): @@ -340,8 +364,8 @@ def process_felts_object(self, arg_info: ArgumentInfo, size: int): for i in range(size) ) return f"""\ -# Create a reference to {arg_info.name} as felt*. -let __{self.var_name}_tmp : felt* = cast(&{arg_info.name}, felt*) +# Create a reference to {self.arg_name_func(arg_info)} as felt*. +let __{self.var_name}_tmp : felt* = cast(&{self.arg_name_func(arg_info)}, felt*) {body} let __{self.var_name}_ptr = __{self.var_name}_ptr + {size} """ @@ -352,9 +376,11 @@ def encode_data( encoding_type: EncodingType, has_range_check_builtin: bool, identifiers: IdentifierManager, + arg_name_func: Callable[[ArgumentInfo], str] = lambda arg_info: arg_info.name, ) -> List[CommentedCodeElement]: parser = DataEncoder( + arg_name_func=arg_name_func, encoding_type=encoding_type, has_range_check_builtin=has_range_check_builtin, identifiers=identifiers, diff --git a/src/starkware/starknet/compiler/data_encoder_test.py b/src/starkware/starknet/compiler/data_encoder_test.py index 33ce8dd6..e4f3a598 100644 --- a/src/starkware/starknet/compiler/data_encoder_test.py +++ b/src/starkware/starknet/compiler/data_encoder_test.py @@ -182,19 +182,20 @@ def test_encode_data_for_return(): encoding_type=EncodingType.RETURN, has_range_check_builtin=True, identifiers=identifiers, + arg_name_func=lambda arg_info: f"x.{arg_info.name}", ) assert ( "".join(code_element.format(100) + "\n" for code_element in code_elements) == """\ -assert [__return_value_ptr] = a +assert [__return_value_ptr] = x.a let __return_value_ptr = __return_value_ptr + 1 -assert [__return_value_ptr] = b_len +assert [__return_value_ptr] = x.b_len let __return_value_ptr = __return_value_ptr + 1 # Check that the length is non-negative. -assert [range_check_ptr] = b_len +assert [range_check_ptr] = x.b_len # Store the updated range_check_ptr as a local variable to keep it available after # the memcpy. local range_check_ptr = range_check_ptr + 1 @@ -202,11 +203,11 @@ def test_encode_data_for_return(): let __return_value_ptr_copy = __return_value_ptr # Store the updated __return_value_ptr as a local variable to keep it available after # the memcpy. -local __return_value_ptr : felt* = __return_value_ptr + b_len -memcpy(dst=__return_value_ptr_copy, src=b, len=b_len) +local __return_value_ptr : felt* = __return_value_ptr + x.b_len +memcpy(dst=__return_value_ptr_copy, src=x.b, len=x.b_len) -# Create a reference to c as felt*. -let __return_value_tmp : felt* = cast(&c, felt*) +# Create a reference to x.c as felt*. +let __return_value_tmp : felt* = cast(&x.c, felt*) assert [__return_value_ptr + 0] = [__return_value_tmp + 0] assert [__return_value_ptr + 1] = [__return_value_tmp + 1] assert [__return_value_ptr + 2] = [__return_value_tmp + 2] diff --git a/src/starkware/starknet/compiler/starknet_preprocessor.py b/src/starkware/starknet/compiler/starknet_preprocessor.py index 301f16e5..82359e21 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor.py @@ -9,10 +9,8 @@ ) from starkware.cairo.lang.compiler.ast.code_elements import ( BuiltinsDirective, - CodeElementCompoundAssertEq, CodeElementFuncCall, CodeElementFunction, - CodeElementHint, CodeElementInstruction, LangDirective, ) @@ -23,16 +21,11 @@ ExprConst, ExprDeref, Expression, - ExprHint, ExprIdentifier, ExprOperator, ExprReg, ) -from starkware.cairo.lang.compiler.ast.instructions import ( - AddApInstruction, - InstructionAst, - RetInstruction, -) +from starkware.cairo.lang.compiler.ast.instructions import InstructionAst, RetInstruction from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall from starkware.cairo.lang.compiler.ast.types import TypedIdentifier from starkware.cairo.lang.compiler.error_handling import Location @@ -44,18 +37,22 @@ ) from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.instruction import Register +from starkware.cairo.lang.compiler.parser import ParserContext from starkware.cairo.lang.compiler.preprocessor.preprocessor import ( PreprocessedProgram, Preprocessor, ) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import autogen_parse_code_block from starkware.cairo.lang.compiler.program import CairoHint from starkware.cairo.lang.compiler.references import create_simple_ref_expr from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system import is_type_resolved +from starkware.cairo.lang.compiler.type_utils import check_felts_only_type from starkware.starknet.compiler.data_encoder import ( EncodingType, decode_data, + encode_data, struct_to_argument_info_list, ) from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE @@ -323,8 +320,8 @@ def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str): # Add function return values. retdata_size, retdata_ptr = self.process_retdata( - ret_struct_ptr=ExprIdentifier(name="ret_struct"), - ret_struct_type=ret_struct_type, + func_name=elm.identifier.name, + ret_struct_ptr="ret_struct", struct_def=ret_struct_def, location=func_location, ) @@ -393,13 +390,16 @@ def add_abi_entry( for struct_name in abi_type_info.structs: self.add_struct_to_abi(struct_name) for m_name, member in ret_struct_def.members.items(): - assert isinstance(member.cairo_type, TypeFelt) + assert is_type_resolved(member.cairo_type) + abi_type_info = prepare_type_for_abi(member.cairo_type) outputs.append( { "name": m_name, - "type": "felt", + "type": abi_type_info.modified_type.format(), } ) + for struct_name in abi_type_info.structs: + self.add_struct_to_abi(struct_name) res = { "name": name, "type": entry_type, @@ -451,52 +451,29 @@ def get_program(self) -> StarknetPreprocessedProgram: def process_retdata( self, - ret_struct_ptr: Expression, - ret_struct_type: CairoType, + func_name: str, + ret_struct_ptr: str, struct_def: StructDefinition, - location: Optional[Location], + location: Location, ) -> Tuple[Expression, Expression]: """ Processes the return values and return retdata_size and retdata_ptr. """ - # Verify all of the return types are felts. - for _, member_def in struct_def.members.items(): + # Verify all of the return types are felts-only type. See check_felts_only_type(). + for member_def in struct_def.members.values(): cairo_type = member_def.cairo_type - if not isinstance(cairo_type, TypeFelt): + is_felts_only = ( + check_felts_only_type(cairo_type=cairo_type, identifier_manager=self.identifiers) + is not None + ) + if not is_felts_only: raise PreprocessorError( - f"Unsupported argument type {cairo_type.format()}.", + f"Unsupported return value type {cairo_type.format()}.", location=cairo_type.location, ) - self.add_reference( - name=self.current_scope + "retdata_ptr", - value=ExprDeref( - addr=ExprReg(reg=Register.AP), - location=location, - ), - cairo_type=TypePointer(TypeFelt()), - require_future_definition=False, - location=location, - ) - - self.visit( - CodeElementHint( - hint=ExprHint( - hint_code="memory[ap] = segments.add()", - n_prefix_newlines=0, - location=location, - ), - location=location, - ) - ) - - # Skip check of hint whitelist as it fails before the workaround below. - super().visit_CodeElementInstruction( - CodeElementInstruction( - InstructionAst(body=AddApInstruction(ExprConst(1)), inc_ap=False, location=location) - ) - ) + self.prepare_return_struct(func_name=func_name, location=location) # Remove the references from the last instruction's flow tracking as they are # not needed by the hint and they cause the hint whitelist to fail. @@ -505,14 +482,43 @@ def process_retdata( self.instructions[-1].hints[0] = hint, dataclasses.replace( hint_flow_tracking_data, reference_ids={} ) - self.visit( - CodeElementCompoundAssertEq( - ExprDeref(ExprCast(ExprIdentifier("retdata_ptr"), TypePointer(ret_struct_type))), - ret_struct_ptr, - ) + code_elements = encode_data( + arguments=struct_to_argument_info_list(struct_def), + encoding_type=EncodingType.RETURN, + has_range_check_builtin="range_check_ptr" in self.get_os_context(), + identifiers=self.identifiers, + arg_name_func=lambda arg_info: f"{ret_struct_ptr}.{arg_info.name}", + ) + + for code_element in code_elements: + self.visit(code_element.code_elm) + + return (ExprConst(struct_def.size), ExprIdentifier("__return_value_ptr_start")) + + def prepare_return_struct(self, func_name: str, location: Location): + code = """\ +let __return_value_ptr_start = [ap] +let __return_value_ptr = __return_value_ptr_start +%{ memory[ap] = segments.add() %} +ap += 1 +""" + + code_block = autogen_parse_code_block( + path=f"autogen/starknet/external/return/{func_name}", + code=code, + parser_context=ParserContext( + parent_location=(location, "While handling return value of"), + resolved_types=True, + ), ) - return (ExprConst(struct_def.size), ExprIdentifier("retdata_ptr")) + # Call super().visit_CodeElementInstruction instead of self.visit on the last code element + # to skip hint whitelist check. + for code_elm in code_block.code_elements[:-1]: + self.visit(code_elm.code_elm) + last_code_element = code_block.code_elements[-1].code_elm + assert isinstance(last_code_element, CodeElementInstruction) + super().visit_CodeElementInstruction(last_code_element) def validate_l1_handler_signature(self, elm: CodeElementFunction): """ diff --git a/src/starkware/starknet/compiler/starknet_preprocessor_test.py b/src/starkware/starknet/compiler/starknet_preprocessor_test.py index fc8fb4cd..fffbdf4c 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor_test.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor_test.py @@ -153,12 +153,17 @@ def test_wrapper_with_return_args(): %lang starknet %builtins pedersen range_check ecdsa +struct Point: + member x : felt + member y : felt +end + struct HashBuiltin: end @external -func f{ecdsa_ptr}(a : felt, b : felt) -> (c : felt, d : felt): - return (c=1, d=2) +func f{ecdsa_ptr}(a : felt, b : felt) -> (c : felt, d : Point): + return (c=1, d=Point(2, 3)) end """ ) @@ -170,7 +175,8 @@ def test_wrapper_with_return_args(): # Implementation of f [ap] = [fp + (-5)]; ap++ # Return ecdsa_ptr. [ap] = 1; ap++ # Return c=1 -[ap] = 2; ap++ # Return d=2 +[ap] = 2; ap++ # Return d.x=2 +[ap] = 3; ap++ # Return d.y=3 ret # Implementation of __wrappers__.f @@ -179,17 +185,18 @@ def test_wrapper_with_return_args(): [ap] = [[fp + (-5)] + 4]; ap++ # Pass ecdsa_ptr. [ap] = [[fp + (-3)]]; ap++ # Pass a. [ap] = [[fp + (-3)] + 1]; ap++ # Pass b. -call rel -12 # Call f. +call rel -14 # Call f. %{ memory[ap] = segments.add() %} # Allocate memory for return value ap += 1 -[[ap + (-1)]] = [ap + (-3)] # [retdata_ptr] = c -[[ap + (-1)] + 1] = [ap + (-2)] # [retdata_ptr + 1] = d +[[ap + (-1)]] = [ap + (-4)] # [retdata_ptr] = c +[[ap + (-1)] + 1] = [ap + (-3)] # [retdata_ptr + 1] = d.x +[[ap + (-1)] + 2] = [ap + (-2)] # [retdata_ptr + 2] = d.y [ap] = [[fp + (-5)]]; ap++ # Return syscall_ptr [ap] = [[fp + (-5)] + 1]; ap++ # Return storage_ptr [ap] = [[fp + (-5)] + 2]; ap++ # Return pedersen_ptr. [ap] = [[fp + (-5)] + 3]; ap++ # Return range_check. -[ap] = [ap + (-8)]; ap++ # Return ecdsa. -[ap] = 2; ap++ # Return retdata_size=2 +[ap] = [ap + (-9)]; ap++ # Return ecdsa. +[ap] = 3; ap++ # Return retdata_size=3 [ap] = [ap + (-7)]; ap++ # Return retdata_ptr ret """ @@ -373,6 +380,23 @@ def test_unsupported_args(): ) +def test_unsupported_return_type(): + verify_exception( + """ +%lang starknet +@external +func fc() -> (arg : felt**): + return (cast(0, felt**)) +end +""", + """ +file:?:?: Unsupported return value type felt**. +func fc() -> (arg : felt**): + ^****^ +""", + ) + + def test_invalid_hint(): verify_exception( """ @@ -411,6 +435,9 @@ def test_abi_basic(): struct NonExternalStruct: end +struct ExternalStruct3: + member x: felt +end @external func f(a : felt, arr_len : felt, arr : felt*) -> (b : felt, c : felt): @@ -418,8 +445,8 @@ def test_abi_basic(): end @view -func g() -> (a: felt): - return (0) +func g() -> (a: ExternalStruct3): + return (ExternalStruct3(0)) end @l1_handler @@ -430,6 +457,12 @@ def test_abi_basic(): ) assert program.abi == [ + { + "type": "struct", + "name": "ExternalStruct3", + "members": [{"name": "x", "offset": 0, "type": "felt"}], + "size": 1, + }, { "type": "struct", "name": "ExternalStruct2", @@ -459,7 +492,7 @@ def test_abi_basic(): "inputs": [], "name": "g", "outputs": [ - {"name": "a", "type": "felt"}, + {"name": "a", "type": "ExternalStruct3"}, ], "type": "function", "stateMutability": "view", diff --git a/src/starkware/starknet/core/os/os_utils.py b/src/starkware/starknet/core/os/os_utils.py index e39a5bd1..bc7ebb1a 100644 --- a/src/starkware/starknet/core/os/os_utils.py +++ b/src/starkware/starknet/core/os/os_utils.py @@ -6,7 +6,7 @@ from starkware.starknet.core.os import segment_utils, syscall_utils from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.public.abi import SYSCALL_PTR_OFFSET -from starkware.starkware_utils.error_handling import wrap_with_stark_exception +from starkware.starkware_utils.error_handling import stark_assert, wrap_with_stark_exception def update_builtin_pointers( @@ -90,6 +90,8 @@ def validate_and_process_os_context( ) expected_stop_ptr = syscall_handler.expected_syscall_ptr - assert ( - syscall_stop_ptr == expected_stop_ptr - ), f"Bad syscall_stop_ptr, Expected {expected_stop_ptr}, got {syscall_stop_ptr}." + stark_assert( + syscall_stop_ptr == expected_stop_ptr, + code=StarknetErrorCode.SECURITY_ERROR, + message=f"Bad syscall_stop_ptr, Expected {expected_stop_ptr}, got {syscall_stop_ptr}.", + ) diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index 0d7b973b..087de905 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -12,11 +12,7 @@ validate_positive, ) from starkware.starkware_utils.marshmallow_dataclass_fields import BytesAsHex, IntAsStr -from starkware.starkware_utils.validated_fields import ( - RangeValidatedField, - int_as_hex_metadata, - sequential_id_metadata, -) +from starkware.starkware_utils.validated_fields import RangeValidatedField, sequential_id_metadata # Fields data: validation data, dataclass metadata. @@ -33,13 +29,12 @@ upper_bound=constants.FELT_UPPER_BOUND, name_in_error_message="Field element", out_of_range_error_code=StarknetErrorCode.INVALID_FIELD_ELEMENT, + formatter=hex, ) def felt_metadata(name_in_error_message: str) -> Dict[str, Any]: - return int_as_hex_metadata( - validated_field=dataclasses.replace(FeltField, name_in_error_message=name_in_error_message) - ) + return dataclasses.replace(FeltField, name_in_error_message=name_in_error_message).metadata() felt_list_metadata = dict(marshmallow_field=mfields.List(IntAsStr(validate=FeltField.validate))) @@ -51,9 +46,10 @@ def felt_metadata(name_in_error_message: str) -> Dict[str, Any]: upper_bound=constants.CONTRACT_ADDRESS_UPPER_BOUND, name_in_error_message="Contract address", out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_CONTRACT_ADDRESS, + formatter=hex, ) -contract_address_metadata = int_as_hex_metadata(validated_field=ContractAddressField) +contract_address_metadata = ContractAddressField.metadata() def bytes_as_hex_dict_keys_metadata( @@ -85,18 +81,20 @@ def bytes_as_hex_dict_keys_metadata( upper_bound=constants.ENTRY_POINT_SELECTOR_UPPER_BOUND, name_in_error_message="Entry point selector", out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_SELECTOR, + formatter=hex, ) -entry_point_selector_metadata = int_as_hex_metadata(validated_field=EntryPointSelectorField) +entry_point_selector_metadata = EntryPointSelectorField.metadata() EntryPointOffsetField = RangeValidatedField( lower_bound=constants.ENTRY_POINT_OFFSET_LOWER_BOUND, upper_bound=constants.ENTRY_POINT_OFFSET_UPPER_BOUND, name_in_error_message="Entry point offset", out_of_range_error_code=StarknetErrorCode.OUT_OF_RANGE_ENTRY_POINT_OFFSET, + formatter=hex, ) -entry_point_offset_metadata = int_as_hex_metadata(validated_field=EntryPointOffsetField) +entry_point_offset_metadata = EntryPointOffsetField.metadata() global_state_commitment_tree_height_metadata = dict( marshmallow_field=mfields.Integer( diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt index 4d9a07b2..3a98b201 100644 --- a/src/starkware/starknet/security/CMakeLists.txt +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -39,6 +39,7 @@ python_lib(starknet_hints_whitelist_lib FILES hints_whitelist.py + whitelists/cairo_blake2s.json whitelists/cairo_keccak.json whitelists/cairo_sha256.json whitelists/latest.json diff --git a/src/starkware/starknet/security/whitelists/cairo_blake2s.json b/src/starkware/starknet/security/whitelists/cairo_blake2s.json new file mode 100644 index 00000000..c17a722a --- /dev/null +++ b/src/starkware/starknet/security/whitelists/cairo_blake2s.json @@ -0,0 +1,134 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [ + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "blake2s.finalize_blake2s.__fp__" + }, + { + "expr": "[cast(fp + (-5), starkware.cairo.common.cairo_builtins.BitwiseBuiltin**)]", + "name": "blake2s.finalize_blake2s.bitwise_ptr" + }, + { + "expr": "[cast(fp + (-3), felt**)]", + "name": "blake2s.finalize_blake2s.blake2s_ptr_end" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "blake2s.finalize_blake2s.blake2s_ptr_start" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "blake2s.finalize_blake2s.n" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "blake2s.finalize_blake2s.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), felt**)]", + "name": "blake2s.finalize_blake2s.sigma" + } + ], + "hint_lines": [ + "# Add dummy pairs of input and output.", + "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress", + "", + "_n_packed_instances = int(ids.N_PACKED_INSTANCES)", + "assert 0 <= _n_packed_instances < 20", + "_blake2s_input_chunk_size_felts = int(ids.BLAKE2S_INPUT_CHUNK_SIZE_FELTS)", + "assert 0 <= _blake2s_input_chunk_size_felts < 100", + "", + "message = [0] * _blake2s_input_chunk_size_felts", + "modified_iv = [IV[0] ^ 0x01010020] + IV[1:]", + "output = blake2s_compress(", + " message=message,", + " h=modified_iv,", + " t0=0,", + " t1=0,", + " f0=0xffffffff,", + " f1=0,", + ")", + "padding = (message + modified_iv + [0, 0xffffffff] + output) * (_n_packed_instances - 1)", + "segments.write_arg(ids.blake2s_ptr_end, padding)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "cast([ap + (-10)] + 10, felt*)", + "name": "blake2s.blake2s.blake2s_ptr" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "blake2s.blake2s.blake2s_start" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "blake2s.blake2s.input" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "blake2s.blake2s.n_bytes" + }, + { + "expr": "cast([ap + (-10)] + 10, felt*)", + "name": "blake2s.blake2s.output" + }, + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "blake2s.blake2s.range_check_ptr" + } + ], + "hint_lines": [ + "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress", + "", + "_blake2s_input_chunk_size_felts = int(ids.BLAKE2S_INPUT_CHUNK_SIZE_FELTS)", + "assert 0 <= _blake2s_input_chunk_size_felts < 100", + "", + "new_state = blake2s_compress(", + " message=memory.get_range(ids.blake2s_start, _blake2s_input_chunk_size_felts),", + " h=[IV[0] ^ 0x01010020] + IV[1:],", + " t0=ids.n_bytes,", + " t1=0,", + " f0=0xffffffff,", + " f1=0,", + ")", + "", + "segments.write_arg(ids.output, new_state)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-6), felt**)]", + "name": "blake2s._blake2s_input.blake2s_ptr" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "blake2s._blake2s_input.full_word" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "blake2s._blake2s_input.input" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "blake2s._blake2s_input.n_bytes" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "blake2s._blake2s_input.n_words" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "blake2s._blake2s_input.range_check_ptr" + } + ], + "hint_lines": [ + "ids.full_word = int(ids.n_bytes >= 4)" + ] + } + ] +} diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py index afc130f9..65476270 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -16,7 +16,7 @@ async def get_contract_addresses(self) -> Dict[str, str]: async def call_contract( self, invoke_tx: InvokeFunction, block_id: Optional[int] = None - ) -> Dict[str, List[int]]: + ) -> Dict[str, List[str]]: raw_response = await self._send_request( send_method="POST", uri=f"/call_contract?blockId={json.dumps(block_id)}", @@ -30,14 +30,14 @@ async def get_block(self, block_id: Optional[int] = None) -> Dict[str, Any]: ) return json.loads(raw_response) - async def get_code(self, contract_address: int, block_id: Optional[int] = None) -> List[int]: + async def get_code(self, contract_address: int, block_id: Optional[int] = None) -> List[str]: uri = f"/get_code?contractAddress={hex(contract_address)}&blockId={json.dumps(block_id)}" raw_response = await self._send_request(send_method="GET", uri=uri) return json.loads(raw_response) async def get_storage_at( self, contract_address: int, key: int, block_id: Optional[int] = None - ) -> int: + ) -> str: uri = ( f"/get_storage_at?contractAddress={hex(contract_address)}&key={key}&" f"blockId={json.dumps(block_id)}"