Skip to content

Commit

Permalink
rearrange concurrency in Test_Concurrent_Submission (#69)
Browse files Browse the repository at this point in the history
rearrange concurrency in Test_Concurrent_Submission
close channels when they're done
simplify channel consuming syntax
WaitGroup-s for intra-function goroutines
SequenceNumberTracker to sync/atomic
BuildTransactions, BuildSignAndSubmitTransactions, BuildSignAndSubmitTransactionsWithSignFunction get options
Test_Concurrent_Submission ExpirationSeconds(20)
  • Loading branch information
brianolson authored Jun 25, 2024
1 parent 577b1ee commit 940b3a6
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 122 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ func (client *Client) PollForTransactions(txnHashes []string, options ...any) er
}

// WaitForTransaction Do a long-GET for one transaction and wait for it to complete
func (client *Client) WaitForTransaction(txnHash string) (data *api.UserTransaction, err error) {
return client.nodeClient.WaitForTransaction(txnHash)
func (client *Client) WaitForTransaction(txnHash string, options ...any) (data *api.UserTransaction, err error) {
return client.nodeClient.WaitForTransaction(txnHash, options...)
}

// Transactions Get recent transactions.
Expand Down
106 changes: 85 additions & 21 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package aptos

import (
"strings"
"sync"
"testing"
"time"

"github.com/aptos-labs/aptos-go-sdk/api"
"github.com/aptos-labs/aptos-go-sdk/bcs"
Expand Down Expand Up @@ -376,10 +378,44 @@ func Test_AccountResources(t *testing.T) {
assert.Greater(t, len(resourcesBcs), 0)
}

// A worker thread that reads from a chan of transactions that have been submitted and waits on their completion status
func concurrentTxnWaiter(
results chan TransactionSubmissionResponse,
waitResults chan ConcResponse[*api.UserTransaction],
client *Client,
t *testing.T,
wg *sync.WaitGroup,
) {
if wg != nil {
defer wg.Done()
}
responseCount := 0
for response := range results {
responseCount++
assert.NoError(t, response.Err)

waitResponse, err := client.WaitForTransaction(response.Response.Hash, PollTimeout(21*time.Second))
if err != nil {
t.Logf("%s err %s", response.Response.Hash, err)
} else if waitResponse == nil {
t.Logf("%s nil response", response.Response.Hash)
} else if !waitResponse.Success {
t.Logf("%s !Success", response.Response.Hash)
}
waitResults <- ConcResponse[*api.UserTransaction]{Result: waitResponse, Err: err}
}
t.Logf("concurrentTxnWaiter done, %d responses", responseCount)
// signal completion
// (do not close the output as there may be other workers writing to it)
waitResults <- ConcResponse[*api.UserTransaction]{Result: nil, Err: nil}
}

func Test_Concurrent_Submission(t *testing.T) {
const numTxns = uint64(10)
const numTxns = uint64(100)
const numWaiters = 4
netConfig := LocalnetConfig

client, err := NewClient(DevnetConfig)
client, err := NewClient(netConfig)
assert.NoError(t, err)

account1, err := NewEd25519Account()
Expand All @@ -393,9 +429,9 @@ func Test_Concurrent_Submission(t *testing.T) {
assert.NoError(t, err)

// start submission goroutine
payloads := make(chan TransactionSubmissionPayload)
results := make(chan TransactionSubmissionResponse)
go client.nodeClient.BuildSignAndSubmitTransactions(account1, payloads, results)
payloads := make(chan TransactionSubmissionPayload, 50)
results := make(chan TransactionSubmissionResponse, 50)
go client.nodeClient.BuildSignAndSubmitTransactions(account1, payloads, results, ExpirationSeconds(20))

transferAmount, err := bcs.SerializeU64(100)
assert.NoError(t, err)
Expand All @@ -413,37 +449,65 @@ func Test_Concurrent_Submission(t *testing.T) {
}},
}
}
close(payloads)
t.Log("done submitting txns")

// Start waiting on txns
// TODO: These final steps should be concurrent rather than serial like this
waitResults := make(chan ConcResponse[*api.UserTransaction], numTxns)
waitResults := make(chan ConcResponse[*api.UserTransaction], numWaiters*10)

// It's interesting, this had to be wrapped in a goroutine to ensure blocking on results dont' block
go func() {
for response := range results {
assert.NoError(t, response.Err)

go fetch[*api.UserTransaction](func() (*api.UserTransaction, error) {
return client.WaitForTransaction(response.Response.Hash)
}, waitResults)
}
}()
var wg sync.WaitGroup
wg.Add(numWaiters)
for _ = range numWaiters {
go concurrentTxnWaiter(results, waitResults, client, t, &wg)
}

// Wait on all the results, recording the succeeding ones
txnMap := make(map[uint64]bool)

waitersRunning := numWaiters

// We could wait on a close, but I'm going to be a little pickier here
for i := uint64(0); i < numTxns; i++ {
i := uint64(0)
txnGoodEvents := 0
for {
response := <-waitResults
if response.Err == nil && response.Result == nil {
t.Log("txn waiter signaled done")
waitersRunning--
if waitersRunning == 0 {
close(results)
t.Log("last txn waiter done")
break
}
continue
}
assert.NoError(t, response.Err)
assert.True(t, response.Result.Success)
txnMap[response.Result.SequenceNumber] = true
assert.True(t, (response.Result != nil) && response.Result.Success)
if response.Result != nil {
txnMap[response.Result.SequenceNumber] = true
txnGoodEvents++
}
i++
if i >= numTxns {
t.Logf("waited on %d txns, done", i)
break
}
}
t.Log("done waiting for txns, waiting for txn waiter threads")

wg.Wait()

// Check all transactions were successful from [0-numTxns)
t.Logf("got %d(%d) successful txns of %d attempted, error submission indexes:", len(txnMap), txnGoodEvents, numTxns)
allTrue := true
for i := uint64(0); i < numTxns; i++ {
assert.True(t, txnMap[i])
allTrue = allTrue && txnMap[i]
if !txnMap[i] {
t.Logf("%d", i)
}
}
assert.True(t, allTrue, "all txns successful")
assert.Equal(t, len(txnMap), int(numTxns), "num txns successful == num txns sent")
}

func TestClient_BlockByHeight(t *testing.T) {
Expand Down
73 changes: 42 additions & 31 deletions nodeClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ func getTransactionPollOptions(defaultPeriod, defaultTimeout time.Duration, opti
return
}

// PollForTransaction waits up to 10 seconds for a transaction to be done, polling at 10Hz
// Accepts options PollPeriod and PollTimeout which should wrap time.Duration values.
// Not just a degenerate case of PollForTransactions, it may return additional information for the single transaction polled.
func (rc *NodeClient) PollForTransaction(hash string, options ...any) (*api.UserTransaction, error) {
period, timeout, err := getTransactionPollOptions(100*time.Millisecond, 10*time.Second, options...)
if err != nil {
Expand All @@ -259,7 +262,7 @@ func (rc *NodeClient) PollForTransaction(hash string, options ...any) (*api.User
deadline := start.Add(timeout)
for {
if time.Now().After(deadline) {
return nil, errors.New("timeout waiting for faucet transactions")
return nil, errors.New("PollForTransaction timeout")
}
time.Sleep(period)
txn, err := rc.TransactionByHash(hash)
Expand Down Expand Up @@ -290,7 +293,7 @@ func (rc *NodeClient) PollForTransactions(txnHashes []string, options ...any) er
deadline := start.Add(timeout)
for len(hashSet) > 0 {
if time.Now().After(deadline) {
return errors.New("timeout waiting for faucet transactions")
return errors.New("PollForTransactions timeout")
}
time.Sleep(period)
for _, hash := range txnHashes {
Expand Down Expand Up @@ -643,8 +646,9 @@ func (rc *NodeClient) buildTransactionInner(
// Fetch requirements concurrently, and then consume them

// Fetch GasUnitPrice which may be cached
gasPriceErrChannel := make(chan error, 1)
var gasPriceErrChannel chan error
if !haveGasUnitPrice {
gasPriceErrChannel = make(chan error, 1)
go func() {
gasPriceEstimation, innerErr := rc.EstimateGasPrice()
if innerErr != nil {
Expand All @@ -655,32 +659,32 @@ func (rc *NodeClient) buildTransactionInner(
}
close(gasPriceErrChannel)
}()
} else {
gasPriceErrChannel <- nil
close(gasPriceErrChannel)
}

// Fetch ChainId which may be cached
chainIdErrChannel := make(chan error, 1)
var chainIdErrChannel chan error
if !haveChainId {
go func() {
chain, innerErr := rc.GetChainId()
if innerErr != nil {
chainIdErrChannel <- innerErr
} else {
chainId = chain
chainIdErrChannel <- nil
}
close(chainIdErrChannel)
}()
} else {
chainIdErrChannel <- nil
close(chainIdErrChannel)
if rc.chainId == 0 {
chainIdErrChannel = make(chan error, 1)
go func() {
chain, innerErr := rc.GetChainId()
if innerErr != nil {
chainIdErrChannel <- innerErr
} else {
chainId = chain
chainIdErrChannel <- nil
}
close(chainIdErrChannel)
}()
} else {
chainId = rc.chainId
}
}

// Fetch sequence number unless provided
accountErrChannel := make(chan error, 1)
var accountErrChannel chan error
if !haveSequenceNumber {
accountErrChannel = make(chan error, 1)
go func() {
account, innerErr := rc.Account(sender)
if innerErr != nil {
Expand All @@ -698,20 +702,27 @@ func (rc *NodeClient) buildTransactionInner(
accountErrChannel <- nil
close(accountErrChannel)
}()
} else {
accountErrChannel <- nil
close(accountErrChannel)
}

// TODO: optionally simulate for max gas
// Wait on the errors
chainIdErr, accountErr, gasPriceErr := <-chainIdErrChannel, <-accountErrChannel, <-gasPriceErrChannel
if chainIdErr != nil {
return nil, chainIdErr
} else if accountErr != nil {
return nil, accountErr
} else if gasPriceErr != nil {
return nil, gasPriceErr
if chainIdErrChannel != nil {
chainIdErr := <-chainIdErrChannel
if chainIdErr != nil {
return nil, chainIdErr
}
}
if accountErrChannel != nil {
accountErr := <-accountErrChannel
if accountErr != nil {
return nil, accountErr
}
}
if gasPriceErrChannel != nil {
gasPriceErr := <-gasPriceErrChannel
if gasPriceErr != nil {
return nil, gasPriceErr
}
}

expirationTimestampSeconds := uint64(time.Now().Unix() + expirationSeconds)
Expand Down
Loading

0 comments on commit 940b3a6

Please sign in to comment.