From 4bf3c72d51584512f3a72263cabedf696d0e389e Mon Sep 17 00:00:00 2001 From: Maksim Zdobnikau <43750648+DelevoXDG@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:16:07 +0100 Subject: [PATCH] Set chain ID in StarknetAccount initializer (#154) --- Sources/Starknet/Accounts/StarknetAccount.swift | 8 +++----- .../Starknet/Accounts/StarknetAccountProtocol.swift | 2 ++ Tests/StarknetTests/Accounts/AccountTest.swift | 8 +++++--- Tests/StarknetTests/Data/ExecutionTests.swift | 3 ++- Tests/StarknetTests/Providers/ProviderTests.swift | 13 ++++++++----- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/Sources/Starknet/Accounts/StarknetAccount.swift b/Sources/Starknet/Accounts/StarknetAccount.swift index 26a9713cf..7ce647ecf 100644 --- a/Sources/Starknet/Accounts/StarknetAccount.swift +++ b/Sources/Starknet/Accounts/StarknetAccount.swift @@ -13,14 +13,16 @@ public enum CairoVersion: String, Encodable { public class StarknetAccount: StarknetAccountProtocol { private let cairoVersion: CairoVersion public let address: Felt + public let chainId: StarknetChainId private let signer: StarknetSignerProtocol private let provider: StarknetProviderProtocol - public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, cairoVersion: CairoVersion) { + public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, chainId: StarknetChainId, cairoVersion: CairoVersion) { self.address = address self.signer = signer self.provider = provider + self.chainId = chainId self.cairoVersion = cairoVersion } @@ -45,7 +47,6 @@ public class StarknetAccount: StarknetAccountProtocol { let transaction = makeInvokeTransactionV1(calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation) - let chainId = try await provider.getChainId() let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId) let signature = try signer.sign(transactionHash: hash) @@ -58,7 +59,6 @@ public class StarknetAccount: StarknetAccountProtocol { let transaction = makeInvokeTransactionV3(calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation) - let chainId = try await provider.getChainId() let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId) let signature = try signer.sign(transactionHash: hash) @@ -69,7 +69,6 @@ public class StarknetAccount: StarknetAccountProtocol { public func signDeployAccountV1(classHash: Felt, calldata: StarknetCalldata, salt: Felt, params: StarknetDeployAccountParamsV1, forFeeEstimation: Bool) async throws -> StarknetDeployAccountTransactionV1 { let transaction = makeDeployAccountTransactionV1(classHash: classHash, salt: salt, calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation) - let chainId = try await provider.getChainId() let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId) let signature = try signer.sign(transactionHash: hash) @@ -80,7 +79,6 @@ public class StarknetAccount: StarknetAccountProtocol { public func signDeployAccountV3(classHash: Felt, calldata: StarknetCalldata, salt: Felt, params: StarknetDeployAccountParamsV3, forFeeEstimation: Bool) async throws -> StarknetDeployAccountTransactionV3 { let transaction = makeDeployAccountTransactionV3(classHash: classHash, salt: salt, calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation) - let chainId = try await provider.getChainId() let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId) let signature = try signer.sign(transactionHash: hash) diff --git a/Sources/Starknet/Accounts/StarknetAccountProtocol.swift b/Sources/Starknet/Accounts/StarknetAccountProtocol.swift index 51d305da9..222b26783 100644 --- a/Sources/Starknet/Accounts/StarknetAccountProtocol.swift +++ b/Sources/Starknet/Accounts/StarknetAccountProtocol.swift @@ -3,6 +3,8 @@ import Foundation public protocol StarknetAccountProtocol { /// Address of starknet account. var address: Felt { get } + /// Chain id of the Starknet provider. + var chainId: StarknetChainId { get } /// Sign list of calls as invoke transaction v1 /// diff --git a/Tests/StarknetTests/Accounts/AccountTest.swift b/Tests/StarknetTests/Accounts/AccountTest.swift index 0fa5f8fe2..e46f742d5 100644 --- a/Tests/StarknetTests/Accounts/AccountTest.swift +++ b/Tests/StarknetTests/Accounts/AccountTest.swift @@ -6,6 +6,7 @@ final class AccountTests: XCTestCase { static var devnetClient: DevnetClientProtocol! var provider: StarknetProviderProtocol! + var chainId: StarknetChainId! var signer: StarknetSignerProtocol! var account: StarknetAccountProtocol! var accountContractClassHash: Felt! @@ -23,7 +24,8 @@ final class AccountTests: XCTestCase { ethContractAddress = Self.devnetClient.constants.ethErc20ContractAddress let accountDetails = Self.devnetClient.constants.predeployedAccount1 signer = StarkCurveSigner(privateKey: accountDetails.privateKey)! - account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .zero) + chainId = try await provider.getChainId() + account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .zero) } override class func setUp() { @@ -146,7 +148,7 @@ final class AccountTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 1234)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress) @@ -172,7 +174,7 @@ final class AccountTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 4567)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress, unit: .fri) diff --git a/Tests/StarknetTests/Data/ExecutionTests.swift b/Tests/StarknetTests/Data/ExecutionTests.swift index c55eafc9f..16a37554d 100644 --- a/Tests/StarknetTests/Data/ExecutionTests.swift +++ b/Tests/StarknetTests/Data/ExecutionTests.swift @@ -20,7 +20,8 @@ final class ExecutionTests: XCTestCase { provider = StarknetProvider(url: Self.devnetClient.rpcUrl)! let accountDetails = ExecutionTests.devnetClient.constants.predeployedAccount1 signer = StarkCurveSigner(privateKey: accountDetails.privateKey)! - account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .one) + let chainId = try await provider.getChainId() + account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .one) balanceContractAddress = try await Self.devnetClient.declareDeployContract(contractName: "Balance").deploy.contractAddress } diff --git a/Tests/StarknetTests/Providers/ProviderTests.swift b/Tests/StarknetTests/Providers/ProviderTests.swift index 240cb1b22..3f855ba52 100644 --- a/Tests/StarknetTests/Providers/ProviderTests.swift +++ b/Tests/StarknetTests/Providers/ProviderTests.swift @@ -6,6 +6,7 @@ final class ProviderTests: XCTestCase { static var devnetClient: DevnetClientProtocol! var provider: StarknetProviderProtocol! + var chainId: StarknetChainId! var signer: StarknetSignerProtocol! var account: StarknetAccountProtocol! var accountContractClassHash: Felt! @@ -33,7 +34,9 @@ final class ProviderTests: XCTestCase { accountContractClassHash = Self.devnetClient.constants.accountContractClassHash let accountDetails = Self.devnetClient.constants.predeployedAccount2 signer = StarkCurveSigner(privateKey: accountDetails.privateKey)! - account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .zero) + + chainId = try await provider.getChainId() + account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .zero) } func makeStarknetProvider(url: String) -> StarknetProviderProtocol { @@ -246,7 +249,7 @@ final class ProviderTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 1111)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress) @@ -267,7 +270,7 @@ final class ProviderTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 3333)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress) @@ -322,7 +325,7 @@ final class ProviderTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 1001)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress) @@ -363,7 +366,7 @@ final class ProviderTests: XCTestCase { let newSigner = StarkCurveSigner(privateKey: 3003)! let newPublicKey = newSigner.publicKey let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero) - let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero) + let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero) try await Self.devnetClient.prefundAccount(address: newAccountAddress, amount: 5_000_000_000_000_000_000, unit: .fri)