Skip to content

Commit

Permalink
Set chain ID in StarknetAccount initializer (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
DelevoXDG authored Feb 19, 2024
1 parent 6a78ba4 commit 4bf3c72
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
8 changes: 3 additions & 5 deletions Sources/Starknet/Accounts/StarknetAccount.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ public enum CairoVersion: String, Encodable {
public class StarknetAccount: StarknetAccountProtocol {
private let cairoVersion: CairoVersion
public let address: Felt
public let chainId: StarknetChainId

private let signer: StarknetSignerProtocol
private let provider: StarknetProviderProtocol

public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, cairoVersion: CairoVersion) {
public init(address: Felt, signer: StarknetSignerProtocol, provider: StarknetProviderProtocol, chainId: StarknetChainId, cairoVersion: CairoVersion) {
self.address = address
self.signer = signer
self.provider = provider
self.chainId = chainId
self.cairoVersion = cairoVersion
}

Expand All @@ -45,7 +47,6 @@ public class StarknetAccount: StarknetAccountProtocol {

let transaction = makeInvokeTransactionV1(calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation)

let chainId = try await provider.getChainId()
let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId)

let signature = try signer.sign(transactionHash: hash)
Expand All @@ -58,7 +59,6 @@ public class StarknetAccount: StarknetAccountProtocol {

let transaction = makeInvokeTransactionV3(calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation)

let chainId = try await provider.getChainId()
let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId)

let signature = try signer.sign(transactionHash: hash)
Expand All @@ -69,7 +69,6 @@ public class StarknetAccount: StarknetAccountProtocol {
public func signDeployAccountV1(classHash: Felt, calldata: StarknetCalldata, salt: Felt, params: StarknetDeployAccountParamsV1, forFeeEstimation: Bool) async throws -> StarknetDeployAccountTransactionV1 {
let transaction = makeDeployAccountTransactionV1(classHash: classHash, salt: salt, calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation)

let chainId = try await provider.getChainId()
let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId)

let signature = try signer.sign(transactionHash: hash)
Expand All @@ -80,7 +79,6 @@ public class StarknetAccount: StarknetAccountProtocol {
public func signDeployAccountV3(classHash: Felt, calldata: StarknetCalldata, salt: Felt, params: StarknetDeployAccountParamsV3, forFeeEstimation: Bool) async throws -> StarknetDeployAccountTransactionV3 {
let transaction = makeDeployAccountTransactionV3(classHash: classHash, salt: salt, calldata: calldata, signature: [], params: params, forFeeEstimation: forFeeEstimation)

let chainId = try await provider.getChainId()
let hash = StarknetTransactionHashCalculator.computeHash(of: transaction, chainId: chainId)

let signature = try signer.sign(transactionHash: hash)
Expand Down
2 changes: 2 additions & 0 deletions Sources/Starknet/Accounts/StarknetAccountProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import Foundation
public protocol StarknetAccountProtocol {
/// Address of starknet account.
var address: Felt { get }
/// Chain id of the Starknet provider.
var chainId: StarknetChainId { get }

/// Sign list of calls as invoke transaction v1
///
Expand Down
8 changes: 5 additions & 3 deletions Tests/StarknetTests/Accounts/AccountTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ final class AccountTests: XCTestCase {
static var devnetClient: DevnetClientProtocol!

var provider: StarknetProviderProtocol!
var chainId: StarknetChainId!
var signer: StarknetSignerProtocol!
var account: StarknetAccountProtocol!
var accountContractClassHash: Felt!
Expand All @@ -23,7 +24,8 @@ final class AccountTests: XCTestCase {
ethContractAddress = Self.devnetClient.constants.ethErc20ContractAddress
let accountDetails = Self.devnetClient.constants.predeployedAccount1
signer = StarkCurveSigner(privateKey: accountDetails.privateKey)!
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .zero)
chainId = try await provider.getChainId()
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .zero)
}

override class func setUp() {
Expand Down Expand Up @@ -146,7 +148,7 @@ final class AccountTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 1234)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

Expand All @@ -172,7 +174,7 @@ final class AccountTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 4567)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress, unit: .fri)

Expand Down
3 changes: 2 additions & 1 deletion Tests/StarknetTests/Data/ExecutionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ final class ExecutionTests: XCTestCase {
provider = StarknetProvider(url: Self.devnetClient.rpcUrl)!
let accountDetails = ExecutionTests.devnetClient.constants.predeployedAccount1
signer = StarkCurveSigner(privateKey: accountDetails.privateKey)!
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .one)
let chainId = try await provider.getChainId()
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .one)
balanceContractAddress = try await Self.devnetClient.declareDeployContract(contractName: "Balance").deploy.contractAddress
}

Expand Down
13 changes: 8 additions & 5 deletions Tests/StarknetTests/Providers/ProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ final class ProviderTests: XCTestCase {
static var devnetClient: DevnetClientProtocol!

var provider: StarknetProviderProtocol!
var chainId: StarknetChainId!
var signer: StarknetSignerProtocol!
var account: StarknetAccountProtocol!
var accountContractClassHash: Felt!
Expand Down Expand Up @@ -33,7 +34,9 @@ final class ProviderTests: XCTestCase {
accountContractClassHash = Self.devnetClient.constants.accountContractClassHash
let accountDetails = Self.devnetClient.constants.predeployedAccount2
signer = StarkCurveSigner(privateKey: accountDetails.privateKey)!
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, cairoVersion: .zero)

chainId = try await provider.getChainId()
account = StarknetAccount(address: accountDetails.address, signer: signer, provider: provider, chainId: chainId, cairoVersion: .zero)
}

func makeStarknetProvider(url: String) -> StarknetProviderProtocol {
Expand Down Expand Up @@ -246,7 +249,7 @@ final class ProviderTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 1111)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

Expand All @@ -267,7 +270,7 @@ final class ProviderTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 3333)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountContractClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

Expand Down Expand Up @@ -322,7 +325,7 @@ final class ProviderTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 1001)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

Expand Down Expand Up @@ -363,7 +366,7 @@ final class ProviderTests: XCTestCase {
let newSigner = StarkCurveSigner(privateKey: 3003)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, cairoVersion: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider, chainId: chainId, cairoVersion: .zero)

try await Self.devnetClient.prefundAccount(address: newAccountAddress, amount: 5_000_000_000_000_000_000, unit: .fri)

Expand Down

0 comments on commit 4bf3c72

Please sign in to comment.