Skip to content

Commit

Permalink
Improved doc and naming variables
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Aug 26, 2024
1 parent 6ffcfa7 commit 6706b47
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 45 deletions.
4 changes: 2 additions & 2 deletions roboquant/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def position_value(self, asset: Asset) -> float:

def short_positions(self) -> dict[Asset, Position]:
"""Return al the short positions in the account"""
return {symbol: position for (symbol, position) in self.positions.items() if position.is_short}
return {asset: position for (asset, position) in self.positions.items() if position.is_short}

def long_positions(self) -> dict[Asset, Position]:
"""Return al the long positions in the account"""
return {symbol: position for (symbol, position) in self.positions.items() if position.is_long}
return {asset: position for (asset, position) in self.positions.items() if position.is_long}

def contract_value(self, asset: Asset, size: Decimal, price: float) -> float:
"""Contract value denoted in the base currency of the account"""
Expand Down
8 changes: 8 additions & 0 deletions roboquant/alpaca/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def _get_asset(symbol: str, asset_class: AssetClass) -> Asset:
return Option(symbol)


def _assert_keys(api_key, secret_key):
assert api_key, "no api key provided or found"
assert secret_key, "no secret key provided or found"


class AlpacaLiveFeed(LiveFeed):
"""Subscribe to live market data for stocks, cryptocurrencies or options"""

