Skip to content

Commit

Permalink
REF: Remove Number in type hints for Python client
Browse files Browse the repository at this point in the history
  • Loading branch information
nmacholl committed Jan 11, 2024
1 parent d1dd6d5 commit ab5d54d
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 48 deletions.
17 changes: 9 additions & 8 deletions databento/common/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from datetime import date
from functools import partial
from functools import singledispatch
from numbers import Number
from numbers import Integral
from typing import Any

import pandas as pd
from databento_dbn import SType
Expand Down Expand Up @@ -59,7 +60,7 @@ def optional_values_list_to_string(

@singledispatch
def optional_symbols_list_to_list(
symbols: Iterable[str] | Iterable[Number] | str | Number | None,
symbols: Iterable[str | int | Integral] | str | int | Integral | None,
stype_in: SType,
) -> list[str]:
"""
Expand All @@ -68,7 +69,7 @@ def optional_symbols_list_to_list(
Parameters
----------
symbols : iterable of str, iterable of Number, str, or Number optional
symbols : Iterable of str or int or Number, or str or int or Number, optional
The symbols to concatenate.
stype_in : SType
The input symbology type for the request.
Expand All @@ -84,7 +85,7 @@ def optional_symbols_list_to_list(
"""
raise TypeError(
f"`{symbols}` is not a valid type for symbol input; "
"allowed types are Iterable[str], Iterable[int], str, int, and None.",
"allowed types are Iterable[str | int], str, int, and None.",
)


Expand All @@ -102,10 +103,10 @@ def _(_: None, __: SType) -> list[str]:
return [ALL_SYMBOLS]


@optional_symbols_list_to_list.register(cls=Number)
def _(symbols: Number, stype_in: SType) -> list[str]:
@optional_symbols_list_to_list.register(cls=Integral)
def _(symbols: Integral, stype_in: SType) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles numerical types,
Dispatch method for optional_symbols_list_to_list. Handles integral types,
alerting when an integer is given for STypes that expect strings.
See Also
Expand Down Expand Up @@ -147,7 +148,7 @@ def _(symbols: str, stype_in: SType) -> list[str]:


@optional_symbols_list_to_list.register(cls=Iterable)
def _(symbols: Iterable[str] | Iterable[int], stype_in: SType) -> list[str]:
def _(symbols: Iterable[Any], stype_in: SType) -> list[str]:
"""
Dispatch method for optional_symbols_list_to_list. Handles Iterables by
dispatching the individual members.
Expand Down
5 changes: 3 additions & 2 deletions databento/historical/api/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
from collections.abc import Iterable
from datetime import date
from os import PathLike
from pathlib import Path
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(self, key: str, gateway: str) -> None:
def submit_job(
self,
dataset: Dataset | str,
symbols: list[str] | str,
symbols: Iterable[str | int] | str | int,
schema: Schema | str,
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
Expand All @@ -75,7 +76,7 @@ def submit_job(
----------
dataset : Dataset or str
The dataset code (string identifier) for the request.
symbols : list[str | int] or str
symbols : Iterable[str | int] or str or int
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If more than 1 symbol is specified, the data is merged and sorted by time.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
Expand Down
13 changes: 7 additions & 6 deletions databento/historical/api/metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable
from datetime import date
from typing import Any

Expand Down Expand Up @@ -261,7 +262,7 @@ def get_record_count(
dataset: Dataset | str,
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
symbols: list[str] | str | None = None,
symbols: Iterable[str | int] | str | int | None = None,
schema: Schema | str = "trades",
stype_in: SType | str = "raw_symbol",
limit: int | None = None,
Expand All @@ -285,7 +286,7 @@ def get_record_count(
If an integer is passed, then this represents nanoseconds since the UNIX epoch.
Values are forward filled based on the resolution provided.
Defaults to the same value as `start`.
symbols : list[str | int] or str, optional
symbols : Iterable[str | int] or str or int, optional
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
schema : Schema or str {'mbo', 'mbp-1', 'mbp-10', 'trades', 'tbbo', 'ohlcv-1s', 'ohlcv-1m', 'ohlcv-1h', 'ohlcv-1d', 'definition', 'statistics', 'status'}, default 'trades' # noqa
Expand Down Expand Up @@ -329,7 +330,7 @@ def get_billable_size(
dataset: Dataset | str,
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
symbols: list[str] | str | None = None,
symbols: Iterable[str | int] | str | int | None = None,
schema: Schema | str = "trades",
stype_in: SType | str = "raw_symbol",
limit: int | None = None,
Expand All @@ -354,7 +355,7 @@ def get_billable_size(
If an integer is passed, then this represents nanoseconds since the UNIX epoch.
Values are forward filled based on the resolution provided.
Defaults to the same value as `start`.
symbols : list[str | int] or str, optional
symbols : Iterable[str | int] or str or int, optional
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
schema : Schema or str {'mbo', 'mbp-1', 'mbp-10', 'trades', 'tbbo', 'ohlcv-1s', 'ohlcv-1m', 'ohlcv-1h', 'ohlcv-1d', 'definition', 'statistics', 'status'}, default 'trades' # noqa
Expand Down Expand Up @@ -399,7 +400,7 @@ def get_cost(
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
mode: FeedMode | str = "historical-streaming",
symbols: list[str] | str | None = None,
symbols: Iterable[str | int] | str | int | None = None,
schema: Schema | str = "trades",
stype_in: SType | str = "raw_symbol",
limit: int | None = None,
Expand All @@ -426,7 +427,7 @@ def get_cost(
Defaults to the same value as `start`.
mode : FeedMode or str {'live', 'historical-streaming', 'historical'}, default 'historical-streaming'
The data feed mode for the request.
symbols : list[str | int] or str, optional
symbols : Iterable[str | int] or str or int, optional
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
schema : Schema or str {'mbo', 'mbp-1', 'mbp-10', 'trades', 'tbbo', 'ohlcv-1s', 'ohlcv-1m', 'ohlcv-1h', 'ohlcv-1d', 'definition', 'statistics', 'status'}, default 'trades' # noqa
Expand Down
5 changes: 3 additions & 2 deletions databento/historical/api/symbology.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable
from datetime import date
from typing import Any

Expand Down Expand Up @@ -28,7 +29,7 @@ def __init__(self, key: str, gateway: str) -> None:
def resolve(
self,
dataset: Dataset | str,
symbols: list[str] | str,
symbols: Iterable[str | int] | str | int,
stype_in: SType | str,
stype_out: SType | str,
start_date: date | str,
Expand All @@ -43,7 +44,7 @@ def resolve(
----------
dataset : Dataset or str
The dataset code (string identifier) for the request.
symbols : list[str | int] or str, optional
symbols : Iterable[str | int] or str or int, optional
The symbols to resolve. Takes up to 2,000 symbols per request.
stype_in : SType or str, default 'raw_symbol'
The input symbology type to resolve from.
Expand Down
9 changes: 5 additions & 4 deletions databento/historical/api/timeseries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable
from datetime import date
from os import PathLike

Expand Down Expand Up @@ -35,7 +36,7 @@ def get_range(
dataset: Dataset | str,
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
symbols: list[str] | str | None = None,
symbols: Iterable[str | int] | str | int | None = None,
schema: Schema | str = "trades",
stype_in: SType | str = "raw_symbol",
stype_out: SType | str = "instrument_id",
Expand Down Expand Up @@ -67,7 +68,7 @@ def get_range(
If an integer is passed, then this represents nanoseconds since the UNIX epoch.
Values are forward filled based on the resolution provided.
Defaults to the same value as `start`.
symbols : list[str | instr | intt] or str, optional
symbols : Iterable[str | int], or str, or int, optional
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If more than 1 symbol is specified, the data is merged and sorted by time.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
Expand Down Expand Up @@ -131,7 +132,7 @@ async def get_range_async(
dataset: Dataset | str,
start: pd.Timestamp | date | str | int,
end: pd.Timestamp | date | str | int | None = None,
symbols: list[str] | str | None = None,
symbols: Iterable[str | int] | str | int | None = None,
schema: Schema | str = "trades",
stype_in: SType | str = "raw_symbol",
stype_out: SType | str = "instrument_id",
Expand Down Expand Up @@ -164,7 +165,7 @@ async def get_range_async(
If an integer is passed, then this represents nanoseconds since the UNIX epoch.
Values are forward filled based on the resolution provided.
Defaults to the same value as `start`.
symbols : list[str | int] or str, optional
symbols : Iterable[str | int] or str or int, optional
The instrument symbols to filter for. Takes up to 2,000 symbols per request.
If more than 1 symbol is specified, the data is merged and sorted by time.
If 'ALL_SYMBOLS' or `None` then will be for **all** symbols.
Expand Down
5 changes: 2 additions & 3 deletions databento/live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import threading
from collections.abc import Iterable
from concurrent import futures
from numbers import Number
from os import PathLike
from typing import IO, Final

Expand Down Expand Up @@ -409,7 +408,7 @@ def subscribe(
self,
dataset: Dataset | str,
schema: Schema | str,
symbols: Iterable[str] | Iterable[Number] | str | Number = ALL_SYMBOLS,
symbols: Iterable[str | int] | str | int = ALL_SYMBOLS,
stype_in: SType | str = SType.RAW_SYMBOL,
start: str | int | None = None,
) -> None:
Expand All @@ -428,7 +427,7 @@ def subscribe(
The dataset for the subscription.
schema : Schema or str
The schema to subscribe to.
symbols : Iterable[str | Number] or str or Number, default 'ALL_SYMBOLS'
symbols : Iterable[str | int] or str or int, default 'ALL_SYMBOLS'
The symbols to subscribe to.
stype_in : SType or str, default 'raw_symbol'
The input symbology type to resolve from.
Expand Down
7 changes: 3 additions & 4 deletions databento/live/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from collections.abc import Iterable
from functools import singledispatchmethod
from numbers import Number
from typing import Final

import databento_dbn
Expand Down Expand Up @@ -52,7 +51,7 @@ class DatabentoLiveProtocol(asyncio.BufferedProtocol):
----------
api_key : str
The user API key for authentication.
dataset : Dataset, or str
dataset : Dataset or str
The dataset for authentication.
ts_out : bool, default False
Flag for requesting `ts_out` to be appending to all records in the session.
Expand Down Expand Up @@ -253,7 +252,7 @@ def received_record(self, record: DBNRecord) -> None:
def subscribe(
self,
schema: Schema | str,
symbols: Iterable[str] | Iterable[Number] | str | Number = ALL_SYMBOLS,
symbols: Iterable[str | int] | str | int = ALL_SYMBOLS,
stype_in: SType | str = SType.RAW_SYMBOL,
start: str | int | None = None,
) -> None:
Expand All @@ -264,7 +263,7 @@ def subscribe(
----------
schema : Schema or str
The schema to subscribe to.
symbols : Iterable[str | Number] or str or Number, default 'ALL_SYMBOLS'
symbols : Iterable[str | int] or str or int, default 'ALL_SYMBOLS'
The symbols to subscribe to.
stype_in : SType or str, default 'raw_symbol'
The input symbology type to resolve from.
Expand Down
9 changes: 4 additions & 5 deletions databento/live/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import queue
import threading
from collections.abc import Iterable
from numbers import Number
from typing import IO, Callable, Final

import databento_dbn
Expand Down Expand Up @@ -335,7 +334,7 @@ def subscribe(
self,
dataset: Dataset | str,
schema: Schema | str,
symbols: Iterable[str] | Iterable[Number] | str | Number = ALL_SYMBOLS,
symbols: Iterable[str | int] | str | int = ALL_SYMBOLS,
stype_in: SType | str = SType.RAW_SYMBOL,
start: str | int | None = None,
) -> None:
Expand All @@ -345,11 +344,11 @@ def subscribe(
Parameters
----------
dataset : Dataset, str
dataset : Dataset or str
The dataset for the subscription.
schema : Schema or str
The schema to subscribe to.
symbols : Iterable[str | Number] or str or Number, default 'ALL_SYMBOLS'
symbols : Iterable[str | int] or str or int, default 'ALL_SYMBOLS'
The symbols to subscribe to.
stype_in : SType or str, default 'raw_symbol'
The input symbology type to resolve from.
Expand Down Expand Up @@ -503,7 +502,7 @@ async def _connect_task(
async def _subscribe_task(
self,
schema: Schema | str,
symbols: Iterable[str] | Iterable[Number] | str | Number = ALL_SYMBOLS,
symbols: Iterable[str | int] | str | int = ALL_SYMBOLS,
stype_in: SType | str = SType.RAW_SYMBOL,
start: str | int | None = None,
) -> None:
Expand Down
17 changes: 3 additions & 14 deletions tests/test_common_parsing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import datetime as dt
from numbers import Number
from typing import Any

import numpy as np
Expand Down Expand Up @@ -105,7 +104,7 @@ def test_optional_symbols_list_to_list_given_valid_inputs_returns_expected(
],
)
def test_optional_symbols_list_to_list_int(
symbols: list[Number] | Number | None,
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
) -> None:
Expand Down Expand Up @@ -138,20 +137,10 @@ def test_optional_symbols_list_to_list_int(
SType.INSTRUMENT_ID,
["12345", "67890"],
),
pytest.param(
[np.int_(12345), np.longlong(67890)],
SType.INSTRUMENT_ID,
["12345", "67890"],
),
pytest.param(
[np.int_(12345), np.longlong(67890)],
SType.INSTRUMENT_ID,
["12345", "67890"],
),
],
)
def test_optional_symbols_list_to_list_numpy(
symbols: list[Number] | Number | None,
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
) -> None:
Expand Down Expand Up @@ -190,7 +179,7 @@ def test_optional_symbols_list_to_list_numpy(
],
)
def test_optional_symbols_list_to_list_raw_symbol(
symbols: list[Number] | Number | None,
symbols: list[int] | int | None,
stype: SType,
expected: list[object] | type[Exception],
) -> None:
Expand Down

0 comments on commit ab5d54d

Please sign in to comment.