Skip to content

Commit

Permalink
Do not use cache map in ChainIDProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
MDobak committed Feb 15, 2024
1 parent d3ae97d commit 88d7a38
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions txmodifier/chainid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package txmodifier
import (
"context"
"fmt"
"sync"

"github.com/defiweb/go-eth/rpc"
"github.com/defiweb/go-eth/types"
Expand All @@ -14,21 +15,27 @@ import (
// To use this modifier, add it using the WithTXModifiers option when creating
// a new rpc.Client.
type ChainIDProvider struct {
chainID map[rpc.RPC]uint64
mu sync.Mutex
chainID *uint64
replace bool
cache bool
}

// ChainIDProviderOptions is the options for NewChainIDProvider.
type ChainIDProviderOptions struct {
Replace bool // Replace is true if the chain ID should be replaced even if it is already set.
Cache bool // Cache is true if the chain ID will be cached instead of being queried for each transaction.
// Replace is true if the transaction chain ID should be replaced even if
// it is already set.
Replace bool

// Cache is true if the chain ID will be cached instead of being queried
// for each transaction. Cached chain ID will be used for all RPC clients
// that use the same ChainIDProvider instance.
Cache bool
}

// NewChainIDProvider returns a new ChainIDProvider.
func NewChainIDProvider(opts ChainIDProviderOptions) *ChainIDProvider {
return &ChainIDProvider{
chainID: make(map[rpc.RPC]uint64),
replace: opts.Replace,
cache: opts.Cache,
}
Expand All @@ -39,17 +46,27 @@ func (p *ChainIDProvider) Modify(ctx context.Context, client rpc.RPC, tx *types.
if !p.replace && tx.ChainID != nil {
return nil
}
if chainID, ok := p.chainID[client]; ok {
if !p.cache {
chainID, err := client.ChainID(ctx)
if err != nil {
return fmt.Errorf("chain ID provider: %w", err)
}
tx.ChainID = &chainID
return nil
}
chainID, err := client.ChainID(ctx)
if err != nil {
return fmt.Errorf("chain ID provider: %w", err)
}
if p.cache {
p.chainID[client] = chainID
p.mu.Lock()
defer p.mu.Unlock()
var cid uint64
if p.chainID != nil {
cid = *p.chainID
} else {
chainID, err := client.ChainID(ctx)
if err != nil {
return fmt.Errorf("chain ID provider: %w", err)
}
p.chainID = &chainID
cid = chainID
}
tx.ChainID = &chainID
tx.ChainID = &cid
return nil
}

0 comments on commit 88d7a38

Please sign in to comment.