Skip to content

Commit

Permalink
Added support for simple converters
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 7, 2024
1 parent d60a3cc commit bacbd9f
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 71 deletions.
4 changes: 2 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
"--ignore=tests/integration",
"--ignore=tests/performance"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
}
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ testpaths = [
"tests/unit"
]

[tool.pyright]
reportOptionalOperand = "off"

[tool.pylint.'MESSAGES CONTROL']
max-line-length = 127
disable = "too-few-public-methods,missing-module-docstring,missing-class-docstring,missing-function-docstring,unnecessary-ellipsis"
Expand Down
128 changes: 81 additions & 47 deletions roboquant/account.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
Expand All @@ -16,8 +17,78 @@ class Position:
"""Average price paid denoted in the currency of the symbol"""


class Converter(ABC):
"""Abstraction that enables trading symbols that are denoted in different currencies and/or contact sizes"""

@abstractmethod
def __call__(self, symbol: str, time: datetime) -> float:
"""Return the conversion rate for the symbol at the given time"""
...


class OptionConverter(Converter):
"""
This converter handles common option contracts of size 100 and 10 and serves as an example.
If no contract size is registered for a symbol, it calculates one based on the symbol name.
If the symbol is not recognized as an OCC compliant option symbol, it is assumed to have a
contract size of 1.0
"""

def __init__(self):
super().__init__()
self._contract_sizes: dict[str, float] = {}

def register(self, symbol: str, contract_size: float = 100.0):
"""Register a contract-size for a symbol"""
self._contract_sizes[symbol] = contract_size

def __call__(self, symbol: str, time: datetime) -> float:
contract_size = self._contract_sizes.get(symbol)

# If no contract has been registered, we try to defer the contract size from the symbol
if contract_size is None:
if len(symbol) == 21:
# OCC compliant option symbol
symbol = symbol[0:6].rstrip()
contract_size = 10.0 if symbol[-1] == "7" else 100.0
else:
# not an option symbol
contract_size = 1.0

self._contract_sizes[symbol] = contract_size

return contract_size


class CurrencyConverter(Converter):
"""Support symbols that are denoted in a different currency from the base currency of the account"""

def __init__(self, base_currency="USD", default_symbol_currency="USD"):
super().__init__()
self.rates = {}
self.base_currency = base_currency
self.default_symbol_currency = default_symbol_currency
self.registered_symbols = {}

def register_symbol(self, symbol: str, currency: str):
"""Register a symbol being denoted in a currency"""
self.registered_symbols[symbol] = currency

def register_rate(self, currency: str, rate: float):
"""Register a conversion rate from a currency to the base_currency"""
self.rates[currency] = rate

def __call__(self, symbol: str, _: datetime) -> float:
currency = self.registered_symbols.get(symbol, self.default_symbol_currency)
if currency == self.base_currency:
return 1.0
return self.rates[currency]


class Account:
"""The account maintains the following state during a run:
"""Represents a trading account with all monetary amounts denoted in a single currency.
The account maintains the following state during a run:
- Available buying power for orders in the base currency of the account
- All the open positions
Expand All @@ -28,11 +99,7 @@ class Account:
Only the broker updates the account and does this only during its `sync` method.
"""

buying_power: float
positions: dict[str, Position]
orders: list[Order]
last_update: datetime
equity: float
__converter: Converter | None = None

def __init__(self):
self.buying_power: float = 0.0
Expand All @@ -41,16 +108,16 @@ def __init__(self):
self.last_update: datetime = datetime.fromisoformat("1900-01-01T00:00:00+00:00")
self.equity: float = 0.0

@staticmethod
def register_converter(converter: Converter):
"""Register a converter"""
Account.__converter = converter

def contract_value(self, symbol: str, size: Decimal, price: float) -> float:
# pylint: disable=unused-argument
"""Return the total value of the provided contract size denoted in the base currency of the account.
The default implementation returns `size * price`.
A subclass can implement different logic to cater for:
- symbols denoted in different currencies
- symbols having different contract sizes like option contracts.
"""
return float(size) * price
"""Return the total value of the provided contract size denoted in the base currency of the account."""
rate = 1.0 if not Account.__converter else Account.__converter.__call__(symbol, self.last_update)
return float(size) * price * rate

def mkt_value(self, prices: dict[str, float]) -> float:
"""Return the market value of all the open positions in the account using the provided prices.
Expand Down Expand Up @@ -115,36 +182,3 @@ def __repr__(self) -> str:
f"""last update : {self.last_update}"""
)
return result


class OptionAccount(Account):
"""
This account handles common option contracts of size 100 and 10 and serves as an example.
If no contract size is registered for a symbol, it creates one based on the option symbol name.
If the symbol is not recognized as an OCC compliant option symbol, it is assumed to have a
contract size of 1.0
"""

def __init__(self):
super().__init__()
self._contract_sizes: dict[str, float] = {}

def register(self, symbol: str, contract_size: float = 100.0):
"""Register a certain contract-size for a symbol"""
self._contract_sizes[symbol] = contract_size

