From 62a93f65f45cece140779fa0539761e634f5febd Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Mon, 2 Oct 2023 12:16:47 -0700 Subject: [PATCH] MOD: Support file paths in Live.add_stream --- CHANGELOG.md | 2 + databento/live/client.py | 24 +++++++++++- tests/test_live_client.py | 78 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2832b00..2d3de61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,13 @@ #### Enhancements - Added `map_symbols` support for DBN data generated by the `Live` client +- Added support for file paths in `Live.add_stream` #### Bug fixes - Fixed an issue where `DBNStore.from_bytes` did not rewind seekable buffers - Fixed an issue where the `DBNStore` would not map symbols with input symbology of `SType.INSTRUMENT_ID` - Fixed an issue with `DBNStore.request_symbology` when the DBN metadata's start date and end date were the same +- Fixed an issue where closed streams were not removed from a `Live` client on shutdown. ## 0.20.0 - 2023-09-21 diff --git a/databento/live/client.py b/databento/live/client.py index 8008478..2b09c71 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -3,11 +3,13 @@ import asyncio import logging import os +import pathlib import queue import threading from collections.abc import Iterable from concurrent import futures from numbers import Number +from os import PathLike from typing import IO import databento_dbn @@ -307,7 +309,7 @@ def add_callback( def add_stream( self, - stream: IO[bytes], + stream: IO[bytes] | PathLike[str] | str, exception_callback: ExceptionCallback | None = None, ) -> None: """ @@ -315,7 +317,7 @@ def add_stream( Parameters ---------- - stream : IO[bytes] + stream : IO[bytes] or PathLike[str] or str The IO stream to write to when handling live records as they arrive. exception_callback : Callable[[Exception], None], optional An error handling callback to process exceptions that are raised @@ -325,12 +327,17 @@ def add_stream( ------ ValueError If `stream` is not a writable byte stream. + OSError + If `stream` is not a path to a writeable file. See Also -------- Live.add_callback """ + if isinstance(stream, (str, PathLike)): + stream = pathlib.Path(stream).open("wb") + if not hasattr(stream, "write"): raise ValueError(f"{type(stream).__name__} does not support write()") @@ -589,6 +596,19 @@ async def _shutdown(self) -> None: if self._session is None: return await self._session.wait_for_close() + + to_remove = [] + for stream in self._user_streams: + stream_name = getattr(stream, "name", str(stream)) + if stream.closed: + logger.info("removing closed user stream %s", stream_name) + to_remove.append(stream) + else: + stream.flush() + + for key in to_remove: + self._user_streams.pop(key) + self._symbology_map.clear() def _map_symbol(self, record: DBNRecord) -> None: diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 17e61c0..628b528 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -309,6 +309,7 @@ def test_live_start_twice( with pytest.raises(ValueError): live_client.start() + def test_live_start_before_subscribe( live_client: client.Live, ) -> None: @@ -318,6 +319,7 @@ def test_live_start_before_subscribe( with pytest.raises(ValueError): live_client.start() + @pytest.mark.parametrize( "schema", [pytest.param(schema, id=str(schema)) for schema in Schema.variants()], @@ -428,6 +430,34 @@ def test_live_stop( live_client.block_for_close() +@pytest.mark.usefixtures("mock_live_server") +def test_live_shutdown_remove_closed_stream( + tmp_path: pathlib.Path, + live_client: client.Live, +) -> None: + """ + Test that closed streams are removed upon disconnection. + """ + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + ) + + output = tmp_path / "output.dbn" + + with output.open("wb") as out: + live_client.add_stream(out) + + assert live_client.is_connected() is True + + live_client.start() + + live_client.stop() + live_client.block_for_close() + + assert live_client._user_streams == {} + + def test_live_stop_twice( live_client: client.Live, ) -> None: @@ -575,6 +605,15 @@ def test_live_add_stream_invalid( with pytest.raises(ValueError): live_client.add_stream(readable_file.open(mode="rb")) +def test_live_add_stream_path_directory( + tmp_path: pathlib.Path, + live_client: client.Live, +) -> None: + """ + Test that passing a path to a directory raises an OSError. + """ + with pytest.raises(OSError): + live_client.add_stream(tmp_path) @pytest.mark.skipif(platform.system() == "Windows", reason="flaky on windows runner") async def test_live_async_iteration( @@ -730,6 +769,7 @@ def test_live_sync_iteration( assert isinstance(records[2], databento_dbn.MBOMsg) assert isinstance(records[3], databento_dbn.MBOMsg) + async def test_live_callback( live_client: client.Live, ) -> None: @@ -800,6 +840,44 @@ async def test_live_stream_to_dbn( assert output.read_bytes() == expected_data.read() +@pytest.mark.parametrize( + "schema", + (pytest.param(schema, id=str(schema)) for schema in Schema.variants()), +) +async def test_live_stream_to_dbn_from_path( + tmp_path: pathlib.Path, + test_data_path: Callable[[Schema], pathlib.Path], + live_client: client.Live, + schema: Schema, +) -> None: + """ + Test that DBN data streamed by the MockLiveServer is properly re- + constructed client side when specifying a file as a path. + """ + output = tmp_path / "output.dbn" + + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=schema, + stype_in=SType.RAW_SYMBOL, + symbols="TEST", + ) + live_client.add_stream(output) + + live_client.start() + + await live_client.wait_for_close() + + expected_data = BytesIO( + zstandard.ZstdDecompressor() + .stream_reader(test_data_path(schema).open("rb")) + .read(), + ) + expected_data.seek(0) # rewind + + assert output.read_bytes() == expected_data.read() + + @pytest.mark.parametrize( "schema", (pytest.param(schema, id=str(schema)) for schema in Schema.variants()),