Expand All @@ -55,6 +60,7 @@ def __init__(self, market: Literal["iex", "sip", "crypto", "option"] = "iex", ap
config = Config()
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
_assert_keys(api_key, secret_key)
self.market = market

match market:
Expand Down Expand Up @@ -153,6 +159,7 @@ def __init__(self, api_key=None, secret_key=None, data_api_url=None, feed: DataF
config = Config()
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
_assert_keys(api_key, secret_key)
self.client = StockHistoricalDataClient(api_key, secret_key, url_override=data_api_url)
self.feed = feed

Expand Down Expand Up @@ -189,6 +196,7 @@ def __init__(self, api_key=None, secret_key=None, data_api_url=None):
config = Config()
api_key = api_key or config.get("alpaca.public.key")
secret_key = secret_key or config.get("alpaca.secret.key")
_assert_keys(api_key, secret_key)
self.client = CryptoHistoricalDataClient(api_key, secret_key, url_override=data_api_url)

def retrieve_bars(self, *symbols, start=None, end=None, resolution: TimeFrame | None = None):
Expand Down
8 changes: 4 additions & 4 deletions roboquant/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _update_account(account: Account, event: Event | None, price_type: str = "DE

account.last_update = event.time

for symbol, position in account.positions.items():
if price := event.get_price(symbol, price_type):
for asset, position in account.positions.items():
if price := event.get_price(asset, price_type):
position.mkt_price = price


Expand All @@ -58,8 +58,8 @@ def _update_positions(account: Account, event: Event | None, price_type: str = "

account.last_update = event.time

for symbol, position in account.positions.items():
if price := event.get_price(symbol, price_type):
for asset, position in account.positions.items():
if price := event.get_price(asset, price_type):
position.mkt_price = price


Expand Down
22 changes: 12 additions & 10 deletions roboquant/feeds/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __background():
thread.start()
return channel

def get_ohlcv(self, asset: Asset, timeframe: Timeframe | None = None) -> dict[str, list]:
def get_ohlcv(self, asset: Asset, timeframe: Timeframe | None = None) -> dict[str, list[float | datetime]]:
"""Get the OHLCV values for an asset in this feed.
The returned value is a dict with the keys being "Date", "Open", "High", "Low", "Close", "Volume"
and the values a list.
Expand All @@ -91,7 +91,7 @@ def get_ohlcv(self, asset: Asset, timeframe: Timeframe | None = None) -> dict[st
result["Volume"].append(item.ohlcv[4])
return result

def print_items(self, timeframe: Timeframe | None = None, timeout: float | None = None):
def print_items(self, timeframe: Timeframe | None = None, timeout: float | None = None) -> None:
"""Print the items in a feed to the console.
This is mostly useful for debugging purposes to see what items a feed generates.
"""
Expand All @@ -102,7 +102,7 @@ def print_items(self, timeframe: Timeframe | None = None, timeout: float | None
for item in event.items:
print("======> ", item)

def count_events(self, timeframe: Timeframe | None = None, timeout: float | None = None, include_empty=False):
def count_events(self, timeframe: Timeframe | None = None, timeout: float | None = None, include_empty=False) -> int:
"""Count the number of events in a feed"""

channel = self.play_background(timeframe)
Expand All @@ -112,7 +112,7 @@ def count_events(self, timeframe: Timeframe | None = None, timeout: float | None
events += 1
return events

def count_items(self, timeframe: Timeframe | None = None, timeout: float | None = None):
def count_items(self, timeframe: Timeframe | None = None, timeout: float | None = None) -> int:
"""Count the number of events in a feed"""

channel = self.play_background(timeframe)
Expand All @@ -121,7 +121,11 @@ def count_items(self, timeframe: Timeframe | None = None, timeout: float | None
items += len(evt.items)
return items

def to_dict(self, *assets: Asset, timeframe: Timeframe | None = None, price_type: str = "DEFAULT"):
def to_dict(
self, *assets: Asset, timeframe: Timeframe | None = None, price_type: str = "DEFAULT"
) -> dict[str, list[float | None]]:
"""Return the prices of one or more assets as a dict with the key being the synbol name."""

assert assets, "provide at least 1 asset"
result = {asset.symbol: [] for asset in assets}
channel = self.play_background(timeframe)
Expand All @@ -131,16 +135,14 @@ def to_dict(self, *assets: Asset, timeframe: Timeframe | None = None, price_type
result[asset.symbol].append(price)
return result

def plot(
self, asset: Asset, price_type: str = "DEFAULT", timeframe: Timeframe | None = None, plt: Any = pyplot, **kwargs
):
def plot(self, asset: Asset, price_type: str = "DEFAULT", timeframe: Timeframe | None = None, plt: Any = pyplot, **kwargs):
"""
Plot the prices of a symbol.
Plot the prices of a single asset.
Parameters
----------
asset : Asset
The symbol for which to plot prices.
The asset for which to plot prices.
price_type : str, optional
The type of price to plot, e.g. open, close, high, low. (default is "DEFAULT")
timeframe : Timeframe or None, optional
Expand Down
6 changes: 3 additions & 3 deletions roboquant/journals/alphabeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def __init__(self, window_size: int, price_type: str = "DEFAULT", risk_free_retu
def __get_market_value(self, prices: dict[Asset, float]):
cnt = 0
result = 0.0
for symbol in prices.keys():
if symbol in self.__last_prices:
for asset in prices.keys():
if asset in self.__last_prices:
cnt += 1
result += prices[symbol] / self.__last_prices[symbol]
result += prices[asset] / self.__last_prices[asset]
return 1.0 if cnt == 0 else result / cnt

def __update(self, equity, prices):
Expand Down
10 changes: 5 additions & 5 deletions roboquant/strategies/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def __init__(self, size: int):
self.size = size

def add_event(self, event: Event) -> set[Asset]:
"""Add a new event and return all the symbols that have been added and are ready to be processed"""
symbols: set[Asset] = set()
"""Add a new event and return all the assets that have been added and are ready to be processed"""
assets: set[Asset] = set()
for item in event.items:
if isinstance(item, Bar):
asset = item.asset
Expand All @@ -97,8 +97,8 @@ def add_event(self, event: Event) -> set[Asset]:
ohlcv = self[asset]
ohlcv.append(item.ohlcv)
if ohlcv.is_full():
symbols.add(asset)
return symbols
assets.add(asset)
return assets

def ready(self):
return {symbol for symbol, ohlcv in self.items() if ohlcv.is_full()}
return {asset for asset, ohlcv in self.items() if ohlcv.is_full()}
6 changes: 3 additions & 3 deletions roboquant/strategies/multistrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

class MultiStrategy(Strategy):
"""Combine one or more signal strategies. The MultiStrategy provides additional control on how to handle conflicting
signals for the same symbols via the signal_filter:
signals for the same asset via the signal_filter:
- first: in case of multiple signals for the same symbol, the first one wins
- last: in case of multiple signals for the same symbol, the last one wins.
- first: in case of multiple signals for the same asset, the first one wins
- last: in case of multiple signals for the same asset, the last one wins.
- mean: return the mean of the signal ratings. All signals will be ENTRY and EXIT.
- none: return all signals. This is also the default.
"""
Expand Down
2 changes: 1 addition & 1 deletion roboquant/strategies/tastrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def create_signals(self, event) -> list[Signal]:
@abstractmethod
def process_asset(self, asset: Asset, ohlcv: OHLCVBuffer) -> Signal | None:
"""
Create zero or more orders for the provided symbol
Create zero or more orders for the provided asset
"""
...
8 changes: 4 additions & 4 deletions roboquant/traders/flextrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def log(self, rule: str, **kwargs):
if logger.isEnabledFor(logging.INFO):
extra = " ".join(f"{k}={v}" for k, v in kwargs.items())
logger.info(
"Discarded signal because %s [symbol=%s rating=%s type=%s position=%s %s]",
"Discarded signal because %s [asset=%s rating=%s type=%s position=%s %s]",
rule,
self.signal.asset,
self.signal.rating,
Expand All @@ -67,11 +67,11 @@ def log(self, rule: str, **kwargs):

class FlexTrader(Trader):
"""Implementation of a Trader that has configurable rules to modify which signals are converted into orders.
This implementation will not generate orders if there is not a price in the event for the underlying symbol.
This implementation will not generate orders if there is not a price in the event for the underlying asset.
The configurable parameters include:
- one_order_only: don't create new orders for a symbol if there is already an open orders for that same symbol
- one_order_only: don't create new orders for a asset if there is already an open orders for that same asset
- size_fractions: enable fractional order sizes (if size_fractions is larger than 0), default is 0
- safety_margin_perc: the safety margin as percentage of equity that should remain available (to avoid margin calls),
default is 0.05 (5%)
Expand Down Expand Up @@ -221,7 +221,7 @@ def create_orders(self, signals: list[Signal], event: Event, account: Account) -

def _get_orders(self, asset: Asset, size: Decimal, item: PriceItem, signal: Signal, dt: datetime) -> list[Order]:
# pylint: disable=unused-argument
"""Return zero or more orders for the provided symbol and size."""
"""Return zero or more orders for the provided asset and size."""
gtd = None if not self.valid_for else dt + self.valid_for
return [Order(asset, size, item.price(), gtd)]

Expand Down
9 changes: 2 additions & 7 deletions tests/samples/sb3_strategy_quotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from roboquant.asset import Stock
from roboquant.ml.features import EquityFeature, QuoteFeature
from roboquant.ml.rl import TradingEnv, SB3PolicyStrategy
from roboquant.feeds.parquet import ParquetFeed
from roboquant.timeframe import Timeframe

# %%
Expand All @@ -17,12 +16,8 @@
assert start < border < end

# %%
feed = ParquetFeed("/tmp/jpm.parquet")
if not feed.exists():
inputFeed = AlpacaHistoricStockFeed()
inputFeed.retrieve_quotes(asset.symbol, start=start, end=end)
feed.record(inputFeed)

feed = AlpacaHistoricStockFeed()
feed.retrieve_quotes(asset.symbol, start=start, end=end)
print("feed timeframe=", feed.timeframe())

obs_feature = QuoteFeature(asset).returns().normalize(20)
Expand Down
3 changes: 1 addition & 2 deletions tests/samples/talib_feature.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# %%
# pylint: disable=no-member
import talib.stream as ta
import roboquant as rq
from roboquant.asset import Asset
from roboquant.ml.features import TaFeature
from roboquant.strategies.buffer import OHLCVBuffer

# pylint: disable=no-member


# %%
class RSIFeature(TaFeature):
Expand Down
3 changes: 1 addition & 2 deletions tests/samples/talib_strategy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# %%
# pylint: disable=no-member
import talib.stream as ta
import roboquant as rq
from roboquant.signal import Signal
from roboquant.strategies import OHLCVBuffer, TaStrategy

# pylint: disable=no-member


# %%
class MyStrategy(TaStrategy):
Expand Down
4 changes: 2 additions & 2 deletions tests/samples/walkforward_yahoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import roboquant as rq

# %%
feed = rq.feeds.YahooFeed("JPM", "IBM", "F", start_date="2000-01-01")
feed = rq.feeds.YahooFeed("JPM", "IBM", "F", start_date="2000-01-01", end_date="2020-01-01")

# %%
# split the feed timeframe in 4 equal parts
# split the feed timeframe into 4 parts
timeframes = feed.timeframe().split(4)

# %%
Expand Down

0 comments on commit 6706b47

Please sign in to comment.