diff --git a/account/account.go b/account/account.go index 1607e34b..16067a47 100644 --- a/account/account.go +++ b/account/account.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/contracts" + "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/hash" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" @@ -180,10 +181,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountType, c case rpc.DeployAccountTxn: calldata := []*felt.Felt{txn.ClassHash, txn.ContractAddressSalt} calldata = append(calldata, txn.ConstructorCalldata...) - calldataHash, err := hash.ComputeHashOnElementsFelt(calldata) - if err != nil { - return nil, err - } + calldataHash := curve.ComputeHashOnElementsFelt(calldata) versionFelt, err := new(felt.Felt).SetString(string(txn.Version)) if err != nil { @@ -200,7 +198,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountType, c txn.MaxFee, account.ChainId, []*felt.Felt{txn.Nonce}, - ) + ), nil case rpc.DeployAccountTxnV3: if txn.Version == "" || txn.ResourceBounds == (rpc.ResourceBoundsMapping{}) || txn.Nonce == nil || txn.PayMasterData == nil { return nil, ErrNotAllParametersSet @@ -265,11 +263,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, return nil, ErrNotAllParametersSet } - calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata) - if err != nil { - return nil, err - } - + calldataHash := curve.ComputeHashOnElementsFelt(txn.Calldata) txnVersionFelt, err := new(felt.Felt).SetString(string(txn.Version)) if err != nil { return nil, err @@ -283,17 +277,14 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, txn.MaxFee, account.ChainId, []*felt.Felt{}, - ) + ), nil case rpc.InvokeTxnV1: if txn.Version == "" || len(txn.Calldata) == 0 || txn.Nonce == nil || txn.MaxFee == nil || txn.SenderAddress == nil { return nil, ErrNotAllParametersSet } - calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata) - if err != nil { - return nil, err - } + calldataHash := curve.ComputeHashOnElementsFelt(txn.Calldata) txnVersionFelt, err := new(felt.Felt).SetString(string(txn.Version)) if err != nil { return nil, err @@ -307,7 +298,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, txn.MaxFee, account.ChainId, []*felt.Felt{txn.Nonce}, - ) + ), nil case rpc.InvokeTxnV3: // https://github.com/starknet-io/SNIPs/blob/main/SNIPS/snip-8.md#protocol-changes if txn.Version == "" || txn.ResourceBounds == (rpc.ResourceBoundsMapping{}) || len(txn.Calldata) == 0 || txn.Nonce == nil || txn.SenderAddress == nil || txn.PayMasterData == nil || txn.AccountDeploymentData == nil { @@ -398,10 +389,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel return nil, ErrNotAllParametersSet } - calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) - if err != nil { - return nil, err - } + calldataHash := curve.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) txnVersionFelt, err := new(felt.Felt).SetString(string(txn.Version)) if err != nil { @@ -416,16 +404,13 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel txn.MaxFee, account.ChainId, []*felt.Felt{txn.Nonce}, - ) + ), nil case rpc.DeclareTxnV2: if txn.CompiledClassHash == nil || txn.SenderAddress == nil || txn.Version == "" || txn.ClassHash == nil || txn.MaxFee == nil || txn.Nonce == nil { return nil, ErrNotAllParametersSet } - calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) - if err != nil { - return nil, err - } + calldataHash := curve.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) txnVersionFelt, err := new(felt.Felt).SetString(string(txn.Version)) if err != nil { @@ -440,7 +425,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel txn.MaxFee, account.ChainId, []*felt.Felt{txn.Nonce, txn.CompiledClassHash}, - ) + ), nil case rpc.DeclareTxnV3: // https://github.com/starknet-io/SNIPs/blob/main/SNIPS/snip-8.md#protocol-changes if txn.Version == "" || txn.ResourceBounds == (rpc.ResourceBoundsMapping{}) || txn.Nonce == nil || txn.SenderAddress == nil || txn.PayMasterData == nil || txn.AccountDeploymentData == nil || @@ -495,10 +480,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel // - *felt.Felt: the precomputed address as a *felt.Felt // - error: an error if any func (account *Account) PrecomputeAccountAddress(salt *felt.Felt, classHash *felt.Felt, constructorCalldata []*felt.Felt) (*felt.Felt, error) { - result, err := contracts.PrecomputeAddress(&felt.Zero, salt, classHash, constructorCalldata) - if err != nil { - return nil, err - } + result := contracts.PrecomputeAddress(&felt.Zero, salt, classHash, constructorCalldata) return result, nil } diff --git a/contracts/contracts.go b/contracts/contracts.go index 79916108..1150f7aa 100644 --- a/contracts/contracts.go +++ b/contracts/contracts.go @@ -6,7 +6,6 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/curve" - "github.com/NethermindEth/starknet.go/utils" ) var PREFIX_CONTRACT_ADDRESS = new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) @@ -61,23 +60,17 @@ func UnmarshalCasmClass(filePath string) (*CasmClass, error) { // - constructorCalldata: the constructor calldata // Returns: // - *felt.Felt: the precomputed address as a *felt.Felt -// - error: an error if any -func PrecomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *felt.Felt, constructorCalldata []*felt.Felt) (*felt.Felt, error) { +func PrecomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *felt.Felt, constructorCalldata []*felt.Felt) *felt.Felt { - bigIntArr := utils.FeltArrToBigIntArr([]*felt.Felt{ + feltArr := []*felt.Felt{ PREFIX_CONTRACT_ADDRESS, deployerAddress, salt, classHash, - }) + } - constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata) - constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) - bigIntArr = append(bigIntArr, constructorCallDataHashInt) + constructorCallDataHash := curve.ComputeHashOnElementsFelt(constructorCalldata) + feltArr = append(feltArr, constructorCallDataHash) - preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr) - if err != nil { - return nil, err - } - return utils.BigIntToFelt(preBigInt), nil + return curve.ComputeHashOnElementsFelt(feltArr) } diff --git a/contracts/contracts_test.go b/contracts/contracts_test.go index 8b36e114..946d5a5f 100644 --- a/contracts/contracts_test.go +++ b/contracts/contracts_test.go @@ -1,4 +1,4 @@ -package contracts_test +package contracts import ( "encoding/json" @@ -6,7 +6,6 @@ import ( "testing" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/starknet.go/contracts" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" "github.com/stretchr/testify/assert" @@ -46,7 +45,7 @@ func TestUnmarshalContractClass(t *testing.T) { // // none func TestUnmarshalCasmClass(t *testing.T) { - casmClass, err := contracts.UnmarshalCasmClass("./tests/hello_starknet_compiled.casm.json") + casmClass, err := UnmarshalCasmClass("./tests/hello_starknet_compiled.casm.json") require.NoError(t, err) assert.Equal(t, casmClass.Prime, "0x800000000000011000000000000000000000000000000000000000000000001") assert.Equal(t, casmClass.Version, "2.1.0") @@ -106,13 +105,12 @@ func TestPrecomputeAddress(t *testing.T) { } for _, test := range testSet { - precomputedAddress, err := contracts.PrecomputeAddress( + precomputedAddress := PrecomputeAddress( utils.TestHexToFelt(t, test.DeployerAddress), utils.TestHexToFelt(t, test.Salt), utils.TestHexToFelt(t, test.ClassHash), test.ConstructorCalldata, ) - require.NoError(t, err) require.Equal(t, test.ExpectedPrecomputedAddress, precomputedAddress.String()) } } diff --git a/curve/curve.go b/curve/curve.go index c7199524..d03131f6 100644 --- a/curve/curve.go +++ b/curve/curve.go @@ -18,6 +18,7 @@ import ( junoCrypto "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/utils" ) var Curve StarkCurve @@ -520,101 +521,81 @@ func (sc StarkCurve) SignFelt(msgHash, privKey *felt.Felt) (*felt.Felt, *felt.Fe return xFelt, yFelt, nil } -// HashElements calculates the hash of a list of elements using the StarkCurve struct and a golang Pedersen Hash. -// (ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) -// +// HashElements calculates the hash of a list of elements using a golang Pedersen Hash. // Parameters: // - elems: slice of big.Int pointers to be hashed // Returns: // - hash: The hash of the list of elements -// - err: An error if any -func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) { +func HashElements(elems []*big.Int) (hash *big.Int) { + feltArr := utils.BigIntArrToFeltArr(elems) if len(elems) == 0 { - elems = append(elems, big.NewInt(0)) + feltArr = append(feltArr, new(felt.Felt)) } - hash = big.NewInt(0) - for _, h := range elems { - hash, err = sc.PedersenHash([]*big.Int{hash, h}) - if err != nil { - return hash, err - } + feltHash := new(felt.Felt) + for _, felt := range feltArr { + feltHash = Pedersen(feltHash, felt) } - return hash, err + + hash = utils.FeltToBigInt(feltHash) + return } // ComputeHashOnElements computes the hash on the given elements using a golang Pedersen Hash implementation. -// (ref: https://github.com/starkware-libs/cairo-lang/blob/13cef109cd811474de114925ee61fd5ac84a25eb/src/starkware/cairo/common/hash_state.py#L6) // -// The function appends the length of `elems` to the slice and then calls the `HashElements` method of the -// `Curve` struct, passing in `elems` as an argument. The resulting hash and -// any error that occurred during computation are returned. +// The function appends the length of `elems` to the slice and then calls the `HashElements` method +// passing in `elems` as an argument. The resulting hash is returned. // // Parameters: // - elems: slice of big.Int pointers to be hashed // Returns: // - hash: The hash of the list of elements -// - err: An error if any -func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int, err error) { +func ComputeHashOnElements(elems []*big.Int) (hash *big.Int) { elems = append(elems, big.NewInt(int64(len(elems)))) - return Curve.HashElements((elems)) + return HashElements(elems) } -// PedersenHash calculates the Pedersen hash of the given elements. -// NOTE: This function assumes the curve has been initialized with constant points -// (ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) -// -// The function requires that the precomputed constant points have been initiated. -// If the length of `sc.ConstantPoints` is zero, an error is returned. -// The function iterates over the elements in `elems` and performs the Pedersen hash calculation. -// For each element, it checks if the value is within the valid range. -// If the value is invalid, an error is returned. -// For each bit in the element, the function performs an addition operation on `ptx` and `pty` -// using the corresponding constant point from the precomputed constant points. -// If the constant point is a duplicate of `ptx`, an error is returned. -// The function returns the resulting hash and a nil error if the calculation is successful. -// Otherwise, it returns `ptx` and an error describing the issue encountered. +// ComputeHashOnElementsFelt computes the hash on elements of a Felt array. +// Does the same as ComputeHashOnElements, but receives and returns felt types. // // Parameters: -// - elems: An array of big integers representing the elements to hash. +// - feltArr: A pointer to an array of Felt objects. // Returns: -// - hash: The resulting Pedersen hash as a big integer. -// - err: An error, if any, encountered during the calculation. -func (sc StarkCurve) PedersenHash(elems []*big.Int) (hash *big.Int, err error) { - if len(sc.ConstantPoints) == 0 { - return hash, fmt.Errorf("must initiate precomputed constant points") - } - - ptx := new(big.Int).Set(sc.Gx) - pty := new(big.Int).Set(sc.Gy) - for i, elem := range elems { - x := new(big.Int).Set(elem) - - if x.Cmp(big.NewInt(0)) == -1 || x.Cmp(sc.P) >= 0 { - return ptx, fmt.Errorf("invalid x: %v", x) - } +// - *felt.Felt: a pointer to a Felt object +func ComputeHashOnElementsFelt(feltArr []*felt.Felt) *felt.Felt { + return PedersenArray(feltArr...) +} - for j := 0; j < 252; j++ { - idx := 2 + (i * 252) + j - xin := new(big.Int).Set(sc.ConstantPoints[idx][0]) - yin := new(big.Int).Set(sc.ConstantPoints[idx][1]) - if xin.Cmp(ptx) == 0 { - return hash, fmt.Errorf("constant point duplication: %v %v", ptx, xin) - } - if x.Bit(0) == 1 { - ptx, pty = sc.Add(ptx, pty, xin, yin) - } - x = x.Rsh(x, 1) - } - } +// Pedersen is a function that implements the Pedersen hash. +// NOTE: This function just wraps the Juno implementation +// (ref: https://github.com/NethermindEth/juno/blob/32fd743c774ec11a1bb2ce3dceecb57515f4873e/core/crypto/pedersen_hash.go#L20) +// +// Parameters: +// - a: a pointers to felt.Felt to be hashed. +// - b: a pointers to felt.Felt to be hashed. +// Returns: +// - *felt.Felt: a pointer to a felt.Felt storing the resulting hash. +func Pedersen(a, b *felt.Felt) *felt.Felt { + return junoCrypto.Pedersen(a, b) +} - return ptx, nil +// PedersenArray is a function that takes a variadic number of felt.Felt pointers as parameters and +// calls the PedersenArray function from the junoCrypto package with the provided parameters. +// NOTE: This function just wraps the Juno implementation +// (ref: https://github.com/NethermindEth/juno/blob/32fd743c774ec11a1bb2ce3dceecb57515f4873e/core/crypto/pedersen_hash.go#L12) +// +// Parameters: +// - felts: A variadic number of pointers to felt.Felt +// Returns: +// - *felt.Felt: pointer to a felt.Felt +func PedersenArray(felts ...*felt.Felt) *felt.Felt { + return junoCrypto.PedersenArray(felts...) } // PoseidonArray is a function that takes a variadic number of felt.Felt pointers as parameters and +// calls the PoseidonArray function from the junoCrypto package with the provided parameters. // NOTE: This function just wraps the Juno implementation // (ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/poseidon_hash.go#L74) -// calls the PoseidonArray function from the junoCrypto package with the provided parameters. // // Parameters: // - felts: A variadic number of pointers to felt.Felt diff --git a/curve/curve_test.go b/curve/curve_test.go index fbbfa5e6..db7a313d 100644 --- a/curve/curve_test.go +++ b/curve/curve_test.go @@ -3,13 +3,18 @@ package curve import ( "crypto/elliptic" "fmt" - "log" "math/big" "testing" + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/utils" + "github.com/stretchr/testify/require" ) +// package level variable to be used by the benchmarking code +// to prevent the compiler from optimizing the code away +var result any + // BenchmarkPedersenHash benchmarks the performance of the PedersenHash function. // // The function takes a 2D slice of big.Int values as input and measures the time @@ -21,21 +26,21 @@ import ( // // none func BenchmarkPedersenHash(b *testing.B) { - suite := [][]*big.Int{ - {utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")}, - {utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")}, + suite := [][]string{ + {"0x12773", "0x872362"}, + {"0x1277312773", "0x872362872362"}, + {"0x1277312773", "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"}, + {"0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", "0x872362872362"}, + {"0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"}, + {"0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd", "0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"}, + {"0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd", "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"}, } for _, test := range suite { - b.Run(fmt.Sprintf("input_size_%d_%d", test[0].BitLen(), test[1].BitLen()), func(b *testing.B) { - if _, err := Curve.PedersenHash(test); err != nil { - log.Fatal(err) - } + b.Run(fmt.Sprintf("input_size_%d_%d", len(test[0]), len(test[1])), func(b *testing.B) { + hexArr, err := utils.HexArrToFelt(test) + require.NoError(b, err) + result = Pedersen(hexArr[0], hexArr[1]) }) } } @@ -66,9 +71,9 @@ func BenchmarkCurveSign(b *testing.B) { }) for _, test := range dataSet { - if _, _, err := Curve.Sign(test.MessageHash, test.PrivateKey, test.Seed); err != nil { - log.Fatal(err) - } + result, _, err := Curve.Sign(test.MessageHash, test.PrivateKey, test.Seed) + require.NoError(b, err) + require.NotEmpty(b, result) } } } @@ -89,24 +94,28 @@ func BenchmarkCurveSign(b *testing.B) { // // none func BenchmarkSignatureVerify(b *testing.B) { - private, _ := Curve.GetRandomPrivateKey() - x, y, _ := Curve.PrivateToPoint(private) - - hash, _ := Curve.PedersenHash( - []*big.Int{ - utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), - utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"), - }) - - r, s, _ := Curve.Sign(hash, private) - - b.Run(fmt.Sprintf("sign_input_size_%d", hash.BitLen()), func(b *testing.B) { - if _, _, err := Curve.Sign(hash, private); err != nil { - log.Fatal(err) - } + private, err := Curve.GetRandomPrivateKey() + require.NoError(b, err) + x, y, err := Curve.PrivateToPoint(private) + require.NoError(b, err) + + hash := Pedersen( + utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), + utils.TestHexToFelt(b, "0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde"), + ) + hashBigInt := utils.FeltToBigInt(hash) + + r, s, err := Curve.Sign(hashBigInt, private) + require.NoError(b, err) + + b.Run(fmt.Sprintf("sign_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) { + result, _, err = Curve.Sign(hashBigInt, private) + require.NoError(b, err) + require.NotEmpty(b, result) }) - b.Run(fmt.Sprintf("verify_input_size_%d", hash.BitLen()), func(b *testing.B) { - Curve.Verify(hash, r, s, x, y) + b.Run(fmt.Sprintf("verify_input_size_%d", hashBigInt.BitLen()), func(b *testing.B) { + result = Curve.Verify(hashBigInt, r, s, x, y) + require.NotEmpty(b, result) }) } @@ -119,13 +128,10 @@ func BenchmarkSignatureVerify(b *testing.B) { // none func TestGeneral_PrivateToPoint(t *testing.T) { x, _, err := Curve.PrivateToPoint(big.NewInt(2)) - if err != nil { - t.Errorf("PrivateToPoint err %v", err) - } - expectedX, _ := new(big.Int).SetString("3324833730090626974525872402899302150520188025637965566623476530814354734325", 10) - if x.Cmp(expectedX) != 0 { - t.Errorf("Actual public key %v different from expected %v", x, expectedX) - } + require.NoError(t, err) + expectedX, ok := new(big.Int).SetString("3324833730090626974525872402899302150520188025637965566623476530814354734325", 10) + require.True(t, ok) + require.Equal(t, expectedX, x) } // TestGeneral_PedersenHash is a test function for the PedersenHash method in the General struct. @@ -140,31 +146,30 @@ func TestGeneral_PrivateToPoint(t *testing.T) { // none func TestGeneral_PedersenHash(t *testing.T) { testPedersen := []struct { - elements []*big.Int - expected *big.Int + elements []string + expected string }{ { - elements: []*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, - expected: utils.HexToBN("0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"), + elements: []string{"0x12773", "0x872362"}, + expected: "0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca", }, { - elements: []*big.Int{utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.HexToBN("0x537461726b4e6574204d61696c")}, - expected: utils.HexToBN("0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c"), + elements: []string{"0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9", "0x537461726b4e6574204d61696c"}, + expected: "0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c", }, { - elements: []*big.Int{big.NewInt(100), big.NewInt(1000)}, - expected: utils.HexToBN("0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7"), + elements: []string{"100", "1000"}, + expected: "0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7", }, } - for _, tt := range testPedersen { - hash, err := Curve.PedersenHash(tt.elements) - if err != nil { - t.Errorf("Hashing err: %v\n", err) - } - if hash.Cmp(tt.expected) != 0 { - t.Errorf("incorrect hash: got %v expected %v\n", hash, tt.expected) - } + for _, test := range testPedersen { + elementsFelt, err := utils.HexArrToFelt(test.elements) + require.NoError(t, err) + expected := utils.TestHexToFelt(t, test.expected) + + result := Pedersen(elementsFelt[0], elementsFelt[1]) + require.Equal(t, expected, result) } } @@ -202,9 +207,7 @@ func TestGeneral_DivMod(t *testing.T) { for _, tt := range testDivmod { divR := DivMod(tt.x, tt.y, Curve.P) - if divR.Cmp(tt.expected) != 0 { - t.Errorf("DivMod Res %v does not == expected %v\n", divR, tt.expected) - } + require.Equal(t, tt.expected, divR) } } @@ -213,7 +216,7 @@ func TestGeneral_DivMod(t *testing.T) { // It tests the addition of two big integers and compares the result with the expected values. // The function takes a slice of test cases, each containing two big integers and their expected sum. // It iterates over the test cases, computes the sum using the Add function, and checks if it matches the expected sum. -// If the computed sum does not match the expected sum, an error is reported using the t.Errorf function. +// If the computed sum does not match the expected sum, an error is reported using the require.Equal function. // // Parameters: // - t: a *testing.T value representing the testing context @@ -243,13 +246,8 @@ func TestGeneral_Add(t *testing.T) { for _, tt := range testAdd { resX, resY := Curve.Add(Curve.Gx, Curve.Gy, tt.x, tt.y) - if resX.Cmp(tt.expectedX) != 0 { - t.Errorf("ResX %v does not == expected %v\n", resX, tt.expectedX) - - } - if resY.Cmp(tt.expectedY) != 0 { - t.Errorf("ResY %v does not == expected %v\n", resY, tt.expectedY) - } + require.Equal(t, tt.expectedX, resX) + require.Equal(t, tt.expectedY, resY) } } @@ -286,25 +284,16 @@ func TestGeneral_MultAir(t *testing.T) { for _, tt := range testMult { x, y, err := Curve.MimicEcMultAir(tt.r, tt.x, tt.y, Curve.Gx, Curve.Gy) - if err != nil { - t.Errorf("MultAirERR %v\n", err) - } - - if x.Cmp(tt.expectedX) != 0 { - t.Errorf("ResX %v does not == expected %v\n", x, tt.expectedX) - - } - if y.Cmp(tt.expectedY) != 0 { - t.Errorf("ResY %v does not == expected %v\n", y, tt.expectedY) - } + require.NoError(t, err) + require.Equal(t, tt.expectedX, x) + require.Equal(t, tt.expectedY, y) } } -// TestGeneral_ComputeHashOnElements is a test function that verifies the correctness of the ComputeHashOnElements function in the General package. +// TestGeneral_ComputeHashOnElements is a test function that verifies the correctness of the ComputeHashOnElements and ComputeHashOnElementsFelt functions in the General package. // -// This function tests the ComputeHashOnElements function by passing in different arrays of big.Int elements and comparing the computed hash with the expected hash. -// It checks the behavior of the ComputeHashOnElements function when an empty array is passed as input, as well as when an array with multiple elements is passed. -// The expected hashes are precalculated using the utils.HexToBN function. +// This function tests both functions by passing in different arrays of big.Int elements and comparing the computed hash with the expected hash. +// It checks the behavior of the functions when an empty array is passed as input, as well as when an array with multiple elements is passed. // // Parameters: // - t: a *testing.T value representing the testing context @@ -312,28 +301,25 @@ func TestGeneral_MultAir(t *testing.T) { // // none func TestGeneral_ComputeHashOnElements(t *testing.T) { - hashEmptyArray, err := Curve.ComputeHashOnElements([]*big.Int{}) + hashEmptyArray := ComputeHashOnElements([]*big.Int{}) + hashEmptyArrayFelt := ComputeHashOnElementsFelt([]*felt.Felt{}) + expectedHashEmmptyArray := utils.HexToBN("0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804") - if err != nil { - t.Errorf("Could no hash an empty array %v\n", err) - } - if hashEmptyArray.Cmp(expectedHashEmmptyArray) != 0 { - t.Errorf("Hash empty array wrong value. Expected %v got %v\n", expectedHashEmmptyArray, hashEmptyArray) - } + require.Equal(t, hashEmptyArray, expectedHashEmmptyArray, "Hash empty array wrong value.") + require.Equal(t, utils.FeltToBigInt(hashEmptyArrayFelt), expectedHashEmmptyArray, "Hash empty array wrong value.") - hashFilledArray, err := Curve.ComputeHashOnElements([]*big.Int{ + filledArray := []*big.Int{ big.NewInt(123782376), big.NewInt(213984), big.NewInt(128763521321), - }) - expectedHashFilledArray := utils.HexToBN("0x7b422405da6571242dfc245a43de3b0fe695e7021c148b918cd9cdb462cac59") - - if err != nil { - t.Errorf("Could no hash an array with values %v\n", err) - } - if hashFilledArray.Cmp(expectedHashFilledArray) != 0 { - t.Errorf("Hash filled array wrong value. Expected %v got %v\n", expectedHashFilledArray, hashFilledArray) } + + hashFilledArray := ComputeHashOnElements(filledArray) + hashFilledArrayFelt := ComputeHashOnElementsFelt(utils.BigIntArrToFeltArr(filledArray)) + + expectedHashFilledArray := utils.HexToBN("0x7b422405da6571242dfc245a43de3b0fe695e7021c148b918cd9cdb462cac59") + require.Equal(t, hashFilledArray, expectedHashFilledArray, "Hash filled array wrong value.") + require.Equal(t, utils.FeltToBigInt(hashFilledArrayFelt), expectedHashFilledArray, "Hash filled array wrong value.") } // TestGeneral_HashAndSign is a test function that verifies the hashing and signing process. @@ -344,29 +330,21 @@ func TestGeneral_ComputeHashOnElements(t *testing.T) { // // none func TestGeneral_HashAndSign(t *testing.T) { - hashy, err := Curve.HashElements([]*big.Int{ + hashy := HashElements([]*big.Int{ big.NewInt(1953658213), big.NewInt(126947999705460), big.NewInt(1953658213), }) - if err != nil { - t.Errorf("Hasing elements: %v\n", err) - } - priv, _ := Curve.GetRandomPrivateKey() + priv, err := Curve.GetRandomPrivateKey() + require.NoError(t, err) x, y, err := Curve.PrivateToPoint(priv) - if err != nil { - t.Errorf("Could not convert random private key to point: %v\n", err) - } + require.NoError(t, err) r, s, err := Curve.Sign(hashy, priv) - if err != nil { - t.Errorf("Could not convert gen signature: %v\n", err) - } + require.NoError(t, err) - if !Curve.Verify(hashy, r, s, x, y) { - t.Errorf("Verified bad signature %v %v\n", r, s) - } + require.True(t, Curve.Verify(hashy, r, s, x, y)) } // TestGeneral_ComputeFact tests the ComputeFact function. @@ -401,9 +379,7 @@ func TestGeneral_ComputeFact(t *testing.T) { for _, tt := range testFacts { hash := utils.ComputeFact(tt.programHash, tt.programOutput) - if hash.Cmp(tt.expected) != 0 { - t.Errorf("Fact does not equal ex %v %v\n", hash, tt.expected) - } + require.Equal(t, tt.expected, hash) } } @@ -415,36 +391,25 @@ func TestGeneral_ComputeFact(t *testing.T) { // // none func TestGeneral_BadSignature(t *testing.T) { - hash, err := Curve.PedersenHash([]*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}) - if err != nil { - t.Errorf("Hashing err: %v\n", err) - } + hash := Pedersen(utils.TestHexToFelt(t, "0x12773"), utils.TestHexToFelt(t, "0x872362")) + hashBigInt := utils.FeltToBigInt(hash) - priv, _ := Curve.GetRandomPrivateKey() + priv, err := Curve.GetRandomPrivateKey() + require.NoError(t, err) x, y, err := Curve.PrivateToPoint(priv) - if err != nil { - t.Errorf("Could not convert random private key to point: %v\n", err) - } + require.NoError(t, err) - r, s, err := Curve.Sign(hash, priv) - if err != nil { - t.Errorf("Could not convert gen signature: %v\n", err) - } + r, s, err := Curve.Sign(hashBigInt, priv) + require.NoError(t, err) badR := new(big.Int).Add(r, big.NewInt(1)) - if Curve.Verify(hash, badR, s, x, y) { - t.Errorf("Verified bad signature %v %v\n", r, s) - } + require.False(t, Curve.Verify(hashBigInt, badR, s, x, y)) badS := new(big.Int).Add(s, big.NewInt(1)) - if Curve.Verify(hash, r, badS, x, y) { - t.Errorf("Verified bad signature %v %v\n", r, s) - } + require.False(t, Curve.Verify(hashBigInt, r, badS, x, y)) - badHash := new(big.Int).Add(hash, big.NewInt(1)) - if Curve.Verify(badHash, r, s, x, y) { - t.Errorf("Verified bad signature %v %v\n", r, s) - } + badHash := new(big.Int).Add(hashBigInt, big.NewInt(1)) + require.False(t, Curve.Verify(badHash, r, s, x, y)) } // TestGeneral_Signature tests the Signature function. @@ -494,28 +459,24 @@ func TestGeneral_Signature(t *testing.T) { var err error for _, tt := range testSignature { + require := require.New(t) if tt.raw != "" { - h, _ := utils.HexToBytes(tt.raw) + h, err := utils.HexToBytes(tt.raw) + require.NoError(err) tt.publicX, tt.publicY = elliptic.Unmarshal(Curve, h) //nolint:all } else if tt.private != nil { tt.publicX, tt.publicY, err = Curve.PrivateToPoint(tt.private) - if err != nil { - t.Errorf("Could not convert random private key to point: %v\n", err) - } + require.NoError(err) } else if tt.publicX != nil { tt.publicY = Curve.GetYCoordinate(tt.publicX) } if tt.rIn == nil && tt.private != nil { tt.rIn, tt.sIn, err = Curve.Sign(tt.hash, tt.private) - if err != nil { - t.Errorf("Could not sign good hash: %v\n", err) - } + require.NoError(err) } - if !Curve.Verify(tt.hash, tt.rIn, tt.sIn, tt.publicX, tt.publicY) { - t.Errorf("successful signature did not verify\n") - } + require.True(Curve.Verify(tt.hash, tt.rIn, tt.sIn, tt.publicX, tt.publicY)) } } @@ -536,11 +497,7 @@ func TestGeneral_SplitFactStr(t *testing.T) { } for _, d := range data { l, h := utils.SplitFactStr(d["input"]) // 0x3 - if l != d["l"] { - t.Errorf("expected %s, got %s", d["l"], l) - } - if h != d["h"] { - t.Errorf("expected %s, got %s", d["h"], h) - } + require.Equal(t, d["l"], l) + require.Equal(t, d["h"], h) } } diff --git a/hash/hash.go b/hash/hash.go index b669753b..9ce5e7b9 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -5,26 +5,8 @@ import ( "github.com/NethermindEth/starknet.go/contracts" "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/rpc" - "github.com/NethermindEth/starknet.go/utils" ) -// ComputeHashOnElementsFelt computes the hash on elements of a Felt array. -// -// Parameters: -// - feltArr: A pointer to an array of Felt objects. -// Returns: -// - *felt.Felt: a pointer to a Felt object -// - error: an error if any -func ComputeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { - bigIntArr := utils.FeltArrToBigIntArr(feltArr) - - hash, err := curve.Curve.ComputeHashOnElements(bigIntArr) - if err != nil { - return nil, err - } - return utils.BigIntToFelt(hash), nil -} - // CalculateTransactionHashCommon calculates the transaction hash common to be used in the StarkNet network - a unique identifier of the transaction. // [specification]: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py#L27C5-L27C38 // @@ -39,7 +21,6 @@ func ComputeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { // - additionalData: Additional data to be included in the hash // Returns: // - *felt.Felt: the calculated transaction hash -// - error: an error if any func CalculateTransactionHashCommon( txHashPrefix *felt.Felt, version *felt.Felt, @@ -48,7 +29,7 @@ func CalculateTransactionHashCommon( calldata *felt.Felt, maxFee *felt.Felt, chainId *felt.Felt, - additionalData []*felt.Felt) (*felt.Felt, error) { + additionalData []*felt.Felt) *felt.Felt { dataToHash := []*felt.Felt{ txHashPrefix, @@ -60,7 +41,7 @@ func CalculateTransactionHashCommon( chainId, } dataToHash = append(dataToHash, additionalData...) - return ComputeHashOnElementsFelt(dataToHash) + return curve.ComputeHashOnElementsFelt(dataToHash) } // ClassHash calculates the hash of a contract class. diff --git a/merkle/merkle.go b/merkle/merkle.go index b6813bf0..c1c32940 100644 --- a/merkle/merkle.go +++ b/merkle/merkle.go @@ -17,24 +17,19 @@ type FixedSizeMerkleTree struct { // // It takes a variable number of *big.Int leaves as input and returns a pointer to a FixedSizeMerkleTree and an error. // The function builds the Merkle tree using the given leaves and sets the tree's root. -// If there is an error during the tree building process, the function returns nil and the error. // // Parameters: // - leaves: a slice of *big.Int representing the leaves of the tree. // Returns: // - *FixedSizeMerkleTree: a pointer to a FixedSizeMerkleTree -// - error: an error if any -func NewFixedSizeMerkleTree(leaves ...*big.Int) (*FixedSizeMerkleTree, error) { +func NewFixedSizeMerkleTree(leaves ...*big.Int) *FixedSizeMerkleTree { mt := &FixedSizeMerkleTree{ Leaves: leaves, Branches: [][]*big.Int{}, } - root, err := mt.build(leaves) - if err != nil { - return nil, err - } + root := mt.build(leaves) mt.Root = root - return mt, err + return mt } // MerkleHash calculates the Merkle hash of two big integers. @@ -44,12 +39,11 @@ func NewFixedSizeMerkleTree(leaves ...*big.Int) (*FixedSizeMerkleTree, error) { // - y: the second big integer // Returns: // - *big.Int: the Merkle hash of the two big integers -// - error: an error if the calculation fails -func MerkleHash(x, y *big.Int) (*big.Int, error) { +func MerkleHash(x, y *big.Int) *big.Int { if x.Cmp(y) <= 0 { - return curve.Curve.HashElements([]*big.Int{x, y}) + return curve.HashElements([]*big.Int{x, y}) } - return curve.Curve.HashElements([]*big.Int{y, x}) + return curve.HashElements([]*big.Int{y, x}) } // build recursively constructs a Merkle tree from the given leaves. @@ -58,26 +52,19 @@ func MerkleHash(x, y *big.Int) (*big.Int, error) { // - leaves: a slice of *big.Int representing the leaves of the tree // Return type(s): // - *big.Int: the root hash of the Merkle tree -// - error: any error that occurred during the construction of the tree -func (mt *FixedSizeMerkleTree) build(leaves []*big.Int) (*big.Int, error) { +func (mt *FixedSizeMerkleTree) build(leaves []*big.Int) *big.Int { if len(leaves) == 1 { - return leaves[0], nil + return leaves[0] } mt.Branches = append(mt.Branches, leaves) newLeaves := []*big.Int{} for i := 0; i < len(leaves); i += 2 { if i+1 == len(leaves) { - hash, err := MerkleHash(leaves[i], big.NewInt(0)) - if err != nil { - return nil, err - } + hash := MerkleHash(leaves[i], big.NewInt(0)) newLeaves = append(newLeaves, hash) break } - hash, err := MerkleHash(leaves[i], leaves[i+1]) - if err != nil { - return nil, err - } + hash := MerkleHash(leaves[i], leaves[i+1]) newLeaves = append(newLeaves, hash) } return mt.build(newLeaves) @@ -125,10 +112,7 @@ func (mt *FixedSizeMerkleTree) recursiveProof(leaf *big.Int, branchIndex int, ha if index%2 != 0 { nextProof = branch[index-1] } - newLeaf, err := MerkleHash(leaf, nextProof) - if err != nil { - return nil, fmt.Errorf("nextproof error: %v", err) - } + newLeaf := MerkleHash(leaf, nextProof) newHashPath := append(hashPath, nextProof) return mt.recursiveProof(newLeaf, branchIndex+1, newHashPath) } @@ -150,9 +134,7 @@ func ProofMerklePath(root *big.Int, leaf *big.Int, path []*big.Int) bool { if len(path) == 0 { return root.Cmp(leaf) == 0 } - nexLeaf, err := MerkleHash(leaf, path[0]) - if err != nil { - return false - } + nexLeaf := MerkleHash(leaf, path[0]) + return ProofMerklePath(root, nexLeaf, path[1:]) } diff --git a/merkle/merkle_test.go b/merkle/merkle_test.go index f2dd6094..d6074a62 100644 --- a/merkle/merkle_test.go +++ b/merkle/merkle_test.go @@ -11,7 +11,8 @@ import ( // - t: a pointer to the testing.T object // - proofs: a slice of pointers to big.Int objects representing the proofs // Returns: -// none +// +// none func debugProof(t *testing.T, proofs []*big.Int) { t.Log("...proof") for k, v := range proofs { @@ -26,23 +27,21 @@ func debugProof(t *testing.T, proofs []*big.Int) { // Parameters: // - t: A testing.T object used for reporting test failures and logging. // Returns: -// none +// +// none func TestGeneral_FixedSizeMerkleTree_Check1(t *testing.T) { leaves := []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4), big.NewInt(5), big.NewInt(6), big.NewInt(7)} - merkleTree, err := NewFixedSizeMerkleTree(leaves...) - proof_7_0, _ := MerkleHash(big.NewInt(7), big.NewInt(0)) - proof_1_2, _ := MerkleHash(big.NewInt(1), big.NewInt(2)) - proof_3_4, _ := MerkleHash(big.NewInt(3), big.NewInt(4)) - proof_1_2_3_4, _ := MerkleHash(proof_1_2, proof_3_4) + merkleTree := NewFixedSizeMerkleTree(leaves...) + proof_7_0 := MerkleHash(big.NewInt(7), big.NewInt(0)) + proof_1_2 := MerkleHash(big.NewInt(1), big.NewInt(2)) + proof_3_4 := MerkleHash(big.NewInt(3), big.NewInt(4)) + proof_1_2_3_4 := MerkleHash(proof_1_2, proof_3_4) manualProof := []*big.Int{ big.NewInt(6), proof_7_0, proof_1_2_3_4, } leaf := big.NewInt(5) - if err != nil { - t.Fatal("should generate merkle tree, error", err) - } proof, err := merkleTree.Proof(leaf) if err != nil { t.Fatal("should generate merkle proof, error", err) diff --git a/typed/typed.go b/typed/typed.go index 67373bcf..54771223 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -70,7 +70,7 @@ func (dm Domain) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { // Returns: // - *felt.Felt: a *felt.Felt with the value of str func strToFelt(str string) *felt.Felt { - var f = &felt.Zero + var f = new(felt.Felt) asciiRegexp := regexp.MustCompile(`^([[:graph:]]|[[:space:]]){1,31}$`) if b, ok := new(big.Int).SetString(str, 0); ok { @@ -128,40 +128,32 @@ func NewTypedData(types map[string]TypeDef, pType string, dom Domain) (td TypedD // Parameters: // - account: A pointer to a big.Int representing the account. // - msg: A TypedMessage object representing the message. -// - sc: A StarkCurve object representing the curve. // Returns: // - hash: A pointer to a big.Int representing the calculated hash. -// - err: An error object indicating any error that occurred during the calculation. -func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { +func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage) (hash *big.Int) { elements := []*big.Int{utils.UTF8StrToBig("StarkNet Message")} - domEnc, err := td.GetTypedMessageHash("StarkNetDomain", td.Domain, sc) - if err != nil { - return hash, fmt.Errorf("could not hash domain: %w", err) - } + domEnc := td.GetTypedMessageHash("StarkNetDomain", td.Domain) + elements = append(elements, domEnc) elements = append(elements, account) - msgEnc, err := td.GetTypedMessageHash(td.PrimaryType, msg, sc) - if err != nil { - return hash, fmt.Errorf("could not hash message: %w", err) - } + msgEnc := td.GetTypedMessageHash(td.PrimaryType, msg) elements = append(elements, msgEnc) - hash, err = sc.ComputeHashOnElements(elements) - return hash, err + hash = curve.ComputeHashOnElements(elements) + return hash } // GetTypedMessageHash calculates the hash of a typed message using the provided StarkCurve. // // Parameters: -// - inType: the type of the message -// - msg: the typed message -// - sc: the StarkCurve used for hashing +// - inType: the type of the message +// - msg: the typed message +// // Returns: -// - hash: the calculated hash -// - err: any error if any -func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { +// - hash: the calculated hash +func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage) (hash *big.Int) { prim := td.Types[inType] elements := []*big.Int{prim.Encoding} @@ -179,15 +171,12 @@ func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc curv innerElements = append(innerElements, fmtDefinitions...) innerElements = append(innerElements, big.NewInt(int64(len(innerElements)))) - innerHash, err := sc.HashElements(innerElements) - if err != nil { - return hash, fmt.Errorf("error hashing internal elements: %v %w", innerElements, err) - } + innerHash := curve.HashElements(innerElements) elements = append(elements, innerHash) } - hash, err = sc.ComputeHashOnElements(elements) - return hash, err + hash = curve.ComputeHashOnElements(elements) + return hash } // GetTypeHash returns the hash of the given type. diff --git a/typed/typed_test.go b/typed/typed_test.go index 19b6b1e1..f58d87bc 100644 --- a/typed/typed_test.go +++ b/typed/typed_test.go @@ -2,12 +2,11 @@ package typed import ( "fmt" - "log" "math/big" "testing" - "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/utils" + "github.com/stretchr/testify/require" ) type Mail struct { @@ -50,7 +49,7 @@ func (mail Mail) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { // // Returns: // - ttd: the generated TypedData object -func MockTypedData() (ttd TypedData) { +func MockTypedData() (ttd TypedData, err error) { exampleTypes := make(map[string]TypeDef) domDefs := []Definition{{"name", "felt"}, {"version", "felt"}, {"chainId", "felt"}} exampleTypes["StarkNetDomain"] = TypeDef{Definitions: domDefs} @@ -65,8 +64,11 @@ func MockTypedData() (ttd TypedData) { ChainId: "1", } - ttd, _ = NewTypedData(exampleTypes, "Mail", dm) - return ttd + ttd, err = NewTypedData(exampleTypes, "Mail", dm) + if err != nil { + return TypedData{}, err + } + return ttd, err } // TestGeneral_GetMessageHash tests the GetMessageHash function. @@ -83,7 +85,8 @@ func MockTypedData() (ttd TypedData) { // Returns: // - None func TestGeneral_GetMessageHash(t *testing.T) { - ttd := MockTypedData() + ttd, err := MockTypedData() + require.NoError(t, err) mail := Mail{ From: Person{ @@ -97,15 +100,10 @@ func TestGeneral_GetMessageHash(t *testing.T) { Contents: "Hello, Bob!", } - hash, err := ttd.GetMessageHash(utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), mail, curve.Curve) - if err != nil { - t.Errorf("Could not hash message: %v\n", err) - } + hash := ttd.GetMessageHash(utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), mail) exp := "0x6fcff244f63e38b9d88b9e3378d44757710d1b244282b435cb472053c8d78d0" - if utils.BigToHex(hash) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + require.Equal(t, exp, utils.BigToHex(hash)) } // BenchmarkGetMessageHash is a benchmark function for testing the GetMessageHash function. @@ -120,7 +118,8 @@ func TestGeneral_GetMessageHash(t *testing.T) { // // none func BenchmarkGetMessageHash(b *testing.B) { - ttd := MockTypedData() + ttd, err := MockTypedData() + require.NoError(b, err) mail := Mail{ From: Person{ @@ -135,9 +134,8 @@ func BenchmarkGetMessageHash(b *testing.B) { } addr := utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826") b.Run(fmt.Sprintf("input_size_%d", addr.BitLen()), func(b *testing.B) { - if _, err := ttd.GetMessageHash(addr, mail, curve.Curve); err != nil { - log.Fatal(err) - } + result := ttd.GetMessageHash(addr, mail) + require.NotEmpty(b, result) }) } @@ -152,17 +150,13 @@ func BenchmarkGetMessageHash(b *testing.B) { // // none func TestGeneral_GetDomainHash(t *testing.T) { - ttd := MockTypedData() + ttd, err := MockTypedData() + require.NoError(t, err) - hash, err := ttd.GetTypedMessageHash("StarkNetDomain", ttd.Domain, curve.Curve) - if err != nil { - t.Errorf("Could not hash message: %v\n", err) - } + hash := ttd.GetTypedMessageHash("StarkNetDomain", ttd.Domain) exp := "0x54833b121883a3e3aebff48ec08a962f5742e5f7b973469c1f8f4f55d470b07" - if utils.BigToHex(hash) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + require.Equal(t, exp, utils.BigToHex(hash)) } // TestGeneral_GetTypedMessageHash is a unit test for the GetTypedMessageHash function @@ -178,7 +172,8 @@ func TestGeneral_GetDomainHash(t *testing.T) { // // none func TestGeneral_GetTypedMessageHash(t *testing.T) { - ttd := MockTypedData() + ttd, err := MockTypedData() + require.NoError(t, err) mail := Mail{ From: Person{ @@ -192,15 +187,10 @@ func TestGeneral_GetTypedMessageHash(t *testing.T) { Contents: "Hello, Bob!", } - hash, err := ttd.GetTypedMessageHash("Mail", mail, curve.Curve) - if err != nil { - t.Errorf("Could get typed message hash: %v\n", err) - } + hash := ttd.GetTypedMessageHash("Mail", mail) exp := "0x4758f1ed5e7503120c228cbcaba626f61514559e9ef5ed653b0b885e0f38aec" - if utils.BigToHex(hash) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + require.Equal(t, exp, utils.BigToHex(hash)) } // TestGeneral_GetTypeHash tests the GetTypeHash function. @@ -215,37 +205,28 @@ func TestGeneral_GetTypedMessageHash(t *testing.T) { // // none func TestGeneral_GetTypeHash(t *testing.T) { - tdd := MockTypedData() + require := require.New(t) - hash, err := tdd.GetTypeHash("StarkNetDomain") - if err != nil { - t.Errorf("error enccoding type %v\n", err) - } + ttd, err := MockTypedData() + require.NoError(err) + + hash, err := ttd.GetTypeHash("StarkNetDomain") + require.NoError(err) exp := "0x1bfc207425a47a5dfa1a50a4f5241203f50624ca5fdf5e18755765416b8e288" - if utils.BigToHex(hash) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + require.Equal(exp, utils.BigToHex(hash)) - enc := tdd.Types["StarkNetDomain"] - if utils.BigToHex(enc.Encoding) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + enc := ttd.Types["StarkNetDomain"] + require.Equal(exp, utils.BigToHex(enc.Encoding)) - pHash, err := tdd.GetTypeHash("Person") - if err != nil { - t.Errorf("error enccoding type %v\n", err) - } + pHash, err := ttd.GetTypeHash("Person") + require.NoError(err) exp = "0x2896dbe4b96a67110f454c01e5336edc5bbc3635537efd690f122f4809cc855" - if utils.BigToHex(pHash) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(pHash), exp) - } + require.Equal(exp, utils.BigToHex(pHash)) - enc = tdd.Types["Person"] - if utils.BigToHex(enc.Encoding) != exp { - t.Errorf("type hash: %v does not match expected %v\n", utils.BigToHex(hash), exp) - } + enc = ttd.Types["Person"] + require.Equal(exp, utils.BigToHex(enc.Encoding)) } // TestGeneral_GetSelectorFromName tests the GetSelectorFromName function. @@ -288,15 +269,12 @@ func TestGeneral_GetSelectorFromName(t *testing.T) { // // none func TestGeneral_EncodeType(t *testing.T) { - tdd := MockTypedData() + ttd, err := MockTypedData() + require.NoError(t, err) - enc, err := tdd.EncodeType("Mail") - if err != nil { - t.Errorf("error enccoding type %v\n", err) - } + enc, err := ttd.EncodeType("Mail") + require.NoError(t, err) exp := "Mail(from:Person,to:Person,contents:felt)Person(name:felt,wallet:felt)" - if enc != exp { - t.Errorf("type encoding: %v does not match expected %v\n", enc, exp) - } + require.Equal(t, exp, enc) } diff --git a/utils/Felt.go b/utils/Felt.go index 50e0357f..002b7cba 100644 --- a/utils/Felt.go +++ b/utils/Felt.go @@ -202,3 +202,17 @@ func feltToString(f *felt.Felt) (string, error) { } return string(b), nil } + +// BigIntArrToFeltArr converts an array of big.Int objects to an array of Felt objects. +// +// Parameters: +// - bigArr: the array of big.Int objects to convert +// Returns: +// - []*felt.Felt: the array of Felt objects +func BigIntArrToFeltArr(bigArr []*big.Int) []*felt.Felt { + var feltArr []*felt.Felt + for _, big := range bigArr { + feltArr = append(feltArr, BigIntToFelt(big)) + } + return feltArr +} diff --git a/utils/keccak.go b/utils/keccak.go index 171487e8..4426a1bd 100644 --- a/utils/keccak.go +++ b/utils/keccak.go @@ -73,6 +73,22 @@ func HexToBN(hexString string) *big.Int { return n } +// HexArrToBNArr converts a hexadecimal string array to a *big.Int array. +// Trim "0x" prefix(if exists) +// +// Parameters: +// - hexArr: the hexadecimal string array to be converted +// Returns: +// - *big.Int: the converted array +func HexArrToBNArr(hexArr []string) []*big.Int { + bigNumArr := make([]*big.Int, len(hexArr)) + for i, e := range hexArr { + bigNum := HexToBN(e) + bigNumArr[i] = bigNum + } + return bigNumArr +} + // HexToBytes converts a hexadecimal string to a byte slice. // trim "0x" prefix(if exists) //