diff --git a/pyinjective/core/broadcaster.py b/pyinjective/core/broadcaster.py index f279baea..9623c529 100644 --- a/pyinjective/core/broadcaster.py +++ b/pyinjective/core/broadcaster.py @@ -3,6 +3,7 @@ from decimal import Decimal from typing import List, Optional +import requests from google.protobuf import any_pb2 from grpc import RpcError @@ -12,6 +13,9 @@ from pyinjective.constant import GAS_PRICE from pyinjective.core.gas_limit_estimator import GasLimitEstimator from pyinjective.core.network import Network +from pyinjective.exceptions import BannedAddressError, OfacListFetchError + +OFAC_LIST_URL = "https://raw.githubusercontent.com/InjectiveLabs/injective-lists/master/wallets/ofac.json" class BroadcasterAccountConfig(ABC): @@ -62,6 +66,17 @@ def __init__( self._client = client self._composer = composer self._fee_calculator = fee_calculator + self._ofac_list = self.load_ofac_list() + + @classmethod + def load_ofac_list(cls) -> List[str]: + try: + response = requests.get(OFAC_LIST_URL) + response.raise_for_status() + ofac_list = response.json() + return ofac_list + except requests.exceptions.RequestException as e: + raise OfacListFetchError(f"Error fetching OFAC list: {e}") @classmethod def new_using_simulation( @@ -157,6 +172,10 @@ async def broadcast(self, messages: List[any_pb2.Any]): messages_for_transaction = self._account_config.messages_prepared_for_transaction(messages=messages) + # before constructing the transaction, check if sender address is in the OFAC list + if self._account_config.trading_injective_address in self._ofac_list: + raise BannedAddressError(f"Address {self._account_config.trading_injective_address} is in the OFAC list") + transaction = Transaction() transaction.with_messages(*messages_for_transaction) transaction.with_sequence(self._client.get_sequence()) diff --git a/pyinjective/exceptions.py b/pyinjective/exceptions.py index f7be2962..dd3f9327 100644 --- a/pyinjective/exceptions.py +++ b/pyinjective/exceptions.py @@ -28,3 +28,11 @@ class ConvertError(PyInjectiveError): class SchemaError(PyInjectiveError): pass + + +class BannedAddressError(PyInjectiveError): + pass + + +class OfacListFetchError(PyInjectiveError): + pass diff --git a/tests/core/test_broadcaster.py b/tests/core/test_broadcaster.py new file mode 100644 index 00000000..97fe48c4 --- /dev/null +++ b/tests/core/test_broadcaster.py @@ -0,0 +1,72 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from pyinjective import PrivateKey +from pyinjective.async_client import AsyncClient +from pyinjective.composer import Composer +from pyinjective.core.broadcaster import BannedAddressError, MsgBroadcasterWithPk, StandardAccountBroadcasterConfig +from pyinjective.core.network import Network +from pyinjective.exceptions import EmptyMsgError + + +@pytest.mark.asyncio +async def test_broadcast_address_in_ofac_list(): + private_key_banned = PrivateKey.from_mnemonic("test mnemonic never use other places") + public_key_banned = private_key_banned.to_public_key() + address_banned = public_key_banned.to_address() + + ofac_list = [address_banned.to_acc_bech32()] + with patch("pyinjective.core.broadcaster.requests.get") as mock_get: + mock_get.return_value.json.return_value = ofac_list + mock_get.return_value.raise_for_status = lambda: None + + network = Network.local() + client = AsyncClient( + network=Network.local(), + ) + composer = AsyncMock(spec=Composer) + + account_config = StandardAccountBroadcasterConfig(private_key=private_key_banned.to_hex()) + broadcaster = MsgBroadcasterWithPk( + network=network, + account_config=account_config, + client=client, + composer=composer, + fee_calculator=AsyncMock(), + ) + broadcaster._ofac_list = ofac_list + with pytest.raises(BannedAddressError): + await broadcaster.broadcast(messages=[]) + + +@pytest.mark.asyncio +async def test_broadcast_address_not_in_ofac_list(): + private_key_allowed = PrivateKey.from_mnemonic("another test mnemonic never use other places") + + ofac_list = ["banned_address"] + with patch("pyinjective.core.broadcaster.requests.get") as mock_get: + mock_get.return_value.json.return_value = ofac_list + mock_get.return_value.raise_for_status = lambda: None + + network = Network.local() + client = AsyncClient(network=Network.local()) + composer = AsyncMock(spec=Composer) + + account_config = StandardAccountBroadcasterConfig(private_key=private_key_allowed.to_hex()) + broadcaster = MsgBroadcasterWithPk( + network=network, + account_config=account_config, + client=client, + composer=composer, + fee_calculator=AsyncMock(), + ) + broadcaster._ofac_list = ofac_list + + try: + await broadcaster.broadcast(messages=[]) + except BannedAddressError: + pytest.fail("BannedAddressError was raised unexpectedly") + except EmptyMsgError: + # expected failure Exception caught + pass