diff --git a/CHANGELOG.md b/CHANGELOG.md index 2346b91..3654b39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## [1.8.0] - 2024-NOV-12 + +### Added +- Custom response type object for WebSocket channels + ## [1.7.0] - 2024-OCT-16 ### Added diff --git a/README.md b/README.md index c3a390a..5d12a8c 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,35 @@ The functions described above handle the asynchronous nature of WebSocket connec We similarly provide async channel specific methods for subscribing and unsubscribing such as `ticker_async`, `ticker_unsubscribe_async`, etc. +### WebSocket Response Types +For your convenience, we have provided a custom, built-in WebSocket response type object to help interact with our WebSocket feeds more easily. + +Assume we simply want the price feed for BTC-USD and ETH-USD. +Like we did in previous steps, we subscribe to the `ticker` channel and include 'BTC-USD' and 'ETH-USD' in the `product_ids` list. +As the data comes through, it is passed into the `on_message` function. From there, we use it to build the `WebsocketResponse` object. + +Using said object, we can now extract only the desired parts. In our case, we retrieve and print only the `product_id` and `price` fields, resulting in a cleaner feed. +```python +def on_message(msg): + ws_object = WebsocketResponse(json.loads(msg)) + if ws_object.channel == "ticker" : + for event in ws_object.events: + for ticker in event.tickers: + print(ticker.product_id + ": " + ticker.price) + +client.open() +client.subscribe(product_ids=["BTC-USD", "ETH-USD"], channels=["ticker"]) +time.sleep(10) +client.unsubscribe(product_ids=["BTC-USD", "ETH-USD"], channels=["ticker"]) +client.close() +``` +#### Avoiding errors +In the example, note how we first checked `if ws_object.channel == "ticker"`. +Since each channel's event field has a unique structure and set of fields, it's important to ensure that the fields we access are actually present in the object. +For example, if we were to subscribe to the `user` channel and try to access a field that does not exist in it, such as the `tickers` field, we would be met with an error. + +Therefore, we urge users to reference our [documentation](https://docs.cdp.coinbase.com/advanced-trade/docs/ws-channels), which outlines the JSON object that each channel will return. + ___ ## Debugging the Clients You can enable debug logging for the REST and WebSocket clients by setting the `verbose` variable to `True` when initializing the clients. This will log useful information throughout the lifecycle of the REST request or WebSocket connection, and is highly recommended for debugging purposes. diff --git a/coinbase/__version__.py b/coinbase/__version__.py index 14d9d2f..29654ee 100644 --- a/coinbase/__version__.py +++ b/coinbase/__version__.py @@ -1 +1 @@ -__version__ = "1.7.0" +__version__ = "1.8.0" diff --git a/coinbase/rest/types/base_response.py b/coinbase/rest/types/base_response.py index 7242782..d62fac1 100644 --- a/coinbase/rest/types/base_response.py +++ b/coinbase/rest/types/base_response.py @@ -9,11 +9,9 @@ class BaseResponse: def __init__(self, **kwargs): - for field in list(kwargs.keys()): - attr_name = field.replace("-", "_") - + for field, formattedField in common_fields.items(): if field in kwargs: - setattr(self, attr_name, kwargs.pop(field)) + setattr(self, formattedField, kwargs.pop(field)) for key in list(kwargs.keys()): setattr(self, key, kwargs.pop(key)) diff --git a/coinbase/rest/types/common_types.py b/coinbase/rest/types/common_types.py index f8da529..a3df2b8 100644 --- a/coinbase/rest/types/common_types.py +++ b/coinbase/rest/types/common_types.py @@ -382,17 +382,18 @@ def __init__(self, **kwargs): **kwargs.pop("limit_limit_fok") ) if "stop_limit_stop_limit_gtc" in kwargs: - self.stop_limit_stop_limit_gtc: Optional[ - StopLimitStopLimitGtc - ] = StopLimitStopLimitGtc(**kwargs.pop("stop_limit_stop_limit_gtc")) + self.stop_limit_stop_limit_gtc: Optional[StopLimitStopLimitGtc] = ( + StopLimitStopLimitGtc(**kwargs.pop("stop_limit_stop_limit_gtc")) + ) if "stop_limit_stop_limit_gtd" in kwargs: - self.stop_limit_stop_limit_gtd: Optional[ - StopLimitStopLimitGtd - ] = StopLimitStopLimitGtd(**kwargs.pop("stop_limit_stop_limit_gtd")) + self.stop_limit_stop_limit_gtd: Optional[StopLimitStopLimitGtd] = ( + StopLimitStopLimitGtd(**kwargs.pop("stop_limit_stop_limit_gtd")) + ) if "trigger_bracket_gtc" in kwargs: self.trigger_bracket_gtc: Optional[TriggerBracketGtc] = TriggerBracketGtc( **kwargs.pop("trigger_bracket_gtc") ) + if "trigger_bracket_gtd" in kwargs: self.trigger_bracket_gtd: Optional[TriggerBracketGtd] = TriggerBracketGtd( **kwargs.pop("trigger_bracket_gtd") diff --git a/coinbase/rest/types/futures_types.py b/coinbase/rest/types/futures_types.py index 688636a..ba2cea0 100644 --- a/coinbase/rest/types/futures_types.py +++ b/coinbase/rest/types/futures_types.py @@ -83,9 +83,9 @@ def __init__(self, response: dict): "is_intraday_margin_killswitch_enabled" ) if "is_intraday_margin_enrollment_killswitch_enabled" in response: - self.is_intraday_margin_enrollment_killswitch_enabled: Optional[ - bool - ] = response.pop("is_intraday_margin_enrollment_killswitch_enabled") + self.is_intraday_margin_enrollment_killswitch_enabled: Optional[bool] = ( + response.pop("is_intraday_margin_enrollment_killswitch_enabled") + ) super().__init__(**response) diff --git a/coinbase/websocket/__init__.py b/coinbase/websocket/__init__.py index 5e27849..7787afd 100644 --- a/coinbase/websocket/__init__.py +++ b/coinbase/websocket/__init__.py @@ -3,6 +3,7 @@ from coinbase.constants import API_ENV_KEY, API_SECRET_ENV_KEY, WS_USER_BASE_URL +from .types.websocket_response import WebsocketResponse from .websocket_base import WSBase, WSClientConnectionClosedException, WSClientException diff --git a/coinbase/websocket/types/base_response.py b/coinbase/websocket/types/base_response.py new file mode 100644 index 0000000..2feb28b --- /dev/null +++ b/coinbase/websocket/types/base_response.py @@ -0,0 +1,13 @@ +from typing import Any + + +class BaseResponse: + def __init__(self, **data): + for key in list(data.keys()): + setattr(self, key, data.pop(key)) + + def __getitem__(self, key: str) -> Any: + return self.__dict__.get(key) + + def __repr__(self): + return str(self.__dict__) diff --git a/coinbase/websocket/types/misc_types.py b/coinbase/websocket/types/misc_types.py new file mode 100644 index 0000000..ac61ef9 --- /dev/null +++ b/coinbase/websocket/types/misc_types.py @@ -0,0 +1,226 @@ +from typing import List, Optional + +from coinbase.websocket.types.base_response import BaseResponse + + +class WSHeartBeats(BaseResponse): + def __init__(self, **kwargs): + self.current_time: Optional[str] = kwargs.pop("current_time", None) + self.heartbeat_counter: Optional[str] = kwargs.pop("heartbeat_counter", None) + super().__init__(**kwargs) + + +class WSCandle(BaseResponse): + def __init__(self, **kwargs): + self.start: str = kwargs.pop("start", None) + self.high: str = kwargs.pop("high", None) + self.low: str = kwargs.pop("low", None) + self.open: str = kwargs.pop("open", None) + self.close: str = kwargs.pop("close", None) + self.volume: str = kwargs.pop("volume", None) + self.product_id: str = kwargs.pop("product_id", None) + super().__init__(**kwargs) + + +class WSHistoricalMarketTrade(BaseResponse): + def __init__(self, **kwargs): + self.product_id: str = kwargs.pop("product_id", None) + self.trade_id: str = kwargs.pop("trade_id", None) + self.price: str = kwargs.pop("price", None) + self.size: str = kwargs.pop("size", None) + self.time: str = kwargs.pop("time", None) + self.side: str = kwargs.pop("side", None) + super().__init__(**kwargs) + + +class WSProduct(BaseResponse): + def __init__(self, **kwargs): + self.product_type: str = kwargs.pop("product_type", None) + self.id: str = kwargs.pop("id", None) + self.base_currency: str = kwargs.pop("base_currency", None) + self.quote_currency: str = kwargs.pop("quote_currency", None) + self.base_increment: str = kwargs.pop("base_increment", None) + self.quote_increment: str = kwargs.pop("quote_increment", None) + self.display_name: str = kwargs.pop("display_name", None) + self.status: str = kwargs.pop("status", None) + self.status_message: str = kwargs.pop("status_message", None) + self.min_market_funds: str = kwargs.pop("min_market_funds", None) + super().__init__(**kwargs) + + +class WSTicker(BaseResponse): + def __init__(self, **kwargs): + self.type: str = kwargs.pop("type", None) + self.product_id: str = kwargs.pop("product_id", None) + self.price: str = kwargs.pop("price", None) + self.volume_24_h: str = kwargs.pop("volume_24_h", None) + self.low_24_h: str = kwargs.pop("low_24_h", None) + self.high_24_h: str = kwargs.pop("high_24_h", None) + self.low_52_w: str = kwargs.pop("low_52_w", None) + self.high_52_w: str = kwargs.pop("high_52_w", None) + self.price_percent_chg_24_h: str = kwargs.pop("price_percent_chg_24_h", None) + self.best_bid: str = kwargs.pop("best_bid", None) + self.best_ask: str = kwargs.pop("best_ask", None) + self.best_bid_quantity: str = kwargs.pop("best_bid_quantity", None) + self.best_ask_quantity: str = kwargs.pop("best_ask_quantity", None) + super().__init__(**kwargs) + + +class L2Update(BaseResponse): + def __init__(self, **kwargs): + self.side: str = kwargs.pop("side", None) + self.event_time: str = kwargs.pop("event_time", None) + self.price_level: str = kwargs.pop("price_level", None) + self.new_quantity: str = kwargs.pop("new_quantity", None) + super().__init__(**kwargs) + + +class UserOrders(BaseResponse): + def __init__(self, **kwargs): + self.avg_price: Optional[str] = kwargs.pop("avg_price", None) + self.cancel_reason: Optional[str] = kwargs.pop("cancel_reason", None) + self.client_order_id: Optional[str] = kwargs.pop("client_order_id", None) + self.completion_percentage: Optional[str] = kwargs.pop( + "completion_percentage", None + ) + self.contract_expiry_type: Optional[str] = kwargs.pop( + "contract_expiry_type", None + ) + self.cumulative_quantity: Optional[str] = kwargs.pop( + "cumulative_quantity", None + ) + self.filled_value: Optional[str] = kwargs.pop("filled_value", None) + self.leaves_quantity: Optional[str] = kwargs.pop("leaves_quantity", None) + self.limit_price: Optional[str] = kwargs.pop("limit_price", None) + self.number_of_fills: Optional[str] = kwargs.pop("number_of_fills", None) + self.order_id: Optional[str] = kwargs.pop("order_id", None) + self.order_side: Optional[str] = kwargs.pop("order_side", None) + self.order_type: Optional[str] = kwargs.pop("order_type", None) + self.outstanding_hold_amount: Optional[str] = kwargs.pop( + "outstanding_hold_amount", None + ) + self.post_only: Optional[str] = kwargs.pop("post_only", None) + self.product_id: Optional[str] = kwargs.pop("product_id", None) + self.product_type: Optional[str] = kwargs.pop("product_type", None) + self.reject_reason: Optional[str] = kwargs.pop("reject_reason", None) + self.retail_portfolio_id: Optional[str] = kwargs.pop( + "retail_portfolio_id", None + ) + self.risk_managed_by: Optional[str] = kwargs.pop("risk_managed_by", None) + self.status: Optional[str] = kwargs.pop("status", None) + self.stop_price: Optional[str] = kwargs.pop("stop_price", None) + self.time_in_force: Optional[str] = kwargs.pop("time_in_force", None) + self.total_fees: Optional[str] = kwargs.pop("total_fees", None) + self.total_value_after_fees: Optional[str] = kwargs.pop( + "total_value_after_fees", None + ) + self.trigger_status: Optional[str] = kwargs.pop("trigger_status", None) + self.creation_time: Optional[str] = kwargs.pop("creation_time", None) + self.end_time: Optional[str] = kwargs.pop("end_time", None) + self.start_time: Optional[str] = kwargs.pop("start_time", None) + super().__init__(**kwargs) + + +class UserPositions(BaseResponse): + def __init__(self, **kwargs): + self.perpetual_futures_positions: Optional[List[UserFuturesPositions]] = ( + [ + UserFuturesPositions(**position) + for position in kwargs.pop("perpetual_futures_positions", []) + ] + if kwargs.get("perpetual_futures_positions") is not None + else [] + ) + self.expiring_futures_positions: Optional[List[UserExpFuturesPositions]] = ( + [ + UserExpFuturesPositions(**position) + for position in kwargs.pop("expiring_futures_positions", []) + ] + if kwargs.get("expiring_futures_positions") is not None + else [] + ) + super().__init__(**kwargs) + + +class UserFuturesPositions(BaseResponse): + def __init__(self, **kwargs): + self.product_id: Optional[str] = kwargs.pop("product_id", None) + self.portfolio_uuid: Optional[str] = kwargs.pop("portfolio_uuid", None) + self.vwap: Optional[str] = kwargs.pop("vwap", None) + self.entry_vwap: Optional[str] = kwargs.pop("entry_vwap", None) + self.position_side: Optional[str] = kwargs.pop("position_side", None) + self.margin_type: Optional[str] = kwargs.pop("margin_type", None) + self.net_size: Optional[str] = kwargs.pop("net_size", None) + self.buy_order_size: Optional[str] = kwargs.pop("buy_order_size", None) + self.sell_order_size: Optional[str] = kwargs.pop("sell_order_size", None) + self.leverage: Optional[str] = kwargs.pop("leverage", None) + self.mark_price: Optional[str] = kwargs.pop("mark_price", None) + self.liquidation_price: Optional[str] = kwargs.pop("liquidation_price", None) + self.im_notional: Optional[str] = kwargs.pop("im_notional", None) + self.mm_notional: Optional[str] = kwargs.pop("mm_notional", None) + self.position_notional: Optional[str] = kwargs.pop("position_notional", None) + self.unrealized_pnl: Optional[str] = kwargs.pop("unrealized_pnl", None) + self.aggregated_pnl: Optional[str] = kwargs.pop("aggregated_pnl", None) + super().__init__(**kwargs) + + +class UserExpFuturesPositions(BaseResponse): + def __init__(self, **kwargs): + self.product_id: Optional[str] = kwargs.pop("product_id", None) + self.side: Optional[str] = kwargs.pop("side", None) + self.number_of_contracts: Optional[str] = kwargs.pop( + "number_of_contracts", None + ) + self.realized_pnl: Optional[str] = kwargs.pop("realized_pnl", None) + self.unrealized_pnl: Optional[str] = kwargs.pop("unrealized_pnl", None) + self.entry_price: Optional[str] = kwargs.pop("entry_price", None) + super().__init__(**kwargs) + + +class WSFCMBalanceSummary(BaseResponse): + def __init__(self, **kwargs): + self.futures_buying_power: str = kwargs.pop("futures_buying_power", None) + self.total_usd_balance: str = kwargs.pop("total_usd_balance", None) + self.cbi_usd_balance: str = kwargs.pop("cbi_usd_balance", None) + self.cfm_usd_balance: str = kwargs.pop("cfm_usd_balance", None) + self.total_open_orders_hold_amount: str = kwargs.pop( + "total_open_orders_hold_amount", None + ) + self.unrealized_pnl: str = kwargs.pop("unrealized_pnl", None) + self.daily_realized_pnl: str = kwargs.pop("daily_realized_pnl", None) + self.initial_margin: str = kwargs.pop("initial_margin", None) + self.available_margin: str = kwargs.pop("available_margin", None) + self.liquidation_threshold: str = kwargs.pop("liquidation_threshold", None) + self.liquidation_buffer_amount: str = kwargs.pop( + "liquidation_buffer_amount", None + ) + self.liquidation_buffer_percentage: str = kwargs.pop( + "liquidation_buffer_percentage", None + ) + self.intraday_margin_window_measure: Optional[FCMMarginWindowMeasure] = ( + FCMMarginWindowMeasure(**kwargs.pop("intraday_margin_window_measure")) + if kwargs.get("intraday_margin_window_measure") + else None + ) + self.overnight_margin_window_measure: Optional[FCMMarginWindowMeasure] = ( + FCMMarginWindowMeasure(**kwargs.pop("overnight_margin_window_measure")) + if kwargs.get("overnight_margin_window_measure") + else None + ) + super().__init__(**kwargs) + + +class FCMMarginWindowMeasure(BaseResponse): + def __init__(self, **kwargs): + self.margin_window_type: Optional[str] = kwargs.pop("margin_window_type", None) + self.margin_level: Optional[str] = kwargs.pop("margin_level", None) + self.initial_margin: Optional[str] = kwargs.pop("initial_margin", None) + self.maintenance_margin: Optional[str] = kwargs.pop("maintenance_margin", None) + self.liquidation_buffer_percentage: Optional[str] = kwargs.pop( + "liquidation_buffer_percentage", None + ) + self.total_hold: Optional[str] = kwargs.pop("total_hold", None) + self.futures_buying_power: Optional[str] = kwargs.pop( + "futures_buying_power", None + ) + super().__init__(**kwargs) diff --git a/coinbase/websocket/types/websocket_response.py b/coinbase/websocket/types/websocket_response.py new file mode 100644 index 0000000..7e2634c --- /dev/null +++ b/coinbase/websocket/types/websocket_response.py @@ -0,0 +1,91 @@ +from typing import List, Optional + +from coinbase.websocket.types.base_response import BaseResponse +from coinbase.websocket.types.misc_types import ( + L2Update, + UserOrders, + UserPositions, + WSCandle, + WSFCMBalanceSummary, + WSHistoricalMarketTrade, + WSProduct, + WSTicker, +) + + +class WebsocketResponse(BaseResponse): + def __init__(self, data): + self.channel = data.pop("channel") + self.client_id = data.pop("client_id") + self.timestamp = data.pop("timestamp") + self.sequence_num = data.pop("sequence_num") + self.events = [ + Event(event_data, self.channel) for event_data in data.pop("events") + ] + + super().__init__(**data) + + +class Event(BaseResponse): + def __init__(self, data, channel): + if channel == "heartbeats": + self.current_time = data.pop("current_time", None) + self.heartbeat_counter = data.pop("heartbeat_counter", None) + elif channel == "candles": + self.type = data.pop("type", None) + self.candles: List[WSCandle] = ( + [WSCandle(**ws_candle) for ws_candle in data.pop("candles", [])] + if "candles" in data + else None + ) + elif channel == "market_trades": + self.type = data.pop("type", None) + self.trades: List[WSHistoricalMarketTrade] = ( + [ + WSHistoricalMarketTrade(**ws_trades) + for ws_trades in data.pop("trades", []) + ] + if "trades" in data + else None + ) + elif channel == "status": + self.type = data.pop("type", None) + self.products: List[WSProduct] = ( + [WSProduct(**ws_product) for ws_product in data.pop("products", [])] + if "products" in data + else None + ) + elif channel == "ticker" or channel == "ticker_batch": + self.type = data.pop("type", None) + self.tickers: List[WSTicker] = ( + [WSTicker(**ws_ticker) for ws_ticker in data.pop("tickers", [])] + if "tickers" in data + else None + ) + elif channel == "l2_data": + self.type = data.pop("type", None) + self.product_id = data.pop("product_id", None) + self.updates: List[L2Update] = ( + [L2Update(**l2_update) for l2_update in data.pop("updates", [])] + if "updates" in data + else None + ) + elif channel == "user": + self.type = data.pop("type", None) + self.orders: Optional[List[UserOrders]] = ( + [UserOrders(**user_order) for user_order in data.pop("orders", [])] + if data.get("orders") is not None + else None + ) + self.positions: Optional[UserPositions] = ( + UserPositions(**data.pop("positions")) if "positions" in data else None + ) + elif channel == "futures_balance_summary": + self.type = data.pop("type", None) + self.fcm_balance_summary: WSFCMBalanceSummary = ( + WSFCMBalanceSummary(**data.pop("fcm_balance_summary")) + if data.get("fcm_balance_summary") + else None + ) + + super().__init__(**data) diff --git a/lint_requirements.txt b/lint_requirements.txt index a45e8c9..5d21eeb 100644 --- a/lint_requirements.txt +++ b/lint_requirements.txt @@ -1,2 +1,2 @@ -black==23.3.0 +black==24.3.0 isort==5.12.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f1d8a28..0f6585d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests>=2.31.0 cryptography>=42.0.4 PyJWT>=2.8.0 -websockets>=12.0 +websockets>=12.0,<14.0 backoff>=2.2.1