Skip to content

Commit

Permalink
(feat) Implemented OFAC list address check
Browse files Browse the repository at this point in the history
  • Loading branch information
shibaeff committed Sep 6, 2024
1 parent 5541211 commit 4ff81ed
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pyinjective/core/broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from decimal import Decimal
from typing import List, Optional

import requests
from google.protobuf import any_pb2
from grpc import RpcError

Expand All @@ -12,6 +13,9 @@
from pyinjective.constant import GAS_PRICE
from pyinjective.core.gas_limit_estimator import GasLimitEstimator
from pyinjective.core.network import Network
from pyinjective.exceptions import BannedAddressError, OfacListFetchError

OFAC_LIST_URL = "https://raw.githubusercontent.com/InjectiveLabs/injective-lists/master/wallets/ofac.json"


class BroadcasterAccountConfig(ABC):
Expand Down Expand Up @@ -62,6 +66,17 @@ def __init__(
self._client = client
self._composer = composer
self._fee_calculator = fee_calculator
self._ofac_list = self.load_ofac_list()

@classmethod
def load_ofac_list(cls) -> List[str]:
try:
response = requests.get(OFAC_LIST_URL)
response.raise_for_status()
ofac_list = response.json()
return ofac_list
except requests.exceptions.RequestException as e:
raise OfacListFetchError(f"Error fetching OFAC list: {e}")

@classmethod
def new_using_simulation(
Expand Down Expand Up @@ -157,6 +172,10 @@ async def broadcast(self, messages: List[any_pb2.Any]):

messages_for_transaction = self._account_config.messages_prepared_for_transaction(messages=messages)

# before constructing the transaction, check if sender address is in the OFAC list
if self._account_config.trading_injective_address in self._ofac_list:
raise BannedAddressError(f"Address {self._account_config.trading_injective_address} is in the OFAC list")

transaction = Transaction()
transaction.with_messages(*messages_for_transaction)
transaction.with_sequence(self._client.get_sequence())
Expand Down
8 changes: 8 additions & 0 deletions pyinjective/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,11 @@ class ConvertError(PyInjectiveError):

class SchemaError(PyInjectiveError):
pass


class BannedAddressError(PyInjectiveError):
pass


class OfacListFetchError(PyInjectiveError):
pass
72 changes: 72 additions & 0 deletions tests/core/test_broadcaster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from unittest.mock import AsyncMock, patch

import pytest

from pyinjective import PrivateKey
from pyinjective.async_client import AsyncClient
from pyinjective.composer import Composer
from pyinjective.core.broadcaster import BannedAddressError, MsgBroadcasterWithPk, StandardAccountBroadcasterConfig
from pyinjective.core.network import Network
from pyinjective.exceptions import EmptyMsgError


@pytest.mark.asyncio
async def test_broadcast_address_in_ofac_list():
private_key_banned = PrivateKey.from_mnemonic("test mnemonic never use other places")
public_key_banned = private_key_banned.to_public_key()
address_banned = public_key_banned.to_address()

ofac_list = [address_banned.to_acc_bech32()]
with patch("pyinjective.core.broadcaster.requests.get") as mock_get:
mock_get.return_value.json.return_value = ofac_list
mock_get.return_value.raise_for_status = lambda: None

network = Network.local()
client = AsyncClient(
network=Network.local(),
)
composer = AsyncMock(spec=Composer)

account_config = StandardAccountBroadcasterConfig(private_key=private_key_banned.to_hex())
broadcaster = MsgBroadcasterWithPk(
network=network,
account_config=account_config,
client=client,
composer=composer,
fee_calculator=AsyncMock(),
)
broadcaster._ofac_list = ofac_list
with pytest.raises(BannedAddressError):
await broadcaster.broadcast(messages=[])


@pytest.mark.asyncio
async def test_broadcast_address_not_in_ofac_list():
private_key_allowed = PrivateKey.from_mnemonic("another test mnemonic never use other places")

ofac_list = ["banned_address"]
with patch("pyinjective.core.broadcaster.requests.get") as mock_get:
mock_get.return_value.json.return_value = ofac_list
mock_get.return_value.raise_for_status = lambda: None

network = Network.local()
client = AsyncClient(network=Network.local())
composer = AsyncMock(spec=Composer)

account_config = StandardAccountBroadcasterConfig(private_key=private_key_allowed.to_hex())
broadcaster = MsgBroadcasterWithPk(
network=network,
account_config=account_config,
client=client,
composer=composer,
fee_calculator=AsyncMock(),
)
broadcaster._ofac_list = ofac_list

try:
await broadcaster.broadcast(messages=[])
except BannedAddressError:
pytest.fail("BannedAddressError was raised unexpectedly")
except EmptyMsgError:
# expected failure Exception caught
pass

0 comments on commit 4ff81ed

Please sign in to comment.