Skip to content

Commit

Permalink
fixes #4
Browse files Browse the repository at this point in the history
  • Loading branch information
phelps-sg committed Sep 22, 2023
1 parent e59fb77 commit 5533742
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 4 additions & 0 deletions tests/test_tardis_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ray.util.client import ray
from zipline import Blotter, TradingAlgorithm, get_calendar
from zipline.assets import AssetDBWriter
from zipline.country import CountryCode
from zipline.data import bundles
from zipline.data.adjustments import SQLiteAdjustmentWriter
from zipline.data.bcolz_daily_bars import BcolzDailyBarReader, BcolzDailyBarWriter
Expand Down Expand Up @@ -284,13 +285,15 @@ def test_generate_empty_metadata():
"symbol",
"calendar_name",
"exchange",
"country_code",
}
assert result.start_date.dtype == dtype("<M8[ns]")
assert result.end_date.dtype == dtype("<M8[ns]")
assert result.auto_close_date.dtype == dtype("<M8[ns]")
assert result.symbol.dtype == dtype("O")
assert result.calendar_name.dtype == dtype("O")
assert result.exchange.dtype == dtype("O")
assert result.country_code.dtype == dtype("O")


def test_generate_metadata():
Expand All @@ -307,6 +310,7 @@ def test_generate_metadata():
pair.symbol,
CALENDAR_24_7,
CALENDAR_24_7,
CountryCode.UNITED_STATES,
)


Expand Down
5 changes: 4 additions & 1 deletion zipline_tardis_bundle/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ray.util.client import RayAPIStub, ray
from tardis_dev import datasets
from zipline.assets import AssetDBWriter
from zipline.country import CountryCode
from zipline.data.adjustments import SQLiteAdjustmentWriter
from zipline.data.bcolz_daily_bars import BcolzDailyBarWriter
from zipline.data.bcolz_minute_bars import BcolzMinuteBarWriter
Expand All @@ -60,7 +61,7 @@

MINUTES_PER_DAY = 60 * 24

_Metadata = Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, str, str, str]
_Metadata = Tuple[pd.Timestamp, pd.Timestamp, pd.Timestamp, str, str, str, str]
_IngestPipeline = Iterator[Tuple[int, pd.DataFrame, _Metadata]]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -342,6 +343,7 @@ def _generate_empty_metadata(pairs: Sized) -> pd.DataFrame:
("symbol", "object"),
("calendar_name", "object"),
("exchange", "object"),
("country_code", "object"),
]
return pd.DataFrame(np.empty(len(pairs), dtype=data_types))

Expand All @@ -363,6 +365,7 @@ def _generate_metadata(pricing_data: pd.DataFrame, asset: Asset) -> _Metadata:
asset.symbol,
CALENDAR_24_7,
CALENDAR_24_7,
CountryCode.UNITED_STATES,
)


Expand Down

0 comments on commit 5533742

Please sign in to comment.