diff --git a/roboquant/__init__.py b/roboquant/__init__.py index 6f3c091..939bed3 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -1,11 +1,11 @@ -__version__ = "0.2.4" +__version__ = "0.2.5" from roboquant import brokers from roboquant import feeds from roboquant import journals from roboquant import strategies from roboquant import traders -from .account import Account, OptionAccount, Position +from .account import Account, Position, Converter, CurrencyConverter, OptionConverter from .config import Config from .event import Event, PriceItem, Candle, Trade, Quote from .order import Order, OrderStatus diff --git a/roboquant/account.py b/roboquant/account.py index a114e41..4b52585 100644 --- a/roboquant/account.py +++ b/roboquant/account.py @@ -62,17 +62,17 @@ def __call__(self, symbol: str, time: datetime) -> float: class CurrencyConverter(Converter): - """Support symbols that are denoted in a different currency from the base currency of the account""" + """Supports trading in symbols that are denoted in a different currency from the base currency of the account""" - def __init__(self, base_currency="USD", default_symbol_currency="USD"): + def __init__(self, base_currency="USD", default_symbol_currency: str | None = "USD"): super().__init__() self.rates = {} self.base_currency = base_currency self.default_symbol_currency = default_symbol_currency - self.registered_symbols = {} + self.registered_symbols: dict[str, str] = {} def register_symbol(self, symbol: str, currency: str): - """Register a symbol being denoted in a currency""" + """Register a symbol and its denoted currency""" self.registered_symbols[symbol] = currency def register_rate(self, currency: str, rate: float): @@ -81,6 +81,8 @@ def register_rate(self, currency: str, rate: float): def __call__(self, symbol: str, _: datetime) -> float: currency = self.registered_symbols.get(symbol, self.default_symbol_currency) + if not currency: + raise ValueError(f"no currency or default_symbol_currency registered for symbol={symbol}") if currency == self.base_currency: return 1.0 return self.rates[currency] @@ -114,9 +116,9 @@ def register_converter(converter: Converter): Account.__converter = converter def contract_value(self, symbol: str, size: Decimal, price: float) -> float: - # pylint: disable=unused-argument + # pylint: disable=not-callable """Return the total value of the provided contract size denoted in the base currency of the account.""" - rate = 1.0 if not Account.__converter else Account.__converter.__call__(symbol, self.last_update) + rate = 1.0 if not Account.__converter else Account.__converter(symbol, self.last_update) return float(size) * price * rate def mkt_value(self, prices: dict[str, float]) -> float: diff --git a/roboquant/feeds/csvfeed.py b/roboquant/feeds/csvfeed.py index 635ff24..a0dbc91 100644 --- a/roboquant/feeds/csvfeed.py +++ b/roboquant/feeds/csvfeed.py @@ -15,28 +15,28 @@ class CSVFeed(HistoricFeed): """Use CSV files with historic data as a feed.""" def __init__( - self, - path: str | pathlib.Path, - columns=None, - adj_close=False, - time_offset: str | None = None, - datetime_fmt: str | None = None, - endswith=".csv", - frequency="", + self, + path: str | pathlib.Path, + columns=None, + adj_close=False, + time_offset: str | None = None, + datetime_fmt: str | None = None, + endswith=".csv", + frequency="", ): super().__init__() - self.columns = columns - self.time_offset = time_offset + columns = columns or ["Date", "Open", "High", "Low", "Close", "Volume", "AdjClose"] + self.ohlcv_columns = columns[1:7] if adj_close else columns[1:6] + self.date_column = columns[0] self.datetime_fmt = datetime_fmt self.adj_close = adj_close - self.endswith = endswith self.freq = frequency + self.endswith = endswith + self.time_offset = time.fromisoformat(time_offset) if time_offset is not None else None files = self._get_files(path) logger.info("located %s files in path %s", len(files), path) - - for file in files: - self._parse_csvfile(file) # type: ignore + self._parse_csvfiles(files) # type: ignore def _get_files(self, path): if pathlib.Path(path).is_file(): @@ -52,31 +52,28 @@ def _get_symbol(self, filename: str): """Return the symbol based on the filename""" return pathlib.Path(filename).stem.upper() - def _parse_csvfile(self, filename: str): + def _parse_csvfiles(self, filenames: list[str]): adj_close = self.adj_close datetime_fmt = self.datetime_fmt - columns = self.columns or ["Date", "Open", "High", "Low", "Close", "Volume", "AdjClose"] - price_columns = columns[1:7] if adj_close else columns[1:6] - date_column = columns[0] - symbol = self._get_symbol(filename) + ohlcv_columns = self.ohlcv_columns + date_column = self.date_column freq = self.freq - t = time.fromisoformat(self.time_offset) if self.time_offset is not None else None - - with open(filename, encoding="utf8") as csvfile: - reader = csv.DictReader(csvfile) - - for row in reader: - dt = ( - datetime.fromisoformat(row[date_column]) # type: ignore - if datetime_fmt is None - else datetime.strptime(row[date_column], datetime_fmt) # type: ignore - ) - if t: - dt = datetime.combine(dt, t) - - prices = array("f", [float(row[column_name]) for column_name in price_columns]) - pb = Candle(symbol, prices, freq) if not adj_close else Candle.from_adj_close(symbol, prices, freq) - self._add_item(dt.astimezone(timezone.utc), pb) + time_offset = self.time_offset + + for filename in filenames: + symbol = self._get_symbol(filename) + with open(filename, encoding="utf8") as csvfile: + reader = csv.DictReader(csvfile) + + for row in reader: + date_str = row[date_column] + dt = datetime.strptime(date_str, datetime_fmt) if datetime_fmt else datetime.fromisoformat(date_str) + if time_offset: + dt = datetime.combine(dt, time_offset) + + ohlcv = array("f", [float(row[column]) for column in ohlcv_columns]) + pb = Candle(symbol, ohlcv, freq) if not adj_close else Candle.from_adj_close(symbol, ohlcv, freq) + self._add_item(dt.astimezone(timezone.utc), pb) @classmethod def stooq_us_daily(cls, path): diff --git a/roboquant/feeds/historicfeed.py b/roboquant/feeds/historicfeed.py index 430de5c..33e6bfe 100644 --- a/roboquant/feeds/historicfeed.py +++ b/roboquant/feeds/historicfeed.py @@ -33,20 +33,20 @@ def _add_item(self, time: datetime, item: PriceItem): @property def symbols(self): - """Return all the symbols available in this feed""" + """Return the list of symbols available in this feed""" self.__update() return self.__symbols def timeline(self) -> List[datetime]: - """Return the timeline of this feed""" + """Return the timeline of this feed as a list of datatime""" self.__update() return list(self.__data.keys()) def timeframe(self): """Return the timeframe of this feed""" tl = self.timeline() - if len(tl) == 0: - return Timeframe.empty() + if not tl: + raise ValueError("Feed doesn't contain any events.") return Timeframe(tl[0], tl[-1], inclusive=True) @@ -62,3 +62,9 @@ def play(self, channel: EventChannel): for k, v in self.__data.items(): evt = Event(k, v) channel.put(evt) + + def __repr__(self) -> str: + events = len(self.timeline()) + timeframe = self.timeframe() if events else None + feed = self.__class__.__name__ + return f"{feed}(events={events} symbols={len(self.symbols)} timeframe={timeframe})" diff --git a/roboquant/feeds/yahoofeed.py b/roboquant/feeds/yahoofeed.py index 36344b9..a091a2e 100644 --- a/roboquant/feeds/yahoofeed.py +++ b/roboquant/feeds/yahoofeed.py @@ -1,6 +1,7 @@ import logging from array import array -from datetime import datetime, timezone +from datetime import timezone +import warnings import yfinance @@ -15,17 +16,14 @@ class YahooFeed(HistoricFeed): def __init__(self, *symbols: str, start_date="2010-01-01", end_date: str | None = None, interval="1d"): super().__init__() - - end_date = end_date or datetime.now().strftime("%Y-%m-%d") - + warnings.simplefilter(action="ignore", category=FutureWarning) columns = ["Open", "High", "Low", "Close", "Volume", "Adj Close"] for symbol in symbols: logger.debug("requesting symbol=%s", symbol) df = yfinance.Ticker(symbol).history( - start=start_date, end=end_date, auto_adjust=False, actions=False, interval=interval + start=start_date, end=end_date, auto_adjust=False, actions=False, interval=interval, timeout=30 )[columns] - df.dropna(inplace=True) if len(df) == 0: logger.warning("no data retrieved for symbol=%s", symbol) @@ -33,6 +31,7 @@ def __init__(self, *symbols: str, start_date="2010-01-01", end_date: str | None # yFinance one doesn't correct the volume, so we use this one instead self.__auto_adjust(df) + for t in df.itertuples(index=True): dt = t[0].to_pydatetime().astimezone(timezone.utc) prices = t[1:6] diff --git a/roboquant/timeframe.py b/roboquant/timeframe.py index 117a802..91f5ddf 100644 --- a/roboquant/timeframe.py +++ b/roboquant/timeframe.py @@ -1,5 +1,6 @@ import random from datetime import datetime, timedelta, timezone +from typing import Any class Timeframe: @@ -94,12 +95,13 @@ def annualize(self, rate: float) -> float: years = timedelta(days=365) / self.duration return (1.0 + rate) ** years - 1.0 - def split(self, n: int | timedelta) -> list["Timeframe"]: + def split(self, n: int | timedelta | Any) -> list["Timeframe"]: """Split the timeframe in sequential parts and return the resulting list of timeframes. - The parameter `n` can be a number or a timedelta instance. + The parameter `n` can be a number or a timedelta instance or other types like relativedelta that support + datetime calculations. """ - period = n if isinstance(n, timedelta) else self.duration / n + period = self.duration / n if isinstance(n, int) else n end = self.start result = [] while end < self.end: @@ -113,7 +115,7 @@ def split(self, n: int | timedelta) -> list["Timeframe"]: last.end = self.end return result - def sample(self, duration: timedelta, n: int = 1) -> list["Timeframe"]: + def sample(self, duration: timedelta | Any, n: int = 1) -> list["Timeframe"]: """Sample one or more periods of `duration` from this timeframe.""" result = [] diff --git a/tests/performance/test_profiling.py b/tests/performance/test_profiling.py index 38cbb5c..4c69228 100644 --- a/tests/performance/test_profiling.py +++ b/tests/performance/test_profiling.py @@ -11,8 +11,8 @@ class TestProfile(unittest.TestCase): def test_profile(self): path = os.path.expanduser("~/data/nasdaq_stocks/1") feed = rq.feeds.CSVFeed.stooq_us_daily(path) - print("timeframe =", feed.timeframe(), " symbols =", len(feed.symbols)) - strategy = rq.strategies.EMACrossover(13, 26) + print(feed) + strategy = rq.strategies.EMACrossover() journal = rq.journals.BasicJournal() # Profile the run to detect bottlenecks diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 845e4ad..1a56c17 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,7 +1,7 @@ import unittest from decimal import Decimal -from roboquant import Account, Position, OptionAccount +from roboquant import Account, Position, OptionConverter class TestAccount(unittest.TestCase): @@ -25,8 +25,11 @@ def test_account_positions(self): self.assertAlmostEqual(acc.unrealized_pnl(prices), 0.0) def test_account_option(self): - acc = OptionAccount() - acc.register("DUMMY", 5.0) + oc = OptionConverter() + oc.register("DUMMY", 5.0) + acc = Account() + acc.register_converter(oc) + self.assertEqual(1000.0, acc.contract_value("DUMMY", Decimal(1), 200.0)) self.assertEqual(200.0, acc.contract_value("TSLA", Decimal(1), 200.0))