Skip to content

Commit

Permalink
moved update_positions to a function
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Mar 10, 2024
1 parent 4254a93 commit a1f9ed6
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 20 deletions.
24 changes: 9 additions & 15 deletions roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from roboquant.event import Event

from roboquant.order import Order

Expand Down Expand Up @@ -133,15 +132,18 @@ def mkt_value(self) -> float:
)

def equity(self) -> float:
"""Return the equity of the account.
equity = cash + sum of the market value of the open positions
"""Return the equity of the account. It will calcaluate the sum of the mkt value of
each open position and add the available cash.
The returned value is denoted in the base currency of the account.
"""
return self.cash + self.mkt_value()

def unrealized_pnl(self) -> float:
"""Return the sum of the unrealized profit and loss for the open position."""
"""Return the sum of the unrealized profit and loss for the open position.
The returned value is denoted in the base currency of the account.
"""
return sum(
[self.contract_value(symbol, pos.size, pos.mkt_price - pos.avg_price) for symbol, pos in self.positions.items()],
0.0,
Expand All @@ -155,19 +157,11 @@ def has_open_order(self, symbol: str) -> bool:
return True
return False

def get_position_size(self, symbol) -> Decimal:
"""Return the position size for the symbol"""
def get_position_size(self, symbol: str) -> Decimal:
"""Return the position size for a symbol"""
pos = self.positions.get(symbol)
return pos.size if pos else Decimal(0)

def update_positions(self, event: Event, price_type: str = "DEFAULT"):
"""update the open positions with the latest market prices"""
self.last_update = event.time

for symbol, position in self.positions.items():
if price := event.get_price(symbol, price_type):
position.mkt_price = price

def open_orders(self):
"""Return a list with the open orders"""
return [order for order in self.orders if order.open]
Expand Down
12 changes: 12 additions & 0 deletions roboquant/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,15 @@ def sync(self, event: Event | None = None) -> Account:
"""
...


def _update_positions(account: Account, event: Event | None, price_type: str = "DEFAULT"):
"""update the open positions in the account with the latest market prices"""
if not event:
return

account.last_update = event.time

for symbol, position in account.positions.items():
if price := event.get_price(symbol, price_type):
position.mkt_price = price
3 changes: 2 additions & 1 deletion roboquant/brokers/ibkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from roboquant.account import Account, Position
from roboquant.event import Event
from roboquant.order import Order, OrderStatus
from .broker import Broker
from .broker import Broker, _update_positions

assert VERSION["major"] == 10 and VERSION["minor"] == 19, "Wrong version of the IBAPI found"

Expand Down Expand Up @@ -194,6 +194,7 @@ def sync(self, event: Event | None = None) -> Account:
acc.buying_power = api.get_buying_power()
acc.cash = api.get_cash()

_update_positions(acc, event)
logger.debug("end sync")
return acc

Expand Down
3 changes: 2 additions & 1 deletion roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from roboquant.account import Account, Position
from roboquant.brokers.broker import Broker
from roboquant.brokers.broker import Broker, _update_positions
from roboquant.event import Event
from roboquant.order import Order, OrderStatus

Expand Down Expand Up @@ -184,6 +184,7 @@ def sync(self, event: Event | None = None) -> Account:

self._process_modify_order()
self._process_create_orders(prices)
_update_positions(acc, event, self.price_type)

acc.buying_power = acc.cash
acc.orders = list(self._create_orders.values())
Expand Down
4 changes: 1 addition & 3 deletions roboquant/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def run(
journal: Journal | None = None,
timeframe: Timeframe | None = None,
capacity: int = 10,
heartbeat_timeout: float | None = None,
price_type: str = "DEFAULT"
heartbeat_timeout: float | None = None
) -> Account:
"""Start a new run. Only the first two parameters, the feed and strategy, are mandatory.
The other parameters are optional.
Expand All @@ -45,7 +44,6 @@ def run(
while event := channel.get(heartbeat_timeout):
signals = strategy.create_signals(event)
account = broker.sync(event)
account.update_positions(event, price_type)
orders = trader.create_orders(signals, event, account)
broker.place_orders(orders)
if journal:
Expand Down

0 comments on commit a1f9ed6

Please sign in to comment.