Skip to content

Commit

Permalink
Update tests to close connection when done
Browse files Browse the repository at this point in the history
  • Loading branch information
SinaKhalili committed Oct 28, 2024
1 parent ea2c15d commit 58f97a6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 39 deletions.
81 changes: 59 additions & 22 deletions tests/ci/devnet.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
import os
import pytest
import asyncio

from pytest import mark

from solana.rpc.async_api import AsyncClient
import os

from anchorpy import Wallet

from driftpy.drift_client import DriftClient
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.constants.perp_markets import devnet_perp_market_configs
from driftpy.constants.spot_markets import devnet_spot_market_configs
from driftpy.drift_client import DriftClient
import pytest
from pytest import mark
from solana.rpc.async_api import AsyncClient


@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the session."""
try:
loop = asyncio.get_event_loop_policy().new_event_loop()
asyncio.set_event_loop(loop)
yield loop
finally:
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()

# Allow tasks to respond to cancellation
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()


@pytest.fixture(scope="session")
def rpc_url():
return os.environ.get("DEVNET_RPC_ENDPOINT")
return "https://api.devnet.solana.com"


@mark.asyncio
Expand All @@ -31,9 +46,7 @@ async def test_devnet_constants(rpc_url: str):
)

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)

await drift_client.subscribe()
print("Subscribed to Drift Client")

expected_perp_markets = sorted(
Expand All @@ -43,16 +56,29 @@ async def test_devnet_constants(rpc_url: str):
drift_client.get_perp_market_accounts(), key=lambda market: market.market_index
)

print("==> Received Perp Markets:")
for market in received_perp_markets:
print(
market.market_index,
market.amm.oracle,
bytes(market.name).decode("utf-8").strip(),
market.amm.oracle_source,
)

for expected, received in zip(expected_perp_markets, received_perp_markets):
market_info = f"Market: {received.pubkey} Market Index: {received.market_index}"

assert (
expected.market_index == received.market_index
), f"Devnet Perp: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}"
), f"Devnet Perp: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}"

assert str(expected.oracle) == str(
received.amm.oracle
), f"Devnet Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} Market: {received.pubkey} Market Index: {received.market_index}"
), f"Devnet Perp: Expected oracle {expected.oracle}, got {received.amm.oracle} {market_info} for {expected.symbol}"

assert str(expected.oracle_source) == str(
received.amm.oracle_source
), f"Devnet Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}"
), f"Devnet Perp: Expected oracle source {expected.oracle_source}, got {received.amm.oracle_source} {market_info} for {expected.symbol}"

expected_spot_markets = sorted(
devnet_spot_market_configs, key=lambda market: market.market_index
Expand All @@ -61,16 +87,29 @@ async def test_devnet_constants(rpc_url: str):
drift_client.get_spot_market_accounts(), key=lambda market: market.market_index
)

print("\n==> Received Spot Markets:")
for market in received_spot_markets:
print(
market.market_index,
market.oracle,
bytes(market.name).decode("utf-8").strip(),
market.oracle_source,
)

for expected, received in zip(expected_spot_markets, received_spot_markets):
market_info = f"Market: {received.pubkey} Market Index: {received.market_index}"

assert (
expected.market_index == received.market_index
), f"Devnet Spot: Expected market index {expected.market_index}, got {received.market_index} Market: {received.pubkey}"
), f"Devnet Spot: Expected market index {expected.market_index}, got {received.market_index} {market_info} for {expected.symbol}"

assert str(expected.oracle) == str(
received.oracle
), f"Devnet Spot: Expected oracle {expected.oracle}, got {received.oracle} Market: {received.pubkey} Market Index: {received.market_index}"
), f"Devnet Spot: Expected oracle {expected.oracle}, got {received.oracle} {market_info} for {expected.symbol}"

assert str(expected.oracle_source) == str(
received.oracle_source
), f"Devnet Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} Market: {received.pubkey} Market Index: {received.market_index}"
), f"Devnet Spot: Expected oracle source {expected.oracle_source}, got {received.oracle_source} {market_info} for {expected.symbol}"


@mark.asyncio
Expand All @@ -85,7 +124,7 @@ async def test_devnet_cached(rpc_url: str):

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)
await drift_client.subscribe()

print("Subscribed to Drift Client")

Expand Down Expand Up @@ -133,9 +172,7 @@ async def test_devnet_ws(rpc_url: str):
)

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)

await drift_client.subscribe()
print("Subscribed to Drift Client")

perp_markets = drift_client.get_perp_market_accounts()
Expand Down
41 changes: 24 additions & 17 deletions tests/ci/mainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@
from solana.rpc.async_api import AsyncClient


@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the session."""
try:
loop = asyncio.get_event_loop_policy().new_event_loop()
asyncio.set_event_loop(loop)
yield loop
finally:
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()

# Allow tasks to respond to cancellation
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()


@pytest.fixture(scope="session")
def rpc_url():
return os.environ.get("MAINNET_RPC_ENDPOINT")


@mark.asyncio
async def test_mainnet_constants(rpc_url: str):
print()
print("Checking mainnet constants")
drift_client = DriftClient(
AsyncClient(rpc_url),
Expand All @@ -28,9 +45,7 @@ async def test_mainnet_constants(rpc_url: str):
)

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)

await drift_client.subscribe()
print("Subscribed to Drift Client")

expected_perp_markets = sorted(
Expand Down Expand Up @@ -72,7 +87,6 @@ async def test_mainnet_constants(rpc_url: str):

@mark.asyncio
async def test_mainnet_cached(rpc_url: str):
print()
print("Checking mainnet cached subscription")
drift_client = DriftClient(
AsyncClient(rpc_url),
Expand All @@ -82,9 +96,7 @@ async def test_mainnet_cached(rpc_url: str):
)

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)

await drift_client.subscribe()
print("Subscribed to Drift Client")

perp_markets = drift_client.get_perp_market_accounts()
Expand Down Expand Up @@ -114,27 +126,23 @@ async def test_mainnet_cached(rpc_url: str):
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")

await drift_client.unsubscribe()

print("Unsubscribed from Drift Client")


@mark.asyncio
async def test_mainnet_ws(rpc_url: str):
print()
print("Checking mainnet websocket subscription")
connection = AsyncClient(rpc_url)
drift_client = DriftClient(
AsyncClient(rpc_url),
connection,
Wallet.dummy(),
env="mainnet",
account_subscription=AccountSubscriptionConfig("websocket"),
)

print("Subscribing to Drift Client")

asyncio.wait_for(await drift_client.subscribe(), 15)

await drift_client.subscribe()
print("Subscribed to Drift Client")

perp_markets = drift_client.get_perp_market_accounts()
Expand Down Expand Up @@ -165,7 +173,6 @@ async def test_mainnet_ws(rpc_url: str):
), f"Expected {len(mainnet_spot_market_configs)} spot markets, got {len(spot_markets)}"

print("Unsubscribing from Drift Client")

await drift_client.unsubscribe()

await connection.close()
print("Unsubscribed from Drift Client")

0 comments on commit 58f97a6

Please sign in to comment.