Skip to content

Commit

Permalink
updated account support
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 8, 2024
1 parent bacbd9f commit b6d687c
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 63 deletions.
4 changes: 2 additions & 2 deletions roboquant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__version__ = "0.2.4"
__version__ = "0.2.5"

from roboquant import brokers
from roboquant import feeds
from roboquant import journals
from roboquant import strategies
from roboquant import traders
from .account import Account, OptionAccount, Position
from .account import Account, Position, Converter, CurrencyConverter, OptionConverter
from .config import Config
from .event import Event, PriceItem, Candle, Trade, Quote
from .order import Order, OrderStatus
Expand Down
14 changes: 8 additions & 6 deletions roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ def __call__(self, symbol: str, time: datetime) -> float:


class CurrencyConverter(Converter):
"""Support symbols that are denoted in a different currency from the base currency of the account"""
"""Supports trading in symbols that are denoted in a different currency from the base currency of the account"""

def __init__(self, base_currency="USD", default_symbol_currency="USD"):
def __init__(self, base_currency="USD", default_symbol_currency: str | None = "USD"):
super().__init__()
self.rates = {}
self.base_currency = base_currency
self.default_symbol_currency = default_symbol_currency
self.registered_symbols = {}
self.registered_symbols: dict[str, str] = {}

def register_symbol(self, symbol: str, currency: str):
"""Register a symbol being denoted in a currency"""
"""Register a symbol and its denoted currency"""
self.registered_symbols[symbol] = currency

def register_rate(self, currency: str, rate: float):
Expand All @@ -81,6 +81,8 @@ def register_rate(self, currency: str, rate: float):

def __call__(self, symbol: str, _: datetime) -> float:
currency = self.registered_symbols.get(symbol, self.default_symbol_currency)
if not currency:
raise ValueError(f"no currency or default_symbol_currency registered for symbol={symbol}")
if currency == self.base_currency:
return 1.0
return self.rates[currency]
Expand Down Expand Up @@ -114,9 +116,9 @@ def register_converter(converter: Converter):
Account.__converter = converter

def contract_value(self, symbol: str, size: Decimal, price: float) -> float:
# pylint: disable=unused-argument
# pylint: disable=not-callable
"""Return the total value of the provided contract size denoted in the base currency of the account."""
rate = 1.0 if not Account.__converter else Account.__converter.__call__(symbol, self.last_update)
rate = 1.0 if not Account.__converter else Account.__converter(symbol, self.last_update)
return float(size) * price * rate

def mkt_value(self, prices: dict[str, float]) -> float:
Expand Down
69 changes: 33 additions & 36 deletions roboquant/feeds/csvfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@ class CSVFeed(HistoricFeed):
"""Use CSV files with historic data as a feed."""

def __init__(
self,
path: str | pathlib.Path,
columns=None,
adj_close=False,
time_offset: str | None = None,
datetime_fmt: str | None = None,
endswith=".csv",
frequency="",
self,
path: str | pathlib.Path,
columns=None,
adj_close=False,
time_offset: str | None = None,
datetime_fmt: str | None = None,
endswith=".csv",
frequency="",
):
super().__init__()
self.columns = columns
self.time_offset = time_offset
columns = columns or ["Date", "Open", "High", "Low", "Close", "Volume", "AdjClose"]
self.ohlcv_columns = columns[1:7] if adj_close else columns[1:6]
self.date_column = columns[0]
self.datetime_fmt = datetime_fmt
self.adj_close = adj_close
self.endswith = endswith
self.freq = frequency
self.endswith = endswith
self.time_offset = time.fromisoformat(time_offset) if time_offset is not None else None

files = self._get_files(path)
logger.info("located %s files in path %s", len(files), path)

for file in files:
self._parse_csvfile(file) # type: ignore
self._parse_csvfiles(files) # type: ignore

def _get_files(self, path):
if pathlib.Path(path).is_file():
Expand All @@ -52,31 +52,28 @@ def _get_symbol(self, filename: str):
"""Return the symbol based on the filename"""
return pathlib.Path(filename).stem.upper()

def _parse_csvfile(self, filename: str):
def _parse_csvfiles(self, filenames: list[str]):
adj_close = self.adj_close
datetime_fmt = self.datetime_fmt
columns = self.columns or ["Date", "Open", "High", "Low", "Close", "Volume", "AdjClose"]
price_columns = columns[1:7] if adj_close else columns[1:6]
date_column = columns[0]
symbol = self._get_symbol(filename)
ohlcv_columns = self.ohlcv_columns
date_column = self.date_column
freq = self.freq
t = time.fromisoformat(self.time_offset) if self.time_offset is not None else None

with open(filename, encoding="utf8") as csvfile:
reader = csv.DictReader(csvfile)

for row in reader:
dt = (
datetime.fromisoformat(row[date_column]) # type: ignore
if datetime_fmt is None
else datetime.strptime(row[date_column], datetime_fmt) # type: ignore
)
if t:
dt = datetime.combine(dt, t)

prices = array("f", [float(row[column_name]) for column_name in price_columns])
pb = Candle(symbol, prices, freq) if not adj_close else Candle.from_adj_close(symbol, prices, freq)
self._add_item(dt.astimezone(timezone.utc), pb)
time_offset = self.time_offset

for filename in filenames:
symbol = self._get_symbol(filename)
with open(filename, encoding="utf8") as csvfile:
reader = csv.DictReader(csvfile)

for row in reader:
date_str = row[date_column]
dt = datetime.strptime(date_str, datetime_fmt) if datetime_fmt else datetime.fromisoformat(date_str)
if time_offset:
dt = datetime.combine(dt, time_offset)