def contract_value(self, symbol: str, size: Decimal, price: float) -> float:
contract_size = self._contract_sizes.get(symbol)

# If no contract has been registered, we try to defer the contract size from the symbol
if contract_size is None:
if len(symbol) == 21:
# OCC compliant option symbol
symbol = symbol[0:6].rstrip()
contract_size = 10.0 if symbol[-1] == "7" else 100.0
else:
# not an option symbol
contract_size = 1.0

return contract_size * float(size) * price
53 changes: 36 additions & 17 deletions roboquant/brokers/ibkrbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ def get_equity(self):
return self.account[equity_tag] or 0.0

def orderStatus(
self,
orderId,
status,
filled,
remaining,
avgFillPrice,
permId,
parentId,
lastFillPrice,
clientId,
whyHeld,
mktCapPrice,
self,
orderId,
status,
filled,
remaining,
avgFillPrice,
permId,
parentId,
lastFillPrice,
clientId,
whyHeld,
mktCapPrice,
):
logger.debug("order status orderId=%s status=%s fill=%s", orderId, status, filled)
orderId = str(orderId)
Expand All @@ -122,10 +122,20 @@ class IBKRBroker(Broker):
Map symbols to IBKR contracts.
If a symbol is not found, the symbol is assumed to represent a US stock
host
the ip number of the host where TWS or IB Gateway is running.
port
By default, TWS uses socket port 7496 for live sessions and 7497 for paper sessions.
IB Gateway by contrast uses 4001 for live sessions and 4002 for paper sessions.
However these are just defaults, and can be modified as desired.
client_id
The client id to use to connect to TWS or IB Gateway.
"""

def __init__(self, host="127.0.0.1", port=4002, account=None, client_id=123) -> None:
self.__account = account or Account()
def __init__(self, host="127.0.0.1", port=4002, client_id=123) -> None:
self.__account = Account()
self.contract_mapping: dict[str, Contract] = {}
api = _IBApi()
api.connect(host, port, client_id)
Expand All @@ -137,16 +147,25 @@ def __init__(self, host="127.0.0.1", port=4002, account=None, client_id=123) ->
self.__api_thread.start()
time.sleep(3.0)

@classmethod
def use_tws(cls, client_id=123):
"""Return a broker connected to the TWS papertrade instance with its default port (7497) settings"""
return cls("127.0.0.1", 7497, client_id)

@classmethod
def use_ibgateway(cls, client_id=123):
"""Return a broker connected to a IB Gateway papertrade instance with its default port (4002) settings"""
return cls("127.0.0.1", 4002, client_id)

def disconnect(self):
self.__api.reader.conn.disconnect()
self.__api.reader.conn.disconnect() # type: ignore

def _should_sync(self, now: datetime):
"""Avoid too many API calls"""
return self.__has_new_orders_since_sync or now - self.__account.last_update > timedelta(seconds=1)

def sync(self, event: Event | None = None) -> Account:
"""Sync with the IBKR account
"""
"""Sync with the IBKR account"""

logger.debug("start sync")
now = datetime.now(timezone.utc)
Expand Down
10 changes: 7 additions & 3 deletions roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
@dataclass(slots=True, frozen=True)
class _Trx:
"""transaction for an executed trade"""

symbol: str
size: Decimal
price: float # denoted in the currency of the symbol
Expand All @@ -29,11 +30,14 @@ class SimBroker(Broker):
"""

def __init__(
self, initial_deposit=1_000_000.0, account=None, price_type="DEFAULT", slippage=0.001, clean_up_orders=True
self,
initial_deposit=1_000_000.0,
price_type="DEFAULT",
slippage=0.001,
clean_up_orders=True,
):
super().__init__()
self.initial_deposit = initial_deposit
self._account = account or Account()
self._account = Account()
self._modify_orders: list[Order] = []
self._account.buying_power = initial_deposit
self.slippage = slippage
Expand Down
2 changes: 1 addition & 1 deletion roboquant/traders/trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Trader(Protocol):
def create_orders(self, signals: dict[str, Signal], event: Event, account: Account) -> list[Order]:
"""Create zero or more orders.
Args:
Arguments
signals: Zero or more signals created by the strategy.
event: The event with its items.
account: The latest account object.
Expand Down
6 changes: 5 additions & 1 deletion tests/samples/papertrade_tiingo_ibkr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
import logging
import roboquant as rq
from roboquant.account import Account, CurrencyConverter
from roboquant.brokers.ibkrbroker import IBKRBroker

from roboquant.feeds.feedutil import get_sp500_symbols
Expand All @@ -10,7 +11,10 @@
logging.getLogger("roboquant").setLevel(level=logging.INFO)

# Connect to local running TWS or IB Gateway
ibkr = IBKRBroker()
converter = CurrencyConverter("EUR", "USD")
converter.register_rate("USD", 0.91)
Account.register_converter(converter)
ibkr = IBKRBroker.use_tws()

# Connect to Tiingo and subscribe to S&P-500 stocks
src_feed = rq.feeds.TiingoLiveFeed(market="iex")
Expand Down

0 comments on commit bacbd9f

Please sign in to comment.