diff --git a/.vscode/settings.json b/.vscode/settings.json index c235eee..8b924ab 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -29,8 +29,8 @@ "--ignore=tests/integration", "--ignore=tests/performance" ], - "python.testing.pytestEnabled": true, - "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true, "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" } diff --git a/pyproject.toml b/pyproject.toml index 7c25e1e..c365162 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,9 @@ testpaths = [ "tests/unit" ] +[tool.pyright] +reportOptionalOperand = "off" + [tool.pylint.'MESSAGES CONTROL'] max-line-length = 127 disable = "too-few-public-methods,missing-module-docstring,missing-class-docstring,missing-function-docstring,unnecessary-ellipsis" diff --git a/roboquant/account.py b/roboquant/account.py index 98df5d5..a114e41 100644 --- a/roboquant/account.py +++ b/roboquant/account.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime from decimal import Decimal @@ -16,8 +17,78 @@ class Position: """Average price paid denoted in the currency of the symbol""" +class Converter(ABC): + """Abstraction that enables trading symbols that are denoted in different currencies and/or contact sizes""" + + @abstractmethod + def __call__(self, symbol: str, time: datetime) -> float: + """Return the conversion rate for the symbol at the given time""" + ... + + +class OptionConverter(Converter): + """ + This converter handles common option contracts of size 100 and 10 and serves as an example. + + If no contract size is registered for a symbol, it calculates one based on the symbol name. + If the symbol is not recognized as an OCC compliant option symbol, it is assumed to have a + contract size of 1.0 + """ + + def __init__(self): + super().__init__() + self._contract_sizes: dict[str, float] = {} + + def register(self, symbol: str, contract_size: float = 100.0): + """Register a contract-size for a symbol""" + self._contract_sizes[symbol] = contract_size + + def __call__(self, symbol: str, time: datetime) -> float: + contract_size = self._contract_sizes.get(symbol) + + # If no contract has been registered, we try to defer the contract size from the symbol + if contract_size is None: + if len(symbol) == 21: + # OCC compliant option symbol + symbol = symbol[0:6].rstrip() + contract_size = 10.0 if symbol[-1] == "7" else 100.0 + else: + # not an option symbol + contract_size = 1.0 + + self._contract_sizes[symbol] = contract_size + + return contract_size + + +class CurrencyConverter(Converter): + """Support symbols that are denoted in a different currency from the base currency of the account""" + + def __init__(self, base_currency="USD", default_symbol_currency="USD"): + super().__init__() + self.rates = {} + self.base_currency = base_currency + self.default_symbol_currency = default_symbol_currency + self.registered_symbols = {} + + def register_symbol(self, symbol: str, currency: str): + """Register a symbol being denoted in a currency""" + self.registered_symbols[symbol] = currency + + def register_rate(self, currency: str, rate: float): + """Register a conversion rate from a currency to the base_currency""" + self.rates[currency] = rate + + def __call__(self, symbol: str, _: datetime) -> float: + currency = self.registered_symbols.get(symbol, self.default_symbol_currency) + if currency == self.base_currency: + return 1.0 + return self.rates[currency] + + class Account: - """The account maintains the following state during a run: + """Represents a trading account with all monetary amounts denoted in a single currency. + The account maintains the following state during a run: - Available buying power for orders in the base currency of the account - All the open positions @@ -28,11 +99,7 @@ class Account: Only the broker updates the account and does this only during its `sync` method. """ - buying_power: float - positions: dict[str, Position] - orders: list[Order] - last_update: datetime - equity: float + __converter: Converter | None = None def __init__(self): self.buying_power: float = 0.0 @@ -41,16 +108,16 @@ def __init__(self): self.last_update: datetime = datetime.fromisoformat("1900-01-01T00:00:00+00:00") self.equity: float = 0.0 + @staticmethod + def register_converter(converter: Converter): + """Register a converter""" + Account.__converter = converter + def contract_value(self, symbol: str, size: Decimal, price: float) -> float: # pylint: disable=unused-argument - """Return the total value of the provided contract size denoted in the base currency of the account. - The default implementation returns `size * price`. - - A subclass can implement different logic to cater for: - - symbols denoted in different currencies - - symbols having different contract sizes like option contracts. - """ - return float(size) * price + """Return the total value of the provided contract size denoted in the base currency of the account.""" + rate = 1.0 if not Account.__converter else Account.__converter.__call__(symbol, self.last_update) + return float(size) * price * rate def mkt_value(self, prices: dict[str, float]) -> float: """Return the market value of all the open positions in the account using the provided prices. @@ -115,36 +182,3 @@ def __repr__(self) -> str: f"""last update : {self.last_update}""" ) return result - - -class OptionAccount(Account): - """ - This account handles common option contracts of size 100 and 10 and serves as an example. - If no contract size is registered for a symbol, it creates one based on the option symbol name. - - If the symbol is not recognized as an OCC compliant option symbol, it is assumed to have a - contract size of 1.0 - """ - - def __init__(self): - super().__init__() - self._contract_sizes: dict[str, float] = {} - - def register(self, symbol: str, contract_size: float = 100.0): - """Register a certain contract-size for a symbol""" - self._contract_sizes[symbol] = contract_size - - def contract_value(self, symbol: str, size: Decimal, price: float) -> float: - contract_size = self._contract_sizes.get(symbol) - - # If no contract has been registered, we try to defer the contract size from the symbol - if contract_size is None: - if len(symbol) == 21: - # OCC compliant option symbol - symbol = symbol[0:6].rstrip() - contract_size = 10.0 if symbol[-1] == "7" else 100.0 - else: - # not an option symbol - contract_size = 1.0 - - return contract_size * float(size) * price diff --git a/roboquant/brokers/ibkrbroker.py b/roboquant/brokers/ibkrbroker.py index 3763b39..151277c 100644 --- a/roboquant/brokers/ibkrbroker.py +++ b/roboquant/brokers/ibkrbroker.py @@ -85,18 +85,18 @@ def get_equity(self): return self.account[equity_tag] or 0.0 def orderStatus( - self, - orderId, - status, - filled, - remaining, - avgFillPrice, - permId, - parentId, - lastFillPrice, - clientId, - whyHeld, - mktCapPrice, + self, + orderId, + status, + filled, + remaining, + avgFillPrice, + permId, + parentId, + lastFillPrice, + clientId, + whyHeld, + mktCapPrice, ): logger.debug("order status orderId=%s status=%s fill=%s", orderId, status, filled) orderId = str(orderId) @@ -122,10 +122,20 @@ class IBKRBroker(Broker): Map symbols to IBKR contracts. If a symbol is not found, the symbol is assumed to represent a US stock + host + the ip number of the host where TWS or IB Gateway is running. + + port + By default, TWS uses socket port 7496 for live sessions and 7497 for paper sessions. + IB Gateway by contrast uses 4001 for live sessions and 4002 for paper sessions. + However these are just defaults, and can be modified as desired. + + client_id + The client id to use to connect to TWS or IB Gateway. """ - def __init__(self, host="127.0.0.1", port=4002, account=None, client_id=123) -> None: - self.__account = account or Account() + def __init__(self, host="127.0.0.1", port=4002, client_id=123) -> None: + self.__account = Account() self.contract_mapping: dict[str, Contract] = {} api = _IBApi() api.connect(host, port, client_id) @@ -137,16 +147,25 @@ def __init__(self, host="127.0.0.1", port=4002, account=None, client_id=123) -> self.__api_thread.start() time.sleep(3.0) + @classmethod + def use_tws(cls, client_id=123): + """Return a broker connected to the TWS papertrade instance with its default port (7497) settings""" + return cls("127.0.0.1", 7497, client_id) + + @classmethod + def use_ibgateway(cls, client_id=123): + """Return a broker connected to a IB Gateway papertrade instance with its default port (4002) settings""" + return cls("127.0.0.1", 4002, client_id) + def disconnect(self): - self.__api.reader.conn.disconnect() + self.__api.reader.conn.disconnect() # type: ignore def _should_sync(self, now: datetime): """Avoid too many API calls""" return self.__has_new_orders_since_sync or now - self.__account.last_update > timedelta(seconds=1) def sync(self, event: Event | None = None) -> Account: - """Sync with the IBKR account - """ + """Sync with the IBKR account""" logger.debug("start sync") now = datetime.now(timezone.utc) diff --git a/roboquant/brokers/simbroker.py b/roboquant/brokers/simbroker.py index 3da34f0..1cb7e2c 100644 --- a/roboquant/brokers/simbroker.py +++ b/roboquant/brokers/simbroker.py @@ -11,6 +11,7 @@ @dataclass(slots=True, frozen=True) class _Trx: """transaction for an executed trade""" + symbol: str size: Decimal price: float # denoted in the currency of the symbol @@ -29,11 +30,14 @@ class SimBroker(Broker): """ def __init__( - self, initial_deposit=1_000_000.0, account=None, price_type="DEFAULT", slippage=0.001, clean_up_orders=True + self, + initial_deposit=1_000_000.0, + price_type="DEFAULT", + slippage=0.001, + clean_up_orders=True, ): super().__init__() - self.initial_deposit = initial_deposit - self._account = account or Account() + self._account = Account() self._modify_orders: list[Order] = [] self._account.buying_power = initial_deposit self.slippage = slippage diff --git a/roboquant/traders/trader.py b/roboquant/traders/trader.py index c5249f5..d241a4e 100644 --- a/roboquant/traders/trader.py +++ b/roboquant/traders/trader.py @@ -16,7 +16,7 @@ class Trader(Protocol): def create_orders(self, signals: dict[str, Signal], event: Event, account: Account) -> list[Order]: """Create zero or more orders. - Args: + Arguments signals: Zero or more signals created by the strategy. event: The event with its items. account: The latest account object. diff --git a/tests/samples/papertrade_tiingo_ibkr.py b/tests/samples/papertrade_tiingo_ibkr.py index 7975094..50d0448 100644 --- a/tests/samples/papertrade_tiingo_ibkr.py +++ b/tests/samples/papertrade_tiingo_ibkr.py @@ -1,6 +1,7 @@ from datetime import timedelta import logging import roboquant as rq +from roboquant.account import Account, CurrencyConverter from roboquant.brokers.ibkrbroker import IBKRBroker from roboquant.feeds.feedutil import get_sp500_symbols @@ -10,7 +11,10 @@ logging.getLogger("roboquant").setLevel(level=logging.INFO) # Connect to local running TWS or IB Gateway - ibkr = IBKRBroker() + converter = CurrencyConverter("EUR", "USD") + converter.register_rate("USD", 0.91) + Account.register_converter(converter) + ibkr = IBKRBroker.use_tws() # Connect to Tiingo and subscribe to S&P-500 stocks src_feed = rq.feeds.TiingoLiveFeed(market="iex")