ohlcv = array("f", [float(row[column]) for column in ohlcv_columns])
pb = Candle(symbol, ohlcv, freq) if not adj_close else Candle.from_adj_close(symbol, ohlcv, freq)
self._add_item(dt.astimezone(timezone.utc), pb)

@classmethod
def stooq_us_daily(cls, path):
Expand Down
14 changes: 10 additions & 4 deletions roboquant/feeds/historicfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ def _add_item(self, time: datetime, item: PriceItem):

@property
def symbols(self):
"""Return all the symbols available in this feed"""
"""Return the list of symbols available in this feed"""
self.__update()
return self.__symbols

def timeline(self) -> List[datetime]:
"""Return the timeline of this feed"""
"""Return the timeline of this feed as a list of datatime"""
self.__update()
return list(self.__data.keys())

def timeframe(self):
"""Return the timeframe of this feed"""
tl = self.timeline()
if len(tl) == 0:
return Timeframe.empty()
if not tl:
raise ValueError("Feed doesn't contain any events.")

return Timeframe(tl[0], tl[-1], inclusive=True)

Expand All @@ -62,3 +62,9 @@ def play(self, channel: EventChannel):
for k, v in self.__data.items():
evt = Event(k, v)
channel.put(evt)

def __repr__(self) -> str:
events = len(self.timeline())
timeframe = self.timeframe() if events else None
feed = self.__class__.__name__
return f"{feed}(events={events} symbols={len(self.symbols)} timeframe={timeframe})"
11 changes: 5 additions & 6 deletions roboquant/feeds/yahoofeed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from array import array
from datetime import datetime, timezone
from datetime import timezone
import warnings

import yfinance

Expand All @@ -15,24 +16,22 @@ class YahooFeed(HistoricFeed):

def __init__(self, *symbols: str, start_date="2010-01-01", end_date: str | None = None, interval="1d"):
super().__init__()

end_date = end_date or datetime.now().strftime("%Y-%m-%d")

warnings.simplefilter(action="ignore", category=FutureWarning)
columns = ["Open", "High", "Low", "Close", "Volume", "Adj Close"]

for symbol in symbols:
logger.debug("requesting symbol=%s", symbol)
df = yfinance.Ticker(symbol).history(
start=start_date, end=end_date, auto_adjust=False, actions=False, interval=interval
start=start_date, end=end_date, auto_adjust=False, actions=False, interval=interval, timeout=30
)[columns]
df.dropna(inplace=True)

if len(df) == 0:
logger.warning("no data retrieved for symbol=%s", symbol)
continue

# yFinance one doesn't correct the volume, so we use this one instead
self.__auto_adjust(df)

for t in df.itertuples(index=True):
dt = t[0].to_pydatetime().astimezone(timezone.utc)
prices = t[1:6]
Expand Down
10 changes: 6 additions & 4 deletions roboquant/timeframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
from datetime import datetime, timedelta, timezone
from typing import Any


class Timeframe:
Expand Down Expand Up @@ -94,12 +95,13 @@ def annualize(self, rate: float) -> float:
years = timedelta(days=365) / self.duration
return (1.0 + rate) ** years - 1.0

def split(self, n: int | timedelta) -> list["Timeframe"]:
def split(self, n: int | timedelta | Any) -> list["Timeframe"]:
"""Split the timeframe in sequential parts and return the resulting list of timeframes.
The parameter `n` can be a number or a timedelta instance.
The parameter `n` can be a number or a timedelta instance or other types like relativedelta that support
datetime calculations.
"""

period = n if isinstance(n, timedelta) else self.duration / n
period = self.duration / n if isinstance(n, int) else n
end = self.start
result = []
while end < self.end:
Expand All @@ -113,7 +115,7 @@ def split(self, n: int | timedelta) -> list["Timeframe"]:
last.end = self.end
return result

def sample(self, duration: timedelta, n: int = 1) -> list["Timeframe"]:
def sample(self, duration: timedelta | Any, n: int = 1) -> list["Timeframe"]:
"""Sample one or more periods of `duration` from this timeframe."""

result = []
Expand Down
4 changes: 2 additions & 2 deletions tests/performance/test_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class TestProfile(unittest.TestCase):
def test_profile(self):
path = os.path.expanduser("~/data/nasdaq_stocks/1")
feed = rq.feeds.CSVFeed.stooq_us_daily(path)
print("timeframe =", feed.timeframe(), " symbols =", len(feed.symbols))
strategy = rq.strategies.EMACrossover(13, 26)
print(feed)
strategy = rq.strategies.EMACrossover()
journal = rq.journals.BasicJournal()

# Profile the run to detect bottlenecks
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_account.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from decimal import Decimal

from roboquant import Account, Position, OptionAccount
from roboquant import Account, Position, OptionConverter


class TestAccount(unittest.TestCase):
Expand All @@ -25,8 +25,11 @@ def test_account_positions(self):
self.assertAlmostEqual(acc.unrealized_pnl(prices), 0.0)

def test_account_option(self):
acc = OptionAccount()
acc.register("DUMMY", 5.0)
oc = OptionConverter()
oc.register("DUMMY", 5.0)
acc = Account()
acc.register_converter(oc)

self.assertEqual(1000.0, acc.contract_value("DUMMY", Decimal(1), 200.0))
self.assertEqual(200.0, acc.contract_value("TSLA", Decimal(1), 200.0))

Expand Down

0 comments on commit b6d687c

Please sign in to comment.