diff --git a/CHANGELOG.md b/CHANGELOG.md
index 05789e2..6c9b256 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,14 @@
 # Changelog

+## 0.23.1 - 2023-11-10
+
+#### Enhancements
+- Added new publishers for consolidated DBEQ.BASIC and DBEQ.PLUS
+
+#### Bug fixes
+- Fixed an issue where `Live.block_for_close` and `Live.wait_for_close` would not flush streams if the timeout was reached
+- Fixed a performance regression when reading a historical DBN file into a numpy array
+
 ## 0.23.0 - 2023-10-26

 #### Enhancements
diff --git a/databento/common/dbnstore.py b/databento/common/dbnstore.py
index c9c8b63..1702ca0 100644
--- a/databento/common/dbnstore.py
+++ b/databento/common/dbnstore.py
@@ -10,7 +10,16 @@ from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, BinaryIO, Callable, Literal, overload
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Literal,
+    Protocol,
+    overload,
+)

 import databento_dbn
 import numpy as np
@@ -638,7 +647,7 @@ def from_file(cls, path: PathLike[str] | str) -> DBNStore:
        Raises
        ------
        FileNotFoundError
-            If a non-existant file is specified.
+            If a non-existent file is specified.
        ValueError
            If an empty file is specified.

@@ -1072,20 +1081,43 @@ def to_ndarray(
        """
        schema = validate_maybe_enum(schema, Schema, "schema")
-        if schema is None:
-            if self.schema is None:
+        ndarray_iter: NDArrayIterator
+
+        if self.schema is None:
+            # The store's schema is None, meaning we're handling heterogeneous
+            # data from the live client. This is less performant because the
+            # records of a given schema are not contiguous in memory.
+            if schema is None:
                raise ValueError("a schema must be specified for mixed DBN data")
-            schema = self.schema

-        dtype = SCHEMA_DTYPES_MAP[schema]
-        ndarray_iter = NDArrayIterator(
-            filter(lambda r: isinstance(r, SCHEMA_STRUCT_MAP[schema]), self),
-            dtype,
-            count,
-        )
+            schema_struct = SCHEMA_STRUCT_MAP[schema]
+            schema_dtype = SCHEMA_DTYPES_MAP[schema]
+            schema_filter = filter(lambda r: isinstance(r, schema_struct), self)
+
+            ndarray_iter = NDArrayBytesIterator(
+                records=map(bytes, schema_filter),
+                dtype=schema_dtype,
+                count=count,
+            )
+        else:
+            # The store's schema is set, meaning we're handling homogeneous
+            # historical data. Copy the dtype list so appending ts_out below
+            # does not mutate the shared SCHEMA_DTYPES_MAP entry.
+            schema_dtype = list(SCHEMA_DTYPES_MAP[self.schema])
+
+            if self._metadata.ts_out:
+                schema_dtype.append(("ts_out", "u8"))
+
+            if schema is not None and schema != self.schema:
+                # The requested schema differs from the stored schema, so no
+                # records can match; yield a single empty array to maintain
+                # identical behavior with NDArrayBytesIterator.
+                ndarray_iter = iter([np.empty([0, 1], dtype=schema_dtype)])
+            else:
+                ndarray_iter = NDArrayStreamIterator(
+                    reader=self.reader,
+                    dtype=schema_dtype,
+                    offset=self._metadata_length,
+                    count=count,
+                )

        if count is None:
-            return next(ndarray_iter, np.empty([0, 1], dtype=dtype))
+            return next(ndarray_iter, np.empty([0, 1], dtype=schema_dtype))

        return ndarray_iter

@@ -1124,10 +1156,66 @@ def _transcode(
        transcoder.flush()


-class NDArrayIterator:
+class NDArrayIterator(Protocol):
+    @abc.abstractmethod
+    def __iter__(self) -> NDArrayIterator:
+        ...
+
+    @abc.abstractmethod
+    def __next__(self) -> np.ndarray[Any, Any]:
+        ...
+
+
+class NDArrayStreamIterator(NDArrayIterator):
+    """
+    Iterator for homogeneous byte streams of DBN records.
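+
+    A rough sketch of the underlying mechanism (the record layout here is
+    hypothetical; real layouts come from SCHEMA_DTYPES_MAP):
+
+        import io
+
+        import numpy as np
+
+        dtype = np.dtype([("ts_event", "u8"), ("price", "i8")])
+        reader = io.BytesIO(np.zeros(4, dtype=dtype).tobytes())
+        array = np.frombuffer(reader.read(dtype.itemsize * 4), dtype=dtype)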
+ """ + + def __init__( + self, + reader: IO[bytes], + dtype: list[tuple[str, str]], + offset: int = 0, + count: int | None = None, + ) -> None: + self._reader = reader + self._dtype = np.dtype(dtype) + self._offset = offset + self._count = count + + self._reader.seek(offset) + + def __iter__(self) -> NDArrayStreamIterator: + return self + + def __next__(self) -> np.ndarray[Any, Any]: + if self._count is None: + read_size = -1 + else: + read_size = self._dtype.itemsize * max(self._count, 1) + + if buffer := self._reader.read(read_size): + try: + return np.frombuffer( + buffer=buffer, + dtype=self._dtype, + ) + except ValueError: + raise BentoError( + "DBN file is truncated or contains an incomplete record", + ) + + raise StopIteration + + +class NDArrayBytesIterator(NDArrayIterator): + """ + Iterator for heterogeneous streams of DBN records. + """ + def __init__( self, - records: Iterator[DBNRecord], + records: Iterator[bytes], dtype: list[tuple[str, str]], count: int | None, ): @@ -1144,7 +1232,7 @@ def __next__(self) -> np.ndarray[Any, Any]: num_records = 0 for record in itertools.islice(self._records, self._count): num_records += 1 - record_bytes.write(bytes(record)) + record_bytes.write(record) if num_records == 0: if self._first_next: @@ -1152,14 +1240,25 @@ def __next__(self) -> np.ndarray[Any, Any]: raise StopIteration self._first_next = False - return np.frombuffer( - record_bytes.getvalue(), - dtype=self._dtype, - count=num_records, - ) + + try: + return np.frombuffer( + record_bytes.getbuffer(), + dtype=self._dtype, + count=num_records, + ) + except ValueError: + raise BentoError( + "DBN file is truncated or contains an incomplete record", + ) class DataFrameIterator: + """ + Iterator for DataFrames that supports batching and column formatting for + DBN records. + """ + def __init__( self, records: Iterator[np.ndarray[Any, Any]], diff --git a/databento/common/publishers.py b/databento/common/publishers.py index 08f7a01..a8ac90f 100644 --- a/databento/common/publishers.py +++ b/databento/common/publishers.py @@ -95,6 +95,8 @@ class Venue(StringyMixin, str, Enum): ICE Futures Europe (Commodities). NDEX ICE Endex. + DBEQ + Databento Equities - Consolidated. """ @@ -137,6 +139,7 @@ class Venue(StringyMixin, str, Enum): MXOP = "MXOP" IFEU = "IFEU" NDEX = "NDEX" + DBEQ = "DBEQ" @classmethod def from_int(cls, value: int) -> Venue: @@ -221,6 +224,8 @@ def from_int(cls, value: int) -> Venue: return Venue.IFEU if value == 39: return Venue.NDEX + if value == 40: + return Venue.DBEQ raise ValueError(f"Integer value {value} does not correspond with any Venue variant") def to_int(self) -> int: @@ -305,6 +310,8 @@ def to_int(self) -> int: return 38 if self == Venue.NDEX: return 39 + if self == Venue.DBEQ: + return 40 raise ValueError("Invalid Venue") @property @@ -390,6 +397,8 @@ def description(self) -> str: return "ICE Futures Europe (Commodities)" if self == Venue.NDEX: return "ICE Endex" + if self == Venue.DBEQ: + return "Databento Equities - Consolidated" raise ValueError("Unexpected Venue value") @unique @@ -805,6 +814,10 @@ class Publisher(StringyMixin, str, Enum): ICE Futures Europe (Commodities). NDEX_IMPACT_NDEX ICE Endex. + DBEQ_BASIC_DBEQ + DBEQ Basic - Consolidated. + DBEQ_PLUS_DBEQ + DBEQ Plus - Consolidated. 
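+
+    A short usage sketch for the new consolidated DBEQ publishers:
+
+        from databento.common.publishers import Dataset, Publisher, Venue
+
+        publisher = Publisher.DBEQ_BASIC_DBEQ
+        assert publisher.to_int() == 59
+        assert publisher.venue == Venue.DBEQ
+        assert publisher.dataset == Dataset.DBEQ_BASIC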
""" @@ -866,6 +879,8 @@ class Publisher(StringyMixin, str, Enum): DBEQ_PLUS_FINC = "DBEQ.PLUS.FINC" IFEU_IMPACT_IFEU = "IFEU.IMPACT.IFEU" NDEX_IMPACT_NDEX = "NDEX.IMPACT.NDEX" + DBEQ_BASIC_DBEQ = "DBEQ.BASIC.DBEQ" + DBEQ_PLUS_DBEQ = "DBEQ.PLUS.DBEQ" @classmethod def from_int(cls, value: int) -> Publisher: @@ -988,6 +1003,10 @@ def from_int(cls, value: int) -> Publisher: return Publisher.IFEU_IMPACT_IFEU if value == 58: return Publisher.NDEX_IMPACT_NDEX + if value == 59: + return Publisher.DBEQ_BASIC_DBEQ + if value == 60: + return Publisher.DBEQ_PLUS_DBEQ raise ValueError(f"Integer value {value} does not correspond with any Publisher variant") def to_int(self) -> int: @@ -1110,6 +1129,10 @@ def to_int(self) -> int: return 57 if self == Publisher.NDEX_IMPACT_NDEX: return 58 + if self == Publisher.DBEQ_BASIC_DBEQ: + return 59 + if self == Publisher.DBEQ_PLUS_DBEQ: + return 60 raise ValueError("Invalid Publisher") @property def venue(self) -> Venue: @@ -1232,6 +1255,10 @@ def venue(self) -> Venue: return Venue.IFEU if self == Publisher.NDEX_IMPACT_NDEX: return Venue.NDEX + if self == Publisher.DBEQ_BASIC_DBEQ: + return Venue.DBEQ + if self == Publisher.DBEQ_PLUS_DBEQ: + return Venue.DBEQ raise ValueError("Unexpected Publisher value") @property def dataset(self) -> Dataset: @@ -1354,6 +1381,10 @@ def dataset(self) -> Dataset: return Dataset.IFEU_IMPACT if self == Publisher.NDEX_IMPACT_NDEX: return Dataset.NDEX_IMPACT + if self == Publisher.DBEQ_BASIC_DBEQ: + return Dataset.DBEQ_BASIC + if self == Publisher.DBEQ_PLUS_DBEQ: + return Dataset.DBEQ_PLUS raise ValueError("Unexpected Publisher value") @property @@ -1477,4 +1508,8 @@ def description(self) -> str: return "ICE Futures Europe (Commodities)" if self == Publisher.NDEX_IMPACT_NDEX: return "ICE Endex" + if self == Publisher.DBEQ_BASIC_DBEQ: + return "DBEQ Basic - Consolidated" + if self == Publisher.DBEQ_PLUS_DBEQ: + return "DBEQ Plus - Consolidated" raise ValueError("Unexpected Publisher value") diff --git a/databento/common/symbology.py b/databento/common/symbology.py index f2c0a51..811aa1b 100644 --- a/databento/common/symbology.py +++ b/databento/common/symbology.py @@ -264,7 +264,7 @@ def insert_metadata(self, metadata: Metadata) -> None: stype_out=stype_out, ) - self._insert_inverval( + self._insert_interval( instrument_id, MappingInterval( start_date=start_date, @@ -308,7 +308,7 @@ def insert_symbol_mapping_msg( else: symbol = msg.stype_out_symbol - self._insert_inverval( + self._insert_interval( msg.hd.instrument_id, MappingInterval( start_date=pd.Timestamp(start_ts, unit="ns", tz="utc").date(), @@ -383,7 +383,7 @@ def insert_json( stype_out=stype_out, ) - self._insert_inverval( + self._insert_interval( instrument_id, MappingInterval( start_date=start_date, @@ -540,7 +540,7 @@ def map_symbols_json( return out_file_valid - def _insert_inverval(self, instrument_id: int, interval: MappingInterval) -> None: + def _insert_interval(self, instrument_id: int, interval: MappingInterval) -> None: """ Insert a SymbolInterval into the map. 
diff --git a/databento/live/client.py b/databento/live/client.py
index 03a74d3..ee7436e 100644
--- a/databento/live/client.py
+++ b/databento/live/client.py
@@ -498,6 +498,7 @@ def terminate(self) -> None:
        if self._session is None:
            raise ValueError("cannot terminate a live client before it is connected")
        self._session.abort()
+        self._cleanup_client()

    def block_for_close(
        self,
@@ -539,6 +540,8 @@ def block_for_close(
            raise
        except Exception:
            raise BentoError("connection lost") from None
+        finally:
+            self._cleanup_client()

    async def wait_for_close(
        self,
@@ -581,9 +584,13 @@ async def wait_for_close(
                self.terminate()
            if isinstance(exc, KeyboardInterrupt):
                raise
+        except BentoError:
+            raise
        except Exception:
            logger.exception("exception encountered waiting for close")
            raise BentoError("connection lost") from None
+        finally:
+            self._cleanup_client()

    async def _shutdown(self) -> None:
        """
@@ -597,6 +604,12 @@ async def _shutdown(self) -> None:
            return
        await self._session.wait_for_close()

+    def _cleanup_client(self) -> None:
+        """
+        Clean up any stateful client data.
+        """
+        self._symbology_map.clear()
+
        to_remove = []
        for stream in self._user_streams:
            stream_name = getattr(stream, "name", str(stream))
@@ -609,8 +622,6 @@ async def _shutdown(self) -> None:
        for key in to_remove:
            self._user_streams.pop(key)

-        self._symbology_map.clear()
-
    def _map_symbol(self, record: DBNRecord) -> None:
        if isinstance(record, databento_dbn.SymbolMappingMsg):
            out_symbol = record.stype_out_symbol
diff --git a/databento/version.py b/databento/version.py
index 8b301a7..43e16f5 100644
--- a/databento/version.py
+++ b/databento/version.py
@@ -1 +1 @@
-__version__ = "0.23.0"
+__version__ = "0.23.1"
diff --git a/pyproject.toml b/pyproject.toml
index d39b4f4..b3dff63 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "databento"
-version = "0.23.0"
+version = "0.23.1"
 description = "Official Python client library for Databento"
 authors = [
     "Databento <support@databento.com>",
diff --git a/tests/test_bento_compression.py b/tests/test_bento_compression.py
index cd6a472..f3bb5f2 100644
--- a/tests/test_bento_compression.py
+++ b/tests/test_bento_compression.py
@@ -20,7 +20,10 @@ def test_is_dbn(data: bytes, expected: bool) -> None:
    """
    Test that buffers that start with DBN are identified as DBN files.
    """
+    # Arrange, Act
    reader = BytesIO(data)
+
+    # Assert
    assert is_dbn(reader) == expected


@@ -52,5 +55,8 @@ def test_is_zstandard(data: bytes, expected: bool) -> None:
    """
    Test that buffers that contain ZSTD data are correctly identified.
    """
+    # Arrange, Act
    reader = BytesIO(data)
+
+    # Assert
    assert is_zstandard(reader) == expected
diff --git a/tests/test_bento_data_source.py b/tests/test_bento_data_source.py
index cd81778..2c741c4 100644
--- a/tests/test_bento_data_source.py
+++ b/tests/test_bento_data_source.py
@@ -15,9 +15,11 @@ def test_memory_data_source(
    """
    Test creation of a MemoryDataSource.
    """
+    # Arrange, Act
    data = test_data(schema)
    data_source = MemoryDataSource(data)

+    # Assert
    assert len(data) == data_source.nbytes
    assert repr(data) == data_source.name

@@ -30,8 +32,10 @@ def test_file_data_source(
    """
    Test creation of a FileDataSource.
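+
+    A sketch of the construction under test (path hypothetical):
+
+        from databento.common.dbnstore import FileDataSource
+
+        source = FileDataSource("tests/data/test_data.mbo.dbn.zst")
+        source.nbytes  # size of the file in bytes
+        source.name    # name of the file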
""" + # Arrange, Act path = test_data_path(schema) data_source = FileDataSource(path) + # Assert assert path.stat().st_size == data_source.nbytes assert path.name == data_source.name diff --git a/tests/test_common_cram.py b/tests/test_common_cram.py index fa9cbb4..f46d149 100644 --- a/tests/test_common_cram.py +++ b/tests/test_common_cram.py @@ -31,8 +31,11 @@ def test_get_challenge_response( - challenge is the CRAM challenge, this is salt for the hash. """ + # Arrange, Act response = cram.get_challenge_response( challenge=challenge, key=key, ) + + # Assert assert response == expected diff --git a/tests/test_common_enums.py b/tests/test_common_enums.py index 9b101da..9d80180 100644 --- a/tests/test_common_enums.py +++ b/tests/test_common_enums.py @@ -52,8 +52,9 @@ def test_int_enum_string_coercion(enum_type: type[Enum]) -> None: See: databento.common.enums.coercible """ - for enum in enum_type: - assert enum == enum_type(str(enum.value)) + # Arrange, Act, Assert + for variant in enum_type: + assert variant == enum_type(str(variant.value)) with pytest.raises(ValueError): enum_type("NaN") # sanity @@ -69,6 +70,7 @@ def test_str_enum_case_coercion(enum_type: type[Enum]) -> None: See: databento.common.enums.coercible """ + # Arrange, Act, Assert for enum in enum_type: assert enum == enum_type(enum.value.lower()) assert enum == enum_type(enum.value.upper()) @@ -88,11 +90,13 @@ def test_enum_name_coercion(enum_type: type[Enum]) -> None: See: databento.common.enums.coercible """ + # Arrange, Act if enum_type in (Compression, Encoding, Schema, SType): enum_it = iter(enum_type.variants()) # type: ignore [attr-defined] else: enum_it = iter(enum_type) + # Assert for enum in enum_it: assert enum == enum_type(enum.name) assert enum == enum_type(enum.name.replace("_", "-")) @@ -113,9 +117,11 @@ def test_enum_none_not_coercible(enum_type: type[Enum]) -> None: See: databento.common.enum.coercible """ + # Arrange, Act if enum_type == Compression: enum_type(None) else: + # Assert with pytest.raises(ValueError): enum_type(None) @@ -131,8 +137,11 @@ def test_int_enum_stringy_mixin(enum_type: type[Enum]) -> None: See: databento.common.enum.StringyMixin """ + # Arrange, Act if not issubclass(enum_type, StringyMixin): pytest.skip(f"{type(enum_type)} is not a subclass of StringyMixin") + + # Assert for enum in enum_type: assert str(enum) == enum.name.lower() @@ -148,8 +157,11 @@ def test_str_enum_stringy_mixin(enum_type: type[Enum]) -> None: See: databento.common.enum.StringyMixin """ + # Arrange, Act if not issubclass(enum_type, StringyMixin): pytest.skip(f"{type(enum_type)} is not a subclass of StringyMixin") + + # Assert for enum in enum_type: assert str(enum) == enum.value @@ -162,8 +174,11 @@ def test_int_flags_stringy_mixin(enum_type: type[Flag]) -> None: """ Test that combinations of int flags are displayed properly. """ + # Arrange, Act for value in map(sum, combinations(enum_type, 2)): # type: ignore [arg-type] record_flags = enum_type(value) + + # Assert assert str(record_flags) == ", ".join( f.name.lower() for f in enum_type if f in record_flags ) diff --git a/tests/test_common_iterator.py b/tests/test_common_iterator.py index 80ffa27..0082561 100644 --- a/tests/test_common_iterator.py +++ b/tests/test_common_iterator.py @@ -29,5 +29,6 @@ def test_chunk( """ Test that an iterable is chunked property. 
""" + # Arrange, Act, Assert chunks = [chunk for chunk in iterator.chunk(things, size)] assert chunks == expected diff --git a/tests/test_common_parsing.py b/tests/test_common_parsing.py index 28dda40..24ba801 100644 --- a/tests/test_common_parsing.py +++ b/tests/test_common_parsing.py @@ -115,6 +115,7 @@ def test_optional_symbols_list_to_list_int( If integers are given for a different SType we expect a ValueError. """ + # Arrange, Act, Assert if isinstance(expected, list): assert optional_symbols_list_to_list(symbols, stype) == expected else: @@ -160,6 +161,7 @@ def test_optional_symbols_list_to_list_numpy( If integers are given for a different SType we expect a ValueError. """ + # Arrange, Act, Assert if isinstance(expected, list): assert optional_symbols_list_to_list(symbols, stype) == expected else: @@ -195,6 +197,7 @@ def test_optional_symbols_list_to_list_raw_symbol( """ Test that str are allowed for SType.RAW_SYMBOL. """ + # Arrange, Act, Assert if isinstance(expected, list): assert optional_symbols_list_to_list(symbols, stype) == expected else: @@ -267,4 +270,5 @@ def test_datetime_to_unix_nanoseconds( """ Test that various inputs for times convert to unix nanoseconds. """ + # Arrange, Act, Assert assert optional_datetime_to_unix_nanoseconds(value) == expected diff --git a/tests/test_common_symbology.py b/tests/test_common_symbology.py index c70f264..275ddc1 100644 --- a/tests/test_common_symbology.py +++ b/tests/test_common_symbology.py @@ -188,6 +188,7 @@ def test_instrument_map( """ Test the creation of an InstrumentMap. """ + # Arrange, Act, Assert assert instrument_map._data == {} @@ -199,6 +200,7 @@ def test_instrument_map_insert_metadata( """ Test the insertion of DBN Metadata. """ + # Arrange symbol = "test" instrument_id = 1234 @@ -219,7 +221,10 @@ def test_instrument_map_insert_metadata( mappings=mappings, ) + # Act instrument_map.insert_metadata(metadata) + + # Assert assert instrument_map.resolve(instrument_id, start_date.date()) == symbol @@ -232,6 +237,7 @@ def test_instrument_map_insert_metadata_multiple_mappings( Test the insertion of a DBN Metadata with multiple mapping for the same instrument_id. """ + # Arrange symbols = ["test_1", "test_2", "test_3"] instrument_id = 1234 @@ -254,8 +260,10 @@ def test_instrument_map_insert_metadata_multiple_mappings( mappings=mappings, ) + # Act instrument_map.insert_metadata(metadata) + # Assert for offset, symbol in enumerate(symbols): assert ( instrument_map.resolve( @@ -274,6 +282,7 @@ def test_instrument_map_insert_metadata_empty_mappings( """ Test the insertion of DBN Metadata that contains an empty mapping. """ + # Arrange mappings = [ SymbolMapping( raw_symbol="empty", @@ -291,7 +300,10 @@ def test_instrument_map_insert_metadata_empty_mappings( mappings=mappings, ) + # Act instrument_map.insert_metadata(metadata) + + # Assert assert instrument_map._data == {} @@ -303,6 +315,7 @@ def test_instrument_map_insert_symbol_mapping_message( """ Test the insertion of a SymbolMappingMessage. """ + # Arrange symbol = "test" instrument_id = 1234 @@ -314,8 +327,10 @@ def test_instrument_map_insert_symbol_mapping_message( end_ts=end_date, ) + # Act instrument_map.insert_symbol_mapping_msg(sym_msg) + # Assert assert instrument_map.resolve(instrument_id, start_date.date()) == symbol @@ -328,9 +343,11 @@ def test_instrument_map_insert_symbol_mapping_message_multiple_mappings( Test the insertion of multiple SymbolMappingMsg object for the same instrument_id. 
""" + # Arrange symbols = ["test_1", "test_2", "test_3"] instrument_id = 1234 + # Act for offset, symbol in enumerate(symbols): sym_mapping_msg = create_symbol_mapping_message( instrument_id=instrument_id, @@ -341,6 +358,7 @@ def test_instrument_map_insert_symbol_mapping_message_multiple_mappings( ) instrument_map.insert_symbol_mapping_msg(sym_mapping_msg) + # Assert for offset, symbol in enumerate(symbols): assert ( instrument_map.resolve( @@ -382,6 +400,7 @@ def test_instrument_map_insert_symbology_response( """ Test the insertion of a symbology responses. """ + # Arrange result = { symbol_in: [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": symbol_out}, @@ -392,8 +411,11 @@ def test_instrument_map_insert_symbology_response( stype_in=stype_in, stype_out=stype_out, ) + + # Act instrument_map.insert_json(sym_resp) + # Assert # This is hard coded because it should be invariant under parameterization assert instrument_map.resolve(1234, start_date.date()) == "test_1" @@ -407,6 +429,7 @@ def test_instrument_map_insert_symbology_response_multiple_mappings( Test the insertion of multiple symbology responses for the same instrument_id. """ + # Arrange, Act symbols = ["test_1", "test_2", "test_3"] instrument_id = 1234 @@ -426,6 +449,7 @@ def test_instrument_map_insert_symbology_response_multiple_mappings( instrument_map.insert_json(sym_resp) + # Assert for offset, symbol in enumerate(symbols): assert ( instrument_map.resolve( @@ -444,6 +468,7 @@ def test_instrument_map_insert_symbology_response_empty_mapping( """ Test the insertion of an empty symbology mapping. """ + # Arrange result = { "test": [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": ""}, @@ -455,7 +480,10 @@ def test_instrument_map_insert_symbology_response_empty_mapping( stype_out=SType.INSTRUMENT_ID, ) + # Act instrument_map.insert_json(sym_resp) + + # Assert assert instrument_map._data == {} @@ -493,6 +521,7 @@ def test_instrument_map_insert_json_str( """ Test the insertion of a JSON symbology response. """ + # Arrange result = { symbol_in: [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": symbol_out}, @@ -504,8 +533,10 @@ def test_instrument_map_insert_json_str( stype_out=stype_out, ) + # Act instrument_map.insert_json(json.dumps(sym_resp)) + # Assert assert instrument_map.resolve(1234, start_date.date()) == expected_symbol @@ -544,6 +575,7 @@ def test_instrument_map_insert_json_file( """ Test the insertion of a JSON file. """ + # Arrange result = { symbol_in: [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": symbol_out}, @@ -559,9 +591,11 @@ def test_instrument_map_insert_json_file( with open(symboloy_json, mode="w") as resp_file: json.dump(sym_resp, resp_file) + # Act with open(symboloy_json) as resp_file: instrument_map.insert_json(resp_file) + # Assert assert instrument_map.resolve(1234, start_date.date()) == expected_symbol @@ -573,6 +607,7 @@ def test_instrument_map_insert_json_str_empty_mapping( """ Test the insertion of an JSON symbology mapping. """ + # Arrange result = { "test": [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": ""}, @@ -584,7 +619,10 @@ def test_instrument_map_insert_json_str_empty_mapping( stype_out=SType.INSTRUMENT_ID, ) + # Act instrument_map.insert_json(json.dumps(sym_resp)) + + # Assert assert instrument_map._data == {} @@ -609,6 +647,7 @@ def test_instrument_map_insert_symbology_response_invalid_stype( Test that a symbology response with no instrument_id mapping raises a ValueError. 
""" + # Arrange result = { symbol_in: [ {"d0": start_date.isoformat(), "d1": end_date.isoformat(), "s": symbol_out}, @@ -620,6 +659,7 @@ def test_instrument_map_insert_symbology_response_invalid_stype( stype_out=stype_out, ) + # Act, Assert with pytest.raises(ValueError): instrument_map.insert_json(sym_resp) @@ -630,6 +670,7 @@ def test_instrument_map_insert_symbology_response_invalid_response( """ Test that an invalid symbology response raises a ValueError. """ + # Arrange, Act, Assert with pytest.raises(ValueError): instrument_map.insert_json({"foo": "bar"}) @@ -641,6 +682,7 @@ def test_instrument_map_insert_symbology_response_invalid_result_entry( """ Test that an invalid symbology response entry raises a ValueError. """ + # Arrange result = { "test_1": [{"d0": start_date.isoformat(), "s": 1234}], } @@ -648,6 +690,7 @@ def test_instrument_map_insert_symbology_response_invalid_result_entry( result=result, ) + # Act, Assert with pytest.raises(ValueError): instrument_map.insert_json(sym_resp) @@ -660,6 +703,7 @@ def test_instrument_map_resolve_with_date( """ Test that resolve accepts `datetime.date` objects. """ + # Arrange, Act symbol = "test_1" instrument_id = 1234 @@ -671,6 +715,7 @@ def test_instrument_map_resolve_with_date( ), ] + # Assert assert ( instrument_map.resolve( instrument_id, @@ -690,6 +735,7 @@ def test_instrument_map_ignore_duplicate( """ Test that a duplicate entry is not inserted into an InstrumentMap. """ + # Arrange, Act symbol = "test_1" instrument_id = 1234 @@ -701,6 +747,7 @@ def test_instrument_map_ignore_duplicate( ), ] + # Act, Assert assert len(instrument_map._data[instrument_id]) == 1 msg = create_symbol_mapping_message( @@ -737,6 +784,7 @@ def test_instrument_map_symbols_csv( Test that a CSV file without mapped symbols is equivelant to a CSV file with mapped symbols after processing with map_symbols_csv. """ + # Arrange, Act store = DBNStore.from_file(test_data_path(schema)) csv_path = tmp_path / f"test_{schema}.csv" store.to_csv( @@ -758,6 +806,7 @@ def test_instrument_map_symbols_csv( out_file=outfile, ) + # Assert assert outfile == written_path assert outfile.read_text() == expected_path.read_text() @@ -783,6 +832,7 @@ def test_instrument_map_symbols_json( Test that a JSON file without mapped symbols is equivelant to a JSON file with mapped symbols after processing with map_symbols_json. """ + # Arrange, Act store = DBNStore.from_file(test_data_path(schema)) json_path = tmp_path / f"test_{schema}.json" store.to_json( @@ -804,5 +854,6 @@ def test_instrument_map_symbols_json( out_file=outfile, ) + # Assert assert outfile == written_path assert outfile.read_text() == expected_path.read_text() diff --git a/tests/test_common_validation.py b/tests/test_common_validation.py index 928e4ed..447489d 100644 --- a/tests/test_common_validation.py +++ b/tests/test_common_validation.py @@ -108,6 +108,7 @@ def test_validate_gateway( """ Tests several correct and malformed URLs. """ + # Arrange, Act, Assert if isinstance(expected, str): assert validate_gateway(url) == expected else: @@ -137,6 +138,7 @@ def test_validate_smart_symbol( """ Test several correct smart symbols and invalid syntax. 
""" + # Arrange, Act, Assert if isinstance(expected, str): assert validate_smart_symbol(symbol) == expected else: @@ -163,6 +165,7 @@ def test_validate_semantic_string( - whitespace - contain unprintable characters """ + # Arrange, Act, Assert if isinstance(expected, str): assert validate_semantic_string(value, "unittest") == expected else: diff --git a/tests/test_historical_bento.py b/tests/test_historical_bento.py index dd15a03..97e72f5 100644 --- a/tests/test_historical_bento.py +++ b/tests/test_historical_bento.py @@ -695,7 +695,7 @@ def test_dbnstore_iterable( Tests the DBNStore iterable implementation to ensure records can be accessed by iteration. """ - # Arrange + # Arrange, Act stub_data = test_data(Schema.MBO) dbnstore = DBNStore.from_bytes(data=stub_data) @@ -703,6 +703,7 @@ def test_dbnstore_iterable( first: MBOMsg = record_list[0] # type: ignore second: MBOMsg = record_list[1] # type: ignore + # Assert assert first.hd.length == 14 assert first.hd.rtype == 160 assert first.hd.rtype == 160 @@ -748,13 +749,14 @@ def test_dbnstore_iterable_parallel( For example, calling next() on one iterator does not affect another. """ - # Arrange + # Arrange, Act stub_data = test_data(Schema.MBO) dbnstore = DBNStore.from_bytes(data=stub_data) first = iter(dbnstore) second = iter(dbnstore) + # Assert assert next(first) == next(second) assert next(first) == next(second) @@ -774,12 +776,15 @@ def test_dbnstore_compression_equality( Note that stub data is compressed with zstandard by default. """ + # Arrange zstd_stub_data = test_data(schema) dbn_stub_data = zstandard.ZstdDecompressor().stream_reader(zstd_stub_data).read() + # Act zstd_dbnstore = DBNStore.from_bytes(zstd_stub_data) dbn_dbnstore = DBNStore.from_bytes(dbn_stub_data) + # Assert assert len(zstd_dbnstore.to_ndarray()) == len(dbn_dbnstore.to_ndarray()) assert zstd_dbnstore.metadata == dbn_dbnstore.metadata assert zstd_dbnstore.reader.read() == dbn_dbnstore.reader.read() @@ -869,6 +874,7 @@ def test_dbnstore_buffer_rewind( dbn_bytes.write(dbn_stub_data) dbnstore = DBNStore.from_bytes(data=dbn_bytes) + # Assert assert len(dbnstore.to_df()) == 4 @@ -899,8 +905,8 @@ def test_dbnstore_to_ndarray_with_count( # Act dbnstore = DBNStore.from_bytes(data=dbn_stub_data) - nd_iter = dbnstore.to_ndarray(count=count) expected = dbnstore.to_ndarray() + nd_iter = dbnstore.to_ndarray(count=count) # Assert aggregator: list[np.ndarray[Any, Any]] = [] @@ -929,8 +935,8 @@ def test_dbnstore_to_ndarray_with_schema( # Act dbnstore = DBNStore.from_bytes(data=dbn_stub_data) - actual = dbnstore.to_ndarray(schema=schema) expected = dbnstore.to_ndarray() + actual = dbnstore.to_ndarray(schema=schema) # Assert for i, row in enumerate(actual): @@ -1008,8 +1014,8 @@ def test_dbnstore_to_df_with_count( # Act dbnstore = DBNStore.from_bytes(data=dbn_stub_data) - df_iter = dbnstore.to_df(count=count) expected = dbnstore.to_df() + df_iter = dbnstore.to_df(count=count) # Assert aggregator: list[pd.DataFrame] = [] @@ -1042,8 +1048,8 @@ def test_dbnstore_to_df_with_schema( # Act dbnstore = DBNStore.from_bytes(data=dbn_stub_data) - actual = dbnstore.to_df(schema=schema) expected = dbnstore.to_df() + actual = dbnstore.to_df(schema=schema) # Assert assert actual.equals(expected) diff --git a/tests/test_historical_client.py b/tests/test_historical_client.py index cd5f474..965682e 100644 --- a/tests/test_historical_client.py +++ b/tests/test_historical_client.py @@ -93,7 +93,7 @@ def test_custom_gateway_force_https( """ Test that custom gateways are forced to the https scheme. 
""" - # Arrange Act + # Arrange, Act client = db.Historical(key="DUMMY_API_KEY", gateway=gateway) # Assert diff --git a/tests/test_historical_data.py b/tests/test_historical_data.py index 3376a0e..b684eb8 100644 --- a/tests/test_historical_data.py +++ b/tests/test_historical_data.py @@ -8,6 +8,7 @@ def test_mbo_fields() -> None: """ Test that columns match the MBO struct. """ + # Arrange struct = SCHEMA_STRUCT_MAP[databento.Schema.MBO] columns = SCHEMA_COLUMNS[databento.Schema.MBO] @@ -16,7 +17,10 @@ def test_mbo_fields() -> None: fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference @@ -35,6 +39,7 @@ def test_mbp_fields( """ Test that columns match the MBP structs. """ + # Arrange struct = SCHEMA_STRUCT_MAP[schema] columns = SCHEMA_COLUMNS[schema] @@ -43,8 +48,10 @@ def test_mbp_fields( fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + # Assert assert "levels" in difference # bid/ask size, price, ct for each level, plus the levels field @@ -66,6 +73,7 @@ def test_ohlcv_fields( """ Test that columns match the OHLCV structs. """ + # Arrange struct = SCHEMA_STRUCT_MAP[schema] columns = SCHEMA_COLUMNS[schema] @@ -74,7 +82,10 @@ def test_ohlcv_fields( fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference @@ -82,6 +93,7 @@ def test_trades_struct() -> None: """ Test that columns match the Trades struct. """ + # Arrange struct = SCHEMA_STRUCT_MAP[databento.Schema.TRADES] columns = SCHEMA_COLUMNS[databento.Schema.TRADES] @@ -90,7 +102,10 @@ def test_trades_struct() -> None: fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference @@ -98,6 +113,7 @@ def test_definition_struct() -> None: """ Test that columns match the Definition struct. """ + # Arrange struct = SCHEMA_STRUCT_MAP[databento.Schema.DEFINITION] columns = SCHEMA_COLUMNS[databento.Schema.DEFINITION] @@ -106,7 +122,10 @@ def test_definition_struct() -> None: fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference @@ -114,6 +133,7 @@ def test_imbalance_struct() -> None: """ Test that columns match the Imbalance struct. """ + # Arrange struct = SCHEMA_STRUCT_MAP[databento.Schema.IMBALANCE] columns = SCHEMA_COLUMNS[databento.Schema.IMBALANCE] @@ -122,7 +142,10 @@ def test_imbalance_struct() -> None: fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference @@ -130,6 +153,7 @@ def test_statistics_struct() -> None: """ Test that columns match the Statistics struct. 
""" + # Arrange struct = SCHEMA_STRUCT_MAP[databento.Schema.STATISTICS] columns = SCHEMA_COLUMNS[databento.Schema.STATISTICS] @@ -138,5 +162,8 @@ def test_statistics_struct() -> None: fields.remove("record_size") fields.remove("size_hint") + # Act difference = fields.symmetric_difference(set(columns)) + + # Assert assert not difference diff --git a/tests/test_historical_error.py b/tests/test_historical_error.py index b19140b..06bf563 100644 --- a/tests/test_historical_error.py +++ b/tests/test_historical_error.py @@ -30,8 +30,11 @@ def test_check_http_status( Test that responses with the given status code raise the expected exception. """ + # Arrange response = requests.Response() response.status_code = status_code + + # Act, Assert with pytest.raises(expected_exception) as exc: check_http_error(response) @@ -57,11 +60,14 @@ async def test_check_http_status_async( Test that responses with the given status code raise the expected exception. """ + # Arrange response = MagicMock( spec=aiohttp.ClientResponse, status=status_code, json=AsyncMock(return_value={}), ) + + # Act, Assert with pytest.raises(expected_exception) as exc: await check_http_error_async(response) diff --git a/tests/test_historical_warnings.py b/tests/test_historical_warnings.py index 852c0f6..0862d80 100644 --- a/tests/test_historical_warnings.py +++ b/tests/test_historical_warnings.py @@ -35,13 +35,16 @@ def test_backend_warning( Test that a backend warning in a response header is correctly parsed as a type of BentoWarning. """ + # Arrange response = Response() expected = f'["{category}: {message}"]' response.headers[header_field] = expected + # Act with pytest.warns() as warnings: check_backend_warnings(response) + # Assert assert len(warnings) == 1 assert warnings.list[0].category.__name__ == expected_category assert str(warnings.list[0].message) == message @@ -60,6 +63,7 @@ def test_multiple_backend_warning( """ Test that multiple backend warnings in a response header are supported. """ + # Arrange response = Response() backend_warnings = [ "Warning: this is a test", @@ -67,9 +71,11 @@ def test_multiple_backend_warning( ] response.headers[header_field] = json.dumps(backend_warnings) + # Act with pytest.warns() as warnings: check_backend_warnings(response) + # Assert assert len(warnings) == len(backend_warnings) assert warnings.list[0].category.__name__ == "BentoWarning" assert str(warnings.list[0].message) == "this is a test" diff --git a/tests/test_live_client.py b/tests/test_live_client.py index c270a04..b54086c 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -38,6 +38,7 @@ def test_live_connection_refused( """ Test that a refused connection raises a BentoError. """ + # Arrange live_client = client.Live( key=test_api_key, # Connect to something that does not exist @@ -45,6 +46,7 @@ def test_live_connection_refused( port=0, ) + # Act, Assert with pytest.raises(BentoError) as exc: live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -66,6 +68,7 @@ def test_live_connection_timeout( set a timeout of 0. """ + # Arrange monkeypatch.setattr( session, "CONNECT_TIMEOUT_SECONDS", @@ -78,6 +81,7 @@ def test_live_connection_timeout( port=mock_live_server.port, ) + # Act, Assert with pytest.raises(BentoError) as exc: live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -103,6 +107,7 @@ def test_live_invalid_gateway( """ Test that specifying an invalid gateway raises a ValueError. 
""" + # Arrange, Act, Assert with pytest.raises(ValueError): client.Live( key=test_api_key, @@ -126,6 +131,7 @@ def test_live_invalid_port( """ Test that specifying an invalid port raises a ValueError. """ + # Arrange, Act, Assert with pytest.raises(ValueError): client.Live( key=test_api_key, @@ -143,6 +149,7 @@ def test_live_connection_cram_failure( Test that a failed auth message due to an incorrect CRAM raises a BentoError. """ + # Arrange # Dork up the API key in the mock client to fail CRAM bucket_id = test_api_key[-BUCKET_ID_LENGTH:] invalid_key = "db-invalidkey00000000000000FFFFF" @@ -154,6 +161,7 @@ def test_live_connection_cram_failure( port=mock_live_server.port, ) + # Act, Assert with pytest.raises(BentoError) as exc: live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -176,18 +184,21 @@ def test_live_creation( """ Test the live constructor and successful connection to the MockLiveServer. """ + # Arrange live_client = client.Live( key=test_api_key, gateway=mock_live_server.host, port=mock_live_server.port, ) + # Act # Subscribe to connect live_client.subscribe( dataset=dataset, schema=Schema.MBO, ) + # Assert assert live_client.gateway == mock_live_server.host assert live_client.port == mock_live_server.port assert live_client._key == test_api_key @@ -204,16 +215,19 @@ def test_live_connect_auth( Test the live sent a correct AuthenticationRequest message after connecting. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, ) + # Act message = mock_live_server.get_message_of_type( gateway.AuthenticationRequest, timeout=1, ) + # Assert assert message.auth.endswith(live_client.key[-BUCKET_ID_LENGTH:]) assert message.dataset == live_client.dataset assert message.encoding == Encoding.DBN @@ -227,6 +241,7 @@ def test_live_connect_auth_two_clients( Test the live sent a correct AuthenticationRequest message after connecting two distinct clients. """ + # Arrange first = client.Live( key=test_api_key, gateway=mock_live_server.host, @@ -239,6 +254,7 @@ def test_live_connect_auth_two_clients( port=mock_live_server.port, ) + # Act first.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -248,6 +264,8 @@ def test_live_connect_auth_two_clients( gateway.AuthenticationRequest, timeout=1, ) + + # Assert assert first_auth.auth.endswith(first.key[-BUCKET_ID_LENGTH:]) assert first_auth.dataset == first.dataset assert first_auth.encoding == Encoding.DBN @@ -274,6 +292,7 @@ def test_live_start( """ Test the live sends a SesssionStart message upon calling start(). """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -281,6 +300,7 @@ def test_live_start( assert live_client.is_connected() is True + # Act live_client.start() live_client.block_for_close() @@ -290,6 +310,7 @@ def test_live_start( timeout=1, ) + # Assert assert message.start_session @@ -299,13 +320,16 @@ def test_live_start_twice( """ Test that calling start() twice raises a ValueError. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, ) + # Act live_client.start() + # Assert with pytest.raises(ValueError): live_client.start() @@ -316,6 +340,7 @@ def test_live_start_before_subscribe( """ Test that calling start() before subscribe raises a ValueError. """ + # Arrange, Act, Assert with pytest.raises(ValueError): live_client.start() @@ -355,6 +380,7 @@ def test_live_subscribe( Test various combination of subscription messages are serialized and correctly deserialized by the MockLiveServer. 
""" + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=schema, @@ -363,6 +389,7 @@ def test_live_subscribe( start=start, ) + # Act message = mock_live_server.get_message_of_type( gateway.SubscriptionRequest, timeout=1, @@ -371,6 +398,7 @@ def test_live_subscribe( if symbols is None: symbols = ALL_SYMBOLS + # Assert assert message.schema == schema assert message.stype_in == stype_in assert message.symbols == symbols @@ -386,9 +414,12 @@ async def test_live_subscribe_large_symbol_list( Test that sending a subscription with a large symbol list breaks that list up into multiple messages. """ + # Arrange large_symbol_list = list( random.choices(string.ascii_uppercase, k=256), # noqa: S311 ) + + # Act live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -417,6 +448,8 @@ async def test_live_subscribe_large_symbol_list( ).symbols.split(",") reconstructed = first_message + second_message + third_message + fourth_message + + # Assert assert reconstructed == large_symbol_list @@ -427,11 +460,13 @@ def test_live_stop( """ Test that calling start() and stop() appropriately update the client state. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, ) + # Act, Assert assert live_client.is_connected() is True live_client.start() @@ -439,6 +474,8 @@ def test_live_stop( live_client.stop() live_client.block_for_close() + assert live_client.is_connected() is False + @pytest.mark.usefixtures("mock_live_server") def test_live_shutdown_remove_closed_stream( @@ -448,6 +485,7 @@ def test_live_shutdown_remove_closed_stream( """ Test that closed streams are removed upon disconnection. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -455,6 +493,7 @@ def test_live_shutdown_remove_closed_stream( output = tmp_path / "output.dbn" + # Act, Assert with output.open("wb") as out: live_client.add_stream(out) @@ -474,11 +513,13 @@ def test_live_stop_twice( """ Test that calling stop() twice does not raise an exception. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, ) + # Act, Assert live_client.stop() live_client.stop() @@ -490,6 +531,7 @@ def test_live_block_for_close( """ Test that block_for_close unblocks when the connection is closed. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -498,6 +540,7 @@ def test_live_block_for_close( start=None, ) + # Act, Assert live_client.start() live_client.block_for_close() @@ -513,6 +556,7 @@ def test_live_block_for_close_timeout( Test that block_for_close terminates the session when the timeout is reached. """ + # Arrange monkeypatch.setattr(live_client, "terminate", MagicMock()) live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -521,10 +565,39 @@ def test_live_block_for_close_timeout( symbols="ALL_SYMBOLS", start=None, ) + + # Act, Assert live_client.block_for_close(timeout=0) live_client.terminate.assert_called_once() # type: ignore +@pytest.mark.usefixtures("mock_live_server") +def test_live_block_for_close_timeout_stream( + live_client: client.Live, + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + """ + Test that block_for_close flushes user streams on timeout. 
+ """ + # Arrange + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.INSTRUMENT_ID, + symbols="ALL_SYMBOLS", + start=None, + ) + path = tmp_path / "test.dbn" + stream = path.open("wb") + monkeypatch.setattr(stream, "flush", MagicMock()) + live_client.add_stream(stream) + + # Act, Assert + live_client.block_for_close(timeout=0) + stream.flush.assert_called() # type: ignore [attr-defined] + + @pytest.mark.usefixtures("mock_live_server") async def test_live_wait_for_close( live_client: client.Live, @@ -532,6 +605,7 @@ async def test_live_wait_for_close( """ Test that wait_for_close unblocks when the connection is closed. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -540,10 +614,11 @@ async def test_live_wait_for_close( start=None, ) + # Act live_client.start() - await live_client.wait_for_close() + #Assert assert not live_client.is_connected() @@ -556,8 +631,10 @@ async def test_live_wait_for_close_timeout( Test that wait_for_close terminates the session when the timeout is reached. """ + # Arrange monkeypatch.setattr(live_client, "terminate", MagicMock()) + # Act live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -565,23 +642,56 @@ async def test_live_wait_for_close_timeout( symbols="ALL_SYMBOLS", start=None, ) - await live_client.wait_for_close(timeout=0) + # Assert live_client.terminate.assert_called_once() # type: ignore +@pytest.mark.usefixtures("mock_live_server") +async def test_live_wait_for_close_timeout_stream( + live_client: client.Live, + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + """ + Test that wait_for_close flushes user streams on timeout. + """ + # Arrange + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.INSTRUMENT_ID, + symbols="ALL_SYMBOLS", + start=None, + ) + + path = tmp_path / "test.dbn" + stream = path.open("wb") + monkeypatch.setattr(stream, "flush", MagicMock()) + live_client.add_stream(stream) + + # Act + await live_client.wait_for_close(timeout=0) + + # Assert + stream.flush.assert_called() # type: ignore [attr-defined] + + def test_live_add_callback( live_client: client.Live, ) -> None: """ Test that calling add_callback adds that callback to the client. """ - + # Arrange def callback(_: object) -> None: pass + # Act live_client.add_callback(callback) + + # Assert assert callback in live_client._user_callbacks assert live_client._user_callbacks[callback] is None assert live_client._user_streams == {} @@ -593,9 +703,13 @@ def test_live_add_stream( """ Test that calling add_stream adds that stream to the client. """ + # Arrange stream = BytesIO() + # Act live_client.add_stream(stream) + + # Assert assert stream in live_client._user_streams assert live_client._user_streams[stream] is None @@ -607,14 +721,18 @@ def test_live_add_stream_invalid( """ Test that passing a non-writable stream to add_stream raises a ValueError. """ + # Arrange, Act with pytest.raises(ValueError): live_client.add_stream(object) # type: ignore readable_file = tmp_path / "nope.txt" readable_file.touch() + + # Assert with pytest.raises(ValueError): live_client.add_stream(readable_file.open(mode="rb")) + def test_live_add_stream_path_directory( tmp_path: pathlib.Path, live_client: client.Live, @@ -622,9 +740,11 @@ def test_live_add_stream_path_directory( """ Test that passing a path to a directory raises an OSError. 
""" + # Arrange, Act, Assert with pytest.raises(OSError): live_client.add_stream(tmp_path) + @pytest.mark.skipif(platform.system() == "Windows", reason="flaky on windows runner") async def test_live_async_iteration( live_client: client.Live, @@ -632,6 +752,7 @@ async def test_live_async_iteration( """ Test async-iteration of DBN records. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -640,9 +761,12 @@ async def test_live_async_iteration( ) records: list[DBNRecord] = [] + + # Act async for record in live_client: records.append(record) + # Assert assert len(records) == 4 assert isinstance(records[0], databento_dbn.MBOMsg) assert isinstance(records[1], databento_dbn.MBOMsg) @@ -660,6 +784,7 @@ async def test_live_async_iteration_backpressure( Test that a full queue disables reading on the transport but will resume it when the queue is depleted when iterating asynchronously. """ + # Arrange monkeypatch.setattr(client, "DEFAULT_QUEUE_SIZE", 4) live_client = client.Live( @@ -681,12 +806,15 @@ async def test_live_async_iteration_backpressure( pause_mock := MagicMock(), ) + # Act live_it = iter(live_client) await live_client.wait_for_close() pause_mock.assert_called() records: list[DBNRecord] = list(live_it) + + # Assert assert len(records) == 4 assert live_client._dbn_queue.empty() @@ -700,6 +828,7 @@ async def test_live_async_iteration_dropped( """ Test that an artificially small queue size will drop messages when full. """ + # Arrange monkeypatch.setattr(client, "DEFAULT_QUEUE_SIZE", 1) live_client = client.Live( @@ -721,12 +850,15 @@ async def test_live_async_iteration_dropped( pause_mock := MagicMock(), ) + # Act live_it = iter(live_client) await live_client.wait_for_close() pause_mock.assert_called() records = list(live_it) + + # Assert assert len(records) == 1 assert live_client._dbn_queue.empty() @@ -739,18 +871,21 @@ async def test_live_async_iteration_stop( Test that stopping in the middle of iteration does not prevent iterating the queue to completion. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, stype_in=SType.RAW_SYMBOL, symbols="TEST", ) - records = [] + + # Act async for record in live_client: records.append(record) live_client.stop() + # Assert assert len(records) > 1 assert live_client._dbn_queue.empty() @@ -762,17 +897,20 @@ def test_live_sync_iteration( """ Test synchronous iteration of DBN records. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, stype_in=SType.RAW_SYMBOL, symbols="TEST", ) - records = [] + + # Act for record in live_client: records.append(record) + # Assert assert len(records) == 4 assert isinstance(records[0], databento_dbn.MBOMsg) assert isinstance(records[1], databento_dbn.MBOMsg) @@ -786,25 +924,27 @@ async def test_live_callback( """ Test callback dispatch of DBN records. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, stype_in=SType.RAW_SYMBOL, symbols="TEST", ) - records = [] def callback(record: DBNRecord) -> None: nonlocal records records.append(record) + # Act live_client.add_callback(callback) live_client.start() await live_client.wait_for_close() + # Assert assert len(records) == 4 assert isinstance(records[0], databento_dbn.MBOMsg) assert isinstance(records[1], databento_dbn.MBOMsg) @@ -826,6 +966,7 @@ async def test_live_stream_to_dbn( Test that DBN data streamed by the MockLiveServer is properly re- constructed client side. 
""" + # Arrange output = tmp_path / "output.dbn" live_client.subscribe( @@ -836,6 +977,7 @@ async def test_live_stream_to_dbn( ) live_client.add_stream(output.open("wb", buffering=0)) + # Act live_client.start() await live_client.wait_for_close() @@ -847,6 +989,7 @@ async def test_live_stream_to_dbn( ) expected_data.seek(0) # rewind + # Assert assert output.read_bytes() == expected_data.read() @@ -864,6 +1007,7 @@ async def test_live_stream_to_dbn_from_path( Test that DBN data streamed by the MockLiveServer is properly re- constructed client side when specifying a file as a path. """ + # Arrange output = tmp_path / "output.dbn" live_client.subscribe( @@ -874,6 +1018,7 @@ async def test_live_stream_to_dbn_from_path( ) live_client.add_stream(output) + # Act live_client.start() await live_client.wait_for_close() @@ -885,6 +1030,7 @@ async def test_live_stream_to_dbn_from_path( ) expected_data.seek(0) # rewind + # Assert assert output.read_bytes() == expected_data.read() @@ -912,6 +1058,7 @@ async def test_live_stream_to_dbn_with_tiny_buffer( Test that DBN data streamed by the MockLiveServer is properly re- constructed client side when using the small values for RECV_BUFFER_SIZE. """ + # Arrange monkeypatch.setattr(protocol, "RECV_BUFFER_SIZE", buffer_size) output = tmp_path / "output.dbn" @@ -923,6 +1070,7 @@ async def test_live_stream_to_dbn_with_tiny_buffer( ) live_client.add_stream(output.open("wb", buffering=0)) + # Act live_client.start() await live_client.wait_for_close() @@ -934,6 +1082,7 @@ async def test_live_stream_to_dbn_with_tiny_buffer( ) expected_data.seek(0) # rewind + # Assert assert output.read_bytes() == expected_data.read() @@ -947,6 +1096,7 @@ async def test_live_disconnect_async( exception. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema="mbo", @@ -958,9 +1108,11 @@ async def test_live_disconnect_async( wait = live_client.wait_for_close() + # Act protocol = live_client._session._protocol protocol.disconnected.set_exception(Exception("test")) + # Assert with pytest.raises(BentoError) as exc: await wait @@ -977,6 +1129,7 @@ def test_live_disconnect( the exception. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema="mbo", @@ -986,9 +1139,11 @@ def test_live_disconnect( live_client.start() assert live_client._session is not None + # Act protocol = live_client._session._protocol protocol.disconnected.set_exception(Exception("test")) + # Assert with pytest.raises(BentoError) as exc: live_client.block_for_close() @@ -1001,6 +1156,7 @@ async def test_live_terminate( """ Test that terminate closes the connection. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -1008,9 +1164,11 @@ async def test_live_terminate( symbols="TEST", ) + # Act live_client.terminate() - await live_client.wait_for_close() + + # Assert assert not live_client.is_connected() @@ -1030,6 +1188,7 @@ async def test_live_iteration_with_reconnect( The iteration should yield every record. 
""" + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=schema, @@ -1046,6 +1205,7 @@ async def test_live_iteration_with_reconnect( assert not live_client.is_connected() + # Act live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=schema, @@ -1066,6 +1226,7 @@ async def test_live_iteration_with_reconnect( ) dbn = DBNStore.from_bytes(expected_data) + # Assert records = list(my_iter) assert len(records) == 2 * len(list(dbn)) for record in records: @@ -1088,9 +1249,11 @@ async def test_live_callback_with_reconnect( That callback should emit every record. """ + # Arrange records: list[DBNRecord] = [] live_client.add_callback(records.append) + # Act, Assert for _ in range(5): live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -1135,6 +1298,7 @@ async def test_live_stream_with_reconnect( That output stream should be readable. """ + # Arrange # TODO: Remove when status schema is available if schema == "status": pytest.skip("no stub data for status schema") @@ -1144,6 +1308,7 @@ async def test_live_stream_with_reconnect( output = tmp_path / "output.dbn" live_client.add_stream(output.open("wb", buffering=0)) + # Act for _ in range(5): live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -1162,6 +1327,7 @@ async def test_live_stream_with_reconnect( data = DBNStore.from_file(output) + # Assert records = list(data) for record in records: assert isinstance(record, SCHEMA_STRUCT_MAP[schema]) @@ -1175,6 +1341,7 @@ def test_live_connection_reconnect_cram_failure( """ Test that a failed connection can reconnect. """ + # Arrange # Dork up the API key in the mock client to fail CRAM bucket_id = test_api_key[-BUCKET_ID_LENGTH:] invalid_key = "db-invalidkey00000000000000FFFFF" @@ -1186,6 +1353,7 @@ def test_live_connection_reconnect_cram_failure( port=mock_live_server.port, ) + # Act, Assert with pytest.raises(BentoError) as exc: live_client.subscribe( dataset=Dataset.GLBX_MDP3, @@ -1210,6 +1378,7 @@ async def test_live_callback_exception_handler( Test exceptions that occur during callbacks are dispatched to the assigned exception handler. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -1224,9 +1393,11 @@ def callback(_: DBNRecord) -> None: live_client.add_callback(callback, exceptions.append) + # Act live_client.start() - await live_client.wait_for_close() + + # Assert assert len(exceptions) == 4 @@ -1237,6 +1408,7 @@ async def test_live_stream_exception_handler( Test exceptions that occur during stream writes are dispatched to the assigned exception handler. """ + # Arrange live_client.subscribe( dataset=Dataset.GLBX_MDP3, schema=Schema.MBO, @@ -1250,7 +1422,9 @@ async def test_live_stream_exception_handler( live_client.add_stream(stream, exceptions.append) stream.close() + # Act live_client.start() + # Assert await live_client.wait_for_close() assert len(exceptions) == 5 # extra write from metadata diff --git a/tests/test_live_gateway_messages.py b/tests/test_live_gateway_messages.py index e07a4ab..30b5af3 100644 --- a/tests/test_live_gateway_messages.py +++ b/tests/test_live_gateway_messages.py @@ -65,6 +65,7 @@ def test_parse_authentication_request( """ Test that a AuthenticationRequest is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, tuple): msg = AuthenticationRequest.parse(line) assert ( @@ -108,6 +109,7 @@ def test_serialize_authentication_request( """ Test that a AuthenticationRequest is serialized as expected. 
""" + # Arrange, Act, Assert assert bytes(message) == expected @@ -126,6 +128,7 @@ def test_parse_authentication_response( """ Test that a AuthenticationResponse is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, tuple): msg = AuthenticationResponse.parse(line) assert (msg.success, msg.session_id) == expected @@ -160,6 +163,7 @@ def test_serialize_authentication_response( """ Test that a AuthenticationResponse is serialized as expected. """ + # Arrange, Act, Assert assert bytes(message) == expected @@ -178,6 +182,7 @@ def test_parse_challenge_request( """ Test that a ChallengeRequest is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, str): msg = ChallengeRequest.parse(line) assert msg.cram == expected @@ -199,6 +204,7 @@ def test_serialize_challenge_request( """ Test that a ChallengeRequest is serialized as expected. """ + # Arrange, Act, Assert assert bytes(message) == expected @@ -217,6 +223,7 @@ def test_parse_greeting( """ Test that a Greeting is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, str): msg = Greeting.parse(line) assert msg.lsg_version == expected @@ -238,6 +245,7 @@ def test_serialize_greeting( """ Test that a Greeting is serialized as expected. """ + # Arrange, Act, Assert assert bytes(message) == expected @@ -257,6 +265,7 @@ def test_parse_session_start( """ Test that a SessionStart is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, str): msg = SessionStart.parse(line) assert msg.start_session == expected @@ -278,6 +287,7 @@ def test_serialize_session_start( """ Test that a SessionStart is serialized as expected. """ + # Arrange, Act, Assert assert bytes(message) == expected @@ -314,6 +324,7 @@ def test_parse_subscription_request( """ Test that a SubscriptionRequest is parsed from a string as expected. """ + # Arrange, Act, Assert if isinstance(expected, tuple): msg = SubscriptionRequest.parse(line) assert ( @@ -359,6 +370,7 @@ def test_serialize_subscription_request( """ Test that a SubscriptionRequest is serialized as expected. """ + # Arrange, Act, Assert assert bytes(message) == expected @@ -378,5 +390,6 @@ def test_parse_bad_key(message_type: GatewayControl, line: str) -> None: """ Test that a ValueError is raised when parsing fails for general cases. """ + # Arrange, Act, Assert with pytest.raises(ValueError): message_type.parse(line) diff --git a/tests/test_live_protocol.py b/tests/test_live_protocol.py index 78bed80..a54238d 100644 --- a/tests/test_live_protocol.py +++ b/tests/test_live_protocol.py @@ -17,6 +17,7 @@ async def test_protocol_connection( Test the low-level DatabentoLiveProtocol can be used to establish a connection to the live subscription gateway. """ + # Arrange transport, protocol = await asyncio.get_event_loop().create_connection( protocol_factory=lambda: DatabentoLiveProtocol( api_key=test_api_key, @@ -26,10 +27,9 @@ async def test_protocol_connection( port=mock_live_server.port, ) + # Act, Assert await asyncio.wait_for(protocol.authenticated, timeout=1) - transport.close() - await asyncio.wait_for(protocol.disconnected, timeout=1) @@ -42,6 +42,7 @@ async def test_protocol_connection_streaming( Test the low-level DatabentoLiveProtocol can be used to stream DBN records from the live subscription gateway. 
""" + # Arrange monkeypatch.setattr( DatabentoLiveProtocol, "received_metadata", metadata_mock := MagicMock(), ) @@ -49,7 +50,7 @@ async def test_protocol_connection_streaming( DatabentoLiveProtocol, "received_record", record_mock := MagicMock(), ) - transport, protocol = await asyncio.get_event_loop().create_connection( + _, protocol = await asyncio.get_event_loop().create_connection( protocol_factory=lambda: DatabentoLiveProtocol( api_key=test_api_key, dataset="TEST", @@ -60,6 +61,7 @@ async def test_protocol_connection_streaming( await asyncio.wait_for(protocol.authenticated, timeout=1) + # Act protocol.subscribe( schema=Schema.MBO, symbols="TEST", @@ -68,8 +70,8 @@ async def test_protocol_connection_streaming( protocol.start() await asyncio.wait_for(protocol.started.wait(), timeout=1) - await asyncio.wait_for(protocol.disconnected, timeout=1) + # Assert assert metadata_mock.call_count == 1 assert record_mock.call_count == 4 diff --git a/tests/test_release.py b/tests/test_release.py index d2c0756..43d1433 100644 --- a/tests/test_release.py +++ b/tests/test_release.py @@ -27,6 +27,7 @@ def fixture_changelog() -> str: str """ + # Arrange, Act, Assert with open(PROJECT_ROOT / "CHANGELOG.md") as changelog: return changelog.read() @@ -41,6 +42,7 @@ def fixture_pyproject_version() -> str: str """ + # Arrange, Act, Assert with open(PROJECT_ROOT / "pyproject.toml", "rb") as pyproject: data = tomli.load(pyproject) return data["tool"]["poetry"]["version"] @@ -59,6 +61,7 @@ def test_release_changelog(changelog: str, pyproject_version: str) -> None: - The release dates are chronological. """ + # Arrange, Act releases = CHANGELOG_RELEASE_TITLE.findall(changelog) try: @@ -75,6 +78,7 @@ def test_release_changelog(changelog: str, pyproject_version: str) -> None: # This could happen if we have TBD as the release date. raise AssertionError("Failed to parse release date from CHANGELOG.md") + # Assert # Ensure latest version matches `__version__` assert databento.__version__ == versions[0]