Skip to content

Commit

Permalink
(feat) Add OFAC list check
Browse files Browse the repository at this point in the history
  • Loading branch information
shibaeff committed Sep 9, 2024
1 parent b7a68cc commit d5f6982
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 2 deletions.
58 changes: 57 additions & 1 deletion client/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ type chainClient struct {

sessionEnabled bool

ofacChecker *OfacChecker

authQueryClient authtypes.QueryClient
authzQueryClient authztypes.QueryClient
bankQueryClient banktypes.QueryClient
Expand All @@ -342,6 +344,29 @@ type chainClient struct {
canSign bool
}

//func (cc *chainClient) loadOfacList() error {
// response, err := http.Get(defaultOfacListURL)
// if err != nil {
// return err
// }
// defer response.Body.Close()
//
// if response.StatusCode != http.StatusOK {
// return fmt.Errorf("request to the OFAC upstream failed with code: %s", response.Status)
// }
//
// body, err := io.ReadAll(response.Body)
// if err != nil {
// return err
// }
//
// var ofacList []string
// if err := json.Unmarshal(body, &ofacList); err != nil {
// return err
// }
// return nil
//}

func NewChainClient(
ctx client.Context,
network common.Network,
Expand Down Expand Up @@ -440,6 +465,11 @@ func NewChainClient(
subaccountToNonce: make(map[ethcommon.Hash]uint32),
}

cc.ofacChecker, err = NewOfacChecker()
if err != nil {
return nil, errors.Wrap(err, "Error creating OFAC checker")
}

if cc.canSign {
var err error

Expand All @@ -453,6 +483,11 @@ func NewChainClient(
go cc.syncTimeoutHeight()
}

if err != nil {
err = errors.Wrap(err, "failed to load the OFAC list")
return nil, err
}

return cc, nil
}

Expand Down Expand Up @@ -774,6 +809,21 @@ func (c *chainClient) BuildSignedTx(clientCtx client.Context, accNum, accSeq, in
}

func (c *chainClient) buildSignedTx(clientCtx client.Context, txf tx.Factory, msgs ...sdk.Msg) ([]byte, error) {
k, err := txf.Keybase().Key(clientCtx.FromName)
if err != nil {
err = errors.Wrap(err, "error parsing signer account address")
return nil, err
}
signerAddressPubKey, err := k.GetPubKey()
if err != nil {
err = errors.Wrap(err, "error getting signer public key")
return nil, err
}
if c.ofacChecker.IsBlacklisted(sdk.AccAddress(signerAddressPubKey.Address()).String()) {
err = errors.Errorf("Address is in the OFAC list")
return nil, err
}

ctx := context.Background()
if clientCtx.Simulate {
simTxBytes, err := txf.BuildSimTx(msgs...)
Expand All @@ -796,7 +846,7 @@ func (c *chainClient) buildSignedTx(clientCtx client.Context, txf tx.Factory, ms
c.gasWanted = adjustedGas
}

txf, err := PrepareFactory(clientCtx, txf)
txf, err = PrepareFactory(clientCtx, txf)
if err != nil {
return nil, errors.Wrap(err, "failed to prepareFactory")
}
Expand Down Expand Up @@ -1153,6 +1203,9 @@ func (c *chainClient) GetAuthzGrants(ctx context.Context, req authztypes.QueryGr
}

func (c *chainClient) BuildGenericAuthz(granter, grantee, msgtype string, expireIn time.Time) *authztypes.MsgGrant {
if c.ofacChecker.IsBlacklisted(granter) {
panic("Address is in the OFAC list") // panics should generally be avoided, but otherwise function signature should be changed
}
authz := authztypes.NewGenericAuthorization(msgtype)
authzAny := codectypes.UnsafePackAny(authz)
return &authztypes.MsgGrant{
Expand Down Expand Up @@ -1184,6 +1237,9 @@ var (
)

func (c *chainClient) BuildExchangeAuthz(granter, grantee string, authzType ExchangeAuthz, subaccountId string, markets []string, expireIn time.Time) *authztypes.MsgGrant {
if c.ofacChecker.IsBlacklisted(granter) {
panic("Address is in the OFAC list") // panics should generally be avoided, but otherwise function signature should be changed
}
var typedAuthzAny codectypes.Any
var typedAuthzBytes []byte
switch authzType {
Expand Down
69 changes: 69 additions & 0 deletions client/chain/chain_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package chain

import (
"context"
"os"
"testing"

exchangetypes "github.com/InjectiveLabs/sdk-go/chain/exchange/types"
exchangeclient "github.com/InjectiveLabs/sdk-go/client/exchange"
spotExchangePB "github.com/InjectiveLabs/sdk-go/exchange/spot_exchange_rpc/pb"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"

"github.com/InjectiveLabs/sdk-go/client"
"github.com/InjectiveLabs/sdk-go/client/common"
rpchttp "github.com/cometbft/cometbft/rpc/client/http"
Expand Down Expand Up @@ -51,6 +58,68 @@ func createClient(senderAddress cosmtypes.AccAddress, cosmosKeyring keyring.Keyr
return chainClient, err
}

func TestOfacList(t *testing.T) {
network := common.LoadNetwork("testnet", "lb")
tmClient, err := rpchttp.New(network.TmEndpoint, "/websocket")
assert.NoError(t, err)

senderAddress, cosmosKeyring, err := accountForTests()
assert.NoError(t, err)

clientCtx, err := NewClientContext(
network.ChainId,
senderAddress.String(),
cosmosKeyring,
)
assert.NoError(t, err)

clientCtx = clientCtx.WithNodeURI(network.TmEndpoint).WithClient(tmClient)

exchangeClient, err := exchangeclient.NewExchangeClient(network)
assert.NoError(t, err)
ctx := context.Background()
req := spotExchangePB.MarketsRequest{
MarketStatus: "active",
}
res, err := exchangeClient.GetSpotMarkets(ctx, &req)
assert.NoError(t, err)

marketsAssistant, err := NewMarketsAssistantInitializedFromChain(ctx, exchangeClient)
assert.NoError(t, err)

cc, err := createClient(senderAddress, cosmosKeyring, network)
assert.NoError(t, err)

defaultSubaccountID := cc.DefaultSubaccount(senderAddress)
marketId := res.Markets[0].MarketId
amount := decimal.NewFromFloat(2)
price := decimal.NewFromFloat(1.02)

order := cc.CreateSpotOrder(
defaultSubaccountID,
&SpotOrderData{
OrderType: exchangetypes.OrderType_BUY, //BUY SELL BUY_PO SELL_PO
Quantity: amount,
Price: price,
FeeRecipient: senderAddress.String(),
MarketId: marketId,
},
marketsAssistant,
)

msg := new(exchangetypes.MsgCreateSpotLimitOrder)
msg.Sender = senderAddress.String()
msg.Order = exchangetypes.SpotOrder(*order)

accNum, accSeq := cc.GetAccNonce()

cc.(*chainClient).ofacTxList = []string{

Check failure on line 116 in client/chain/chain_test.go

View workflow job for this annotation

GitHub Actions / run-tests

cc.(*chainClient).ofacTxList undefined (type *chainClient has no field or method ofacTxList)
senderAddress.String(),
}
_, err = cc.BuildSignedTx(clientCtx, accNum, accSeq, 20000, msg)
assert.Error(t, err)
}

func TestDefaultSubaccount(t *testing.T) {
network := common.LoadNetwork("devnet", "lb")
senderAddress, cosmosKeyring, err := accountForTests()
Expand Down
85 changes: 85 additions & 0 deletions client/chain/ofac.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package chain

import (
"encoding/json"
"errors"
"io"
"net/http"
"os"
"path/filepath"
"strings"
)

const (
defaultOfacListURL = "https://raw.githubusercontent.com/InjectiveLabs/injective-lists/master/wallets/ofac.json"
defaultofacListFilename = "ofac.json"
)

type OfacChecker struct {
ofacListPath string
ofacList map[string]interface{}
}

func NewOfacChecker() (*OfacChecker, error) {
checker := &OfacChecker{
ofacListPath: getOfacListPath(),
}
if _, err := os.Stat(checker.ofacListPath); os.IsNotExist(err) {
if err := checker.downloadOfacList(); err != nil {
return nil, err
}
}
if err := checker.loadOfacList(); err != nil {
return nil, err
}
return checker, nil
}

func getOfacListPath() string {
currentDirectory, _ := os.Getwd()
for !strings.HasSuffix(currentDirectory, "sdk-go") {
currentDirectory = filepath.Dir(currentDirectory)
}
return filepath.Join(currentDirectory, defaultofacListFilename)
}

func (oc *OfacChecker) downloadOfacList() error {
resp, err := http.Get(defaultOfacListURL)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.New("failed to download OFAC list")
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

if err := os.WriteFile(oc.ofacListPath, body, 0644); err != nil {
return err
}

return nil
}

func (oc *OfacChecker) loadOfacList() error {
file, err := os.ReadFile(oc.ofacListPath)
if err != nil {
return err
}

err = json.Unmarshal(file, &oc.ofacList)
if err != nil {
return err
}
return nil
}

func (oc *OfacChecker) IsBlacklisted(address string) bool {
_, exists := oc.ofacList[address]
return exists
}
2 changes: 1 addition & 1 deletion examples/chain/8_OfflineSigning/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func main() {
}

defaultSubaccountID := chainClient.DefaultSubaccount(senderAddress)
marketId := "0xa508cb32923323679f29a032c70342c147c17d0145625922b0ef22e955c844c0"
marketId := "0x01edfab47f124748dc89998eb33144af734484ba07099014594321729a0ca16b"
amount := decimal.NewFromFloat(2)
price := decimal.NewFromFloat(1.02)

Expand Down

0 comments on commit d5f6982

Please sign in to comment.