diff --git a/roboquant/__init__.py b/roboquant/__init__.py index d93c5ec..0a5c69b 100644 --- a/roboquant/__init__.py +++ b/roboquant/__init__.py @@ -5,6 +5,7 @@ from roboquant import journals from roboquant import strategies from roboquant import traders +from roboquant import ml from .account import Account, Position, Converter, CurrencyConverter, OptionConverter from .config import Config from .event import Event, PriceItem, Candle, Trade, Quote diff --git a/roboquant/brokers/broker.py b/roboquant/brokers/broker.py index 47ea653..63e440c 100644 --- a/roboquant/brokers/broker.py +++ b/roboquant/brokers/broker.py @@ -1,13 +1,14 @@ -from typing import Protocol +from abc import ABC, abstractmethod from roboquant.account import Account from roboquant.event import Event from roboquant.order import Order -class Broker(Protocol): +class Broker(ABC): """A broker accepts orders and communicates its state through the account object""" + @abstractmethod def place_orders(self, orders: list[Order]): """ Place zero or more orders at this broker. @@ -23,6 +24,7 @@ def place_orders(self, orders: list[Order]): """ ... + @abstractmethod def sync(self, event: Event | None = None) -> Account: """Sync the state, and return an updated account to reflect the latest state. @@ -35,6 +37,9 @@ def sync(self, event: Event | None = None) -> Account: """ ... + def reset(self): + """Reset the state""" + def _update_positions(account: Account, event: Event | None, price_type: str = "DEFAULT"): """update the open positions in the account with the latest market prices""" diff --git a/roboquant/brokers/simbroker.py b/roboquant/brokers/simbroker.py index 1d3ab4b..99dd612 100644 --- a/roboquant/brokers/simbroker.py +++ b/roboquant/brokers/simbroker.py @@ -39,10 +39,20 @@ def __init__( self._create_orders: dict[str, Order] = {} self._account.cash = initial_deposit self._account.buying_power = initial_deposit + self._order_id = 0 + self.slippage = slippage self.price_type = price_type self.clean_up_orders = clean_up_orders - self.__order_id = 0 + self.initial_deposit = initial_deposit + + def reset(self): + self._account = Account() + self._modify_orders: list[Order] = [] + self._create_orders: dict[str, Order] = {} + self._account.cash = self.initial_deposit + self._account.buying_power = self.initial_deposit + self._order_id = 0 def _update_account(self, trx: _Trx): """Update a position and cash based on a new transaction""" @@ -89,8 +99,8 @@ def _execute(self, order: Order, item) -> _Trx | None: return None def __next_order_id(self): - result = str(self.__order_id) - self.__order_id += 1 + result = str(self._order_id) + self._order_id += 1 return result def _has_expired(self, order: Order) -> bool: diff --git a/roboquant/event.py b/roboquant/event.py index 1e08e69..02b7fb8 100644 --- a/roboquant/event.py +++ b/roboquant/event.py @@ -145,5 +145,12 @@ def get_price(self, symbol: str, price_type: str = "DEFAULT") -> float | None: return item.price(price_type) return None + def get_volume(self, symbol: str, volume_type: str = "DEFAULT") -> float | None: + """Return the volume for the symbol, or None if not found.""" + + if item := self.price_items.get(symbol): + return item.volume(volume_type) + return None + def __repr__(self) -> str: return f"Event(time={self.time} items={len(self.items)})" diff --git a/roboquant/ml/__init__.py b/roboquant/ml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/roboquant/strategies/features.py b/roboquant/ml/features.py similarity index 60% rename from roboquant/strategies/features.py rename to roboquant/ml/features.py index 3f008eb..08f5200 100644 --- a/roboquant/strategies/features.py +++ b/roboquant/ml/features.py @@ -5,7 +5,8 @@ import numpy as np from numpy.typing import NDArray -from roboquant import Signal +from roboquant.signal import Signal +from roboquant.account import Account from roboquant.event import Event, Candle from roboquant.feeds.feed import Feed from roboquant.strategies.strategy import Strategy @@ -14,12 +15,16 @@ class Feature(ABC): @abstractmethod - def calc(self, evt: Event) -> NDArray: + def calc(self, evt: Event, account: Account) -> NDArray: """ Return the result as a 1-dimensional NDArray. The result should always be the same size. """ + @abstractmethod + def size(self) -> int: + "return the size of this feature" + def returns(self, period=1): if period == 1: return ReturnsFeature(self) @@ -28,6 +33,12 @@ def returns(self, period=1): def __getitem__(self, *args): return SlicedFeature(self, args) + def reset(self): + """Reset the state of the feature""" + + def _get_nan(self): + return np.full((self.size(),), float("nan")) + class SlicedFeature(Feature): @@ -35,11 +46,15 @@ def __init__(self, feature: Feature, args) -> None: super().__init__() self.args = args self.feature = feature + self._size = len(np.zeros((self.feature.size(),))[args]) - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) return values[self.args] + def size(self): + return self._size + class TrueRangeFeature(Feature): """Calculates the true range value for a symbol""" @@ -49,7 +64,7 @@ def __init__(self, symbol: str) -> None: self.prev_close = None self.symbol = symbol - def calc(self, evt): + def calc(self, evt, account): item = evt.price_items.get(self.symbol) if item is None or not isinstance(item, Candle): return np.array([float("nan")]) @@ -66,6 +81,12 @@ def calc(self, evt): return np.array([result]) + def size(self) -> int: + return 1 + + def reset(self): + self.prev_close = None + class FixedValueFeature(Feature): @@ -73,39 +94,90 @@ def __init__(self, value: NDArray) -> None: super().__init__() self.value = value - def calc(self, evt): + def size(self) -> int: + return len(self.value) + + def calc(self, evt, account): return self.value class PriceFeature(Feature): - """Extract a single price for a symbol""" + """Extract a single price for one or more symbols""" - def __init__(self, symbol: str, price_type: str = "DEFAULT") -> None: - self.symbol = symbol + def __init__(self, *symbols: str, price_type: str = "DEFAULT") -> None: + super().__init__() + self.symbols = symbols self.price_type = price_type - self.name = f"{symbol}-{price_type}-PRICE" - def calc(self, evt): - item = evt.price_items.get(self.symbol) - price = item.price(self.price_type) if item else float("nan") - return np.array([price]) + def calc(self, evt, account): + prices = [evt.get_price(symbol, self.price_type) for symbol in self.symbols] + return np.array(prices, dtype=np.float32) + + def size(self) -> int: + return len(self.symbols) + + +class PositionSizeFeature(Feature): + """Extract the position value for a symbol as fraction of the total equity""" + + def __init__(self, *symbols: str) -> None: + super().__init__() + self.symbols = symbols + + def calc(self, evt, account): + size = self.size() + result = np.zeros((size,), dtype=np.float32) + for i in range(size): + symbol = self.symbols[i] + position = account.positions.get(symbol) + if position: + value = account.contract_value(symbol, position.size, position.mkt_price) + pos_size = value / account.equity() - 1.0 + result[i] = pos_size + return result + + def size(self) -> int: + return len(self.symbols) + + +class PositionPNLFeature(Feature): + """Extract the pnl for an open position for a symbol. Returns 0.0 if no open position""" + + def __init__(self, *symbols: str) -> None: + super().__init__() + self.symbols = symbols + + def calc(self, evt, account): + size = self.size() + result = np.zeros((size,), dtype=np.float32) + for i in range(size): + position = account.positions.get(self.symbols[i]) + if position: + pnl = position.mkt_price / position.avg_price - 1.0 + result[i] = pnl + return result + + def size(self) -> int: + return len(self.symbols) class CandleFeature(Feature): """Extract the ohlcv values for a symbol""" def __init__(self, symbol: str) -> None: - + super().__init__() self.symbol = symbol - self.name = f"{symbol}-CANDLE" - def calc(self, evt): + def calc(self, evt, account): item = evt.price_items.get(self.symbol) if isinstance(item, Candle): return np.array(item.ohlcv) return np.full((5,), float("nan")) + def size(self) -> int: + return 5 + class FillFeature(Feature): """If the feature returns nan, use the last complete values instead""" @@ -115,8 +187,8 @@ def __init__(self, feature: Feature) -> None: self.history = None self.feature: Feature = feature - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) if np.any(np.isnan(values)): if self.history is not None: @@ -126,19 +198,28 @@ def calc(self, evt): self.history = values return values + def reset(self): + self.history = None + self.feature.reset() + + def size(self) -> int: + return self.feature.size() + class VolumeFeature(Feature): - """Extract the volume for a symbol""" + """Extract the volume for one or more symbols""" - def __init__(self, symbol: str, volume_type: str = "DEFAULT") -> None: + def __init__(self, *symbols: str, volume_type: str = "DEFAULT") -> None: super().__init__() - self.symbol = symbol + self.symbols = symbols self.volume_type = volume_type - def calc(self, evt: Event): - price_data = evt.price_items.get(self.symbol) - volume = price_data.volume(self.volume_type) if price_data else float("nan") - return np.array([volume]) + def calc(self, evt: Event, account: Account): + volumes = [evt.get_price(symbol, self.volume_type) for symbol in self.symbols] + return np.array(volumes, dtype=np.float32) + + def size(self) -> int: + return len(self.symbols) class ReturnsFeature(Feature): @@ -148,8 +229,8 @@ def __init__(self, feature: Feature) -> None: self.history = None self.feature: Feature = feature - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) if self.history is None: self.history = values @@ -159,6 +240,13 @@ def calc(self, evt): self.history = values return r + def size(self) -> int: + return self.feature.size() + + def reset(self): + self.history = None + self.feature.reset() + class LongReturnsFeature(Feature): @@ -167,8 +255,8 @@ def __init__(self, feature: Feature, period: int) -> None: self.history = deque(maxlen=period) self.feature: Feature = feature - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) h = self.history if len(h) < h.maxlen: # type: ignore @@ -179,6 +267,13 @@ def calc(self, evt): h.append(values) return r + def size(self) -> int: + return self.feature.size() + + def reset(self): + self.history.clear() + self.feature.reset() + class MaxReturnFeature(Feature): """Calculate the maximum return over a certain period. @@ -190,8 +285,8 @@ def __init__(self, feature: Feature, period: int) -> None: self.history = deque(maxlen=period) self.feature: Feature = feature - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) h = self.history if len(h) < h.maxlen: # type: ignore @@ -202,6 +297,13 @@ def calc(self, evt): h.append(values) return r + def size(self) -> int: + return self.feature.size() + + def reset(self): + self.history.clear() + self.feature.reset() + class MinReturnFeature(Feature): """Calculate the minimum return over a certain period. @@ -213,8 +315,8 @@ def __init__(self, feature: Feature, period: int) -> None: self.history = deque(maxlen=period) self.feature: Feature = feature - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) h = self.history if len(h) < h.maxlen: # type: ignore @@ -225,6 +327,13 @@ def calc(self, evt): h.append(values) return r + def size(self) -> int: + return self.feature.size() + + def reset(self): + self.history.clear() + self.feature.reset() + class SMAFeature(Feature): @@ -235,8 +344,8 @@ def __init__(self, feature: Feature, period: int) -> None: self.history = None self._cnt = 0 - def calc(self, evt): - values = self.feature.calc(evt) + def calc(self, evt, account): + values = self.feature.calc(evt, account) if self.history is None: self.history = np.zeros((self.period, values.size)) @@ -249,6 +358,13 @@ def calc(self, evt): return np.mean(self.history, axis=0) + def size(self) -> int: + return self.feature.size() + + def reset(self): + self.history = None + self.feature.reset() + class DayOfWeekFeature(Feature): """Calculate a one-hot-encoded day of the week, Monday being 0""" @@ -256,13 +372,16 @@ class DayOfWeekFeature(Feature): def __init__(self, tz=timezone.utc) -> None: self.tz = tz - def calc(self, evt): + def calc(self, evt, account): dt = datetime.astimezone(evt.time, self.tz) weekday = dt.weekday() result = np.zeros(7) result[weekday] = 1.0 return result + def size(self) -> int: + return 7 + class FeatureStrategy(Strategy, ABC): """Abstract base class for strategies wanting to use features @@ -294,7 +413,7 @@ def create_signals(self, event: Event) -> dict[str, Signal]: def predict(self, x: NDArray) -> dict[str, Signal]: ... def __get_row(self, evt, features) -> NDArray: - data = [feature.calc(evt) for feature in features] + data = [feature.calc(evt, None) for feature in features] return np.hstack(data, dtype=self._dtype) def _get_xy(self, feed: Feed, timeframe=None, warmup=0) -> tuple[NDArray, NDArray]: @@ -304,9 +423,9 @@ def _get_xy(self, feed: Feed, timeframe=None, warmup=0) -> tuple[NDArray, NDArra while evt := channel.get(): if warmup: for f in self._features_x: - f.calc(evt) + f.calc(evt, None) for f in self._features_y: - f.calc(evt) + f.calc(evt, None) warmup -= 1 else: x.append(self.__get_row(evt, self._features_x)) diff --git a/roboquant/ml/gymenv.py b/roboquant/ml/gymenv.py new file mode 100644 index 0000000..279e73b --- /dev/null +++ b/roboquant/ml/gymenv.py @@ -0,0 +1,124 @@ +import logging +import gymnasium as gym +from gymnasium import spaces +import numpy as np +from roboquant.account import Account + +from roboquant.brokers.simbroker import SimBroker +from roboquant.event import Event +from roboquant.feeds.eventchannel import EventChannel +from roboquant.feeds.feed import Feed +from roboquant.signal import Signal +from roboquant.ml.features import Feature +from roboquant.ml.torch import Normalize +from roboquant.traders.flextrader import FlexTrader + + +logger = logging.getLogger(__name__) + + +class TradingEnv(gym.Env): + # pylint: disable=too-many-instance-attributes,unused-argument + + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} + + def __init__( + self, features: list[Feature], feed: Feed, rating_symbols: list[str], warmup: int = 0, broker=None, trader=None + ): + self.broker = broker or SimBroker() + self.trader = trader or FlexTrader() + self.channel = EventChannel() + self.feed = feed + self.event: Event | None = None + self.account = self.broker.sync() + self.symbols = rating_symbols + self.features = features + self.warmup = warmup + self.last_equity = self.account.equity() + self.obs_normalizer = None + self.reward_normalizer = None + + action_size = len(rating_symbols) + obs_size = sum(feature.size() for feature in features) + + self.observation_space = spaces.Box(-1.0, 1.0, shape=(obs_size,), dtype=np.float32) + self.action_space = spaces.Box(-1.0, 1.0, shape=(action_size,), dtype=np.float32) + + self.render_mode = None + + def get_broker(self): + return SimBroker() + + def get_trader(self): + return FlexTrader() + + def calc_normalization(self, steps: int): + obs, _ = self.reset() + obs_buffer = np.zeros((steps, obs.shape[0]), dtype=np.float32) + reward_buffer = np.zeros((steps,), dtype=np.float32) + step = 0 + while step < steps: + action = self.action_space.sample() + obs, reward, terminated, _, _ = self.step(action) + if not terminated: + obs_buffer[step] = obs + reward_buffer[step] = reward + else: + self.reset() + step += 1 + + obs_norm = obs_buffer.mean(axis=0), obs_buffer.std(axis=0) + reward_norm = reward_buffer.mean(axis=0).item(), reward_buffer.std(axis=0).item() + self.obs_normalizer = Normalize(obs_norm) + self.reward_normalizer = Normalize(reward_norm) + + def _get_obs(self, evt: Event, account: Account): + data = [feature.calc(evt, account) for feature in self.features] + obs = np.hstack(data, dtype=np.float32) + return self.obs_normalizer(obs) if self.obs_normalizer else obs + + def _get_reward(self, evt: Event, account: Account) -> float: + equity = account.equity() + reward = equity / self.last_equity - 1.0 + self.last_equity = equity + return self.reward_normalizer(reward) if self.reward_normalizer else reward + + def step(self, action): + assert self.event is not None + assert self.account is not None + signals = {symbol: Signal(rating) for symbol, rating in zip(self.symbols, action)} + + orders = self.trader.create_orders(signals, self.event, self.account) + self.broker.place_orders(orders) + self.event = self.channel.get() + + if self.event: + self.account = self.broker.sync(self.event) + observation = self._get_obs(self.event, self.account) + reward = self._get_reward(self.event, self.account) + return observation, reward, False, False, {} + + return None, 0.0, True, False, {} + + def reset(self, *, seed=None, options=None): + super().reset(seed=seed, options=options) + self.broker.reset() + self.trader.reset() + for feature in self.features: + feature.reset() + + self.channel = self.feed.play_background() + + i = 0 + while i <= self.warmup: + self.event = self.channel.get() + assert self.event is not None, "feed empty during warmup" + self.account = self.broker.sync(self.event) + self.trader.create_orders({}, self.event, self.account) + observation = self._get_obs(self.event, self.account) + i += 1 + self.last_equity = self.account.equity() + return observation, {} + + def render(self): + pass diff --git a/roboquant/strategies/torch.py b/roboquant/ml/torch.py similarity index 97% rename from roboquant/strategies/torch.py rename to roboquant/ml/torch.py index 69f6c1a..b53eb9e 100644 --- a/roboquant/strategies/torch.py +++ b/roboquant/ml/torch.py @@ -4,8 +4,8 @@ import torch from torch.utils.data import DataLoader, Dataset -from roboquant import Signal, BUY, SELL -from roboquant.strategies.features import FeatureStrategy +from roboquant.signal import Signal, BUY, SELL +from roboquant.ml.features import FeatureStrategy logger = logging.getLogger(__name__) @@ -19,6 +19,9 @@ def __init__(self, norm): def __call__(self, sample): return (sample - self.mean) / self.stdev + def __str__(self) -> str: + return f"mean={self.mean} stdev={self.stdev}" + class SequenceDataset(Dataset): """Sequence Dataset""" diff --git a/roboquant/strategies/buffer.py b/roboquant/strategies/buffer.py index b2eb283..9b238f5 100644 --- a/roboquant/strategies/buffer.py +++ b/roboquant/strategies/buffer.py @@ -10,49 +10,41 @@ class NumpyBuffer: It uses a single Numpy array to store its data. """ - __slots__ = "_data", "_idx" + __slots__ = "_data", "_idx", "rows" - def __init__(self, columns: int, capacity: int, dtype: Any = "float32") -> None: + def __init__(self, rows: int, columns: int, dtype: Any = "float32", order="C") -> None: """Create a new Numpy buffer""" - self._data: NDArray = np.full((capacity, columns), np.nan, dtype=dtype) + size = int(rows * 1.25 + 3) + self._data: NDArray = np.full((size, columns), np.nan, dtype=dtype, order=order) # type: ignore self._idx = 0 + self.rows = rows - @classmethod - def _from_data(cls, data): - result = cls(0, 0) - result._data = data - result._idx = len(data) - return result + def append(self, data: array | NDArray | list | tuple): + if self._idx >= len(self._data): + self._data[0: self.rows] = self._data[-self.rows:] + self._idx = self.rows - def append(self, data: array | NDArray): - idx = self._idx % self.capacity - self._data[idx] = data + self._data[self._idx] = data self._idx += 1 - @property - def capacity(self): - return len(self._data) + def __array__(self): + start = max(0, self._idx - self.rows) + return self._data[start: self._idx] def _get(self, column): - if self._idx < self.capacity: - return self._data[: self._idx, column] - - idx = self._idx % self.capacity - return np.concatenate([self._data[idx:, column], self._data[:idx, column]]) + start = max(0, self._idx - self.rows) + return self._data[start: self._idx, column] def __len__(self): - return min(self._idx, self.capacity) - - def get_all(self): - """Return all the values in the buffer""" - if self._idx < self.capacity: - return self._data[: self._idx] + return min(self._idx, self.rows) - idx = self._idx % self.capacity - return np.concatenate([self._data[idx:], self._data[:idx]]) + def to_numpy(self): + """Return all the values in the buffer as a 2D numpy array""" + start = max(0, self._idx - self.rows) + return self._data[start: self._idx] def is_full(self) -> bool: - return self._idx >= self.capacity + return self._idx >= self.rows def reset(self): """reset the buffer""" @@ -61,12 +53,11 @@ def reset(self): class OHLCVBuffer(NumpyBuffer): - """A OHLCV buffer (first-in-first-out) of a fixed capacity. - """ + """A OHLCV buffer (first-in-first-out) of a fixed capacity.""" def __init__(self, capacity: int, dtype="float64") -> None: """Create a new OHLCV buffer""" - super().__init__(5, capacity, dtype) + super().__init__(capacity, 5, dtype) def open(self) -> NDArray: """Return the open prices""" diff --git a/roboquant/strategies/candlestrategy.py b/roboquant/strategies/candlestrategy.py index 239239f..d22dc78 100644 --- a/roboquant/strategies/candlestrategy.py +++ b/roboquant/strategies/candlestrategy.py @@ -1,5 +1,4 @@ from abc import abstractmethod, ABC -from typing import Dict from roboquant.event import Candle from roboquant.signal import Signal @@ -14,10 +13,10 @@ class CandleStrategy(Strategy, ABC): def __init__(self, size: int) -> None: super().__init__() - self._data: Dict[str, OHLCVBuffer] = {} + self._data: dict[str, OHLCVBuffer] = {} self.size = size - def create_signals(self, event) -> Dict[str, Signal]: + def create_signals(self, event) -> dict[str, Signal]: signals = {} for item in event.items: if isinstance(item, Candle): diff --git a/roboquant/traders/trader.py b/roboquant/traders/trader.py index d241a4e..585d496 100644 --- a/roboquant/traders/trader.py +++ b/roboquant/traders/trader.py @@ -1,4 +1,4 @@ -from typing import Protocol +from abc import ABC, abstractmethod from roboquant.account import Account from roboquant.event import Event @@ -6,13 +6,14 @@ from roboquant.signal import Signal -class Trader(Protocol): +class Trader(ABC): """A trader creates the orders, typically based on the signals it receives from a strategy. But it is also possible to implement all logic in a Trader and don't rely on signals at all. In contrast to a `Strategy`, a `Trader` can also access the `Account` object. """ + @abstractmethod def create_orders(self, signals: dict[str, Signal], event: Event, account: Account) -> list[Order]: """Create zero or more orders. @@ -25,3 +26,6 @@ def create_orders(self, signals: dict[str, Signal], event: Event, account: Accou A list containing zero or more orders. """ ... + + def reset(self): + """Reset the state""" diff --git a/tests/samples/sb3.py b/tests/samples/sb3.py new file mode 100644 index 0000000..30237ee --- /dev/null +++ b/tests/samples/sb3.py @@ -0,0 +1,43 @@ +from gymnasium.wrappers.frame_stack import FrameStack +from stable_baselines3 import A2C + +from roboquant.feeds.yahoo import YahooFeed +from roboquant.ml.features import PriceFeature, VolumeFeature, SMAFeature, PositionPNLFeature +from roboquant.ml.gymenv import TradingEnv + + +def run(): + # pylint: disable=unused-variable + yahoo = YahooFeed("IBM", "JPM", start_date="2000-01-01", end_date="2020-12-31") + + features = [ + PriceFeature("IBM", "JPM").returns(), + VolumeFeature("IBM", "JPM").returns(), + SMAFeature(PriceFeature("JPM"), 10).returns(), + PositionPNLFeature("IBM", "JPM"), + ] + + trading = TradingEnv(features, yahoo, yahoo.symbols, warmup=20) + trading.calc_normalization(1000) + + env = FrameStack(trading, 10) + model = A2C("MlpPolicy", env, verbose=1) + + # Train the model + model.learn(total_timesteps=1_000_000) + + # Run the trained model on out of sample data + venv = model.get_env() + assert venv is not None + trading.feed = YahooFeed("IBM", "JPM", start_date="2021-01-01") + obs = venv.reset() + done = False + while not done: + action, _state = model.predict(obs, deterministic=True) # type: ignore + obs, reward, done, info = venv.step(action) + + print(trading.last_equity) + + +if __name__ == "__main__": + run() diff --git a/tests/samples/torch_lstm.py b/tests/samples/torch_lstm.py index 4410814..8cfc872 100644 --- a/tests/samples/torch_lstm.py +++ b/tests/samples/torch_lstm.py @@ -6,8 +6,8 @@ import roboquant as rq from roboquant.journals.basicjournal import BasicJournal -from roboquant.strategies.features import CandleFeature, MaxReturnFeature, PriceFeature, SMAFeature -from roboquant.strategies.torch import RNNStrategy +from roboquant.ml.features import CandleFeature, MaxReturnFeature, PriceFeature, SMAFeature +from roboquant.ml.torch import RNNStrategy class _MyModel(nn.Module): diff --git a/tests/unit/test_buffer.py b/tests/unit/test_buffer.py new file mode 100644 index 0000000..b445b6d --- /dev/null +++ b/tests/unit/test_buffer.py @@ -0,0 +1,30 @@ +from array import array +import unittest +import numpy as np + +from roboquant.strategies.buffer import NumpyBuffer + + +class TestBuffer(unittest.TestCase): + + def test_buffer(self): + b = NumpyBuffer(10, 5) + x = np.arange(100).reshape(20, 5) + for row in x: + b.append(row) + + c = np.asarray(b) + e = c == np.arange(50, 100).reshape(10, 5) + self.assertTrue(e.all()) + + def test_buffer_append(self): + b = NumpyBuffer(10, 2) + b.append(array("f", [1, 2])) + b.append([3, 4]) + b.append((5, 6)) + a = np.asarray(b) + self.assertEqual(3, len(a)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_features.py b/tests/unit/test_features.py index 61f3324..ffd4cc2 100644 --- a/tests/unit/test_features.py +++ b/tests/unit/test_features.py @@ -1,9 +1,10 @@ import unittest import numpy as np +from roboquant.account import Account from roboquant.event import Event -from roboquant.strategies.features import ( +from roboquant.ml.features import ( PriceFeature, SMAFeature, ReturnsFeature, @@ -31,20 +32,22 @@ def test_all_features(self): VolumeFeature(symbol2), DayOfWeekFeature() ] + account = Account() channel = feed.play_background() while evt := channel.get(): for feature in fs: - result = feature.calc(evt) + result = feature.calc(evt, account) self.assertTrue(len(result) > 0) def test_core_feature(self): + account = Account() f = FixedValueFeature(np.ones(10,))[2:5] - values = f.calc(Event.empty()) + values = f.calc(Event.empty(), account) self.assertEqual(3, len(values)) f = FixedValueFeature(np.ones(10,)).returns() - values = f.calc(Event.empty()) - values = f.calc(Event.empty()) + values = f.calc(Event.empty(), account) + values = f.calc(Event.empty(), account) self.assertEqual(0, values[0]) diff --git a/tests/unit/test_torch.py b/tests/unit/test_torch.py index 6d924d1..3054a50 100644 --- a/tests/unit/test_torch.py +++ b/tests/unit/test_torch.py @@ -4,8 +4,8 @@ import torch.nn.functional as F import roboquant as rq -from roboquant.strategies.features import CandleFeature, PriceFeature, SMAFeature -from roboquant.strategies.torch import RNNStrategy +from roboquant.ml.features import CandleFeature, PriceFeature, SMAFeature +from roboquant.ml.torch import RNNStrategy from tests.common import get_feed @@ -35,8 +35,8 @@ def test_lstm_model(self): model = _MyModel() strategy = RNNStrategy(model, symbol, sequences=20, buy_pct=0.01) strategy.add_x(CandleFeature(symbol).returns()) - strategy.add_x(SMAFeature(PriceFeature(symbol, "HIGH"), 10).returns()) - strategy.add_y(PriceFeature(symbol, "CLOSE").returns(prediction)) + strategy.add_x(SMAFeature(PriceFeature(symbol, price_type="HIGH"), 10).returns()) + strategy.add_y(PriceFeature(symbol, price_type="CLOSE").returns(prediction)) # Train the model with 10 years of data tf = rq.Timeframe.fromisoformat("2010-01-01", "2020-01-01")