Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests + make coingecko key optional #51

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/pull_request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@ jobs:
black --check ./
- name: Type Check (mypy)
run: mypy src
- name: Run Pytest
env:
MORALIS_API_KEY: ${{ secrets.MORALIS_API_KEY }}
CHAIN_SLEEP_TIME: ${{ secrets.CHAIN_SLEEP_TIME }}
NODE_URL: ${{ secrets.NODE_URL }}
CHAIN_NAME: ${{ secrets.CHAIN_NAME }}
run: pytest
32 changes: 23 additions & 9 deletions src/price_providers/coingecko_pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
COINGECKO_BUFFER_TIME,
)

coingecko_api_key = os.getenv("COINGECKO_API_KEY")


class CoingeckoPriceProvider(AbstractPriceProvider):
"""
Expand All @@ -24,28 +22,35 @@ class CoingeckoPriceProvider(AbstractPriceProvider):

def __init__(self) -> None:
self.web3 = get_web3_instance()
self.filtered_token_list = self.fetch_coingecko_list()
self.last_reload_time = time.time() # current time in seconds since epoch
self.last_reload_time = time.time()
self.coingecko_api_key = os.getenv("COINGECKO_API_KEY")
try:
self.filtered_token_list = self.fetch_coingecko_list()
self.last_reload_time = time.time() # current time in seconds since epoch
except Exception as e:
logger.warning(f"Failed to fetch initial token list: {e}")

@property
def name(self) -> str:
return "Coingecko"

def fetch_coingecko_list(self) -> list[dict]:
def fetch_coingecko_list(self) -> list[dict] | None:
"""
Fetch and filter the list of tokens (currently filters only Ethereum)
from the Coingecko API.
"""
if not self.coingecko_api_key:
return None

url = (
f"https://pro-api.coingecko.com/api/v3/coins/"
f"list?include_platform=true&status=active"
)
headers = {
"accept": "application/json",
}
if coingecko_api_key:
headers["x-cg-pro-api-key"] = coingecko_api_key

headers["x-cg-pro-api-key"] = self.coingecko_api_key
response = requests.get(url, headers=headers)
tokens_list = json.loads(response.text)
return [
Expand All @@ -71,6 +76,9 @@ def get_token_id_by_address(self, token_address: str) -> str | None:
self.last_reload_time = (
time.time()
) # update the last reload time to current time
if not self.filtered_token_list:
return None

for token in self.filtered_token_list:
if token["platforms"].get("ethereum") == token_address:
return token["id"]
Expand All @@ -82,7 +90,7 @@ def fetch_api_price(
"""
Makes call to Coingecko API to fetch price, between a start and end timestamp.
"""
if not coingecko_api_key:
if not self.coingecko_api_key:
logger.warning("Coingecko API key is not set.")
return None
# price of token is returned in ETH
Expand All @@ -92,7 +100,7 @@ def fetch_api_price(
)
headers = {
"accept": "application/json",
"x-cg-pro-api-key": coingecko_api_key,
"x-cg-pro-api-key": self.coingecko_api_key,
}
try:
response = requests.get(url, headers=headers)
Expand Down Expand Up @@ -122,6 +130,12 @@ def get_price(self, price_params: dict) -> float | None:
Function returns coingecko price for a token address,
closest to and at least as large as the block timestamp for a given tx hash.
"""
if not self.filtered_token_list:
logger.warning(
"Token list is empty, possibly the Coingecko API key isn't set."
)
return None

token_address, block_number = extract_params(price_params, is_block=True)
block_start_timestamp = self.web3.eth.get_block(block_number)["timestamp"]
if self.price_not_retrievable(block_start_timestamp):
Expand Down
2 changes: 1 addition & 1 deletion src/price_providers/moralis_pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def get_price(self, price_params: dict) -> float | None:
self.logger.warning(f"Error: {e}")
except Exception as e:
self.logger.warning(
f"Price retrieval for token: {token_address} returned: {e}"
f"Price retrieval for token: {token_address} returned: {e}. Possibly the Moralis API key is missing."
)
return None
28 changes: 28 additions & 0 deletions tests/test_fees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from hexbytes import HexBytes
from src.fees.compute_fees import batch_fee_imbalances


def test_batch_fee_imbalances():
"""
Test the batch_fee_imbalances function with a valid transaction hash.
"""
tx_hash = "0x714bb3b1a804af7a493bcfa991b9859e03c52387b027783f175255885fa97dbd"
protocol_fees, network_fees = batch_fee_imbalances(HexBytes(tx_hash))

# verify that the returned fees are dicts
assert isinstance(protocol_fees, dict), "Protocol fees should be a dict."
assert isinstance(network_fees, dict), "Network fees should be a dict."

# Check that keys and values in the dict have the correct types
for token, fee in protocol_fees.items():
assert isinstance(token, str), "Token address should be string."
assert isinstance(fee, int), "Fee amount should be int."

for token, fee in network_fees.items():
assert isinstance(token, str), "Token address should be string."
assert isinstance(fee, int), "Fee amount should be int."


if __name__ == "__main__":
pytest.main()
4 changes: 2 additions & 2 deletions tests/basic_test.py → tests/test_imbalances.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_imbalances(tx_hash, expected_imbalances):
Asserts imbalances match for main script with test values provided.
"""
chain_name = os.getenv("CHAIN_NAME")
rt = RawTokenImbalances(get_web3_instance(), chain_name)
imbalances = rt.compute_imbalances(tx_hash)
compute = RawTokenImbalances(get_web3_instance(), chain_name)
imbalances = compute.compute_imbalances(tx_hash)
for token_address, expected_imbalance in expected_imbalances.items():
assert imbalances.get(token_address) == expected_imbalance
65 changes: 65 additions & 0 deletions tests/test_pricefeed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
from src.price_providers.price_feed import PriceFeed


@pytest.fixture
def price_feed():
return PriceFeed()


def test_get_price_real(price_feed):
"""Test with legitimate parameters."""

# Test parameters
tx_hash = "0x94af3d98b0af4ca6bf41e85c05ed42fccd71d5aaa04cbe01fab00d1b2268c4e1"
token_address = "0xd1d2Eb1B1e90B638588728b4130137D262C87cae"
block_number = 20630508
price_params = {
"tx_hash": tx_hash,
"token_address": token_address,
"block_number": block_number,
}

# Get the price
result = price_feed.get_price(price_params)
assert result is not None

price, source = result
# Assert that the price is a positive float
assert isinstance(price, float)
assert price > 0
assert source in ["Coingecko", "Dune", "Moralis", "AuctionPrices"]


def test_get_price_unknown_token(price_feed):
"""Test with an unknown token address."""

tx_hash = "0x94af3d98b0af4ca6bf41e85c05ed42fccd71d5aaa04cbe01fab00d1b2268c4e1"
unknown_token = "0xd1d2Eb1B1e90B638588728b4130137D262C87cad"
price_params = {
"tx_hash": tx_hash,
"token_address": unknown_token,
"block_number": 20630508,
}
result = price_feed.get_price(price_params)

# expect None for an unknown token
assert result is None


def test_get_price_future_block(price_feed):
"""Test with a block number in the future."""
future_block = 99999999
price_params = {
"token_address": "0x6B175474E89094C44Da98b954EedeAC495271d0F",
"block_number": future_block,
}

result = price_feed.get_price(price_params)

# expect None for a future block
assert result is None


if __name__ == "__main__":
pytest.main()
Loading