Skip to content

Commit

Permalink
removed some pandas dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 4, 2024
1 parent d0909dc commit 82638cd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion roboquant/feeds/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import roboquant.feeds.feedutil
from roboquant.feeds import feedutil
from .candlefeed import CandleFeed
from .csvfeed import CSVFeed
from .eventchannel import EventChannel
Expand Down
22 changes: 8 additions & 14 deletions roboquant/feeds/feedutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,23 @@ def get_symbol_prices(
return x, y


def get_symbol_ohlcv(feed: Feed, symbol: str, timeframe: Timeframe | None = None) -> list[tuple]:
def get_symbol_ohlcv(feed: Feed, symbol: str, timeframe: Timeframe | None = None) -> dict[str, list]:
"""Get the candles for a single symbol from a feed"""

result = []
result = {column: [] for column in ["Date", "Open", "High", "Low", "Close", "Volume"]}
channel = feed.play_background(timeframe)
while event := channel.get():
item = event.price_items.get(symbol)
if item and isinstance(item, Candle):
result.append((event.time, *item.ohlcv))
result["Date"].append(event.time)
result["Open"].append(item.ohlcv[0])
result["High"].append(item.ohlcv[1])
result["Low"].append(item.ohlcv[2])
result["Close"].append(item.ohlcv[3])
result["Volume"].append(item.ohlcv[4])
return result


def get_symbol_dataframe(feed: Feed, symbol: str, timeframe: Timeframe | None = None):
# pylint: disable=import-outside-toplevel
"""Get prices for a single symbol from a feed as a pandas dataframe"""

# noinspection PyPackageRequirements
import pandas as pd

ohlcv = get_symbol_ohlcv(feed, symbol, timeframe)
return pd.DataFrame(ohlcv, columns=["Date", "Open", "High", "Low", "Close", "Volume"]).set_index("Date")


def get_sp500_symbols():
full_path = pathlib.Path(__file__).parent.resolve().joinpath("sp500.json")
with open(full_path, encoding="utf8") as f:
Expand Down
16 changes: 16 additions & 0 deletions tests/samples/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
if __name__ == "__main__":
# %%
import pandas as pd
import warnings
import roboquant as rq

warnings.simplefilter(action="ignore", category=FutureWarning)

# %%
feed = rq.feeds.YahooFeed("JPM", "IBM", "F", start_date="2010-01-01")
ohlcv = rq.feeds.feedutil.get_symbol_ohlcv(feed, "IBM")

# %%
pd.DataFrame(ohlcv).set_index("Date")

# %%

0 comments on commit 82638cd

Please sign in to comment.