Skip to content

Commit

Permalink
feat/refactor_cookies_management (#156)
Browse files Browse the repository at this point in the history
* (feat) Refactored all cookies related logic into the CookieAssistant structs. Changed chainClient, exchangeClient and explorerClient to receive the network and use the cookie through the CookieAssistant in the network. Removed the HTTP websocket from chainClient

* (fix) Added port number to the endpoints that did not have it

---------

Co-authored-by: abel <[email protected]>
  • Loading branch information
aarmoa and abel committed Sep 21, 2023
1 parent 0d2b3a6 commit 468f46b
Show file tree
Hide file tree
Showing 111 changed files with 887 additions and 815 deletions.
192 changes: 68 additions & 124 deletions client/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"math"
"math/big"
"net/http"
"os"
"strconv"
"strings"
Expand All @@ -19,7 +18,6 @@ import (
sdkmath "cosmossdk.io/math"
wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types"
log "github.com/InjectiveLabs/suplog"
rpcclient "github.com/cometbft/cometbft/rpc/client"
rpchttp "github.com/cometbft/cometbft/rpc/client/http"
"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
Expand Down Expand Up @@ -132,7 +130,9 @@ type ChainClient interface {
GetGasFee() (string, error)

StreamEventOrderFail(sender string, failEventCh chan map[string]uint)
StreamEventOrderFailWithWebsocket(sender string, websocket *rpchttp.HTTP, failEventCh chan map[string]uint)
StreamOrderbookUpdateEvents(orderbookType OrderbookType, marketIds []string, orderbookCh chan exchangetypes.Orderbook)
StreamOrderbookUpdateEventsWithWebsocket(orderbookType OrderbookType, marketIds []string, websocket *rpchttp.HTTP, orderbookCh chan exchangetypes.Orderbook)

// get tx from chain node
GetTx(ctx context.Context, txHash string) (*txtypes.GetTxResponse, error)
Expand All @@ -141,6 +141,7 @@ type ChainClient interface {

type chainClient struct {
ctx client.Context
network common.Network
opts *common.ClientOptions
logger log.Logger
conn *grpc.ClientConn
Expand All @@ -162,7 +163,6 @@ type chainClient struct {
sessionCookie string
sessionEnabled bool

cometbftClient rpcclient.Client
txClient txtypes.ServiceClient
authQueryClient authtypes.QueryClient
exchangeQueryClient exchangetypes.QueryClient
Expand All @@ -175,15 +175,19 @@ type chainClient struct {
canSign bool
}

// NewCosmosClient creates a new gRPC client that communicates with gRPC server at protoAddr.
// NewChainClient creates a new gRPC client that communicates with gRPC server at protoAddr.
// protoAddr must be in form "tcp://127.0.0.1:8080" or "unix:///tmp/test.sock", protocol is required.
func NewChainClient(
ctx client.Context,
protoAddr string,
network common.Network,
options ...common.ClientOption,
) (ChainClient, error) {
// process options
opts := common.DefaultClientOptions()

if network.ChainTlsCert != nil {
options = append(options, common.OptionTLSCert(network.ChainTlsCert))
}
for _, opt := range options {
if err := opt(opts); err != nil {
err = errors.Wrap(err, "error in client option")
Expand All @@ -207,37 +211,22 @@ func NewChainClient(
var err error
stickySessionEnabled := true
if opts.TLSCert != nil {
conn, err = grpc.Dial(protoAddr, grpc.WithTransportCredentials(opts.TLSCert), grpc.WithContextDialer(common.DialerFunc))
conn, err = grpc.Dial(network.ChainGrpcEndpoint, grpc.WithTransportCredentials(opts.TLSCert), grpc.WithContextDialer(common.DialerFunc))
} else {
conn, err = grpc.Dial(protoAddr, grpc.WithInsecure(), grpc.WithContextDialer(common.DialerFunc))
conn, err = grpc.Dial(network.ChainGrpcEndpoint, grpc.WithInsecure(), grpc.WithContextDialer(common.DialerFunc))
stickySessionEnabled = false
}
if err != nil {
err = errors.Wrapf(err, "failed to connect to the gRPC: %s", protoAddr)
err = errors.Wrapf(err, "failed to connect to the gRPC: %s", network.ChainGrpcEndpoint)
return nil, err
}

// init tm websocket
var cometbftClient *rpchttp.HTTP
if ctx.NodeURI != "" {
cometbftClient, err = rpchttp.New(ctx.NodeURI, "/websocket")
if err != nil {
panic(err)
}

if !cometbftClient.IsRunning() {
err = cometbftClient.Start()
if err != nil {
return nil, err
}
}
}

cancelCtx, cancelFn := context.WithCancel(context.Background())
// build client
cc := &chainClient{
ctx: ctx,
opts: opts,
ctx: ctx,
network: network,
opts: opts,

logger: log.WithFields(log.Fields{
"module": "sdk-go",
Expand All @@ -255,7 +244,6 @@ func NewChainClient(

sessionEnabled: stickySessionEnabled,

cometbftClient: cometbftClient,
txClient: txtypes.NewServiceClient(conn),
authQueryClient: authtypes.NewQueryClient(conn),
exchangeQueryClient: exchangetypes.NewQueryClient(conn),
Expand Down Expand Up @@ -371,92 +359,16 @@ func (c *chainClient) getAccSeq() uint64 {
return c.accSeq
}

func (c *chainClient) setCookie(metadata metadata.MD) {
if !c.sessionEnabled {
return
}
md := metadata.Get("set-cookie")
if len(md) > 0 {
// write to client instance
c.sessionCookie = md[0]
// write to disk
err := os.WriteFile(defaultChainCookieName, []byte(md[0]), 0644)
if err != nil {
c.logger.Errorln(err)
return
}
c.logger.Infoln("chain session cookie saved to disk")
}
}

func (c *chainClient) fetchCookie(ctx context.Context) context.Context {
func (c *chainClient) requestCookie() metadata.MD {
var header metadata.MD
c.txClient.GetTx(context.Background(), &txtypes.GetTxRequest{}, grpc.Header(&header))
c.setCookie(header)
time.Sleep(defaultBlockTime)

return metadata.NewOutgoingContext(ctx, metadata.Pairs("cookie", c.sessionCookie))
}

func cookieByName(cookies []*http.Cookie, key string) *http.Cookie {
for _, c := range cookies {
if c.Name == key {
return c
}
}
return nil
}

func (c *chainClient) getCookieExpirationTime(cookies []*http.Cookie) (time.Time, error) {
var expiresAt string
if cookieByName(cookies, "GCLB") != nil {
// parse global load balance cookie timestamp
cookie := cookieByName(cookies, "expires")
expiresAt = strings.Replace(cookie.Value, "-", " ", -1)
} else {
cookie := cookieByName(cookies, "Expires")
if cookie == nil {
return time.Time{}, nil
}

expiresAt = strings.Replace(cookie.Value, "-", " ", -1)
yyyy := fmt.Sprintf("20%s", expiresAt[12:14])
expiresAt = expiresAt[:12] + yyyy + expiresAt[14:]
}

return time.Parse(time.RFC1123, expiresAt)
return header
}

func (c *chainClient) getCookie(ctx context.Context) context.Context {
md := metadata.Pairs("cookie", c.sessionCookie)
if !c.sessionEnabled {
return metadata.NewOutgoingContext(ctx, md)
}

// borrow http request to parse cookie
header := http.Header{}
header.Add("Cookie", c.sessionCookie)
request := http.Request{Header: header}
cookies := request.Cookies()

if len(cookies) > 0 {
// parse expire field into unix timestamp
expiresTimestamp, err := c.getCookieExpirationTime(cookies)
if err != nil {
panic(err)
}

if !expiresTimestamp.IsZero() {
// renew session if timestamp diff < offset
timestampDiff := expiresTimestamp.Unix() - time.Now().Unix()
if timestampDiff < defaultSessionRenewalOffset {
return c.fetchCookie(ctx)
}
}
} else {
return c.fetchCookie(ctx)
}

provider := common.NewMetadataProvider(c.requestCookie)
cookie, _ := c.network.ChainMetadata(provider)
md := metadata.Pairs("cookie", cookie)
return metadata.NewOutgoingContext(ctx, md)
}

Expand Down Expand Up @@ -499,10 +411,6 @@ func (c *chainClient) Close() {
if c.conn != nil {
c.conn.Close()
}

if c.cometbftClient != nil {
c.cometbftClient.Stop()
}
}

func (c *chainClient) GetBankBalances(ctx context.Context, address string) (*banktypes.QueryAllBalancesResponse, error) {
Expand Down Expand Up @@ -580,8 +488,7 @@ func (c *chainClient) SimulateMsg(clientCtx client.Context, msgs ...sdk.Msg) (*t

ctx := context.Background()
ctx = c.getCookie(ctx)
var header metadata.MD
simRes, err := c.txClient.Simulate(ctx, &txtypes.SimulateRequest{TxBytes: simTxBytes}, grpc.Header(&header))
simRes, err := c.txClient.Simulate(ctx, &txtypes.SimulateRequest{TxBytes: simTxBytes})
if err != nil {
err = errors.Wrap(err, "failed to CalculateGas")
return nil, err
Expand Down Expand Up @@ -671,9 +578,8 @@ func (c *chainClient) SyncBroadcastSignedTx(txBytes []byte) (*txtypes.BroadcastT
}

ctx := context.Background()
var header metadata.MD
ctx = c.getCookie(ctx)
res, err := c.txClient.BroadcastTx(ctx, &req, grpc.Header(&header))
res, err := c.txClient.BroadcastTx(ctx, &req)
if err != nil {
return res, err
}
Expand Down Expand Up @@ -720,9 +626,8 @@ func (c *chainClient) AsyncBroadcastSignedTx(txBytes []byte) (*txtypes.Broadcast

ctx := context.Background()
// use our own client to broadcast tx
var header metadata.MD
ctx = c.getCookie(ctx)
res, err := c.txClient.BroadcastTx(ctx, &req, grpc.Header(&header))
res, err := c.txClient.BroadcastTx(ctx, &req)
if err != nil {
return nil, err
}
Expand All @@ -749,8 +654,7 @@ func (c *chainClient) broadcastTx(
return nil, err
}
ctx := c.getCookie(ctx)
var header metadata.MD
simRes, err := c.txClient.Simulate(ctx, &txtypes.SimulateRequest{TxBytes: simTxBytes}, grpc.Header(&header))
simRes, err := c.txClient.Simulate(ctx, &txtypes.SimulateRequest{TxBytes: simTxBytes})
if err != nil {
err = errors.Wrap(err, "failed to CalculateGas")
return nil, err
Expand Down Expand Up @@ -787,9 +691,8 @@ func (c *chainClient) broadcastTx(
Mode: txtypes.BroadcastMode_BROADCAST_MODE_SYNC,
}
// use our own client to broadcast tx
var header metadata.MD
ctx = c.getCookie(ctx)
res, err := c.txClient.BroadcastTx(ctx, &req, grpc.Header(&header))
res, err := c.txClient.BroadcastTx(ctx, &req)
if !await || err != nil {
return res, err
}
Expand Down Expand Up @@ -1242,8 +1145,28 @@ func (c *chainClient) BuildExchangeBatchUpdateOrdersAuthz(
}

func (c *chainClient) StreamEventOrderFail(sender string, failEventCh chan map[string]uint) {
var cometbftClient *rpchttp.HTTP
var err error

cometbftClient, err = rpchttp.New(c.network.TmEndpoint, "/websocket")
if err != nil {
panic(err)
}

if !cometbftClient.IsRunning() {
err = cometbftClient.Start()
if err != nil {
panic(err)
}
}
defer cometbftClient.Stop()

c.StreamEventOrderFailWithWebsocket(sender, cometbftClient, failEventCh)
}

func (c *chainClient) StreamEventOrderFailWithWebsocket(sender string, websocket *rpchttp.HTTP, failEventCh chan map[string]uint) {
filter := fmt.Sprintf("tm.event='Tx' AND message.sender='%s' AND message.action='/injective.exchange.v1beta1.MsgBatchUpdateOrders' AND injective.exchange.v1beta1.EventOrderFail.flags EXISTS", sender)
eventCh, err := c.cometbftClient.Subscribe(context.Background(), "OrderFail", filter, 10000)
eventCh, err := websocket.Subscribe(context.Background(), "OrderFail", filter, 10000)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -1276,8 +1199,29 @@ func (c *chainClient) StreamEventOrderFail(sender string, failEventCh chan map[s
}

func (c *chainClient) StreamOrderbookUpdateEvents(orderbookType OrderbookType, marketIds []string, orderbookCh chan exchangetypes.Orderbook) {
var cometbftClient *rpchttp.HTTP
var err error

cometbftClient, err = rpchttp.New(c.network.TmEndpoint, "/websocket")
if err != nil {
panic(err)
}

if !cometbftClient.IsRunning() {
err = cometbftClient.Start()
if err != nil {
panic(err)
}
}
defer cometbftClient.Stop()

c.StreamOrderbookUpdateEventsWithWebsocket(orderbookType, marketIds, cometbftClient, orderbookCh)

}

func (c *chainClient) StreamOrderbookUpdateEventsWithWebsocket(orderbookType OrderbookType, marketIds []string, websocket *rpchttp.HTTP, orderbookCh chan exchangetypes.Orderbook) {
filter := fmt.Sprintf("tm.event='NewBlock' AND %s EXISTS", orderbookType)
eventCh, err := c.cometbftClient.Subscribe(context.Background(), "OrderbookUpdate", filter, 10000)
eventCh, err := websocket.Subscribe(context.Background(), "OrderbookUpdate", filter, 10000)
if err != nil {
panic(err)
}
Expand Down
9 changes: 4 additions & 5 deletions client/chain/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,15 @@ func createClient(senderAddress cosmtypes.AccAddress, cosmosKeyring keyring.Keyr

chainClient, err := NewChainClient(
clientCtx,
network.ChainGrpcEndpoint,
common.OptionTLSCert(network.ChainTlsCert),
network,
common.OptionGasPrices("500000000inj"),
)

return chainClient, err
}

func TestDefaultSubaccount(t *testing.T) {
network := common.LoadNetwork("testnet", "k8s")
network := common.LoadNetwork("testnet", "lb")
senderAddress, cosmosKeyring, err := accountForTests()

if err != nil {
Expand All @@ -71,8 +70,8 @@ func TestDefaultSubaccount(t *testing.T) {
}
}

func TestGetSubaccountWithIndes(t *testing.T) {
network := common.LoadNetwork("testnet", "k8s")
func TestGetSubaccountWithIndex(t *testing.T) {
network := common.LoadNetwork("testnet", "lb")
senderAddress, cosmosKeyring, err := accountForTests()

if err != nil {
Expand Down
Loading

0 comments on commit 468f46b

Please sign in to comment.