diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f7a86..05789e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## 0.23.0 - 2023-10-26 + +#### Enhancements +- Added `map_symbols_csv` function to the `databento` module for using `symbology.json` files to map a symbol column onto a CSV file +- Added `map_symbols_json` function to the `databento` module for using `symbology.json` files to add a symbol key to a file of JSON records +- Added new publisher values in preparation for IFEU.IMPACT and NDEX.IMPACT datasets + +#### Bug fixes +- Fixed an issue where a large unreadable symbol subscription message could be sent +- Fixed an issue where `DBNStore.to_df` with `pretty_ts=True` was very slow + ## 0.22.1 - 2023-10-24 #### Bug fixes diff --git a/databento/__init__.py b/databento/__init__.py index c06776e..92f0c99 100644 --- a/databento/__init__.py +++ b/databento/__init__.py @@ -19,6 +19,7 @@ from databento_dbn import TradeMsg from databento.common import bentologging +from databento.common import symbology from databento.common.dbnstore import DBNStore from databento.common.enums import Delivery from databento.common.enums import FeedMode @@ -35,6 +36,7 @@ from databento.common.publishers import Dataset from databento.common.publishers import Publisher from databento.common.publishers import Venue +from databento.common.symbology import InstrumentMap from databento.historical.api import API_VERSION from databento.historical.client import Historical from databento.live import DBNRecord @@ -60,6 +62,7 @@ "RecordFlags", "Historical", "HistoricalGateway", + "InstrumentMap", "Live", "Packaging", "RollRule", @@ -91,3 +94,5 @@ # Convenience imports enable_logging = bentologging.enable_logging from_dbn = DBNStore.from_file +map_symbols_csv = symbology.map_symbols_csv +map_symbols_json = symbology.map_symbols_json diff --git a/databento/common/bentologging.py b/databento/common/bentologging.py index 44f5bf4..47796c5 100644 --- a/databento/common/bentologging.py +++ b/databento/common/bentologging.py @@ -6,7 +6,7 @@ def enable_logging(level: int | str = logging.INFO) -> None: """ Enable logging for the Databento module. This function should be used for - simple applications and examples. It is advisible to configure your own + simple applications and examples. It is advisable to configure your own logging for serious applications. 
Parameters diff --git a/databento/common/constants.py b/databento/common/constants.py new file mode 100644 index 0000000..4176b4a --- /dev/null +++ b/databento/common/constants.py @@ -0,0 +1 @@ +ALL_SYMBOLS = "ALL_SYMBOLS" diff --git a/databento/common/dbnstore.py b/databento/common/dbnstore.py index 4f38287..c9c8b63 100644 --- a/databento/common/dbnstore.py +++ b/databento/common/dbnstore.py @@ -7,7 +7,6 @@ import warnings from collections.abc import Generator from collections.abc import Iterator -from functools import partial from io import BytesIO from os import PathLike from pathlib import Path @@ -1112,8 +1111,8 @@ def _transcode( compression=compression, pretty_px=pretty_px, pretty_ts=pretty_ts, - map_symbols=map_symbols, has_metadata=True, + map_symbols=map_symbols, symbol_map=symbol_map, # type: ignore [arg-type] schema=schema, ) @@ -1246,9 +1245,7 @@ def _format_px( def _format_pretty_ts(self, df: pd.DataFrame) -> None: for field in self._struct._timestamp_fields: - df[field] = df[field].apply( - partial(pd.to_datetime, utc=True, errors="coerce"), - ) + df[field] = pd.to_datetime(df[field], utc=True, errors="coerce") def _format_set_index(self, df: pd.DataFrame) -> None: index_column = ( diff --git a/databento/common/parsing.py b/databento/common/parsing.py index ea1a54c..cfa44aa 100644 --- a/databento/common/parsing.py +++ b/databento/common/parsing.py @@ -9,7 +9,7 @@ import pandas as pd from databento_dbn import SType -from databento.common.symbology import ALL_SYMBOLS +from databento.common.constants import ALL_SYMBOLS from databento.common.validation import validate_smart_symbol diff --git a/databento/common/publishers.py b/databento/common/publishers.py index 69b7134..08f7a01 100644 --- a/databento/common/publishers.py +++ b/databento/common/publishers.py @@ -91,6 +91,10 @@ class Venue(StringyMixin, str, Enum): Cboe BZX Options Exchange. MXOP MEMX LLC Options. + IFEU + ICE Futures Europe (Commodities). + NDEX + ICE Endex. """ @@ -131,6 +135,8 @@ class Venue(StringyMixin, str, Enum): XPHL = "XPHL" BATO = "BATO" MXOP = "MXOP" + IFEU = "IFEU" + NDEX = "NDEX" @classmethod def from_int(cls, value: int) -> Venue: @@ -211,6 +217,10 @@ def from_int(cls, value: int) -> Venue: return Venue.BATO if value == 37: return Venue.MXOP + if value == 38: + return Venue.IFEU + if value == 39: + return Venue.NDEX raise ValueError(f"Integer value {value} does not correspond with any Venue variant") def to_int(self) -> int: @@ -291,6 +301,10 @@ def to_int(self) -> int: return 36 if self == Venue.MXOP: return 37 + if self == Venue.IFEU: + return 38 + if self == Venue.NDEX: + return 39 raise ValueError("Invalid Venue") @property @@ -372,6 +386,10 @@ def description(self) -> str: return "Cboe BZX Options Exchange" if self == Venue.MXOP: return "MEMX LLC Options" + if self == Venue.IFEU: + return "ICE Futures Europe (Commodities)" + if self == Venue.NDEX: + return "ICE Endex" raise ValueError("Unexpected Venue value") @unique @@ -434,6 +452,10 @@ class Dataset(StringyMixin, str, Enum): Nasdaq QBBO. XNAS_NLS Nasdaq NLS. + IFEU_IMPACT + ICE Futures Europe (Commodities) iMpact. + NDEX_IMPACT + ICE Endex iMpact. 
""" @@ -464,6 +486,8 @@ class Dataset(StringyMixin, str, Enum): XNYS_TRADES = "XNYS.TRADES" XNAS_QBBO = "XNAS.QBBO" XNAS_NLS = "XNAS.NLS" + IFEU_IMPACT = "IFEU.IMPACT" + NDEX_IMPACT = "NDEX.IMPACT" @classmethod def from_int(cls, value: int) -> Dataset: @@ -524,6 +548,10 @@ def from_int(cls, value: int) -> Dataset: return Dataset.XNAS_QBBO if value == 27: return Dataset.XNAS_NLS + if value == 28: + return Dataset.IFEU_IMPACT + if value == 29: + return Dataset.NDEX_IMPACT raise ValueError(f"Integer value {value} does not correspond with any Dataset variant") def to_int(self) -> int: @@ -584,6 +612,10 @@ def to_int(self) -> int: return 26 if self == Dataset.XNAS_NLS: return 27 + if self == Dataset.IFEU_IMPACT: + return 28 + if self == Dataset.NDEX_IMPACT: + return 29 raise ValueError("Invalid Dataset") @property @@ -645,6 +677,10 @@ def description(self) -> str: return "Nasdaq QBBO" if self == Dataset.XNAS_NLS: return "Nasdaq NLS" + if self == Dataset.IFEU_IMPACT: + return "ICE Futures Europe (Commodities) iMpact" + if self == Dataset.NDEX_IMPACT: + return "ICE Endex iMpact" raise ValueError("Unexpected Dataset value") @unique @@ -765,6 +801,10 @@ class Publisher(StringyMixin, str, Enum): DBEQ Plus - FINRA/Nasdaq TRF Carteret. DBEQ_PLUS_FINC DBEQ Plus - FINRA/Nasdaq TRF Chicago. + IFEU_IMPACT_IFEU + ICE Futures Europe (Commodities). + NDEX_IMPACT_NDEX + ICE Endex. """ @@ -824,6 +864,8 @@ class Publisher(StringyMixin, str, Enum): DBEQ_PLUS_FINN = "DBEQ.PLUS.FINN" DBEQ_PLUS_FINY = "DBEQ.PLUS.FINY" DBEQ_PLUS_FINC = "DBEQ.PLUS.FINC" + IFEU_IMPACT_IFEU = "IFEU.IMPACT.IFEU" + NDEX_IMPACT_NDEX = "NDEX.IMPACT.NDEX" @classmethod def from_int(cls, value: int) -> Publisher: @@ -942,6 +984,10 @@ def from_int(cls, value: int) -> Publisher: return Publisher.DBEQ_PLUS_FINY if value == 56: return Publisher.DBEQ_PLUS_FINC + if value == 57: + return Publisher.IFEU_IMPACT_IFEU + if value == 58: + return Publisher.NDEX_IMPACT_NDEX raise ValueError(f"Integer value {value} does not correspond with any Publisher variant") def to_int(self) -> int: @@ -1060,6 +1106,10 @@ def to_int(self) -> int: return 55 if self == Publisher.DBEQ_PLUS_FINC: return 56 + if self == Publisher.IFEU_IMPACT_IFEU: + return 57 + if self == Publisher.NDEX_IMPACT_NDEX: + return 58 raise ValueError("Invalid Publisher") @property def venue(self) -> Venue: @@ -1178,6 +1228,10 @@ def venue(self) -> Venue: return Venue.FINY if self == Publisher.DBEQ_PLUS_FINC: return Venue.FINC + if self == Publisher.IFEU_IMPACT_IFEU: + return Venue.IFEU + if self == Publisher.NDEX_IMPACT_NDEX: + return Venue.NDEX raise ValueError("Unexpected Publisher value") @property def dataset(self) -> Dataset: @@ -1296,6 +1350,10 @@ def dataset(self) -> Dataset: return Dataset.DBEQ_PLUS if self == Publisher.DBEQ_PLUS_FINC: return Dataset.DBEQ_PLUS + if self == Publisher.IFEU_IMPACT_IFEU: + return Dataset.IFEU_IMPACT + if self == Publisher.NDEX_IMPACT_NDEX: + return Dataset.NDEX_IMPACT raise ValueError("Unexpected Publisher value") @property @@ -1415,4 +1473,8 @@ def description(self) -> str: return "DBEQ Plus - FINRA/Nasdaq TRF Carteret" if self == Publisher.DBEQ_PLUS_FINC: return "DBEQ Plus - FINRA/Nasdaq TRF Chicago" + if self == Publisher.IFEU_IMPACT_IFEU: + return "ICE Futures Europe (Commodities)" + if self == Publisher.NDEX_IMPACT_NDEX: + return "ICE Endex" raise ValueError("Unexpected Publisher value") diff --git a/databento/common/symbology.py b/databento/common/symbology.py index 65dadb0..f2c0a51 100644 --- a/databento/common/symbology.py +++ 
b/databento/common/symbology.py @@ -1,12 +1,15 @@ from __future__ import annotations import bisect +import csv import datetime as dt import functools import json from collections import defaultdict from collections.abc import Mapping from io import TextIOWrapper +from os import PathLike +from pathlib import Path from typing import Any, ClassVar, NamedTuple, TextIO import pandas as pd @@ -15,8 +18,7 @@ from databento_dbn import SType from databento_dbn import SymbolMappingMsg - -ALL_SYMBOLS = "ALL_SYMBOLS" +from databento.common.parsing import datetime_to_unix_nanoseconds class MappingInterval(NamedTuple): @@ -39,6 +41,126 @@ class MappingInterval(NamedTuple): symbol: str +def _validate_path_pair( + in_file: Path | PathLike[str] | str, + out_file: Path | PathLike[str] | str | None, +) -> tuple[Path, Path]: + in_file_valid = Path(in_file) + + if not in_file_valid.exists(): + raise ValueError(f"{in_file_valid} does not exist") + if not in_file_valid.is_file(): + raise ValueError(f"{in_file_valid} is not a file") + + if out_file is not None: + out_file_valid = Path(out_file) + else: + out_file_valid = in_file_valid.with_name( + f"{in_file_valid.stem}_mapped{in_file_valid.suffix}", + ) + + i = 0 + while out_file_valid.exists(): + out_file_valid = in_file_valid.with_name( + f"{in_file_valid.stem}_mapped_{i}{in_file_valid.suffix}", + ) + i += 1 + + if in_file_valid == out_file_valid: + raise ValueError("The input file cannot be the same path as the output file.") + + return in_file_valid, out_file_valid + + +def map_symbols_csv( + symbology_file: Path | PathLike[str] | str, + csv_file: Path | PathLike[str] | str, + out_file: Path | PathLike[str] | str | None = None, +) -> Path: + """ + Use a `symbology.json` file to map a symbols column onto an existing CSV + file. The result is written to `out_file`. + + Parameters + ---------- + symbology_file: Path | PathLike[str] | str + Path to a `symbology.json` file to use as a symbology source. + csv_file: Path | PathLike[str] | str + Path to a CSV file that contains encoded DBN data; must contain + a `ts_recv` or `ts_event` and `instrument_id` column. + out_file: Path | PathLike[str] | str (optional) + Path to a file to write results to. If unspecified, `_mapped` will be + appended to the `csv_file` name. + + Returns + ------- + Path + The path to the written file. + + Raises + ------ + ValueError + When the input or output paths are invalid. + When the input CSV file does not contain a valid timestamp or instrument_id column. + + See Also + -------- + map_symbols_json + + """ + instrument_map = InstrumentMap() + with open(symbology_file) as input_symbology: + instrument_map.insert_json(json.load(input_symbology)) + return instrument_map.map_symbols_csv( + csv_file=csv_file, + out_file=out_file, + ) + + +def map_symbols_json( + symbology_file: Path | PathLike[str] | str, + json_file: Path | PathLike[str] | str, + out_file: Path | PathLike[str] | str | None = None, +) -> Path: + """ + Use a `symbology.json` file to insert a symbols key into records of an + existing JSON file. The result is written to `out_file`. + + Parameters + ---------- + symbology_file: Path | PathLike[str] | str + Path to a `symbology.json` file to use as a symbology source. + json_file: Path | PathLike[str] | str + Path to a JSON file that contains encoded DBN data. + out_file: Path | PathLike[str] | str (optional) + Path to a file to write results to. If unspecified, `_mapped` will be + appended to the `json_file` name. + + Returns + ------- + Path + The path to the written file. 
+ + Raises + ------ + ValueError + When the input or output paths are invalid. + When the input JSON file does not contain a valid record. + + See Also + -------- + map_symbols_csv + + """ + instrument_map = InstrumentMap() + with open(symbology_file) as input_symbology: + instrument_map.insert_json(json.load(input_symbology)) + return instrument_map.map_symbols_json( + json_file=json_file, + out_file=out_file, + ) + + class InstrumentMap: SYMBOLOGY_RESOLVE_KEYS: ClassVar[tuple[str, ...]] = ( "result", @@ -94,7 +216,7 @@ def resolve( If the InstrumentMap does not contain a mapping for the `instrument_id`. """ - mappings = self._data[instrument_id] + mappings = self._data[int(instrument_id)] for entry in mappings: if entry.start_date <= date < entry.end_date: return entry.symbol @@ -270,6 +392,154 @@ def insert_json( ), ) + def map_symbols_csv( + self, + csv_file: Path | PathLike[str] | str, + out_file: Path | PathLike[str] | str | None = None, + ) -> Path: + """ + Use the loaded symbology data to map a symbols column onto an existing + CSV file. The result is written to `out_file`. + + Parameters + ---------- + csv_file: Path | PathLike[str] | str + Path to a CSV file that contains encoded DBN data; must contain + a `ts_recv` or `ts_event` and `instrument_id` column. + out_file: Path | PathLike[str] | str (optional) + Path to a file to write results to. If unspecified, `_mapped` will be + appended to the `csv_file` name. + + Returns + ------- + Path + The path to the written file. + + Raises + ------ + ValueError + When the input or output paths are invalid. + When the input CSV file does not contain a valid timestamp or instrument_id column. + + See Also + -------- + InstrumentMap.map_symbols_json + + """ + csv_file_valid, out_file_valid = _validate_path_pair(csv_file, out_file) + + with csv_file_valid.open() as input_: + reader = csv.DictReader(input_) + + in_fields = reader.fieldnames + + if in_fields is None: + raise ValueError(f"no CSV header in {csv_file}") + + if "ts_recv" in in_fields: + ts_field = "ts_recv" + elif "ts_event" in in_fields: + ts_field = "ts_event" + else: + raise ValueError( + f"{csv_file} does not have a 'ts_recv' or 'ts_event' column", + ) + + if "instrument_id" not in in_fields: + raise ValueError(f"{csv_file} does not have an 'instrument_id' column") + + out_fields = (*in_fields, "symbol") + + with out_file_valid.open("w") as output: + writer = csv.DictWriter( + output, + fieldnames=out_fields, + lineterminator="\n", + ) + writer.writeheader() + + for row in reader: + ts = datetime_to_unix_nanoseconds(row[ts_field]) + date = pd.Timestamp(ts, unit="ns").date() + instrument_id = row["instrument_id"] + if instrument_id is None: + row["symbol"] = "" + else: + row["symbol"] = self.resolve(instrument_id, date) + + writer.writerow(row) + + return out_file_valid + + def map_symbols_json( + self, + json_file: Path | PathLike[str] | str, + out_file: Path | PathLike[str] | str | None = None, + ) -> Path: + """ + Use the loaded symbology data to insert a symbols key into records of + an existing JSON file. The result is written to `out_file`. + + Parameters + ---------- + json_file: Path | PathLike[str] | str + Path to a JSON file that contains encoded DBN data. + out_file: Path | PathLike[str] | str (optional) + Path to a file to write results to. If unspecified, `_mapped` will be + appended to the `json_file` name. + + Returns + ------- + Path + The path to the written file. + + Raises + ------ + ValueError + When the input or output paths are invalid. 
+ When the input JSON file does not contain a valid record. + + See Also + -------- + InstrumentMap.map_symbols_csv + + """ + json_file_valid, out_file_valid = _validate_path_pair(json_file, out_file) + + with json_file_valid.open() as input_: + with out_file_valid.open("w") as output: + for i, record in enumerate(map(json.loads, input_)): + try: + header = record["hd"] + instrument_id = header["instrument_id"] + except KeyError: + raise ValueError( + f"{json_file}:{i} does not contain a valid JSON encoded record", + ) + + if "ts_recv" in record: + ts_field = record["ts_recv"] + elif "ts_event" in header: + ts_field = header["ts_event"] + else: + raise ValueError( + f"{json_file}:{i} does not have a 'ts_recv' or 'ts_event' key", + ) + + ts = datetime_to_unix_nanoseconds(ts_field) + + date = pd.Timestamp(ts, unit="ns").date() + record["symbol"] = self.resolve(instrument_id, date) + + json.dump( + record, + output, + separators=(",", ":"), + ) + output.write("\n") + + return out_file_valid + def _insert_inverval(self, instrument_id: int, interval: MappingInterval) -> None: """ Insert a SymbolInterval into the map. diff --git a/databento/live/client.py b/databento/live/client.py index 2b09c71..03a74d3 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -16,11 +16,11 @@ from databento_dbn import Schema from databento_dbn import SType +from databento.common.constants import ALL_SYMBOLS from databento.common.cram import BUCKET_ID_LENGTH from databento.common.error import BentoError from databento.common.parsing import optional_datetime_to_unix_nanoseconds from databento.common.publishers import Dataset -from databento.common.symbology import ALL_SYMBOLS from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string from databento.live import DBNRecord diff --git a/databento/live/protocol.py b/databento/live/protocol.py index e71a415..04e6d74 100644 --- a/databento/live/protocol.py +++ b/databento/live/protocol.py @@ -5,18 +5,19 @@ from collections.abc import Iterable from functools import singledispatchmethod from numbers import Number +from typing import Final import databento_dbn from databento_dbn import Schema from databento_dbn import SType from databento.common import cram +from databento.common.constants import ALL_SYMBOLS from databento.common.error import BentoError from databento.common.iterator import chunk from databento.common.parsing import optional_datetime_to_unix_nanoseconds from databento.common.parsing import optional_symbols_list_to_list from databento.common.publishers import Dataset -from databento.common.symbology import ALL_SYMBOLS from databento.common.validation import validate_enum from databento.common.validation import validate_semantic_string from databento.live import DBNRecord @@ -30,7 +31,8 @@ from databento.live.gateway import SubscriptionRequest -RECV_BUFFER_SIZE: int = 64 * 2**10 # 64kb +RECV_BUFFER_SIZE: Final = 64 * 2**10 # 64kb +SYMBOL_LIST_BATCH_SIZE: Final = 64 logger = logging.getLogger(__name__) @@ -278,7 +280,7 @@ def subscribe( stype_in_valid = validate_enum(stype_in, SType, "stype_in") symbols_list = optional_symbols_list_to_list(symbols, stype_in_valid) - for batch in chunk(symbols_list, 128): + for batch in chunk(symbols_list, SYMBOL_LIST_BATCH_SIZE): batch_str = ",".join(batch) message = SubscriptionRequest( schema=validate_enum(schema, Schema, "schema"), diff --git a/databento/live/session.py b/databento/live/session.py index 4bfc602..cc3e447 100644 --- 
a/databento/live/session.py +++ b/databento/live/session.py @@ -14,9 +14,9 @@ from databento_dbn import Schema from databento_dbn import SType +from databento.common.constants import ALL_SYMBOLS from databento.common.error import BentoError from databento.common.publishers import Dataset -from databento.common.symbology import ALL_SYMBOLS from databento.live import AUTH_TIMEOUT_SECONDS from databento.live import CONNECT_TIMEOUT_SECONDS from databento.live import DBNRecord diff --git a/databento/version.py b/databento/version.py index d74a474..8b301a7 100644 --- a/databento/version.py +++ b/databento/version.py @@ -1 +1 @@ -__version__ = "0.22.1" +__version__ = "0.23.0" diff --git a/pyproject.toml b/pyproject.toml index 4bc144e..d39b4f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databento" -version = "0.22.1" +version = "0.23.0" description = "Official Python client library for Databento" authors = [ "Databento ", diff --git a/tests/data/test_data.definition.dbn.zst b/tests/data/test_data.definition.dbn.zst index 76c7c3e..181e410 100644 Binary files a/tests/data/test_data.definition.dbn.zst and b/tests/data/test_data.definition.dbn.zst differ diff --git a/tests/data/test_data.statistics.dbn.zst b/tests/data/test_data.statistics.dbn.zst index 31cab7f..8a86fb0 100644 Binary files a/tests/data/test_data.statistics.dbn.zst and b/tests/data/test_data.statistics.dbn.zst differ diff --git a/tests/test_common_symbology.py b/tests/test_common_symbology.py index b1abf45..c70f264 100644 --- a/tests/test_common_symbology.py +++ b/tests/test_common_symbology.py @@ -3,10 +3,11 @@ import json import pathlib from collections.abc import Iterable -from typing import NamedTuple +from typing import Callable, NamedTuple import pandas as pd import pytest +from databento.common.dbnstore import DBNStore from databento.common.symbology import InstrumentMap from databento.common.symbology import MappingInterval from databento_dbn import UNDEF_TIMESTAMP @@ -713,3 +714,95 @@ def test_instrument_map_ignore_duplicate( instrument_map.insert_symbol_mapping_msg(msg) assert len(instrument_map._data[instrument_id]) == 1 + + +@pytest.mark.parametrize( + "schema", + [pytest.param(s, id=str(s)) for s in Schema.variants()], +) +@pytest.mark.parametrize( + "pretty_ts", + [ + True, + False, + ], +) +def test_instrument_map_symbols_csv( + tmp_path: pathlib.Path, + test_data_path: Callable[[Schema], pathlib.Path], + pretty_ts: bool, + schema: Schema, +) -> None: + """ + Test that a CSV file without mapped symbols is equivalent to a CSV file + with mapped symbols after processing with map_symbols_csv. 
+ """ + store = DBNStore.from_file(test_data_path(schema)) + csv_path = tmp_path / f"test_{schema}.csv" + store.to_csv( + csv_path, + pretty_ts=pretty_ts, + map_symbols=False, + ) + + expected_path = tmp_path / "expected.csv" + store.to_csv( + expected_path, + pretty_ts=pretty_ts, + map_symbols=True, + ) + + outfile = tmp_path / f"test_{schema}_mapped.csv" + written_path = store._instrument_map.map_symbols_csv( + csv_file=csv_path, + out_file=outfile, + ) + + assert outfile == written_path + assert outfile.read_text() == expected_path.read_text() + + +@pytest.mark.parametrize( + "schema", + [pytest.param(s, id=str(s)) for s in Schema.variants()], +) +@pytest.mark.parametrize( + "pretty_ts", + [ + True, + False, + ], +) +def test_instrument_map_symbols_json( + tmp_path: pathlib.Path, + test_data_path: Callable[[Schema], pathlib.Path], + pretty_ts: bool, + schema: Schema, +) -> None: + """ + Test that a JSON file without mapped symbols is equivelant to a JSON file + with mapped symbols after processing with map_symbols_json. + """ + store = DBNStore.from_file(test_data_path(schema)) + json_path = tmp_path / f"test_{schema}.json" + store.to_json( + json_path, + pretty_ts=pretty_ts, + map_symbols=False, + ) + + expected_path = tmp_path / "expected.json" + store.to_json( + expected_path, + pretty_ts=pretty_ts, + map_symbols=True, + ) + + outfile = tmp_path / f"test_{schema}_mapped.json" + written_path = store._instrument_map.map_symbols_json( + json_file=json_path, + out_file=outfile, + ) + + assert outfile == written_path + assert outfile.read_text() == expected_path.read_text() diff --git a/tests/test_historical_bento.py b/tests/test_historical_bento.py index 3b085c0..dd15a03 100644 --- a/tests/test_historical_bento.py +++ b/tests/test_historical_bento.py @@ -1017,7 +1017,11 @@ def test_dbnstore_to_df_with_count( assert len(batch) <= count aggregator.append(batch) - assert expected.equals(pd.concat(aggregator)) + pd.testing.assert_frame_equal( + pd.concat(aggregator), + expected, + check_dtype=False, + ) @pytest.mark.parametrize( diff --git a/tests/test_live_client.py b/tests/test_live_client.py index 628b528..c270a04 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -14,12 +14,12 @@ import databento_dbn import pytest import zstandard +from databento.common.constants import ALL_SYMBOLS from databento.common.cram import BUCKET_ID_LENGTH from databento.common.data import SCHEMA_STRUCT_MAP from databento.common.dbnstore import DBNStore from databento.common.error import BentoError from databento.common.publishers import Dataset -from databento.common.symbology import ALL_SYMBOLS from databento.live import DBNRecord from databento.live import client from databento.live import gateway @@ -399,14 +399,24 @@ async def test_live_subscribe_large_symbol_list( first_message = mock_live_server.get_message_of_type( gateway.SubscriptionRequest, timeout=1, - ) + ).symbols.split(",") second_message = mock_live_server.get_message_of_type( gateway.SubscriptionRequest, timeout=1, - ) + ).symbols.split(",") + + third_message = mock_live_server.get_message_of_type( + gateway.SubscriptionRequest, + timeout=1, + ).symbols.split(",") + + fourth_message = mock_live_server.get_message_of_type( + gateway.SubscriptionRequest, + timeout=1, + ).symbols.split(",") - reconstructed = first_message.symbols.split(",") + second_message.symbols.split(",") + reconstructed = first_message + second_message + third_message + fourth_message assert reconstructed == large_symbol_list