1
1
from __future__ import annotations
2
2
3
3
import abc
4
- import datetime as dt
5
4
import itertools
6
5
import logging
7
6
from collections .abc import Generator
30
29
from databento .common .data import SCHEMA_DTYPES_MAP
31
30
from databento .common .data import SCHEMA_STRUCT_MAP
32
31
from databento .common .error import BentoError
33
- from databento .common .symbology import InstrumentIdMappingInterval
32
+ from databento .common .symbology import InstrumentMap
34
33
from databento .common .validation import validate_file_write_path
35
34
from databento .common .validation import validate_maybe_enum
36
35
from databento .live import DBNRecord
@@ -98,7 +97,6 @@ def format_dataframe(
98
97
schema : Schema ,
99
98
pretty_px : bool ,
100
99
pretty_ts : bool ,
101
- instrument_id_index : dict [dt .date , dict [int , str ]],
102
100
) -> pd .DataFrame :
103
101
struct = SCHEMA_STRUCT_MAP [schema ]
104
102
@@ -122,13 +120,6 @@ def format_dataframe(
122
120
index_column = "ts_event" if schema .value .startswith ("ohlcv" ) else "ts_recv"
123
121
df .set_index (index_column , inplace = True )
124
122
125
- if instrument_id_index :
126
- df_index = df .index if pretty_ts else pd .to_datetime (df .index , utc = True )
127
- dates = [ts .date () for ts in df_index ]
128
- df ["symbol" ] = [
129
- instrument_id_index [dates [i ]][p ] for i , p in enumerate (df ["instrument_id" ])
130
- ]
131
-
132
123
return df
133
124
134
125
@@ -252,7 +243,12 @@ class MemoryDataSource(DataSource):
252
243
"""
253
244
254
245
def __init__ (self , source : BytesIO | bytes | IO [bytes ]):
255
- initial_data = source if isinstance (source , bytes ) else source .read ()
246
+ if isinstance (source , bytes ):
247
+ initial_data = source
248
+ else :
249
+ source .seek (0 )
250
+ initial_data = source .read ()
251
+
256
252
if len (initial_data ) == 0 :
257
253
raise ValueError (
258
254
f"Cannot create data source from empty { type (source ).__name__ } " ,
@@ -397,11 +393,7 @@ def __init__(self, data_source: DataSource) -> None:
397
393
metadata_bytes .getvalue (),
398
394
)
399
395
400
- # This is populated when _map_symbols is called
401
- self ._instrument_id_index : dict [
402
- dt .date ,
403
- dict [int , str ],
404
- ] = {}
396
+ self ._instrument_map = InstrumentMap ()
405
397
406
398
def __iter__ (self ) -> Generator [DBNRecord , None , None ]:
407
399
reader = self .reader
@@ -417,6 +409,8 @@ def __iter__(self) -> Generator[DBNRecord, None, None]:
417
409
for record in records :
418
410
if isinstance (record , databento_dbn .Metadata ):
419
411
continue
412
+ if isinstance (record , databento_dbn .SymbolMappingMsg ):
413
+ self ._instrument_map .insert_symbol_mapping_msg (record )
420
414
yield record
421
415
else :
422
416
if len (decoder .buffer ()) > 0 :
@@ -429,38 +423,6 @@ def __repr__(self) -> str:
429
423
name = self .__class__ .__name__
430
424
return f"<{ name } (schema={ self .schema } )>"
431
425
432
- def _build_instrument_id_index (self ) -> dict [dt .date , dict [int , str ]]:
433
- intervals : list [InstrumentIdMappingInterval ] = []
434
- for raw_symbol , i in self .mappings .items ():
435
- for row in i :
436
- symbol = row ["symbol" ]
437
- if symbol == "" :
438
- continue
439
- intervals .append (
440
- InstrumentIdMappingInterval (
441
- start_date = row ["start_date" ],
442
- end_date = row ["end_date" ],
443
- raw_symbol = raw_symbol ,
444
- instrument_id = int (row ["symbol" ]),
445
- ),
446
- )
447
-
448
- instrument_id_index : dict [dt .date , dict [int , str ]] = {}
449
- for interval in intervals :
450
- for ts in pd .date_range (
451
- start = interval .start_date ,
452
- end = interval .end_date ,
453
- # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.date_range.html
454
- ** {"inclusive" if pd .__version__ >= "1.4.0" else "closed" : "left" },
455
- ):
456
- d : dt .date = ts .date ()
457
- date_map : dict [int , str ] = instrument_id_index .get (d , {})
458
- if not date_map :
459
- instrument_id_index [d ] = date_map
460
- date_map [interval .instrument_id ] = interval .raw_symbol
461
-
462
- return instrument_id_index
463
-
464
426
@property
465
427
def compression (self ) -> Compression :
466
428
"""
@@ -808,13 +770,20 @@ def request_symbology(self, client: Historical) -> dict[str, Any]:
808
770
date range.
809
771
810
772
"""
773
+ if self .end is None :
774
+ end_date = None
775
+ elif self .start .date () == self .end .date ():
776
+ end_date = (self .start + pd .Timedelta (days = 1 )).date ()
777
+ else :
778
+ end_date = self .end
779
+
811
780
return client .symbology .resolve (
812
781
dataset = self .dataset ,
813
782
symbols = self .symbols ,
814
783
stype_in = self .stype_in ,
815
784
stype_out = self .stype_out ,
816
785
start_date = self .start .date (),
817
- end_date = self . end . date () if self . end else None ,
786
+ end_date = end_date ,
818
787
)
819
788
820
789
def to_csv (
@@ -877,7 +846,7 @@ def to_df(
877
846
self ,
878
847
pretty_px : bool = ...,
879
848
pretty_ts : bool = ...,
880
- map_symbols : bool | None = ...,
849
+ map_symbols : bool = ...,
881
850
schema : Schema | str | None = ...,
882
851
count : None = ...,
883
852
) -> pd .DataFrame :
@@ -888,7 +857,7 @@ def to_df(
888
857
self ,
889
858
pretty_px : bool = ...,
890
859
pretty_ts : bool = ...,
891
- map_symbols : bool | None = ...,
860
+ map_symbols : bool = ...,
892
861
schema : Schema | str | None = ...,
893
862
count : int = ...,
894
863
) -> DataFrameIterator :
@@ -898,7 +867,7 @@ def to_df(
898
867
self ,
899
868
pretty_px : bool = True ,
900
869
pretty_ts : bool = True ,
901
- map_symbols : bool | None = None ,
870
+ map_symbols : bool = True ,
902
871
schema : Schema | str | None = None ,
903
872
count : int | None = None ,
904
873
) -> pd .DataFrame | DataFrameIterator :
@@ -945,29 +914,22 @@ def to_df(
945
914
raise ValueError ("a schema must be specified for mixed DBN data" )
946
915
schema = self .schema
947
916
948
- if map_symbols is None :
949
- map_symbols = self .stype_out == SType .INSTRUMENT_ID
950
-
951
- if map_symbols :
952
- if self .stype_out != SType .INSTRUMENT_ID :
953
- raise ValueError (
954
- "`map_symbols` is not supported when `stype_out` is not 'instrument_id'" ,
955
- )
956
- if not self ._instrument_id_index :
957
- self ._instrument_id_index = self ._build_instrument_id_index ()
958
-
959
917
if count is None :
960
918
records = iter ([self .to_ndarray (schema )])
961
919
else :
962
920
records = self .to_ndarray (schema , count )
963
921
922
+ if map_symbols :
923
+ self ._instrument_map .insert_metadata (self .metadata )
924
+
964
925
df_iter = DataFrameIterator (
965
926
records = records ,
966
927
schema = schema ,
967
928
count = count ,
929
+ instrument_map = self ._instrument_map ,
968
930
pretty_px = pretty_px ,
969
931
pretty_ts = pretty_ts ,
970
- instrument_id_index = self . _instrument_id_index if map_symbols else {} ,
932
+ map_symbols = map_symbols ,
971
933
)
972
934
973
935
if count is None :
@@ -1111,7 +1073,7 @@ def to_ndarray(
1111
1073
1112
1074
dtype = SCHEMA_DTYPES_MAP [schema ]
1113
1075
ndarray_iter = NDArrayIterator (
1114
- filter (lambda r : isinstance (r , SCHEMA_STRUCT_MAP [schema ]), self ), # type: ignore [arg-type]
1076
+ filter (lambda r : isinstance (r , SCHEMA_STRUCT_MAP [schema ]), self ),
1115
1077
dtype ,
1116
1078
count ,
1117
1079
)
@@ -1163,30 +1125,38 @@ def __init__(
1163
1125
records : Iterator [np .ndarray [Any , Any ]],
1164
1126
count : int | None ,
1165
1127
schema : Schema ,
1128
+ instrument_map : InstrumentMap ,
1166
1129
pretty_px : bool = True ,
1167
1130
pretty_ts : bool = True ,
1168
- instrument_id_index : dict [ dt . date , dict [ int , str ]] | None = None ,
1131
+ map_symbols : bool = True ,
1169
1132
):
1170
1133
self ._records = records
1171
1134
self ._schema = schema
1172
1135
self ._count = count
1173
1136
self ._pretty_px = pretty_px
1174
1137
self ._pretty_ts = pretty_ts
1175
- self ._instrument_id_index = (
1176
- instrument_id_index if instrument_id_index is not None else {}
1177
- )
1138
+ self ._map_symbols = map_symbols
1139
+ self ._instrument_map = instrument_map
1178
1140
1179
1141
def __iter__ (self ) -> DataFrameIterator :
1180
1142
return self
1181
1143
1182
1144
def __next__ (self ) -> pd .DataFrame :
1183
- return format_dataframe (
1145
+ df = format_dataframe (
1184
1146
pd .DataFrame (
1185
1147
next (self ._records ),
1186
1148
columns = SCHEMA_COLUMNS [self ._schema ],
1187
1149
),
1188
1150
schema = self ._schema ,
1189
1151
pretty_px = self ._pretty_px ,
1190
1152
pretty_ts = self ._pretty_ts ,
1191
- instrument_id_index = self ._instrument_id_index ,
1192
1153
)
1154
+
1155
+ if self ._map_symbols :
1156
+ df_index = df .index if self ._pretty_ts else pd .to_datetime (df .index , utc = True )
1157
+ dates = [ts .date () for ts in df_index ]
1158
+ df ["symbol" ] = [
1159
+ self ._instrument_map .resolve (inst , dates [i ]) for i , inst in enumerate (df ["instrument_id" ])
1160
+ ]
1161
+
1162
+ return df
0 commit comments