Skip to content

Commit

Permalink
fixed pylint findings
Browse files Browse the repository at this point in the history
  • Loading branch information
jbaron committed Feb 29, 2024
1 parent 22ced95 commit eba654e
Show file tree
Hide file tree
Showing 19 changed files with 75 additions and 59 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ testpaths = [
[tool.pylint.'MESSAGES CONTROL']
max-line-length = 127
disable = "too-few-public-methods,missing-module-docstring,missing-class-docstring,missing-function-docstring,unnecessary-ellipsis"
max-args = 15
max-locals = 20
max-attributes = 10

[build-system]
requires = ["setuptools>=61.0"]
Expand Down
2 changes: 1 addition & 1 deletion roboquant/brokers/ibkrbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def sync(self, event: Event | None = None) -> Account:
api.request_account()

acc.positions = {k: v for k, v in api.positions.items() if not v.size.is_zero()}
acc.orders = [order for order in api.orders.values()]
acc.orders = list(api.orders.values())
acc.buying_power = api.get_buying_power()
acc.equity = api.get_equity()

Expand Down
49 changes: 28 additions & 21 deletions roboquant/brokers/simbroker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
@dataclass(slots=True, frozen=True)
class _Trx:
"""transaction for an executed trade"""

symbol: str
size: Decimal
price: float # denoted in the currency of the symbol
Expand All @@ -20,7 +19,7 @@ class _Trx:
@dataclass
class _OrderState:
order: Order
accepted: datetime | None = None
expires_at: datetime | None = None


class SimBroker(Broker):
Expand Down Expand Up @@ -83,32 +82,39 @@ def _simulate_market(self, order: Order, item) -> _Trx | None:
"""Simulate a market for the three order types"""

price = self._get_execution_price(order, item)
if self._is_executable(order, price):
return _Trx(order.symbol, order.size, price)
fill = self._get_fill(order, price)
if fill:
return _Trx(order.symbol, fill, price)
return None

def __next_order_id(self):
result = str(self.__order_id)
self.__order_id += 1
return result

def _has_expired(self, state: _OrderState) -> bool:
if state.accepted is None:
"""Returns true if the order has expired, false otherwise"""
if state.expires_at is None:
return False
else:
return self._account.last_update - state.accepted > timedelta(days=180)

def _is_executable(self, order, price) -> bool:
"""Is this order executable given the provided execution price.
A market order is always executable, a limit order only when the limit is below the BUY price or
above the SELL price"""
return self._account.last_update >= state.expires_at

def _get_fill(self, order, price) -> Decimal:
"""Return the fill for the order given the provided price.
The default implementation is:
- A market order is always fully filled,
- A limit order only when the limit is below the BUY price or
above the SELL price."""
if order.limit is None:
return True
return order.remaining
if order.is_buy and price <= order.limit:
return True
return order.remaining
if order.is_sell and price >= order.limit:
return True
return order.remaining

return False
return Decimal(0)

def __update_mkt_prices(self, price_items):
"""track the latest market prices for all open positions"""
Expand All @@ -123,21 +129,21 @@ def place_orders(self, orders):
processed during time `t+1`. This protects against future bias.
"""
for order in orders:
assert not order.closed, "cannot place closed orders"
assert not order.closed, "cannot place a closed order"
if order.id is None:
order.id = self.__next_order_id()
assert order.id not in self._orders
self._orders[order.id] = _OrderState(order)
else:
assert order.id in self._orders, "existing order id not found"
assert order.id in self._orders, "existing order id is not found"
self._modify_orders.append(order)

def _process_modify_order(self):
for order in self._modify_orders:
state = self._orders[order.id] # type: ignore
if state.order.closed:
continue
elif order.is_cancellation:
if order.is_cancellation:
state.order.status = OrderStatus.CANCELLED
else:
state.order.size = order.size or state.order.size
Expand All @@ -153,12 +159,13 @@ def _process_create_orders(self, prices):
order.status = OrderStatus.EXPIRED
else:
if (item := prices.get(order.symbol)) is not None:
state.accepted = state.accepted or self._account.last_update
state.expires_at = state.expires_at or self._account.last_update + timedelta(days=90)
trx = self._simulate_market(order, item)
if trx is not None:
self._update_account(trx)
order.status = OrderStatus.FILLED
order.fill = order.size
order.fill += trx.size
if order.fill == order.size:
order.status = OrderStatus.FILLED

def sync(self, event: Event | None = None) -> Account:
"""This will perform the trading simulation for open orders and update the account accordingly."""
Expand Down
1 change: 1 addition & 0 deletions roboquant/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def get_price(self, symbol: str, price_type: str = "DEFAULT") -> float | None:

if item := self.price_items.get(symbol):
return item.price(price_type)
return None

def __repr__(self) -> str:
return f"Event(time={self.time} item={len(self.items)})"
3 changes: 2 additions & 1 deletion roboquant/feeds/eventchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def put(self, event: Event):

if self.timeframe is None or event.time in self.timeframe:
self._queue.put(event)
elif not (event.time < self.timeframe.start):
elif not event.time < self.timeframe.start:
# we get in this branch when timeframe is not None and
# the event is past the provided timeframe.
self.close()
Expand All @@ -57,6 +57,7 @@ def get(self, timeout=None) -> Event | None:
return Event.empty(now)

self._closed = True
return None

def close(self):
"""close this channel and put a None message on the channel to indicate to consumers it is closed"""
Expand Down
3 changes: 2 additions & 1 deletion roboquant/feeds/feedutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __background():
try:
feed.play(channel)
except ChannelClosed:
"""this exception we can expect"""
# this exception we can expect
pass
finally:
channel.close()

Expand Down
4 changes: 2 additions & 2 deletions roboquant/feeds/historicfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def timeframe(self):
tl = self.timeline()
if len(tl) == 0:
return Timeframe.empty()
else:
return Timeframe(tl[0], tl[-1], inclusive=True)

return Timeframe(tl[0], tl[-1], inclusive=True)

def __update(self):
if self.__modified:
Expand Down
11 changes: 6 additions & 5 deletions roboquant/feeds/tiingofeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, key: str | None = None):
super().__init__()
self.key = key or Config().get("tiingo.key")
assert self.key, "no Tiingo key found"
self.timeout = 10

@staticmethod
def __get_csv_iter(response: requests.Response):
Expand All @@ -60,7 +61,7 @@ def retrieve_eod_stocks(
for symbol in symbols:
url = f"https://api.tiingo.com/tiingo/daily/{symbol}/prices?{query}"
logger.debug("eod stock url is %s", url)
response = requests.get(url)
response = requests.get(url, timeout=self.timeout)
if not response.ok:
logger.warning("error symbol=%s reason=%s", symbol, response.reason)
continue
Expand All @@ -87,7 +88,7 @@ def retrieve_intraday_iex(self, *symbols: str, start_date="2023-01-01", end_date
for symbol in symbols:
url = f"https://api.tiingo.com/iex/{symbol}/prices?{query}"
logger.debug("intraday iex is %s", url)
response = requests.get(url)
response = requests.get(url, timeout=self.timeout)
if not response.ok:
logger.warning("error symbol=%s reason=%s", symbol, response.reason)
continue
Expand All @@ -106,7 +107,7 @@ def retrieve_intraday_crypto(self, *symbols: str, start_date="2023-01-01", end_d

url = f"https://api.tiingo.com/tiingo/crypto/prices?{query}"
logger.debug("intraday crypto url is %s", url)
response = requests.get(url)
response = requests.get(url, timeout=self.timeout)
if not response.ok:
logger.warning("error reason=%s", response.reason)
return
Expand All @@ -126,7 +127,7 @@ def retrieve_intraday_fx(self, *symbols: str, start_date="2023-01-01", end_date:
query = f"startDate={start_date}&endDate={end_date}&format=csv&resampleFreq={frequency}&token={self.key}"
url = f"https://api.tiingo.com/tiingo/fx/{symbols_str}/prices?{query}"

response = requests.get(url)
response = requests.get(url, timeout=self.timeout)
logger.debug("intraday fx url is %s", url)
if not response.ok:
logger.warning("error reason=%s", response.reason)
Expand Down Expand Up @@ -154,7 +155,7 @@ def __init__(self, key: str | None = None, market: Literal["crypto", "iex", "fx"
self.channel = None

url = f"wss://api.tiingo.com/{market}"
logger.info(f"Opening websocket {url}")
logger.info("Opening websocket url=%s", url)
self.ws = websocket.WebSocketApp( # type: ignore
url, on_message=self._handle_message, on_error=self._handle_error, on_close=self._handle_close # type: ignore
)
Expand Down
6 changes: 3 additions & 3 deletions roboquant/journals/alphabeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def calc(self, event, account, signals, orders):

if self.__cnt <= self._data.shape[-1]:
return {}
else:
alpha, beta = self.alpha_beta()
return {"perf/alpha": alpha, "perf/beta": beta}

alpha, beta = self.alpha_beta()
return {"perf/alpha": alpha, "perf/beta": beta}

def alpha_beta(self) -> Tuple[float, float]:
ar_total, mr_total = np.cumprod(self._data, axis=1)[:, -1]
Expand Down
4 changes: 4 additions & 0 deletions roboquant/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,7 @@ def is_buy(self):
def is_sell(self):
"""Return True if this is a SELL order, False otherwise"""
return self.size < 0

@property
def remaining(self):
return self.size - self.fill
15 changes: 7 additions & 8 deletions roboquant/strategies/featureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,12 @@ def calc(self, evt):
values = self.feature.calc(evt)

if not self.history:
nan = float("nan")
self.history = values
return np.full(values.shape, [nan])
else:
r = values / self.history - 1.0
self.history = values
return r
return np.full(values.shape, float("nan"))

r = values / self.history - 1.0
self.history = values
return r


class SMAFeature(Feature):
Expand All @@ -95,8 +94,8 @@ def calc(self, evt):

if self._cnt < self.period:
return np.full((values.size,), np.nan)
else:
return np.mean(self.history, axis=0)

return np.mean(self.history, axis=0)


class DayOfWeekFeature(Feature):
Expand Down
6 changes: 3 additions & 3 deletions roboquant/strategies/rnnstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def predict_rating(self) -> float | None:
p = output.item()
if p > self.pct:
return 1.0
elif p < -self.pct:
if p < -self.pct:
return -1.0
else:
return None

return None

def _train_epoch(self, data_loader):
model, opt, crit = self.model, self.optimizer, self.criterion
Expand Down
3 changes: 1 addition & 2 deletions roboquant/timeframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def next(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0,
def __contains__(self, time):
if self.inclusive:
return self.start <= time <= self.end
else:
return self.start <= time < self.end
return self.start <= time < self.end

def __repr__(self):
last_char = "]" if self.inclusive else ">"
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_tiingolivefeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ def setUp(self):
config = Config()
self.key = config.get("tiingo.key")

def test_tiingo_cryptolivefeed(self):
def test_tiingo_crypt_live_feed(self):
feed = TiingoLiveFeed(self.key)
feed.subscribe("btcusdt", "ethusdt")
run_priceitem_feed(feed, ["BTCUSDT", "ETHUSDT"], self, Timeframe.next(minutes=1))
feed.close()

def test_tiingo_fxlivefeed(self):
def test_tiingo_fx_live_feed(self):
feed = TiingoLiveFeed(self.key, "fx")
feed.subscribe("eurusd")
run_priceitem_feed(feed, ["EURUSD"], self, Timeframe.next(minutes=1))
feed.close()

def test_tiingo_iexlivefeed(self):
def test_tiingo_iex_live_feed(self):
feed = TiingoLiveFeed(self.key, "iex")
feed.subscribe("IBM", "TSLA")
run_priceitem_feed(feed, ["IBM", "TSLA"], self, Timeframe.next(minutes=1))
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_crossover.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

class TestCrossover(unittest.TestCase):

def test_smacrossover(self):
def test_sma_crossover(self):
strategy = SMACrossover(13, 26)
run_strategy(strategy, self)

def test_emacrossover(self):
def test_ema_crossover(self):
strategy = SMACrossover(13, 26)
run_strategy(strategy, self)

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_csvfeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

class TestCSVFeed(unittest.TestCase):

def test_csvfeed_generic(self):
def test_csv_feed_generic(self):
root = pathlib.Path(__file__).parent.resolve().joinpath("data", "csv")
feed = CSVFeed(root, time_offset="21:00:00+00:00")
run_priceitem_feed(feed, ["AAPL", "AMZN", "TSLA"], self)

def test_csvfeed_yahoo(self):
def test_csv_feed_yahoo(self):
root = pathlib.Path(__file__).parent.resolve().joinpath("data", "yahoo")
feed = CSVFeed.yahoo(root)
run_priceitem_feed(feed, ["META"], self)

def test_csvfeed_stooq(self):
def test_csv_feed_stooq(self):
root = pathlib.Path(__file__).parent.resolve().joinpath("data", "stooq")
feed = CSVFeed.stooq_us_daily(root)
run_priceitem_feed(feed, ["IBM"], self)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_featureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class TestFeatureSet(unittest.TestCase):

def test_featureset(self):
def test_feature_set(self):
feed = get_feed()
symbols = feed.symbols
symbol1 = symbols[0]
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_flextrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def _get_orders(self, symbol, size, action, rating):

class TestFlexTrader(unittest.TestCase):

def test_default_flextrader(self):
def test_default_flex_trader(self):
feed = get_feed()
journal = BasicJournal()
rq.run(feed, EMACrossover(), journal=journal)
self.assertGreater(journal.orders, 0)

def test_custom_flextrader(self):
def test_custom_flex_trader(self):
feed = get_feed()
journal = BasicJournal()
rq.run(feed, EMACrossover(), trader=_MyTrader(), journal=journal)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_multistrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class TestMultiStrategy(unittest.TestCase):

def test_multistrategies(self):
def test_multi_strategies(self):
strategy = MultiStrategy(
EMACrossover(13, 26),
EMACrossover(5, 12),
Expand Down

0 comments on commit eba654e

Please sign in to comment.