Skip to content

Commit

Permalink
Add new api method to retrieve n newest candles of a cluster. Add tes…
Browse files Browse the repository at this point in the history
…t coverage as well.
  • Loading branch information
sirEven committed Nov 7, 2024
1 parent 581e248 commit 32fc7f3
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 4 deletions.
8 changes: 8 additions & 0 deletions locast/candle_storage/candle_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ async def retrieve_cluster(
resolution: ResolutionDetail,
) -> List[Candle]: ...

async def retrieve_newest_candles(
self,
exchange: Exchange,
market: str,
resolution: ResolutionDetail,
amount: int,
) -> List[Candle]: ...

async def delete_cluster(
self,
exchange: Exchange,
Expand Down
32 changes: 32 additions & 0 deletions locast/candle_storage/sql/sqlite_candle_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,38 @@ async def retrieve_cluster(
else:
return []

async def retrieve_newest_candles(
self,
exchange: Exchange,
market: str,
resolution: ResolutionDetail,
amount: int,
) -> List[Candle]:
with Session(self._engine) as session:
if foreign_keys := self._look_up_foreign_keys(
exchange,
market,
resolution,
session,
):
sqlite_exchange, sqlite_market, sqlite_resolution = foreign_keys

stmnt = (
select(SqliteCandle)
.where(
(SqliteCandle.exchange_id == sqlite_exchange.id)
& (SqliteCandle.market_id == sqlite_market.id)
& (SqliteCandle.resolution_id == sqlite_resolution.id)
)
.order_by(desc(SqliteCandle.started_at))
.limit(amount)
)

results = session.exec(stmnt)
return self._to_candles(list(results.all()))
else:
return []

async def delete_cluster(
self,
exchange: Exchange,
Expand Down
27 changes: 25 additions & 2 deletions locast/store_manager/store_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,36 @@ async def retrieve_cluster(
) -> List[Candle]:
cluster_info = await self.get_cluster_info(exchange, market, resolution)

if not cluster_info.newest_candle:
if cluster_info.size == 0:
raise MissingClusterException(
f"Cluster does not exist for market {market} and resolution {resolution.notation}."
)

return await self._candle_storage.retrieve_cluster(exchange, market, resolution)

async def retrieve_newest_candles(
self,
exchange: Exchange,
market: str,
resolution: ResolutionDetail,
amount: int,
) -> List[Candle]:
cluster_info = await self.get_cluster_info(exchange, market, resolution)

if cluster_info.size == 0:
raise MissingClusterException(
f"Cluster does not exist for market {market} and resolution {resolution.notation}."
)

amount_to_retrieve = min(amount, cluster_info.size)

return await self._candle_storage.retrieve_newest_candles(
exchange,
market,
resolution,
amount_to_retrieve,
)

async def update_cluster(
self,
exchange: Exchange,
Expand Down Expand Up @@ -138,7 +161,7 @@ async def _check_horizon(
market: str,
resolution: ResolutionDetail,
start_date: datetime,
):
) -> datetime:
if not (horizon := self._horizon_cache.get(f"{market}_{resolution.notation}")):
horizon = await self._candle_fetcher.find_horizon(market, resolution)
self._horizon_cache[f"{market}_{resolution.notation}"] = horizon
Expand Down
88 changes: 86 additions & 2 deletions tests/candle_storage/sql/test_sqlite_candle_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def test_store_candles_results_in_correct_storage_state(

@pytest.mark.parametrize("amount", few_amounts)
@pytest.mark.asyncio
async def test_retrieve_candles_results_in_correct_cluster(
async def test_retrieve_cluster_results_in_correct_cluster(
sqlite_candle_storage_memory: SqliteCandleStorage,
amount: int,
) -> None:
Expand All @@ -103,7 +103,7 @@ async def test_retrieve_candles_results_in_correct_cluster(


@pytest.mark.asyncio
async def test_retrieve_candles_results_in_empty_list(
async def test_retrieve_cluster_results_in_empty_list(
sqlite_candle_storage_memory: SqliteCandleStorage,
) -> None:
# given
Expand All @@ -120,6 +120,90 @@ async def test_retrieve_candles_results_in_empty_list(
assert len(retrieved_candles) == 0


@pytest.mark.asyncio
async def test_retrieve_newest_candles_results_in_correct_list(
sqlite_candle_storage_memory: SqliteCandleStorage,
) -> None:
# given
storage = sqlite_candle_storage_memory

exchange = Exchange.DYDX_V4
res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN")
start_date = string_to_datetime("2022-01-01T00:00:00.000Z")
market = "ETH-USD"
amount_mocked = 100
amount_retreived = 10

candles = mock_dydx_v4_candles(market, res, amount_mocked, start_date)
await storage.store_candles(candles)

# when
retrieved_candles = await storage.retrieve_newest_candles(
exchange,
market,
res,
amount_retreived,
)

# then
assert len(retrieved_candles) == amount_retreived
assert candles[0].started_at == retrieved_candles[0].started_at


@pytest.mark.asyncio
async def test_retrieve_newest_candles_corrects_amount_to_cluster_size(
sqlite_candle_storage_memory: SqliteCandleStorage,
) -> None:
# given
storage = sqlite_candle_storage_memory

exchange = Exchange.DYDX_V4
res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN")
start_date = string_to_datetime("2022-01-01T00:00:00.000Z")
market = "ETH-USD"
amount_mocked = 50
amount_retreived = 60

candles = mock_dydx_v4_candles(market, res, amount_mocked, start_date)
await storage.store_candles(candles)

# when cluster size is less than requested to retrieve
retrieved_candles = await storage.retrieve_newest_candles(
exchange,
market,
res,
amount_retreived,
)

# then retrieved size equals cluster size
assert len(retrieved_candles) == amount_mocked
assert candles[0].started_at == retrieved_candles[0].started_at


@pytest.mark.asyncio
async def test_retrieve_newest_candles_results_in_empty_list(
sqlite_candle_storage_memory: SqliteCandleStorage,
) -> None:
# given
storage = sqlite_candle_storage_memory

exchange = Exchange.DYDX_V4
res = ResolutionDetail(Seconds.ONE_MINUTE, "1MIN")
market = "ETH-USD"
amount = 10

# when no cluster in storage
retrieved_candles = await storage.retrieve_newest_candles(
exchange,
market,
res,
amount,
)

# then
assert len(retrieved_candles) == 0


@pytest.mark.parametrize("amount", few_amounts)
@pytest.mark.asyncio
async def test_delete_cluster_results_in_correct_state(
Expand Down

0 comments on commit 32fc7f3

Please sign in to comment.