diff --git a/account/account.go b/account/account.go index dd70f689..ea3643b9 100644 --- a/account/account.go +++ b/account/account.go @@ -6,7 +6,8 @@ import ( "time" "github.com/NethermindEth/juno/core/felt" - starknetgo "github.com/NethermindEth/starknet.go" + "github.com/NethermindEth/starknet.go/curve" + "github.com/NethermindEth/starknet.go/hash" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" ) @@ -47,10 +48,10 @@ type Account struct { ChainId *felt.Felt AccountAddress *felt.Felt publicKey string - ks starknetgo.Keystore + ks Keystore } -func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore starknetgo.Keystore) (*Account, error) { +func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore Keystore) (*Account, error) { account := &Account{ provider: provider, AccountAddress: accountAddress, @@ -134,7 +135,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountTxn, co } calldata := []*felt.Felt{tx.ClassHash, tx.ContractAddressSalt} calldata = append(calldata, tx.ConstructorCalldata...) - calldataHash, err := computeHashOnElementsFelt(calldata) + calldataHash, err := hash.ComputeHashOnElementsFelt(calldata) if err != nil { return nil, err } @@ -145,7 +146,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountTxn, co } // https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/transactions/#deploy_account_hash_calculation - return calculateTransactionHashCommon( + return hash.CalculateTransactionHashCommon( PREFIX_DEPLOY_ACCOUNT, versionFelt, contractAddress, @@ -166,7 +167,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, return nil, ErrNotAllParametersSet } - calldataHash, err := computeHashOnElementsFelt(txn.Calldata) + calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata) if err != nil { return nil, err } @@ -175,7 +176,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, if err != nil { return nil, err } - return calculateTransactionHashCommon( + return hash.CalculateTransactionHashCommon( PREFIX_TRANSACTION, txnVersionFelt, txn.ContractAddress, @@ -191,7 +192,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, return nil, ErrNotAllParametersSet } - calldataHash, err := computeHashOnElementsFelt(txn.Calldata) + calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata) if err != nil { return nil, err } @@ -199,7 +200,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt, if err != nil { return nil, err } - return calculateTransactionHashCommon( + return hash.CalculateTransactionHashCommon( PREFIX_TRANSACTION, txnVersionFelt, txn.SenderAddress, @@ -224,7 +225,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel return nil, ErrNotAllParametersSet } - calldataHash, err := computeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) + calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) if err != nil { return nil, err } @@ -233,7 +234,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel if err != nil { return nil, err } - return calculateTransactionHashCommon( + return hash.CalculateTransactionHashCommon( PREFIX_DECLARE, txnVersionFelt, txn.SenderAddress, @@ -248,7 +249,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel return nil, ErrNotAllParametersSet } - calldataHash, err := computeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) + calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash}) if err != nil { return nil, err } @@ -257,7 +258,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel if err != nil { return nil, err } - return calculateTransactionHashCommon( + return hash.CalculateTransactionHashCommon( PREFIX_DECLARE, txnVersionFelt, txn.SenderAddress, @@ -284,10 +285,10 @@ func (account *Account) PrecomputeAddress(deployerAddress *felt.Felt, salt *felt }) constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata) - constructorCallDataHashInt, _ := starknetgo.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) + constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) bigIntArr = append(bigIntArr, constructorCallDataHashInt) - preBigInt, err := starknetgo.Curve.ComputeHashOnElements(bigIntArr) + preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr) if err != nil { return nil, err } @@ -418,3 +419,27 @@ func (account *Account) TransactionByBlockIdAndIndex(ctx context.Context, blockI func (account *Account) TransactionByHash(ctx context.Context, hash *felt.Felt) (rpc.Transaction, error) { return account.provider.TransactionByHash(ctx, hash) } + +/* +Formats the multicall transactions in a format which can be signed and verified by the network and OpenZeppelin account contracts +*/ +func FmtCalldata(fnCalls []rpc.FunctionCall) []*felt.Felt { + callArray := []*felt.Felt{} + callData := []*felt.Felt{new(felt.Felt).SetUint64(uint64(len(fnCalls)))} + + for _, tx := range fnCalls { + callData = append(callData, tx.ContractAddress, tx.EntryPointSelector) + + if len(tx.Calldata) == 0 { + callData = append(callData, &felt.Zero, &felt.Zero) + continue + } + + callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray))), new(felt.Felt).SetUint64(uint64(len(tx.Calldata))+1)) + callArray = append(callArray, tx.Calldata...) + } + callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray)+1))) + callData = append(callData, callArray...) + callData = append(callData, new(felt.Felt).SetUint64(0)) + return callData +} diff --git a/account/account_test.go b/account/account_test.go index a931244c..863099fe 100644 --- a/account/account_test.go +++ b/account/account_test.go @@ -11,10 +11,9 @@ import ( "testing" "time" - "github.com/NethermindEth/juno/core/felt" - starknetgo "github.com/NethermindEth/starknet.go" + "github.com/golang/mock/gomock" "github.com/joho/godotenv" - + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/starknet.go/account" "github.com/NethermindEth/starknet.go/contracts" "github.com/NethermindEth/starknet.go/devnet" @@ -22,7 +21,6 @@ import ( "github.com/NethermindEth/starknet.go/mocks" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" - "github.com/golang/mock/gomock" "github.com/test-go/testify/require" ) @@ -134,7 +132,7 @@ func TestTransactionHashInvoke(t *testing.T) { for _, test := range testSet { t.Run("Transaction hash", func(t *testing.T) { - ks := starknetgo.NewMemKeystore() + ks := account.NewMemKeystore() if test.SetKS { privKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0) require.True(t, ok) @@ -225,7 +223,7 @@ func TestChainIdMOCK(t *testing.T) { for _, test := range testSet { mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainID, nil) - account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", starknetgo.NewMemKeystore()) + account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore()) require.NoError(t, err) require.Equal(t, account.ChainId.String(), test.ExpectedID) } @@ -256,7 +254,7 @@ func TestChainId(t *testing.T) { require.NoError(t, err, "Error in rpc.NewClient") provider := rpc.NewProvider(client) - account, err := account.NewAccount(provider, &felt.Zero, "pubkey", starknetgo.NewMemKeystore()) + account, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore()) require.NoError(t, err) require.Equal(t, account.ChainId.String(), test.ExpectedID) } @@ -297,7 +295,7 @@ func TestSignMOCK(t *testing.T) { for _, test := range testSet { privKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0) require.True(t, ok) - ks := starknetgo.NewMemKeystore() + ks := account.NewMemKeystore() ks.Put(test.Address.String(), privKeyBI) mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainId, nil) @@ -388,7 +386,7 @@ func TestAddInvoke(t *testing.T) { provider := rpc.NewProvider(client) // Set up ks - ks := starknetgo.NewMemKeystore() + ks := account.NewMemKeystore() if test.SetKS { fakePrivKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0) require.True(t, ok) @@ -425,7 +423,7 @@ func TestAddDeployAccountDevnet(t *testing.T) { fakeUserPub := utils.TestHexToFelt(t, fakeUser.PublicKey) // Set up ks - ks := starknetgo.NewMemKeystore() + ks := account.NewMemKeystore() fakePrivKeyBI, ok := new(big.Int).SetString(fakeUser.PrivateKey, 0) require.True(t, ok) ks.Put(fakeUser.PublicKey, fakePrivKeyBI) @@ -471,7 +469,7 @@ func TestTransactionHashDeployAccountTestnet(t *testing.T) { ExpectedHash := utils.TestHexToFelt(t, "0x5b6b5927cd70ad7a80efdbe898244525871875c76540b239f6730118598b9cb") ExpectedPrecomputeAddr := utils.TestHexToFelt(t, "0x88d0038623a89bf853c70ea68b1062ccf32b094d1d7e5f924cda8404dc73e1") - ks := starknetgo.NewMemKeystore() + ks := account.NewMemKeystore() fakePrivKeyBI, ok := new(big.Int).SetString(PrivKey.String(), 0) require.True(t, ok) ks.Put(PubKey.String(), fakePrivKeyBI) @@ -515,7 +513,7 @@ func TestTransactionHashDeclare(t *testing.T) { require.NoError(t, err, "Error in rpc.NewClient") provider := rpc.NewProvider(client) - acnt, err := account.NewAccount(provider, &felt.Zero, "", starknetgo.NewMemKeystore()) + acnt, err := account.NewAccount(provider, &felt.Zero, "", account.NewMemKeystore()) require.NoError(t, err) tx := rpc.DeclareTxnV2{ diff --git a/account/hash.go b/account/hash.go deleted file mode 100644 index 77d02f18..00000000 --- a/account/hash.go +++ /dev/null @@ -1,67 +0,0 @@ -package account - -import ( - "github.com/NethermindEth/juno/core/felt" - starknetgo "github.com/NethermindEth/starknet.go" - "github.com/NethermindEth/starknet.go/rpc" - "github.com/NethermindEth/starknet.go/utils" -) - -// computeHashOnElementsFelt hashes the array of felts provided as input -func computeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { - bigIntArr := utils.FeltArrToBigIntArr(feltArr) - hash, err := starknetgo.Curve.ComputeHashOnElements(bigIntArr) - if err != nil { - return nil, err - } - return utils.BigIntToFelt(hash), nil -} - -// calculateTransactionHashCommon [specification] calculates the transaction hash in the StarkNet network - a unique identifier of the transaction. -// [specification]: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/starknet/core/os/transaction_hash/transaction_hash.py#L27C5-L27C38 -func calculateTransactionHashCommon( - txHashPrefix *felt.Felt, - version *felt.Felt, - contractAddress *felt.Felt, - entryPointSelector *felt.Felt, - calldata *felt.Felt, - maxFee *felt.Felt, - chainId *felt.Felt, - additionalData []*felt.Felt) (*felt.Felt, error) { - - dataToHash := []*felt.Felt{ - txHashPrefix, - version, - contractAddress, - entryPointSelector, - calldata, - maxFee, - chainId, - } - dataToHash = append(dataToHash, additionalData...) - return computeHashOnElementsFelt(dataToHash) -} - -/* -Formats the multicall transactions in a format which can be signed and verified by the network and OpenZeppelin account contracts -*/ -func FmtCalldata(fnCalls []rpc.FunctionCall) []*felt.Felt { - callArray := []*felt.Felt{} - callData := []*felt.Felt{new(felt.Felt).SetUint64(uint64(len(fnCalls)))} - - for _, tx := range fnCalls { - callData = append(callData, tx.ContractAddress, tx.EntryPointSelector) - - if len(tx.Calldata) == 0 { - callData = append(callData, &felt.Zero, &felt.Zero) - continue - } - - callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray))), new(felt.Felt).SetUint64(uint64(len(tx.Calldata))+1)) - callArray = append(callArray, tx.Calldata...) - } - callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray)+1))) - callData = append(callData, callArray...) - callData = append(callData, new(felt.Felt).SetUint64(0)) - return callData -} diff --git a/keystore.go b/account/keystore.go similarity index 93% rename from keystore.go rename to account/keystore.go index b5f0543f..536f6c22 100644 --- a/keystore.go +++ b/account/keystore.go @@ -1,4 +1,4 @@ -package starknetgo +package account import ( "context" @@ -6,6 +6,8 @@ import ( "fmt" "math/big" "sync" + + "github.com/NethermindEth/starknet.go/curve" ) type Keystore interface { @@ -68,7 +70,7 @@ func sign(ctx context.Context, msgHash *big.Int, key *big.Int) (x *big.Int, y *b err = ctx.Err() default: - x, y, err = Curve.Sign(msgHash, key) + x, y, err = curve.Curve.Sign(msgHash, key) } return x, y, err } diff --git a/curve.go b/curve/curve.go similarity index 52% rename from curve.go rename to curve/curve.go index 06dce1f8..7df3f8f7 100644 --- a/curve.go +++ b/curve/curve.go @@ -1,4 +1,4 @@ -package starknetgo +package curve /* Although the library adheres to the 'elliptic/curve' interface. @@ -6,12 +6,18 @@ package starknetgo It is recommended to use in the same way(i.e. `curve.Sign` and not `ecdsa.Sign`). */ import ( + "bytes" "crypto/elliptic" + "crypto/rand" + "crypto/sha256" _ "embed" "encoding/json" "fmt" "log" "math/big" + + junoCrypto "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" ) var Curve StarkCurve @@ -271,16 +277,309 @@ func (sc StarkCurve) EcMult(m, x1, y1 *big.Int) (x, y *big.Int) { return x, y } -// Finds a nonnegative integer 0 <= x < p such that (m * x) % p == n -// -// (ref: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/crypto/starkware/crypto/signature/math_utils.py) -func DivMod(n, m, p *big.Int) *big.Int { - q := new(big.Int) - gx := new(big.Int) - gy := new(big.Int) - q.GCD(gx, gy, m, p) - - r := new(big.Int).Mul(n, gx) - r = r.Mod(r, p) - return r +/* +Verifies the validity of the stark curve signature +given the message hash, and public key (x, y) coordinates +used to sign the message. + +(ref: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/crypto/starkware/crypto/signature/signature.py) +*/ +func (sc StarkCurve) Verify(msgHash, r, s, pubX, pubY *big.Int) bool { + w := sc.InvModCurveSize(s) + + if s.Cmp(big.NewInt(0)) != 1 || s.Cmp(sc.N) != -1 { + return false + } + if r.Cmp(big.NewInt(0)) != 1 || r.Cmp(sc.Max) != -1 { + return false + } + if w.Cmp(big.NewInt(0)) != 1 || w.Cmp(sc.Max) != -1 { + return false + } + if msgHash.Cmp(big.NewInt(0)) != 1 || msgHash.Cmp(sc.Max) != -1 { + return false + } + if !sc.IsOnCurve(pubX, pubY) { + return false + } + + zGx, zGy, err := sc.MimicEcMultAir(msgHash, sc.EcGenX, sc.EcGenY, sc.MinusShiftPointX, sc.MinusShiftPointY) + if err != nil { + return false + } + + rQx, rQy, err := sc.MimicEcMultAir(r, pubX, pubY, sc.Gx, sc.Gy) + if err != nil { + return false + } + inX, inY := sc.Add(zGx, zGy, rQx, rQy) + wBx, wBy, err := sc.MimicEcMultAir(w, inX, inY, sc.Gx, sc.Gy) + if err != nil { + return false + } + + outX, _ := sc.Add(wBx, wBy, sc.MinusShiftPointX, sc.MinusShiftPointY) + if r.Cmp(outX) == 0 { + return true + } else { + altY := new(big.Int).Neg(pubY) + + zGx, zGy, err = sc.MimicEcMultAir(msgHash, sc.EcGenX, sc.EcGenY, sc.MinusShiftPointX, sc.MinusShiftPointY) + if err != nil { + return false + } + + rQx, rQy, err = sc.MimicEcMultAir(r, pubX, new(big.Int).Set(altY), sc.Gx, sc.Gy) + if err != nil { + return false + } + inX, inY = sc.Add(zGx, zGy, rQx, rQy) + wBx, wBy, err = sc.MimicEcMultAir(w, inX, inY, sc.Gx, sc.Gy) + if err != nil { + return false + } + + outX, _ = sc.Add(wBx, wBy, sc.MinusShiftPointX, sc.MinusShiftPointY) + if r.Cmp(outX) == 0 { + return true + } + } + return false +} + +/* +Signs the hash value of contents with the provided private key. +Secret is generated using a golang implementation of RFC 6979. +Implementation does not yet include "extra entropy" or "retry gen". + +(ref: https://datatracker.ietf.org/doc/html/rfc6979) +*/ +func (sc StarkCurve) Sign(msgHash, privKey *big.Int, seed ...*big.Int) (x, y *big.Int, err error) { + if msgHash == nil { + return x, y, fmt.Errorf("nil msgHash") + } + if privKey == nil { + return x, y, fmt.Errorf("nil privKey") + } + if msgHash.Cmp(big.NewInt(0)) != 1 || msgHash.Cmp(sc.Max) != -1 { + return x, y, fmt.Errorf("invalid bit length") + } + + inSeed := big.NewInt(0) + if len(seed) == 1 && inSeed != nil { + inSeed = seed[0] + } + for { + k := sc.GenerateSecret(big.NewInt(0).Set(msgHash), big.NewInt(0).Set(privKey), big.NewInt(0).Set(inSeed)) + // In case r is rejected k shall be generated with new seed + inSeed = inSeed.Add(inSeed, big.NewInt(1)) + + r, _ := sc.EcMult(k, sc.EcGenX, sc.EcGenY) + + // DIFF: in classic ECDSA, we take int(x) % n. + if r.Cmp(big.NewInt(0)) != 1 || r.Cmp(sc.Max) != -1 { + // Bad value. This fails with negligible probability. + continue + } + + agg := new(big.Int).Mul(r, privKey) + agg = agg.Add(agg, msgHash) + + if new(big.Int).Mod(agg, sc.N).Cmp(big.NewInt(0)) == 0 { + // Bad value. This fails with negligible probability. + continue + } + + w := DivMod(k, agg, sc.N) + if w.Cmp(big.NewInt(0)) != 1 || w.Cmp(sc.Max) != -1 { + // Bad value. This fails with negligible probability. + continue + } + + s := sc.InvModCurveSize(w) + return r, s, nil + } + + return x, y, nil +} + +/* +See Sign. SignFelt just wraps Sign. +*/ +func (sc StarkCurve) SignFelt(msgHash, privKey *felt.Felt) (*felt.Felt, *felt.Felt, error) { + msgHashInt := msgHash.BigInt(new(big.Int)) + privKeyInt := privKey.BigInt(new(big.Int)) + x, y, err := sc.Sign(msgHashInt, privKeyInt) + if err != nil { + return nil, nil, err + } + xFelt := felt.NewFelt(new(felt.Felt).Impl().SetBigInt(x)) + yFelt := felt.NewFelt(new(felt.Felt).Impl().SetBigInt(y)) + return xFelt, yFelt, nil + +} + +/* +Hashes the contents of a given array using a golang Pedersen Hash implementation. + +(ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) +*/ +func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) { + if len(elems) == 0 { + elems = append(elems, big.NewInt(0)) + } + + hash = big.NewInt(0) + for _, h := range elems { + hash, err = sc.PedersenHash([]*big.Int{hash, h}) + if err != nil { + return hash, err + } + } + return hash, err +} + +/* +Hashes the contents of a given array with its size using a golang Pedersen Hash implementation. + +(ref: https://github.com/starkware-libs/cairo-lang/blob/13cef109cd811474de114925ee61fd5ac84a25eb/src/starkware/cairo/common/hash_state.py#L6) +*/ +func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int, err error) { + elems = append(elems, big.NewInt(int64(len(elems)))) + return Curve.HashElements((elems)) +} + +/* +Provides the pedersen hash of given array of big integers. +NOTE: This function assumes the curve has been initialized with contant points + +(ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) +*/ +func (sc StarkCurve) PedersenHash(elems []*big.Int) (hash *big.Int, err error) { + if len(sc.ConstantPoints) == 0 { + return hash, fmt.Errorf("must initiate precomputed constant points") + } + + ptx := new(big.Int).Set(sc.Gx) + pty := new(big.Int).Set(sc.Gy) + for i, elem := range elems { + x := new(big.Int).Set(elem) + + if x.Cmp(big.NewInt(0)) != -1 && x.Cmp(sc.P) != -1 { + return ptx, fmt.Errorf("invalid x: %v", x) + } + + for j := 0; j < 252; j++ { + idx := 2 + (i * 252) + j + xin := new(big.Int).Set(sc.ConstantPoints[idx][0]) + yin := new(big.Int).Set(sc.ConstantPoints[idx][1]) + if xin.Cmp(ptx) == 0 { + return hash, fmt.Errorf("constant point duplication: %v %v", ptx, xin) + } + if x.Bit(0) == 1 { + ptx, pty = sc.Add(ptx, pty, xin, yin) + } + x = x.Rsh(x, 1) + } + } + + return ptx, nil +} + +/* +Provides the pedersen hash of given array of felts. +NOTE: This function just wraps the Juno implementation + +(ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/poseidon_hash.go#L74) +*/ +func (sc StarkCurve) PoseidonArray(felts ...*felt.Felt) *felt.Felt { + return junoCrypto.PoseidonArray(felts...) +} + +/* +Provides the starknet keccak hash . +NOTE: This function just wraps the Juno implementation + +(ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/keccak.go#L11) +*/ +func (sc StarkCurve) StarknetKeccak(b []byte) (*felt.Felt, error) { + return junoCrypto.StarknetKeccak(b) +} + +// implementation based on https://github.com/codahale/rfc6979/blob/master/rfc6979.go +// for the specification, see https://tools.ietf.org/html/rfc6979#section-3.2 +func (sc StarkCurve) GenerateSecret(msgHash, privKey, seed *big.Int) (secret *big.Int) { + alg := sha256.New + holen := alg().Size() + rolen := (sc.BitSize + 7) >> 3 + + if msgHash.BitLen()%8 <= 4 && msgHash.BitLen() >= 248 { + msgHash = msgHash.Mul(msgHash, big.NewInt(16)) + } + + by := append(int2octets(privKey, rolen), bits2octets(msgHash, sc.N, sc.BitSize, rolen)...) + + if seed.Cmp(big.NewInt(0)) == 1 { + by = append(by, seed.Bytes()...) + } + + v := bytes.Repeat([]byte{0x01}, holen) + + k := bytes.Repeat([]byte{0x00}, holen) + + k = mac(alg, k, append(append(v, 0x00), by...), k) + + v = mac(alg, k, v, v) + + k = mac(alg, k, append(append(v, 0x01), by...), k) + + v = mac(alg, k, v, v) + + for { + var t []byte + + for len(t) < rolen { + v = mac(alg, k, v, v) + t = append(t, v...) + } + + secret = bits2int(new(big.Int).SetBytes(t), sc.BitSize) + // TODO: implement seed here, final gating function + if secret.Cmp(big.NewInt(0)) == 1 && secret.Cmp(sc.N) == -1 { + return secret + } + k = mac(alg, k, append(v, 0x00), k) + v = mac(alg, k, v, v) + } +} + +// obtain random primary key on stark curve +// NOTE: to be used for testing purposes +func (sc StarkCurve) GetRandomPrivateKey() (priv *big.Int, err error) { + max := new(big.Int).Sub(sc.Max, big.NewInt(1)) + + priv, err = rand.Int(rand.Reader, max) + if err != nil { + return priv, err + } + + x, y, err := sc.PrivateToPoint(priv) + if err != nil { + return priv, err + } + + if !sc.IsOnCurve(x, y) { + return priv, fmt.Errorf("key gen is not on stark cruve") + } + + return priv, nil +} + +// obtain public key coordinates from stark curve given the private key +func (sc StarkCurve) PrivateToPoint(privKey *big.Int) (x, y *big.Int, err error) { + if privKey.Cmp(big.NewInt(0)) != 1 || privKey.Cmp(sc.N) != -1 { + return x, y, fmt.Errorf("private key not in curve range") + } + x, y = sc.EcMult(privKey, sc.EcGenX, sc.EcGenY) + return x, y, nil } diff --git a/starknetgo_test.go b/curve/curve_test.go similarity index 50% rename from starknetgo_test.go rename to curve/curve_test.go index bcb1546b..8d283933 100644 --- a/starknetgo_test.go +++ b/curve/curve_test.go @@ -1,4 +1,4 @@ -package starknetgo +package curve import ( "crypto/elliptic" @@ -9,6 +9,48 @@ import ( "github.com/NethermindEth/starknet.go/utils" ) +func BenchmarkPedersenHash(b *testing.B) { + suite := [][]*big.Int{ + {utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, + {utils.HexToBN("0x1277312773"), utils.HexToBN("0x872362872362")}, + {utils.HexToBN("0x1277312773"), utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")}, + {utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.HexToBN("0x872362872362")}, + {utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")}, + {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")}, + {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")}, + } + + for _, test := range suite { + b.Run(fmt.Sprintf("input_size_%d_%d", test[0].BitLen(), test[1].BitLen()), func(b *testing.B) { + Curve.PedersenHash(test) + }) + } +} + +func BenchmarkCurveSign(b *testing.B) { + type data struct { + MessageHash *big.Int + PrivateKey *big.Int + Seed *big.Int + } + + dataSet := []data{} + MessageHash := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(250), nil) + PrivateKey := big.NewInt(0).Add(MessageHash, big.NewInt(1)) + Seed := big.NewInt(0) + for i := int64(0); i < 20; i++ { + dataSet = append(dataSet, data{ + MessageHash: big.NewInt(0).Add(MessageHash, big.NewInt(i)), + PrivateKey: big.NewInt(0).Add(PrivateKey, big.NewInt(i)), + Seed: big.NewInt(0).Add(Seed, big.NewInt(i)), + }) + + for _, test := range dataSet { + Curve.Sign(test.MessageHash, test.PrivateKey, test.Seed) + } + } +} + func BenchmarkSignatureVerify(b *testing.B) { private, _ := Curve.GetRandomPrivateKey() x, y, _ := Curve.PrivateToPoint(private) @@ -29,6 +71,140 @@ func BenchmarkSignatureVerify(b *testing.B) { }) } +func TestGeneral_PrivateToPoint(t *testing.T) { + x, _, err := Curve.PrivateToPoint(big.NewInt(2)) + if err != nil { + t.Errorf("PrivateToPoint err %v", err) + } + expectedX, _ := new(big.Int).SetString("3324833730090626974525872402899302150520188025637965566623476530814354734325", 10) + if x.Cmp(expectedX) != 0 { + t.Errorf("Actual public key %v different from expected %v", x, expectedX) + } +} + +func TestGeneral_PedersenHash(t *testing.T) { + testPedersen := []struct { + elements []*big.Int + expected *big.Int + }{ + { + elements: []*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, + expected: utils.HexToBN("0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"), + }, + { + elements: []*big.Int{utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.HexToBN("0x537461726b4e6574204d61696c")}, + expected: utils.HexToBN("0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c"), + }, + { + elements: []*big.Int{big.NewInt(100), big.NewInt(1000)}, + expected: utils.HexToBN("0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7"), + }, + } + + for _, tt := range testPedersen { + hash, err := Curve.PedersenHash(tt.elements) + if err != nil { + t.Errorf("Hashing err: %v\n", err) + } + if hash.Cmp(tt.expected) != 0 { + t.Errorf("incorrect hash: got %v expected %v\n", hash, tt.expected) + } + } +} + +func TestGeneral_DivMod(t *testing.T) { + testDivmod := []struct { + x *big.Int + y *big.Int + expected *big.Int + }{ + { + x: utils.StrToBig("311379432064974854430469844112069886938521247361583891764940938105250923060"), + y: utils.StrToBig("621253665351494585790174448601059271924288186997865022894315848222045687999"), + expected: utils.StrToBig("2577265149861519081806762825827825639379641276854712526969977081060187505740"), + }, + { + x: big.NewInt(1), + y: big.NewInt(2), + expected: utils.HexToBN("0x0400000000000008800000000000000000000000000000000000000000000001"), + }, + } + + for _, tt := range testDivmod { + divR := DivMod(tt.x, tt.y, Curve.P) + + if divR.Cmp(tt.expected) != 0 { + t.Errorf("DivMod Res %v does not == expected %v\n", divR, tt.expected) + } + } +} + +func TestGeneral_Add(t *testing.T) { + testAdd := []struct { + x *big.Int + y *big.Int + expectedX *big.Int + expectedY *big.Int + }{ + { + x: utils.StrToBig("1468732614996758835380505372879805860898778283940581072611506469031548393285"), + y: utils.StrToBig("1402551897475685522592936265087340527872184619899218186422141407423956771926"), + expectedX: utils.StrToBig("2573054162739002771275146649287762003525422629677678278801887452213127777391"), + expectedY: utils.StrToBig("3086444303034188041185211625370405120551769541291810669307042006593736192813"), + }, + { + x: big.NewInt(1), + y: big.NewInt(2), + expectedX: utils.StrToBig("225199957243206662471193729647752088571005624230831233470296838210993906468"), + expectedY: utils.StrToBig("190092378222341939862849656213289777723812734888226565973306202593691957981"), + }, + } + + for _, tt := range testAdd { + resX, resY := Curve.Add(Curve.Gx, Curve.Gy, tt.x, tt.y) + if resX.Cmp(tt.expectedX) != 0 { + t.Errorf("ResX %v does not == expected %v\n", resX, tt.expectedX) + + } + if resY.Cmp(tt.expectedY) != 0 { + t.Errorf("ResY %v does not == expected %v\n", resY, tt.expectedY) + } + } +} + +func TestGeneral_MultAir(t *testing.T) { + testMult := []struct { + r *big.Int + x *big.Int + y *big.Int + expectedX *big.Int + expectedY *big.Int + }{ + { + r: utils.StrToBig("2458502865976494910213617956670505342647705497324144349552978333078363662855"), + x: utils.StrToBig("1468732614996758835380505372879805860898778283940581072611506469031548393285"), + y: utils.StrToBig("1402551897475685522592936265087340527872184619899218186422141407423956771926"), + expectedX: utils.StrToBig("182543067952221301675635959482860590467161609552169396182763685292434699999"), + expectedY: utils.StrToBig("3154881600662997558972388646773898448430820936643060392452233533274798056266"), + }, + } + + for _, tt := range testMult { + x, y, err := Curve.MimicEcMultAir(tt.r, tt.x, tt.y, Curve.Gx, Curve.Gy) + if err != nil { + t.Errorf("MultAirERR %v\n", err) + } + + if x.Cmp(tt.expectedX) != 0 { + t.Errorf("ResX %v does not == expected %v\n", x, tt.expectedX) + + } + if y.Cmp(tt.expectedY) != 0 { + t.Errorf("ResY %v does not == expected %v\n", y, tt.expectedY) + } + } +} + func TestGeneral_ComputeHashOnElements(t *testing.T) { hashEmptyArray, err := Curve.ComputeHashOnElements([]*big.Int{}) expectedHashEmmptyArray := utils.HexToBN("0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804") @@ -197,3 +373,19 @@ func TestGeneral_Signature(t *testing.T) { } } } + +func TestGeneral_SplitFactStr(t *testing.T) { + data := []map[string]string{ + {"input": "0x3", "h": "0x0", "l": "0x3"}, + {"input": "0x300000000000000000000000000000000", "h": "0x3", "l": "0x0"}, + } + for _, d := range data { + l, h := utils.SplitFactStr(d["input"]) // 0x3 + if l != d["l"] { + t.Errorf("expected %s, got %s", d["l"], l) + } + if h != d["h"] { + t.Errorf("expected %s, got %s", d["h"], h) + } + } +} diff --git a/opts.go b/curve/opts.go similarity index 97% rename from opts.go rename to curve/opts.go index c3ec447d..afeb1f31 100644 --- a/opts.go +++ b/curve/opts.go @@ -1,4 +1,4 @@ -package starknetgo +package curve type curveOptions struct { initConstants bool diff --git a/pedersen_params.json b/curve/pedersen_params.json similarity index 100% rename from pedersen_params.json rename to curve/pedersen_params.json diff --git a/curve/utils.go b/curve/utils.go new file mode 100644 index 00000000..b7e48bcf --- /dev/null +++ b/curve/utils.go @@ -0,0 +1,103 @@ +package curve + +import ( + "crypto/hmac" + "hash" + "math/big" +) + +// Finds a nonnegative integer 0 <= x < p such that (m * x) % p == n +// +// (ref: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/crypto/starkware/crypto/signature/math_utils.py) +func DivMod(n, m, p *big.Int) *big.Int { + q := new(big.Int) + gx := new(big.Int) + gy := new(big.Int) + q.GCD(gx, gy, m, p) + + r := new(big.Int).Mul(n, gx) + r = r.Mod(r, p) + return r +} + +// https://tools.ietf.org/html/rfc6979#section-2.3.3 +func int2octets(v *big.Int, rolen int) []byte { + out := v.Bytes() + + // pad with zeros if it's too short + if len(out) < rolen { + out2 := make([]byte, rolen) + copy(out2[rolen-len(out):], out) + return out2 + } + + // drop most significant bytes if it's too long + if len(out) > rolen { + out2 := make([]byte, rolen) + copy(out2, out[len(out)-rolen:]) + return out2 + } + + return out +} + +// https://tools.ietf.org/html/rfc6979#section-2.3.4 +func bits2octets(in, q *big.Int, qlen, rolen int) []byte { + z1 := bits2int(in, qlen) + z2 := new(big.Int).Sub(z1, q) + if z2.Sign() < 0 { + return int2octets(z1, rolen) + } + return int2octets(z2, rolen) +} + +// https://tools.ietf.org/html/rfc6979#section-2.3.2 +func bits2int(in *big.Int, qlen int) *big.Int { + blen := len(in.Bytes()) * 8 + + if blen > qlen { + + return new(big.Int).Rsh(in, uint(blen-qlen)) + } + return in +} + +// mac returns an HMAC of the given key and message. +func mac(alg func() hash.Hash, k, m, buf []byte) []byte { + h := hmac.New(alg, k) + h.Write(m) + return h.Sum(buf[:0]) +} + +// mask excess bits +func MaskBits(mask, wordSize int, slice []byte) (ret []byte) { + excess := len(slice)*wordSize - mask + for _, by := range slice { + if excess > 0 { + if excess > wordSize { + excess = excess - wordSize + continue + } + by <<= excess + by >>= excess + excess = 0 + } + ret = append(ret, by) + } + return ret +} + +// format the bytes in Keccak hash +func FmtKecBytes(in *big.Int, rolen int) (buf []byte) { + buf = append(buf, in.Bytes()...) + + // pad with zeros if too short + if len(buf) < rolen { + padded := make([]byte, rolen) + copy(padded[rolen-len(buf):], buf) + + return padded + } + + return buf +} diff --git a/curve_test.go b/curve_test.go deleted file mode 100644 index a18129cd..00000000 --- a/curve_test.go +++ /dev/null @@ -1,185 +0,0 @@ -package starknetgo - -import ( - "fmt" - "math/big" - "testing" - - "github.com/NethermindEth/starknet.go/utils" -) - -func BenchmarkPedersenHash(b *testing.B) { - suite := [][]*big.Int{ - {utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0x1277312773"), utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826")}, - {utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), utils.HexToBN("0x872362872362")}, - {utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), utils.HexToBN("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9")}, - {utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbddd"), utils.HexToBN("0x7f15c38ea577a26f4f553282fcfe4f1feeb8ecfaad8f221ae41abf8224cbdde")}, - } - - for _, test := range suite { - b.Run(fmt.Sprintf("input_size_%d_%d", test[0].BitLen(), test[1].BitLen()), func(b *testing.B) { - Curve.PedersenHash(test) - }) - } -} - -func BenchmarkCurveSign(b *testing.B) { - type data struct { - MessageHash *big.Int - PrivateKey *big.Int - Seed *big.Int - } - - dataSet := []data{} - MessageHash := big.NewInt(0).Exp(big.NewInt(2), big.NewInt(250), nil) - PrivateKey := big.NewInt(0).Add(MessageHash, big.NewInt(1)) - Seed := big.NewInt(0) - for i := int64(0); i < 20; i++ { - dataSet = append(dataSet, data{ - MessageHash: big.NewInt(0).Add(MessageHash, big.NewInt(i)), - PrivateKey: big.NewInt(0).Add(PrivateKey, big.NewInt(i)), - Seed: big.NewInt(0).Add(Seed, big.NewInt(i)), - }) - - for _, test := range dataSet { - Curve.Sign(test.MessageHash, test.PrivateKey, test.Seed) - } - } -} - -func TestGeneral_PrivateToPoint(t *testing.T) { - x, _, err := Curve.PrivateToPoint(big.NewInt(2)) - if err != nil { - t.Errorf("PrivateToPoint err %v", err) - } - expectedX, _ := new(big.Int).SetString("3324833730090626974525872402899302150520188025637965566623476530814354734325", 10) - if x.Cmp(expectedX) != 0 { - t.Errorf("Actual public key %v different from expected %v", x, expectedX) - } -} - -func TestGeneral_PedersenHash(t *testing.T) { - testPedersen := []struct { - elements []*big.Int - expected *big.Int - }{ - { - elements: []*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}, - expected: utils.HexToBN("0x5ed2703dfdb505c587700ce2ebfcab5b3515cd7e6114817e6026ec9d4b364ca"), - }, - { - elements: []*big.Int{utils.HexToBN("0x13d41f388b8ea4db56c5aa6562f13359fab192b3db57651af916790f9debee9"), utils.HexToBN("0x537461726b4e6574204d61696c")}, - expected: utils.HexToBN("0x180c0a3d13c1adfaa5cbc251f4fc93cc0e26cec30ca4c247305a7ce50ac807c"), - }, - { - elements: []*big.Int{big.NewInt(100), big.NewInt(1000)}, - expected: utils.HexToBN("0x45a62091df6da02dce4250cb67597444d1f465319908486b836f48d0f8bf6e7"), - }, - } - - for _, tt := range testPedersen { - hash, err := Curve.PedersenHash(tt.elements) - if err != nil { - t.Errorf("Hashing err: %v\n", err) - } - if hash.Cmp(tt.expected) != 0 { - t.Errorf("incorrect hash: got %v expected %v\n", hash, tt.expected) - } - } -} - -func TestGeneral_DivMod(t *testing.T) { - testDivmod := []struct { - x *big.Int - y *big.Int - expected *big.Int - }{ - { - x: utils.StrToBig("311379432064974854430469844112069886938521247361583891764940938105250923060"), - y: utils.StrToBig("621253665351494585790174448601059271924288186997865022894315848222045687999"), - expected: utils.StrToBig("2577265149861519081806762825827825639379641276854712526969977081060187505740"), - }, - { - x: big.NewInt(1), - y: big.NewInt(2), - expected: utils.HexToBN("0x0400000000000008800000000000000000000000000000000000000000000001"), - }, - } - - for _, tt := range testDivmod { - divR := DivMod(tt.x, tt.y, Curve.P) - - if divR.Cmp(tt.expected) != 0 { - t.Errorf("DivMod Res %v does not == expected %v\n", divR, tt.expected) - } - } -} - -func TestGeneral_Add(t *testing.T) { - testAdd := []struct { - x *big.Int - y *big.Int - expectedX *big.Int - expectedY *big.Int - }{ - { - x: utils.StrToBig("1468732614996758835380505372879805860898778283940581072611506469031548393285"), - y: utils.StrToBig("1402551897475685522592936265087340527872184619899218186422141407423956771926"), - expectedX: utils.StrToBig("2573054162739002771275146649287762003525422629677678278801887452213127777391"), - expectedY: utils.StrToBig("3086444303034188041185211625370405120551769541291810669307042006593736192813"), - }, - { - x: big.NewInt(1), - y: big.NewInt(2), - expectedX: utils.StrToBig("225199957243206662471193729647752088571005624230831233470296838210993906468"), - expectedY: utils.StrToBig("190092378222341939862849656213289777723812734888226565973306202593691957981"), - }, - } - - for _, tt := range testAdd { - resX, resY := Curve.Add(Curve.Gx, Curve.Gy, tt.x, tt.y) - if resX.Cmp(tt.expectedX) != 0 { - t.Errorf("ResX %v does not == expected %v\n", resX, tt.expectedX) - - } - if resY.Cmp(tt.expectedY) != 0 { - t.Errorf("ResY %v does not == expected %v\n", resY, tt.expectedY) - } - } -} - -func TestGeneral_MultAir(t *testing.T) { - testMult := []struct { - r *big.Int - x *big.Int - y *big.Int - expectedX *big.Int - expectedY *big.Int - }{ - { - r: utils.StrToBig("2458502865976494910213617956670505342647705497324144349552978333078363662855"), - x: utils.StrToBig("1468732614996758835380505372879805860898778283940581072611506469031548393285"), - y: utils.StrToBig("1402551897475685522592936265087340527872184619899218186422141407423956771926"), - expectedX: utils.StrToBig("182543067952221301675635959482860590467161609552169396182763685292434699999"), - expectedY: utils.StrToBig("3154881600662997558972388646773898448430820936643060392452233533274798056266"), - }, - } - - for _, tt := range testMult { - x, y, err := Curve.MimicEcMultAir(tt.r, tt.x, tt.y, Curve.Gx, Curve.Gy) - if err != nil { - t.Errorf("MultAirERR %v\n", err) - } - - if x.Cmp(tt.expectedX) != 0 { - t.Errorf("ResX %v does not == expected %v\n", x, tt.expectedX) - - } - if y.Cmp(tt.expectedY) != 0 { - t.Errorf("ResY %v does not == expected %v\n", y, tt.expectedY) - } - } -} diff --git a/examples/curve/go.mod b/examples/curve/go.mod deleted file mode 100644 index 07dbe2ee..00000000 --- a/examples/curve/go.mod +++ /dev/null @@ -1,12 +0,0 @@ -module github.com/NethermindEth/starknet.go/examples/curve - -go 1.18 - -replace github.com/NethermindEth/starknet.go => ../../ - -require github.com/NethermindEth/starknet.go v0.3.1 - -require ( - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect - golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c // indirect -) diff --git a/examples/curve/go.work b/examples/curve/go.work deleted file mode 100644 index b32bc53a..00000000 --- a/examples/curve/go.work +++ /dev/null @@ -1,6 +0,0 @@ -go 1.18 - -use ( - . - ../.. -) diff --git a/examples/curve/main.go b/examples/curve/main.go deleted file mode 100644 index a130f28c..00000000 --- a/examples/curve/main.go +++ /dev/null @@ -1,40 +0,0 @@ -package main - -import ( - "fmt" - "math/big" - - starknetgo "github.com/NethermindEth/starknet.go" - "github.com/NethermindEth/starknet.go/utils" -) - -func main() { - /* - Although the library adheres to the 'elliptic/curve' interface. - All testing has been done against library function explicity. - It is recommended to use in the same way(i.e. `curve.Sign` and not `ecdsa.Sign`). - NOTE: when not given local file path this pulls the curve data from Starkware github repo - */ - hash, err := starknetgo.Curve.PedersenHash([]*big.Int{utils.HexToBN("0x12773"), utils.HexToBN("0x872362")}) - if err != nil { - panic(err.Error()) - } - - priv, _ := starknetgo.Curve.GetRandomPrivateKey() - - x, y, err := starknetgo.Curve.PrivateToPoint(priv) - if err != nil { - panic(err.Error()) - } - - r, s, err := starknetgo.Curve.Sign(hash, priv) - if err != nil { - panic(err.Error()) - } - - if starknetgo.Curve.Verify(hash, r, s, x, y) { - fmt.Println("signature is valid") - } else { - fmt.Println("signature is invalid") - } -} diff --git a/examples/deployAccount/main.go b/examples/deployAccount/main.go index a9e6d98c..edfec071 100644 --- a/examples/deployAccount/main.go +++ b/examples/deployAccount/main.go @@ -6,7 +6,7 @@ import ( "os" "github.com/NethermindEth/juno/core/felt" - starknetgo "github.com/NethermindEth/starknet.go" + "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" ethrpc "github.com/ethereum/go-ethereum/rpc" @@ -70,7 +70,7 @@ func main() { panic(err) } fmt.Println("Transaction hash:", hash) - x, y, err := starknetgo.Curve.SignFelt(hash, priv) + x, y, err := curve.Curve.SignFelt(hash, priv) if err != nil { panic(err) } @@ -87,24 +87,19 @@ func main() { } func getRandomKeys() (*felt.Felt, *felt.Felt) { - privateKey, err := starknetgo.Curve.GetRandomPrivateKey() + privateKey, err := curve.Curve.GetRandomPrivateKey() if err != nil { fmt.Println("can't get random private key:", err) os.Exit(1) } - pubX, _, err := starknetgo.Curve.PrivateToPoint(privateKey) + pubX, _, err := curve.Curve.PrivateToPoint(privateKey) if err != nil { fmt.Println("can't generate public key:", err) os.Exit(1) } - privFelt, err := utils.BigIntToFelt(privateKey) - if err != nil { - panic(err) - } - pubFelt, err := utils.BigIntToFelt(pubX) - if err != nil { - panic(err) - } + privFelt := utils.BigIntToFelt(privateKey) + pubFelt := utils.BigIntToFelt(pubX) + return pubFelt, privFelt } @@ -114,38 +109,31 @@ func getRandomKeys() (*felt.Felt, *felt.Felt) { func precomputeAddress(deployerAddress *felt.Felt, salt *felt.Felt, classHash *felt.Felt, constructorCalldata []*felt.Felt) (*felt.Felt, error) { CONTRACT_ADDRESS_PREFIX := new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) - bigIntArr, err := utils.FeltArrToBigIntArr([]*felt.Felt{ + bigIntArr := utils.FeltArrToBigIntArr([]*felt.Felt{ CONTRACT_ADDRESS_PREFIX, deployerAddress, salt, classHash, }) - if err != nil { - return nil, err - } - - constructorCalldataBigIntArr, err := utils.FeltArrToBigIntArr(constructorCalldata) - constructorCallDataHashInt, _ := starknetgo.Curve.ComputeHashOnElements(*constructorCalldataBigIntArr) - *bigIntArr = append(*bigIntArr, constructorCallDataHashInt) + constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata) + constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr) + bigIntArr = append(bigIntArr, constructorCallDataHashInt) - preBigInt, err := starknetgo.Curve.ComputeHashOnElements(*bigIntArr) + preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr) if err != nil { return nil, err } - return utils.BigIntToFelt(preBigInt) + return utils.BigIntToFelt(preBigInt), nil } func computeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { - bigIntArr, err := utils.FeltArrToBigIntArr(feltArr) - if err != nil { - return nil, err - } - hash, err := starknetgo.Curve.ComputeHashOnElements(*bigIntArr) + bigIntArr := utils.FeltArrToBigIntArr(feltArr) + hash, err := curve.Curve.ComputeHashOnElements(bigIntArr) if err != nil { return nil, err } - return utils.BigIntToFelt(hash) + return utils.BigIntToFelt(hash), nil } // calculateDeployAccountTransactionHash computes the transaction hash for deployAccount transactions diff --git a/hash/hash.go b/hash/hash.go index 16c23e93..4ab72bdf 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -2,8 +2,8 @@ package hash import ( "github.com/NethermindEth/juno/core/felt" - starknetgo "github.com/NethermindEth/starknet.go" "github.com/NethermindEth/starknet.go/contracts" + "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/rpc" "github.com/NethermindEth/starknet.go/utils" ) @@ -12,7 +12,7 @@ import ( func ComputeHashOnElementsFelt(feltArr []*felt.Felt) (*felt.Felt, error) { bigIntArr := utils.FeltArrToBigIntArr(feltArr) - hash, err := starknetgo.Curve.ComputeHashOnElements(bigIntArr) + hash, err := curve.Curve.ComputeHashOnElements(bigIntArr) if err != nil { return nil, err } @@ -52,14 +52,14 @@ func ClassHash(contract rpc.ContractClass) (*felt.Felt, error) { ConstructorHash := hashEntryPointByType(contract.EntryPointsByType.Constructor) ExternalHash := hashEntryPointByType(contract.EntryPointsByType.External) L1HandleHash := hashEntryPointByType(contract.EntryPointsByType.L1Handler) - SierraProgamHash := starknetgo.Curve.PoseidonArray(contract.SierraProgram...) - ABIHash, err := starknetgo.Curve.StarknetKeccak([]byte(contract.ABI)) + SierraProgamHash := curve.Curve.PoseidonArray(contract.SierraProgram...) + ABIHash, err := curve.Curve.StarknetKeccak([]byte(contract.ABI)) if err != nil { return nil, err } // https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/transactions/#deploy_account_hash_calculation - return starknetgo.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ABIHash, SierraProgamHash), nil + return curve.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ABIHash, SierraProgamHash), nil } func hashEntryPointByType(entryPoint []rpc.SierraEntryPoint) *felt.Felt { @@ -67,7 +67,7 @@ func hashEntryPointByType(entryPoint []rpc.SierraEntryPoint) *felt.Felt { for _, elt := range entryPoint { flattened = append(flattened, elt.Selector, new(felt.Felt).SetUint64(uint64(elt.FunctionIdx))) } - return starknetgo.Curve.PoseidonArray(flattened...) + return curve.Curve.PoseidonArray(flattened...) } func CompiledClassHash(casmClass contracts.CasmClass) *felt.Felt { @@ -75,10 +75,10 @@ func CompiledClassHash(casmClass contracts.CasmClass) *felt.Felt { ExternalHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.External) L1HandleHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.L1Handler) ConstructorHash := hashCasmClassEntryPointByType(casmClass.EntryPointByType.Constructor) - ByteCodeHasH := starknetgo.Curve.PoseidonArray(casmClass.ByteCode...) + ByteCodeHasH := curve.Curve.PoseidonArray(casmClass.ByteCode...) // https://github.com/software-mansion/starknet.py/blob/development/starknet_py/hash/casm_class_hash.py#L10 - return starknetgo.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ByteCodeHasH) + return curve.Curve.PoseidonArray(ContractClassVersionHash, ExternalHash, L1HandleHash, ConstructorHash, ByteCodeHasH) } func hashCasmClassEntryPointByType(entryPoint []contracts.CasmClassEntryPoint) *felt.Felt { @@ -88,8 +88,8 @@ func hashCasmClassEntryPointByType(entryPoint []contracts.CasmClassEntryPoint) * for _, builtIn := range elt.Builtins { builtInFlat = append(builtInFlat, new(felt.Felt).SetBytes([]byte(builtIn))) } - builtInHash := starknetgo.Curve.PoseidonArray(builtInFlat...) + builtInHash := curve.Curve.PoseidonArray(builtInFlat...) flattened = append(flattened, elt.Selector, new(felt.Felt).SetUint64(uint64(elt.Offset)), builtInHash) } - return starknetgo.Curve.PoseidonArray(flattened...) + return curve.Curve.PoseidonArray(flattened...) } diff --git a/merkle.go b/merkle/merkle.go similarity index 93% rename from merkle.go rename to merkle/merkle.go index 2a6f3e85..007512cd 100644 --- a/merkle.go +++ b/merkle/merkle.go @@ -1,8 +1,11 @@ -package starknetgo +package merkle import ( "fmt" "math/big" + + + "github.com/NethermindEth/starknet.go/curve" ) type FixedSizeMerkleTree struct { @@ -26,9 +29,9 @@ func NewFixedSizeMerkleTree(leaves ...*big.Int) (*FixedSizeMerkleTree, error) { func MerkleHash(x, y *big.Int) (*big.Int, error) { if x.Cmp(y) <= 0 { - return Curve.HashElements([]*big.Int{x, y}) + return curve.Curve.HashElements([]*big.Int{x, y}) } - return Curve.HashElements([]*big.Int{y, x}) + return curve.Curve.HashElements([]*big.Int{y, x}) } func (mt *FixedSizeMerkleTree) build(leaves []*big.Int) (*big.Int, error) { diff --git a/merkle_test.go b/merkle/merkle_test.go similarity index 98% rename from merkle_test.go rename to merkle/merkle_test.go index 4e3939c7..84d97eaf 100644 --- a/merkle_test.go +++ b/merkle/merkle_test.go @@ -1,4 +1,4 @@ -package starknetgo +package merkle import ( "math/big" diff --git a/providers_test.go b/providers_test.go deleted file mode 100644 index 48af8916..00000000 --- a/providers_test.go +++ /dev/null @@ -1,155 +0,0 @@ -package starknetgo - -import ( - "context" - "flag" - "fmt" - "math/big" - "os" - "strings" - "testing" - - "github.com/NethermindEth/starknet.go/rpc" - ethrpc "github.com/ethereum/go-ethereum/rpc" - "github.com/joho/godotenv" -) - -const ( - TestPublicKey = "0x783318b2cc1067e5c06d374d2bb9a0382c39aabd009b165d7a268b882971d6" - DevNetETHAddress = "0x62230ea046a9a5fbc261ac77d03c8d41e5d442db2284587570ab46455fd2488" - TestNetETHAddress = "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7" - DevNetAccount032Address = "0x0536244bba4dc9bb219d964b477af6d18f7096635a96284bb0e008bf137650ec" - TestNetAccount032Address = "0x6ca4fdd437dffde5253ba7021ef7265c88b07789aa642eafda37791626edf00" - DevNetAccount040Address = "0x058079067104f58fd9f1ef949cd2d2b482d7bca39b793983f077edaf51d979e9" - TestNetAccount040Address = "0x6cbfa37f409610fee26eeb427ed854b3a4b24580d9b9ef6c3e38db7b3f7322c" - TestnetCounterAddress = "0x51e94d515df16ecae5be4a377666121494eb54193d854fcf5baba2b0da679c6" -) - -var ( - // set the environment for the test, default: mock - testEnv = "mock" - - // testConfigurations are predefined test configurations - testRPCConfigurations = map[string]testRPCConfiguration{ - // Requires a Mainnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) - "mainnet": { - base: "http://localhost:9545", - }, - // Requires a Testnet Starknet JSON-RPC compliant node (e.g. pathfinder) - // (ref: https://github.com/eqlabs/pathfinder) - "testnet": { - base: "http://localhost:9545", - }, - // Requires a Devnet configuration running locally - // (ref: https://github.com/Shard-Labs/starknet-devnet) - "devnet": { - base: "http://localhost:5050/rpc", - }, - // Used with a mock as a standard configuration, see `mock_test.go`` - "mock": {}, - } -) - -// testConfiguration is a type that is used to configure tests -type testRPCConfiguration struct { - providerv02 *rpc.Provider - base string -} - -// TestMain is used to trigger the tests and, in that case, check for the environment to use. -func TestMain(m *testing.M) { - baseURL := "" - flag.StringVar(&testEnv, "env", "mock", "set the test environment") - flag.StringVar(&baseURL, "base-url", "", "change the baseUrl") - flag.Parse() - godotenv.Load(fmt.Sprintf(".env.%s", testEnv), ".env") - if baseURL != "" { - rpcLocalConfig := testRPCConfigurations[testEnv] - rpcLocalConfig.base = baseURL - testRPCConfigurations[testEnv] = rpcLocalConfig - } - os.Exit(m.Run()) -} - -// beforeEach checks the configuration and initializes it before running the script -func beforeRPCEach(t *testing.T) *testRPCConfiguration { - t.Helper() - godotenv.Load(fmt.Sprintf(".env.%s", testEnv), ".env") - testConfig, ok := testRPCConfigurations[testEnv] - if !ok { - t.Fatal("env supports mock, testnet, mainnet or devnet") - } - testConfig.base = "https://starknet-goerli.cartridge.gg" - base := os.Getenv("INTEGRATION_BASE") - if base != "" { - testConfig.base = base - } - c, err := ethrpc.DialContext(context.Background(), testConfig.base) - if err != nil { - t.Fatal("connect should succeed, instead:", err) - } - clientv02 := rpc.NewProvider(c) - testConfig.providerv02 = clientv02 - return &testConfig -} - -// TestChainID checks the chainId matches the one for the environment -func TestGeneral_ChainID(t *testing.T) { - testConfig := beforeRPCEach(t) - - type testSetType struct { - ChainID string - } - testSet := map[string][]testSetType{ - "devnet": {{ChainID: "SN_GOERLI"}}, - "mainnet": {{ChainID: "SN_MAIN"}}, - "mock": {{ChainID: "MOCK"}}, - "testnet": {{ChainID: "SN_GOERLI"}}, - }[testEnv] - - fmt.Printf("----------------------------\n") - fmt.Printf("Env: %s\n", testEnv) - fmt.Printf("Url: %s\n", testConfig.base) - fmt.Printf("----------------------------\n") - - for _, test := range testSet { - chain, err := testConfig.providerv02.ChainID(context.Background()) - if err != nil { - t.Fatal(err) - } - if chain != test.ChainID { - t.Fatalf("expecting %s, instead: %s", test.ChainID, chain) - } - } -} - -// TestSyncing checks the values returned are consistent -func TestGeneral_Syncing(t *testing.T) { - testConfig := beforeRPCEach(t) - - type testSetType struct { - ChainID string - } - - testSet := map[string][]testSetType{ - "devnet": {}, - "mainnet": {{ChainID: "SN_MAIN"}}, - "mock": {{ChainID: "MOCK"}}, - "testnet": {{ChainID: "SN_GOERLI"}}, - }[testEnv] - - for range testSet { - syncv02, err := testConfig.providerv02.Syncing(context.Background()) - if err != nil { - t.Fatal("BlockWithTxHashes match the expected error:", err) - } - i, ok := big.NewInt(0).SetString(string(syncv02.CurrentBlockNum), 0) - if !ok || i.Cmp(big.NewInt(0)) <= 0 { - t.Fatal("CurrentBlockNum should be positive number, instead: ", syncv02.CurrentBlockNum) - } - if !strings.HasPrefix(syncv02.CurrentBlockHash.String(), "0x") { - t.Fatal("current block hash should return a string starting with 0x") - } - } -} diff --git a/starknetgo.go b/starknetgo.go deleted file mode 100644 index 184a8543..00000000 --- a/starknetgo.go +++ /dev/null @@ -1,371 +0,0 @@ -package starknetgo - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "fmt" - "hash" - "math/big" - - junoCrypto "github.com/NethermindEth/juno/core/crypto" - "github.com/NethermindEth/juno/core/felt" -) - -/* -Verifies the validity of the stark curve signature -given the message hash, and public key (x, y) coordinates -used to sign the message. - -(ref: https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/crypto/starkware/crypto/signature/signature.py) -*/ -func (sc StarkCurve) Verify(msgHash, r, s, pubX, pubY *big.Int) bool { - w := sc.InvModCurveSize(s) - - if s.Cmp(big.NewInt(0)) != 1 || s.Cmp(sc.N) != -1 { - return false - } - if r.Cmp(big.NewInt(0)) != 1 || r.Cmp(sc.Max) != -1 { - return false - } - if w.Cmp(big.NewInt(0)) != 1 || w.Cmp(sc.Max) != -1 { - return false - } - if msgHash.Cmp(big.NewInt(0)) != 1 || msgHash.Cmp(sc.Max) != -1 { - return false - } - if !sc.IsOnCurve(pubX, pubY) { - return false - } - - zGx, zGy, err := sc.MimicEcMultAir(msgHash, sc.EcGenX, sc.EcGenY, sc.MinusShiftPointX, sc.MinusShiftPointY) - if err != nil { - return false - } - - rQx, rQy, err := sc.MimicEcMultAir(r, pubX, pubY, sc.Gx, sc.Gy) - if err != nil { - return false - } - inX, inY := sc.Add(zGx, zGy, rQx, rQy) - wBx, wBy, err := sc.MimicEcMultAir(w, inX, inY, sc.Gx, sc.Gy) - if err != nil { - return false - } - - outX, _ := sc.Add(wBx, wBy, sc.MinusShiftPointX, sc.MinusShiftPointY) - if r.Cmp(outX) == 0 { - return true - } else { - altY := new(big.Int).Neg(pubY) - - zGx, zGy, err = sc.MimicEcMultAir(msgHash, sc.EcGenX, sc.EcGenY, sc.MinusShiftPointX, sc.MinusShiftPointY) - if err != nil { - return false - } - - rQx, rQy, err = sc.MimicEcMultAir(r, pubX, new(big.Int).Set(altY), sc.Gx, sc.Gy) - if err != nil { - return false - } - inX, inY = sc.Add(zGx, zGy, rQx, rQy) - wBx, wBy, err = sc.MimicEcMultAir(w, inX, inY, sc.Gx, sc.Gy) - if err != nil { - return false - } - - outX, _ = sc.Add(wBx, wBy, sc.MinusShiftPointX, sc.MinusShiftPointY) - if r.Cmp(outX) == 0 { - return true - } - } - return false -} - -/* -Signs the hash value of contents with the provided private key. -Secret is generated using a golang implementation of RFC 6979. -Implementation does not yet include "extra entropy" or "retry gen". - -(ref: https://datatracker.ietf.org/doc/html/rfc6979) -*/ -func (sc StarkCurve) Sign(msgHash, privKey *big.Int, seed ...*big.Int) (x, y *big.Int, err error) { - if msgHash == nil { - return x, y, fmt.Errorf("nil msgHash") - } - if privKey == nil { - return x, y, fmt.Errorf("nil privKey") - } - if msgHash.Cmp(big.NewInt(0)) != 1 || msgHash.Cmp(sc.Max) != -1 { - return x, y, fmt.Errorf("invalid bit length") - } - - inSeed := big.NewInt(0) - if len(seed) == 1 && inSeed != nil { - inSeed = seed[0] - } - for { - k := sc.GenerateSecret(big.NewInt(0).Set(msgHash), big.NewInt(0).Set(privKey), big.NewInt(0).Set(inSeed)) - // In case r is rejected k shall be generated with new seed - inSeed = inSeed.Add(inSeed, big.NewInt(1)) - - r, _ := sc.EcMult(k, sc.EcGenX, sc.EcGenY) - - // DIFF: in classic ECDSA, we take int(x) % n. - if r.Cmp(big.NewInt(0)) != 1 || r.Cmp(sc.Max) != -1 { - // Bad value. This fails with negligible probability. - continue - } - - agg := new(big.Int).Mul(r, privKey) - agg = agg.Add(agg, msgHash) - - if new(big.Int).Mod(agg, sc.N).Cmp(big.NewInt(0)) == 0 { - // Bad value. This fails with negligible probability. - continue - } - - w := DivMod(k, agg, sc.N) - if w.Cmp(big.NewInt(0)) != 1 || w.Cmp(sc.Max) != -1 { - // Bad value. This fails with negligible probability. - continue - } - - s := sc.InvModCurveSize(w) - return r, s, nil - } - - return x, y, nil -} - -/* -See Sign. SignFelt just wraps Sign. -*/ -func (sc StarkCurve) SignFelt(msgHash, privKey *felt.Felt) (*felt.Felt, *felt.Felt, error) { - msgHashInt := msgHash.BigInt(new(big.Int)) - privKeyInt := privKey.BigInt(new(big.Int)) - x, y, err := sc.Sign(msgHashInt, privKeyInt) - if err != nil { - return nil, nil, err - } - xFelt := felt.NewFelt(new(felt.Felt).Impl().SetBigInt(x)) - yFelt := felt.NewFelt(new(felt.Felt).Impl().SetBigInt(y)) - return xFelt, yFelt, nil - -} - -/* -Hashes the contents of a given array using a golang Pedersen Hash implementation. - -(ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) -*/ -func (sc StarkCurve) HashElements(elems []*big.Int) (hash *big.Int, err error) { - if len(elems) == 0 { - elems = append(elems, big.NewInt(0)) - } - - hash = big.NewInt(0) - for _, h := range elems { - hash, err = sc.PedersenHash([]*big.Int{hash, h}) - if err != nil { - return hash, err - } - } - return hash, err -} - -/* -Hashes the contents of a given array with its size using a golang Pedersen Hash implementation. - -(ref: https://github.com/starkware-libs/cairo-lang/blob/13cef109cd811474de114925ee61fd5ac84a25eb/src/starkware/cairo/common/hash_state.py#L6) -*/ -func (sc StarkCurve) ComputeHashOnElements(elems []*big.Int) (hash *big.Int, err error) { - elems = append(elems, big.NewInt(int64(len(elems)))) - return Curve.HashElements((elems)) -} - -/* -Provides the pedersen hash of given array of big integers. -NOTE: This function assumes the curve has been initialized with contant points - -(ref: https://github.com/seanjameshan/starknet.js/blob/main/src/utils/ellipticCurve.ts) -*/ -func (sc StarkCurve) PedersenHash(elems []*big.Int) (hash *big.Int, err error) { - if len(sc.ConstantPoints) == 0 { - return hash, fmt.Errorf("must initiate precomputed constant points") - } - - ptx := new(big.Int).Set(sc.Gx) - pty := new(big.Int).Set(sc.Gy) - for i, elem := range elems { - x := new(big.Int).Set(elem) - - if x.Cmp(big.NewInt(0)) != -1 && x.Cmp(sc.P) != -1 { - return ptx, fmt.Errorf("invalid x: %v", x) - } - - for j := 0; j < 252; j++ { - idx := 2 + (i * 252) + j - xin := new(big.Int).Set(sc.ConstantPoints[idx][0]) - yin := new(big.Int).Set(sc.ConstantPoints[idx][1]) - if xin.Cmp(ptx) == 0 { - return hash, fmt.Errorf("constant point duplication: %v %v", ptx, xin) - } - if x.Bit(0) == 1 { - ptx, pty = sc.Add(ptx, pty, xin, yin) - } - x = x.Rsh(x, 1) - } - } - - return ptx, nil -} - -/* -Provides the pedersen hash of given array of felts. -NOTE: This function just wraps the Juno implementation - -(ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/poseidon_hash.go#L74) -*/ -func (sc StarkCurve) PoseidonArray(felts ...*felt.Felt) *felt.Felt { - return junoCrypto.PoseidonArray(felts...) -} - -/* -Provides the starknet keccak hash . -NOTE: This function just wraps the Juno implementation - -(ref: https://github.com/NethermindEth/juno/blob/main/core/crypto/keccak.go#L11) -*/ -func (sc StarkCurve) StarknetKeccak(b []byte) (*felt.Felt, error) { - return junoCrypto.StarknetKeccak(b) -} - -// implementation based on https://github.com/codahale/rfc6979/blob/master/rfc6979.go -// for the specification, see https://tools.ietf.org/html/rfc6979#section-3.2 -func (sc StarkCurve) GenerateSecret(msgHash, privKey, seed *big.Int) (secret *big.Int) { - alg := sha256.New - holen := alg().Size() - rolen := (sc.BitSize + 7) >> 3 - - if msgHash.BitLen()%8 <= 4 && msgHash.BitLen() >= 248 { - msgHash = msgHash.Mul(msgHash, big.NewInt(16)) - } - - by := append(int2octets(privKey, rolen), bits2octets(msgHash, sc.N, sc.BitSize, rolen)...) - - if seed.Cmp(big.NewInt(0)) == 1 { - by = append(by, seed.Bytes()...) - } - - v := bytes.Repeat([]byte{0x01}, holen) - - k := bytes.Repeat([]byte{0x00}, holen) - - k = mac(alg, k, append(append(v, 0x00), by...), k) - - v = mac(alg, k, v, v) - - k = mac(alg, k, append(append(v, 0x01), by...), k) - - v = mac(alg, k, v, v) - - for { - var t []byte - - for len(t) < rolen { - v = mac(alg, k, v, v) - t = append(t, v...) - } - - secret = bits2int(new(big.Int).SetBytes(t), sc.BitSize) - // TODO: implement seed here, final gating function - if secret.Cmp(big.NewInt(0)) == 1 && secret.Cmp(sc.N) == -1 { - return secret - } - k = mac(alg, k, append(v, 0x00), k) - v = mac(alg, k, v, v) - } -} - -// https://tools.ietf.org/html/rfc6979#section-2.3.3 -func int2octets(v *big.Int, rolen int) []byte { - out := v.Bytes() - - // pad with zeros if it's too short - if len(out) < rolen { - out2 := make([]byte, rolen) - copy(out2[rolen-len(out):], out) - return out2 - } - - // drop most significant bytes if it's too long - if len(out) > rolen { - out2 := make([]byte, rolen) - copy(out2, out[len(out)-rolen:]) - return out2 - } - - return out -} - -// https://tools.ietf.org/html/rfc6979#section-2.3.4 -func bits2octets(in, q *big.Int, qlen, rolen int) []byte { - z1 := bits2int(in, qlen) - z2 := new(big.Int).Sub(z1, q) - if z2.Sign() < 0 { - return int2octets(z1, rolen) - } - return int2octets(z2, rolen) -} - -// https://tools.ietf.org/html/rfc6979#section-2.3.2 -func bits2int(in *big.Int, qlen int) *big.Int { - blen := len(in.Bytes()) * 8 - - if blen > qlen { - - return new(big.Int).Rsh(in, uint(blen-qlen)) - } - return in -} - -// mac returns an HMAC of the given key and message. -func mac(alg func() hash.Hash, k, m, buf []byte) []byte { - h := hmac.New(alg, k) - h.Write(m) - return h.Sum(buf[:0]) -} - -// mask excess bits -func MaskBits(mask, wordSize int, slice []byte) (ret []byte) { - excess := len(slice)*wordSize - mask - for _, by := range slice { - if excess > 0 { - if excess > wordSize { - excess = excess - wordSize - continue - } - by <<= excess - by >>= excess - excess = 0 - } - ret = append(ret, by) - } - return ret -} - -// format the bytes in Keccak hash -func FmtKecBytes(in *big.Int, rolen int) (buf []byte) { - buf = append(buf, in.Bytes()...) - - // pad with zeros if too short - if len(buf) < rolen { - padded := make([]byte, rolen) - copy(padded[rolen-len(buf):], buf) - - return padded - } - - return buf -} diff --git a/typed.go b/typed/typed.go similarity index 93% rename from typed.go rename to typed/typed.go index 90364847..1bddc958 100644 --- a/typed.go +++ b/typed/typed.go @@ -1,4 +1,4 @@ -package starknetgo +package typed import ( "bytes" @@ -8,6 +8,7 @@ import ( "regexp" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/utils" ) @@ -43,11 +44,9 @@ encoding definition for standard Starknet Domain messages */ func (dm Domain) FmtDefinitionEncoding(field string) (fmtEnc []*big.Int) { processStrToBig := func(fieldVal string) { - felt := strToFelt(fieldVal) - bigInt, ok := feltToBig(felt) - if ok { - fmtEnc = append(fmtEnc, bigInt) - } + feltVal := strToFelt(fieldVal) + bigInt := utils.FeltToBigInt(feltVal) + fmtEnc = append(fmtEnc, bigInt) } switch field { @@ -82,11 +81,6 @@ func strToFelt(str string) *felt.Felt { return f } -func feltToBig(feltNum *felt.Felt) (*big.Int, bool) { - return new(big.Int).SetString(feltNum.String(), 0) - -} - /* 'typedData' interface for interacting and signing typed data in accordance with https://github.com/0xs34n/starknet.js/tree/develop/src/utils/typedData */ @@ -112,7 +106,7 @@ func NewTypedData(types map[string]TypeDef, pType string, dom Domain) (td TypedD } // (ref: https://github.com/0xs34n/starknet.js/blob/767021a203ac0b9cdb282eb6d63b33bfd7614858/src/utils/typedData/index.ts#L166) -func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc StarkCurve) (hash *big.Int, err error) { +func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { elements := []*big.Int{utils.UTF8StrToBig("Starknet Message")} domEnc, err := td.GetTypedMessageHash("StarknetDomain", td.Domain, sc) @@ -132,7 +126,7 @@ func (td TypedData) GetMessageHash(account *big.Int, msg TypedMessage, sc StarkC return hash, err } -func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc StarkCurve) (hash *big.Int, err error) { +func (td TypedData) GetTypedMessageHash(inType string, msg TypedMessage, sc curve.StarkCurve) (hash *big.Int, err error) { prim := td.Types[inType] elements := []*big.Int{prim.Encoding} diff --git a/typed_test.go b/typed/typed_test.go similarity index 95% rename from typed_test.go rename to typed/typed_test.go index 63b008db..9065fbbb 100644 --- a/typed_test.go +++ b/typed/typed_test.go @@ -1,10 +1,11 @@ -package starknetgo +package typed import ( "fmt" "math/big" "testing" + "github.com/NethermindEth/starknet.go/curve" "github.com/NethermindEth/starknet.go/utils" ) @@ -66,7 +67,7 @@ func TestGeneral_GetMessageHash(t *testing.T) { Contents: "Hello, Bob!", } - hash, err := ttd.GetMessageHash(utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), mail, Curve) + hash, err := ttd.GetMessageHash(utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826"), mail, curve.Curve) if err != nil { t.Errorf("Could not hash message: %v\n", err) } @@ -93,14 +94,14 @@ func BenchmarkGetMessageHash(b *testing.B) { } addr := utils.HexToBN("0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826") b.Run(fmt.Sprintf("input_size_%d", addr.BitLen()), func(b *testing.B) { - ttd.GetMessageHash(addr, mail, Curve) + ttd.GetMessageHash(addr, mail, curve.Curve) }) } func TestGeneral_GetDomainHash(t *testing.T) { ttd := MockTypedData() - hash, err := ttd.GetTypedMessageHash("StarknetDomain", ttd.Domain, Curve) + hash, err := ttd.GetTypedMessageHash("StarknetDomain", ttd.Domain, curve.Curve) if err != nil { t.Errorf("Could not hash message: %v\n", err) } @@ -127,7 +128,7 @@ func TestGeneral_GetTypedMessageHash(t *testing.T) { Contents: "Hello, Bob!", } - hash, err := ttd.GetTypedMessageHash("Mail", mail, Curve) + hash, err := ttd.GetTypedMessageHash("Mail", mail, curve.Curve) if err != nil { t.Errorf("Could get typed message hash: %v\n", err) } diff --git a/utils.go b/utils.go deleted file mode 100644 index 64e3c822..00000000 --- a/utils.go +++ /dev/null @@ -1,38 +0,0 @@ -package starknetgo - -import ( - "crypto/rand" - "fmt" - "math/big" -) - -// obtain random primary key on stark curve -// NOTE: to be used for testing purposes -func (sc StarkCurve) GetRandomPrivateKey() (priv *big.Int, err error) { - max := new(big.Int).Sub(sc.Max, big.NewInt(1)) - - priv, err = rand.Int(rand.Reader, max) - if err != nil { - return priv, err - } - - x, y, err := sc.PrivateToPoint(priv) - if err != nil { - return priv, err - } - - if !sc.IsOnCurve(x, y) { - return priv, fmt.Errorf("key gen is not on stark cruve") - } - - return priv, nil -} - -// obtain public key coordinates from stark curve given the private key -func (sc StarkCurve) PrivateToPoint(privKey *big.Int) (x, y *big.Int, err error) { - if privKey.Cmp(big.NewInt(0)) != 1 || privKey.Cmp(sc.N) != -1 { - return x, y, fmt.Errorf("private key not in curve range") - } - x, y = sc.EcMult(privKey, sc.EcGenX, sc.EcGenY) - return x, y, nil -} diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index efce2c52..00000000 --- a/utils_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package starknetgo - -import ( - "testing" - - "github.com/NethermindEth/starknet.go/utils" -) - -func TestGeneral_SplitFactStr(t *testing.T) { - data := []map[string]string{ - {"input": "0x3", "h": "0x0", "l": "0x3"}, - {"input": "0x300000000000000000000000000000000", "h": "0x3", "l": "0x0"}, - } - for _, d := range data { - l, h := utils.SplitFactStr(d["input"]) // 0x3 - if l != d["l"] { - t.Errorf("expected %s, got %s", d["l"], l) - } - if h != d["h"] { - t.Errorf("expected %s, got %s", d["h"], h) - } - } -}