From 4254a93feb9123671791d24651de50c7ed438260 Mon Sep 17 00:00:00 2001 From: Peter Dekkers Date: Sat, 9 Mar 2024 23:52:13 +0100 Subject: [PATCH] Using cash in account --- pyproject.toml | 1 + requirements.txt | 2 + roboquant/__init__.py | 2 +- roboquant/account.py | 59 +++++++------- roboquant/brokers/{ibkrbroker.py => ibkr.py} | 30 +++++--- roboquant/brokers/simbroker.py | 77 +++++++++---------- roboquant/feeds/__init__.py | 6 +- roboquant/feeds/alpacafeed.py | 52 +++++++++++++ roboquant/feeds/csvfeed.py | 2 +- .../feeds/{historicfeed.py => historic.py} | 0 roboquant/feeds/randomwalk.py | 2 +- roboquant/feeds/{tiingofeed.py => tiingo.py} | 2 +- roboquant/feeds/{yahoofeed.py => yahoo.py} | 2 +- roboquant/journals/alphabeta.py | 2 +- roboquant/journals/basicjournal.py | 17 ++-- roboquant/journals/pnlmetric.py | 10 +-- roboquant/order.py | 10 ++- roboquant/run.py | 2 + roboquant/strategies/emacrossover.py | 2 +- roboquant/traders/flextrader.py | 2 +- tests/integration/test_ibkrbroker.py | 4 +- tests/samples/close_positions_ibkr.py | 2 +- .../{dataframe.py => pandas_dataframe.py} | 0 tests/samples/papertrade_tiingo_ibkr.py | 2 +- tests/samples/walkforward.py | 4 +- tests/unit/test_account.py | 10 +-- tests/unit/test_csvfeed.py | 2 +- tests/unit/{test_roboquant.py => test_run.py} | 0 28 files changed, 186 insertions(+), 120 deletions(-) rename roboquant/brokers/{ibkrbroker.py => ibkr.py} (90%) create mode 100644 roboquant/feeds/alpacafeed.py rename roboquant/feeds/{historicfeed.py => historic.py} (100%) rename roboquant/feeds/{tiingofeed.py => tiingo.py} (99%) rename roboquant/feeds/{yahoofeed.py => yahoo.py} (97%) rename tests/samples/{dataframe.py => pandas_dataframe.py} (100%) rename tests/unit/{test_roboquant.py => test_run.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index c44e3eb..4804f4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ version = {attr = "roboquant.__version__"} torch = ["torch>=2.1.0", "tensorboard>=2.15.2"] yahoo = ["yfinance~=0.2.36"] ibkr = ["nautilus-ibapi~=10.19.2"] +alpaca = ["alpaca-py"] all = [ "roboquant[torch,yahoo,ibkr]" ] diff --git a/requirements.txt b/requirements.txt index 903986a..a06de5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,9 +8,11 @@ yfinance~=0.2.36 torch>=2.1.0 tensorboard>=2.15.1 nautilus-ibapi~=10.19.2 +alpaca-py # Build tools build~=1.0.3 twine>=5.0.0 flake8>=7.0.0 +pylint>=3.1.0 diff --git a/roboquant/__init__.py b/roboquant/__init__.py index 939bed3..d93c5ec 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.5" +__version__ = "0.2.6" from roboquant import brokers from roboquant import feeds diff --git a/roboquant/account.py b/roboquant/account.py index 4b52585..12236dc 100644 --- a/roboquant/account.py +++ b/roboquant/account.py @@ -2,11 +2,12 @@ from dataclasses import dataclass from datetime import datetime from decimal import Decimal +from roboquant.event import Event from roboquant.order import Order -@dataclass(slots=True, frozen=True) +@dataclass(slots=True) class Position: """Position of a symbol""" @@ -16,6 +17,9 @@ class Position: avg_price: float """Average price paid denoted in the currency of the symbol""" + mkt_price: float + """latest market price denoted in the currency of the symbol""" + class Converter(ABC): """Abstraction that enables trading symbols that are denoted in different currencies and/or contact sizes""" @@ -108,7 +112,7 @@ def __init__(self): self.positions: dict[str, Position] = {} self.orders: list[Order] = [] self.last_update: datetime = datetime.fromisoformat("1900-01-01T00:00:00+00:00") - self.equity: float = 0.0 + self.cash: float = 0.0 @staticmethod def register_converter(converter: Converter): @@ -121,42 +125,33 @@ def contract_value(self, symbol: str, size: Decimal, price: float) -> float: rate = 1.0 if not Account.__converter else Account.__converter(symbol, self.last_update) return float(size) * price * rate - def mkt_value(self, prices: dict[str, float]) -> float: - """Return the market value of all the open positions in the account using the provided prices. - If there is no known price provided for a position, the average price paid will be used instead. - - Args: - prices: The prices to use to calculate the market value. - """ + def mkt_value(self) -> float: + """Return the sum of the market values of the open positions in the account.""" return sum( - [ - self.contract_value(symbol, pos.size, prices[symbol] if symbol in prices else pos.avg_price) - for symbol, pos in self.positions.items() - ], + [self.contract_value(symbol, pos.size, pos.mkt_price) for symbol, pos in self.positions.items()], 0.0, ) - def unrealized_pnl(self, prices: dict[str, float]) -> float: - """Return the unrealized profit and loss for the open position given the provided market prices - If there is no known price provided for a position, it will be ignored. + def equity(self) -> float: + """Return the equity of the account. + + equity = cash + sum of the market value of the open positions - Args: - prices: The prices to use to calculate the unrealized PNL. """ + return self.cash + self.mkt_value() + + def unrealized_pnl(self) -> float: + """Return the sum of the unrealized profit and loss for the open position.""" return sum( - [ - self.contract_value(symbol, pos.size, prices[symbol] - pos.avg_price) - for symbol, pos in self.positions.items() - if symbol in prices - ], + [self.contract_value(symbol, pos.size, pos.mkt_price - pos.avg_price) for symbol, pos in self.positions.items()], 0.0, ) def has_open_order(self, symbol: str) -> bool: - """Return True if there is an open order for the symbol, False otherwise""" + """Return True if there is at least one open order for the symbol, False otherwise""" for order in self.orders: - if order.symbol == symbol and not order.closed: + if order.symbol == symbol and order.open: return True return False @@ -165,9 +160,17 @@ def get_position_size(self, symbol) -> Decimal: pos = self.positions.get(symbol) return pos.size if pos else Decimal(0) + def update_positions(self, event: Event, price_type: str = "DEFAULT"): + """update the open positions with the latest market prices""" + self.last_update = event.time + + for symbol, position in self.positions.items(): + if price := event.get_price(symbol, price_type): + position.mkt_price = price + def open_orders(self): """Return a list with the open orders""" - return [order for order in self.orders if not order.closed] + return [order for order in self.orders if order.open] def __repr__(self) -> str: p = [f"{v.size}@{k}" for k, v in self.positions.items()] @@ -178,8 +181,10 @@ def __repr__(self) -> str: result = ( f"""buying power : {self.buying_power:_.2f}\n""" - f"""equity : {self.equity:_.2f}\n""" + f"""cash : {self.cash:_.2f}\n""" + f"""equity : {self.equity():_.2f}\n""" f"""positions : {p_str}\n""" + f"""mkt value : {self.mkt_value():_.2f}\n""" f"""open orders : {o_str}\n""" f"""last update : {self.last_update}""" ) diff --git a/roboquant/brokers/ibkrbroker.py b/roboquant/brokers/ibkr.py similarity index 90% rename from roboquant/brokers/ibkrbroker.py rename to roboquant/brokers/ibkr.py index 151277c..443adde 100644 --- a/roboquant/brokers/ibkrbroker.py +++ b/roboquant/brokers/ibkr.py @@ -28,7 +28,7 @@ def __init__(self): EClient.__init__(self, self) self.orders: dict[str, Order] = {} self.positions: dict[str, Position] = {} - self.account = {"EquityWithLoanValue": 0.0, "AvailableFunds": 0.0} + self.__account = {AccountSummaryTags.TotalCashValue: 0.0, AccountSummaryTags.BuyingPower: 0.0} self.__account_end = threading.Condition() self.__order_id = 0 @@ -44,11 +44,13 @@ def get_next_order_id(self): def position(self, account: str, contract: Contract, position: Decimal, avgCost: float): logger.debug("position=%s symbol=%s avgCost=%s", position, contract.localSymbol, avgCost) symbol = contract.localSymbol or contract.symbol - self.positions[symbol] = Position(position, avgCost) + old_position = self.positions.get(symbol) + mkt_price = old_position.mkt_price if old_position else avgCost + self.positions[symbol] = Position(position, avgCost, mkt_price) def accountSummary(self, reqId: int, account: str, tag: str, value: str, currency: str): logger.debug("account %s=%s", tag, value) - self.account[tag] = float(value) + self.__account[tag] = float(value) def accountSummaryEnd(self, reqId: int): with self.__account_end: @@ -71,18 +73,18 @@ def openOrder(self, orderId: int, contract, order: IBKROrder, orderState): def request_account(self): """blocking call till account summary has been received""" buyingpower_tag = AccountSummaryTags.BuyingPower - equity_tag = AccountSummaryTags.NetLiquidation + cash_tag = AccountSummaryTags.TotalCashValue with self.__account_end: - super().reqAccountSummary(1, "All", f"{buyingpower_tag},{equity_tag}") + super().reqAccountSummary(1, "All", f"{buyingpower_tag},{cash_tag}") self.__account_end.wait() def get_buying_power(self): buyingpower_tag = AccountSummaryTags.BuyingPower - return self.account[buyingpower_tag] or 0.0 + return self.__account[buyingpower_tag] or 0.0 - def get_equity(self): - equity_tag = AccountSummaryTags.NetLiquidation - return self.account[equity_tag] or 0.0 + def get_cash(self): + cash_tag = AccountSummaryTags.TotalCashValue + return self.__account[cash_tag] or 0.0 def orderStatus( self, @@ -141,6 +143,7 @@ def __init__(self, host="127.0.0.1", port=4002, client_id=123) -> None: api.connect(host, port, client_id) self.__api = api self.__has_new_orders_since_sync = False + self.price_type = "DEFAULT" # Start the handling in a thread self.__api_thread = threading.Thread(target=api.run, daemon=True) @@ -189,7 +192,7 @@ def sync(self, event: Event | None = None) -> Account: acc.positions = {k: v for k, v in api.positions.items() if not v.size.is_zero()} acc.orders = list(api.orders.values()) acc.buying_power = api.get_buying_power() - acc.equity = api.get_equity() + acc.cash = api.get_cash() logger.debug("end sync") return acc @@ -234,7 +237,12 @@ def __get_order(order: Order): o = IBKROrder() o.action = "BUY" if order.is_buy else "SELL" o.totalQuantity = abs(order.size) - o.tif = "GTC" + if order.gtd: + o.tif = "GTD" + o.goodTillDate = order.gtd.strftime("%Y%m%d %H:%M:%S %Z") + else: + o.tif = "GTC" + if order.limit: o.orderType = "LMT" o.lmtPrice = order.limit diff --git a/roboquant/brokers/simbroker.py b/roboquant/brokers/simbroker.py index 7e32afb..a765589 100644 --- a/roboquant/brokers/simbroker.py +++ b/roboquant/brokers/simbroker.py @@ -1,12 +1,15 @@ from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal +import logging from roboquant.account import Account, Position from roboquant.brokers.broker import Broker from roboquant.event import Event from roboquant.order import Order, OrderStatus +logger = logging.getLogger(__name__) + @dataclass(slots=True, frozen=True) class _Trx: @@ -17,12 +20,6 @@ class _Trx: price: float # denoted in the currency of the symbol -@dataclass -class _OrderState: - order: Order - expires_at: datetime | None = None - - class SimBroker(Broker): """Implementation of a Broker that simulates order handling and trade execution. @@ -39,11 +36,11 @@ def __init__( super().__init__() self._account = Account() self._modify_orders: list[Order] = [] + self._create_orders: dict[str, Order] = {} + self._account.cash = initial_deposit self._account.buying_power = initial_deposit self.slippage = slippage self.price_type = price_type - self._prices: dict[str, float] = {} - self._orders: dict[str, _OrderState] = {} self.clean_up_orders = clean_up_orders self.__order_id = 0 @@ -51,13 +48,13 @@ def _update_account(self, trx: _Trx): """Update a position and cash based on a new transaction""" acc = self._account symbol = trx.symbol - acc.buying_power -= acc.contract_value(symbol, trx.size, trx.price) + acc.cash -= acc.contract_value(symbol, trx.size, trx.price) size = acc.get_position_size(symbol) if size.is_zero(): # opening of position - acc.positions[symbol] = Position(trx.size, trx.price) + acc.positions[symbol] = Position(trx.size, trx.price, trx.price) else: new_size: Decimal = size + trx.size if new_size.is_zero(): @@ -65,12 +62,12 @@ def _update_account(self, trx: _Trx): del acc.positions[symbol] elif new_size.is_signed() != size.is_signed(): # reverse of position - acc.positions[symbol] = Position(new_size, trx.price) + acc.positions[symbol] = Position(new_size, trx.price, trx.price) else: # increase of position size old_price = acc.positions[symbol].avg_price avg_price = (old_price * float(size) + trx.price * float(trx.size)) / (float(size + trx.size)) - acc.positions[symbol] = Position(new_size, avg_price) + acc.positions[symbol] = Position(new_size, avg_price, trx.price) def _get_execution_price(self, order, item) -> float: """Return the execution price to use for an order based on the price item. @@ -96,12 +93,11 @@ def __next_order_id(self): self.__order_id += 1 return result - def _has_expired(self, state: _OrderState) -> bool: + def _has_expired(self, order: Order) -> bool: """Returns true if the order has expired, false otherwise""" - if state.expires_at is None: + if not order.gtd: return False - - return self._account.last_update >= state.expires_at + return self._account.last_update >= order.gtd def _get_fill(self, order, price) -> Decimal: """Return the fill for the order given the provided price. @@ -120,12 +116,6 @@ def _get_fill(self, order, price) -> Decimal: return Decimal(0) - def __update_mkt_prices(self, price_items): - """track the latest market prices for all open positions""" - for symbol in self._account.positions: - if item := price_items.get(symbol): - self._prices[symbol] = item.price(self.price_type) - def place_orders(self, orders): """Place new orders at this broker. The order gets assigned a unique id if it hasn't one already. @@ -136,43 +126,51 @@ def place_orders(self, orders): assert not order.closed, "cannot place a closed order" if order.id is None: order.id = self.__next_order_id() - assert order.id not in self._orders - self._orders[order.id] = _OrderState(order) + assert order.id not in self._create_orders + self._create_orders[order.id] = order else: - assert order.id in self._orders, "existing order id is not found" + assert order.id in self._create_orders, "existing order id is not found" self._modify_orders.append(order) def _process_modify_order(self): for order in self._modify_orders: - state = self._orders[order.id] # type: ignore - if state.order.closed: + orig_order = self._create_orders.get(order.id) # type: ignore + if not orig_order: + logger.info("couldn't find order with id %s", order.id) + continue + if orig_order.closed: + logger.info("cannot modify order because order is already closed %s", orig_order) continue if order.is_cancellation: - state.order.status = OrderStatus.CANCELLED + orig_order.status = OrderStatus.CANCELLED else: - state.order.size = order.size or state.order.size - state.order.limit = order.limit or state.order.limit + orig_order.size = order.size or orig_order.size + orig_order.limit = order.limit or orig_order.limit + logger.info("modified order %s", orig_order) + self._modify_orders = [] def _process_create_orders(self, prices): - for state in self._orders.values(): - order = state.order + for order in self._create_orders.values(): if order.closed: continue - if self._has_expired(state): + if self._has_expired(order): + logger.info("order expired order=%s time=%s", order, self._account.last_update) order.status = OrderStatus.EXPIRED else: if (item := prices.get(order.symbol)) is not None: - state.expires_at = state.expires_at or self._account.last_update + timedelta(days=90) + if not order.gtd: + order.gtd = self._account.last_update + timedelta(days=90) trx = self._execute(order, item) if trx is not None: + logger.info("executed order=%s trx=%s", order, trx) self._update_account(trx) order.fill += trx.size if order.fill == order.size: order.status = OrderStatus.FILLED def sync(self, event: Event | None = None) -> Account: - """This will perform the trading simulation for open orders and update the account accordingly.""" + """This will perform the trading simulation for open orders areturn an updated the account""" acc = self._account if event: @@ -182,12 +180,11 @@ def sync(self, event: Event | None = None) -> Account: if self.clean_up_orders: # remove all the closed orders from the previous step - self._orders = {order_id: state for order_id, state in self._orders.items() if not state.order.closed} + self._create_orders = {order_id: order for order_id, order in self._create_orders.items() if not order.closed} self._process_modify_order() self._process_create_orders(prices) - self.__update_mkt_prices(prices) - acc.equity = acc.mkt_value(self._prices) + acc.buying_power - acc.orders = [state.order for state in self._orders.values()] + acc.buying_power = acc.cash + acc.orders = list(self._create_orders.values()) return acc diff --git a/roboquant/feeds/__init__.py b/roboquant/feeds/__init__.py index bf25721..5768ff3 100644 --- a/roboquant/feeds/__init__.py +++ b/roboquant/feeds/__init__.py @@ -3,12 +3,12 @@ from .csvfeed import CSVFeed from .eventchannel import EventChannel from .feed import Feed -from .historicfeed import HistoricFeed +from .historic import HistoricFeed from .randomwalk import RandomWalk from .sqllitefeed import SQLFeed -from .tiingofeed import TiingoLiveFeed, TiingoHistoricFeed +from .tiingo import TiingoLiveFeed, TiingoHistoricFeed try: - from .yahoofeed import YahooFeed + from .yahoo import YahooFeed except ImportError: pass diff --git a/roboquant/feeds/alpacafeed.py b/roboquant/feeds/alpacafeed.py new file mode 100644 index 0000000..1f51a34 --- /dev/null +++ b/roboquant/feeds/alpacafeed.py @@ -0,0 +1,52 @@ +import threading +import time + +from alpaca.data.live.crypto import CryptoDataStream + +from roboquant.config import Config +from roboquant.event import Event, Trade +from roboquant.feeds.eventchannel import EventChannel + +from roboquant.feeds.feed import Feed + + +class AlpacaLiveFeed(Feed): + + def __init__(self) -> None: + super().__init__() + config = Config() + api_key = config.get("alpaca.public.key") + secret_key = config.get("alpaca.secret.key") + self.stream = CryptoDataStream(api_key, secret_key) + thread = threading.Thread(None, self.stream.run, daemon=True) + thread.start() + # print("running", flush=True) + self._channel = None + + def play(self, channel: EventChannel): + self._channel = channel + while not channel.is_closed: + time.sleep(1) + self._channel = None + + async def handle_trades(self, data): + print(data) + if self._channel: + item = Trade(data.symbol, data.price, data.size) + event = Event(data["timestamp"], [item]) + self._channel.put(event) + + def subscribe(self, *symbols: str): + self.stream.subscribe_trades(self.handle_trades, *symbols) + + +def run(): + feed = AlpacaLiveFeed() + feed.subscribe("BTC/USD", "ETH/USD") + channel = feed.play_background() + while event := channel.get(60.0): + print(event) + + +if __name__ == "__main__": + run() diff --git a/roboquant/feeds/csvfeed.py b/roboquant/feeds/csvfeed.py index a0dbc91..f51c874 100644 --- a/roboquant/feeds/csvfeed.py +++ b/roboquant/feeds/csvfeed.py @@ -6,7 +6,7 @@ from datetime import datetime, time, timezone from roboquant.event import Candle -from roboquant.feeds.historicfeed import HistoricFeed +from roboquant.feeds.historic import HistoricFeed logger = logging.getLogger(__name__) diff --git a/roboquant/feeds/historicfeed.py b/roboquant/feeds/historic.py similarity index 100% rename from roboquant/feeds/historicfeed.py rename to roboquant/feeds/historic.py diff --git a/roboquant/feeds/randomwalk.py b/roboquant/feeds/randomwalk.py index f0500af..adbfb14 100644 --- a/roboquant/feeds/randomwalk.py +++ b/roboquant/feeds/randomwalk.py @@ -4,7 +4,7 @@ import numpy as np from roboquant.event import Trade -from .historicfeed import HistoricFeed +from .historic import HistoricFeed class RandomWalk(HistoricFeed): diff --git a/roboquant/feeds/tiingofeed.py b/roboquant/feeds/tiingo.py similarity index 99% rename from roboquant/feeds/tiingofeed.py rename to roboquant/feeds/tiingo.py index 457b55d..c939613 100644 --- a/roboquant/feeds/tiingofeed.py +++ b/roboquant/feeds/tiingo.py @@ -16,7 +16,7 @@ from roboquant.event import Trade, Quote, Event from roboquant.feeds.eventchannel import EventChannel from roboquant.feeds.feed import Feed -from roboquant.feeds.historicfeed import HistoricFeed +from roboquant.feeds.historic import HistoricFeed logger = logging.getLogger(__name__) diff --git a/roboquant/feeds/yahoofeed.py b/roboquant/feeds/yahoo.py similarity index 97% rename from roboquant/feeds/yahoofeed.py rename to roboquant/feeds/yahoo.py index a091a2e..b793c09 100644 --- a/roboquant/feeds/yahoofeed.py +++ b/roboquant/feeds/yahoo.py @@ -6,7 +6,7 @@ import yfinance from roboquant.event import Candle -from roboquant.feeds.historicfeed import HistoricFeed +from roboquant.feeds.historic import HistoricFeed logger = logging.getLogger(__name__) diff --git a/roboquant/journals/alphabeta.py b/roboquant/journals/alphabeta.py index 4ffecc1..796833b 100644 --- a/roboquant/journals/alphabeta.py +++ b/roboquant/journals/alphabeta.py @@ -40,7 +40,7 @@ def __update(self, equity, prices): def calc(self, event, account, signals, orders): prices = event.get_prices(self.price_type) - equity = account.equity + equity = account.equity() if self.__last_equity is None: self.__update(equity, prices) return {} diff --git a/roboquant/journals/basicjournal.py b/roboquant/journals/basicjournal.py index 7358431..c060dc9 100644 --- a/roboquant/journals/basicjournal.py +++ b/roboquant/journals/basicjournal.py @@ -10,32 +10,27 @@ class BasicJournal(Journal): """Tracks a number of basic metrics: - total number of events, items, signals and orders until that time - - total pnl percentage + + It will also log these values at each step in the run at `info` level. This journal adds little overhead to a run, both CPU and memory wise. """ + items: int orders: int signals: int events: int - pnl: float def __init__(self): + self.events = 0 + self.signals = 0 self.items = 0 self.orders = 0 - self.signals = 0 - self.events = 0 - self.pnl = 0.0 - self.__first_equity = None def track(self, event, account, signals, orders): - if self.__first_equity is None: - self.__first_equity = account.equity - self.items += len(event.items) - self.orders += len(orders) self.events += 1 self.signals += len(signals) - self.pnl = account.equity / self.__first_equity - 1.0 + self.orders += len(orders) logger.info("time=%s info=%s", event.time, self) diff --git a/roboquant/journals/pnlmetric.py b/roboquant/journals/pnlmetric.py index 78fa5e2..b32998f 100644 --- a/roboquant/journals/pnlmetric.py +++ b/roboquant/journals/pnlmetric.py @@ -18,12 +18,11 @@ def __init__(self): self.prev_equity = None self.max_equity = -10e10 self.min_equity = 10e10 - self._prices = {} def calc(self, event, account, signals, orders) -> dict[str, float]: - equity = account.equity + equity = account.equity() - total, realized, unrealized = self.__get_pnl_values(equity, event, account) + total, realized, unrealized = self.__get_pnl_values(equity, account) return { "pnl/equity": equity, @@ -35,12 +34,11 @@ def calc(self, event, account, signals, orders) -> dict[str, float]: "pnl/unrealized": unrealized, } - def __get_pnl_values(self, equity, event, account): + def __get_pnl_values(self, equity, account): if self.first_equity is None: self.first_equity = equity - self._prices.update(event.get_prices()) - unrealized = account.unrealized_pnl(self._prices) + unrealized = account.unrealized_pnl() total = equity - self.first_equity realized = total - unrealized return total, realized, unrealized diff --git a/roboquant/order.py b/roboquant/order.py index 6533efc..e559c98 100644 --- a/roboquant/order.py +++ b/roboquant/order.py @@ -1,5 +1,6 @@ from copy import copy from dataclasses import dataclass +from datetime import datetime from decimal import Decimal from enum import Flag, auto from typing import Any @@ -51,18 +52,23 @@ class Order: symbol: str size: Decimal limit: float | None + gtd: datetime | None info: dict[str, Any] id: str | None status: OrderStatus fill: Decimal - def __init__(self, symbol: str, size: Decimal | str | int | float, limit: float | None = None, **kwargs): + def __init__( + self, symbol: str, size: Decimal | str | int | float, limit: float | None = None, gtd: datetime | None = None, **kwargs + ): self.symbol = symbol self.size = Decimal(size) assert not self.size.is_zero(), "Cannot create a new order with size is zero" self.limit = limit + self.gtd = gtd + self.id: str | None = None self.status: OrderStatus = OrderStatus.INITIAL self.fill = Decimal(0) @@ -111,7 +117,7 @@ def update(self, size: Decimal | str | int | float | None = None, limit: float | return result def __copy__(self): - result = Order(self.symbol, self.size, self.limit, **self.info) + result = Order(self.symbol, self.size, self.limit, self.gtd, **self.info) result.id = self.id result.status = self.status result.fill = self.fill diff --git a/roboquant/run.py b/roboquant/run.py index 242b642..73e3824 100644 --- a/roboquant/run.py +++ b/roboquant/run.py @@ -18,6 +18,7 @@ def run( timeframe: Timeframe | None = None, capacity: int = 10, heartbeat_timeout: float | None = None, + price_type: str = "DEFAULT" ) -> Account: """Start a new run. Only the first two parameters, the feed and strategy, are mandatory. The other parameters are optional. @@ -44,6 +45,7 @@ def run( while event := channel.get(heartbeat_timeout): signals = strategy.create_signals(event) account = broker.sync(event) + account.update_positions(event, price_type) orders = trader.create_orders(signals, event, account) broker.place_orders(orders) if journal: diff --git a/roboquant/strategies/emacrossover.py b/roboquant/strategies/emacrossover.py index e985a7e..61cf917 100644 --- a/roboquant/strategies/emacrossover.py +++ b/roboquant/strategies/emacrossover.py @@ -39,7 +39,7 @@ def create_signals(self, event: Event) -> dict[str, Signal]: class _Calculator: - __slots__ = "momentum", "price", "step" + __slots__ = "momentum", "price" def __init__(self, momentum, price): self.momentum = momentum diff --git a/roboquant/traders/flextrader.py b/roboquant/traders/flextrader.py index 773aab1..eaccfd6 100644 --- a/roboquant/traders/flextrader.py +++ b/roboquant/traders/flextrader.py @@ -84,7 +84,7 @@ def create_orders(self, signals: dict[str, Signal], event: Event, account: Accou return [] orders: list[Order] = [] - equity = account.equity + equity = account.equity() max_order_value = equity * self.max_order_perc min_order_value = equity * self.min_order_perc available = account.buying_power - self.min_buying_power_perc * equity diff --git a/tests/integration/test_ibkrbroker.py b/tests/integration/test_ibkrbroker.py index f0da1b0..dc17224 100644 --- a/tests/integration/test_ibkrbroker.py +++ b/tests/integration/test_ibkrbroker.py @@ -4,7 +4,7 @@ from decimal import Decimal from roboquant import OrderStatus, Order -from roboquant.brokers.ibkrbroker import IBKRBroker +from roboquant.brokers.ibkr import IBKRBroker class TestIBKRBroker(unittest.TestCase): @@ -17,7 +17,7 @@ def test_ibkr_order(self): broker = IBKRBroker() account = broker.sync() - self.assertGreater(account.equity, 0) + self.assertGreater(account.equity(), 0) self.assertEqual(len(account.orders), 0) # Place an order diff --git a/tests/samples/close_positions_ibkr.py b/tests/samples/close_positions_ibkr.py index a844c16..e928b22 100644 --- a/tests/samples/close_positions_ibkr.py +++ b/tests/samples/close_positions_ibkr.py @@ -1,7 +1,7 @@ from time import sleep import logging from roboquant import Order -from roboquant.brokers.ibkrbroker import IBKRBroker +from roboquant.brokers.ibkr import IBKRBroker if __name__ == "__main__": diff --git a/tests/samples/dataframe.py b/tests/samples/pandas_dataframe.py similarity index 100% rename from tests/samples/dataframe.py rename to tests/samples/pandas_dataframe.py diff --git a/tests/samples/papertrade_tiingo_ibkr.py b/tests/samples/papertrade_tiingo_ibkr.py index 50d0448..ea272c8 100644 --- a/tests/samples/papertrade_tiingo_ibkr.py +++ b/tests/samples/papertrade_tiingo_ibkr.py @@ -2,7 +2,7 @@ import logging import roboquant as rq from roboquant.account import Account, CurrencyConverter -from roboquant.brokers.ibkrbroker import IBKRBroker +from roboquant.brokers.ibkr import IBKRBroker from roboquant.feeds.feedutil import get_sp500_symbols diff --git a/tests/samples/walkforward.py b/tests/samples/walkforward.py index 8107f7e..8513122 100644 --- a/tests/samples/walkforward.py +++ b/tests/samples/walkforward.py @@ -7,8 +7,8 @@ # split the feed timeframe in 4 equal parts timeframes = feed.timeframe().split(4) - # run a back-test on each timeframe + # run a walkforward back-test on each timeframe for timeframe in timeframes: strategy = rq.strategies.EMACrossover(13, 26) account = rq.run(feed, strategy, timeframe=timeframe) - print(f"{timeframe} equity={account.equity:7_.2f}") + print(f"{timeframe} equity={account.equity():7_.2f}") diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 1a56c17..76de4cc 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -9,8 +9,8 @@ class TestAccount(unittest.TestCase): def test_account_init(self): acc = Account() self.assertEqual(acc.buying_power, 0.0) - self.assertEqual(acc.unrealized_pnl({}), 0.0) - self.assertEqual(acc.mkt_value({}), 0.0) + self.assertEqual(acc.unrealized_pnl(), 0.0) + self.assertEqual(acc.mkt_value(), 0.0) def test_account_positions(self): acc = Account() @@ -18,11 +18,11 @@ def test_account_positions(self): for i in range(10): symbol = f"AA${i}" price = 10.0 + i - acc.positions[symbol] = Position(Decimal(10), price) + acc.positions[symbol] = Position(Decimal(10), price, price) prices[symbol] = price - self.assertAlmostEqual(acc.mkt_value(prices), 1450.0) - self.assertAlmostEqual(acc.unrealized_pnl(prices), 0.0) + self.assertAlmostEqual(acc.mkt_value(), 1450.0) + self.assertAlmostEqual(acc.unrealized_pnl(), 0.0) def test_account_option(self): oc = OptionConverter() diff --git a/tests/unit/test_csvfeed.py b/tests/unit/test_csvfeed.py index 79eb63b..cfa23cf 100644 --- a/tests/unit/test_csvfeed.py +++ b/tests/unit/test_csvfeed.py @@ -1,7 +1,7 @@ import pathlib import unittest -from roboquant.feeds.csvfeed import CSVFeed +from roboquant.feeds import CSVFeed from tests.common import run_price_item_feed diff --git a/tests/unit/test_roboquant.py b/tests/unit/test_run.py similarity index 100% rename from tests/unit/test_roboquant.py rename to tests/unit/test_run.py