Skip to content

Commit 4c16994

Browse files
committed
Add WebSocket support
1 parent 8a755f8 commit 4c16994

File tree

7 files changed

+341
-3
lines changed

7 files changed

+341
-3
lines changed

requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ html5lib>=1.1
1111
peewee>=3.16.2
1212
requests_cache>=1.0
1313
requests_ratelimiter>=0.3.1
14-
scipy>=1.6.3
14+
scipy>=1.6.3
15+
protobuf>=5.29.2
16+
websockets>=14.1

setup.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,17 @@
6363
'requests>=2.31', 'multitasking>=0.0.7',
6464
'lxml>=4.9.1', 'platformdirs>=2.0.0', 'pytz>=2022.5',
6565
'frozendict>=2.3.4', 'peewee>=3.16.2',
66-
'beautifulsoup4>=4.11.1', 'html5lib>=1.1'],
66+
'beautifulsoup4>=4.11.1', 'html5lib>=1.1',
67+
'protobuf>=5.29.2', 'websockets>=14.1'],
6768
extras_require={
6869
'nospam': ['requests_cache>=1.0', 'requests_ratelimiter>=0.3.1'],
6970
'repair': ['scipy>=1.6.3'],
7071
},
72+
# Include protobuf files for websocket support
73+
package_data={
74+
'yfinance': ['pricing.proto', 'pricing_pb2.py'],
75+
},
76+
include_package_data=True,
7177
# Note: Pandas.read_html() needs html5lib & beautifulsoup4
7278
entry_points={
7379
'console_scripts': [

tests/test_live.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import unittest
2+
from unittest.mock import Mock
3+
4+
from yfinance.live import BaseWebSocket
5+
6+
7+
class TestWebSocket(unittest.TestCase):
8+
def test_decode_message_valid(self):
9+
message = ("CgdCVEMtVVNEFYoMuUcYwLCVgIplIgNVU0QqA0NDQzApOAFFPWrEP0iAgOrxvANVx/25R12csrRHZYD8skR9/"
10+
"7i0R7ABgIDq8bwD2AEE4AGAgOrxvAPoAYCA6vG8A/IBA0JUQ4ECAAAAwPrjckGJAgAA2P5ZT3tC")
11+
12+
ws = BaseWebSocket(Mock())
13+
decoded = ws._decode_message(message)
14+
15+
expected = {'id': 'BTC-USD', 'price': 94745.08, 'time': '1736509140000', 'currency': 'USD', 'exchange': 'CCC',
16+
'quote_type': 41, 'market_hours': 1, 'change_percent': 1.5344921, 'day_volume': '59712028672',
17+
'day_high': 95227.555, 'day_low': 92517.22, 'change': 1431.8906, 'open_price': 92529.99,
18+
'last_size': '59712028672', 'price_hint': '2', 'vol_24hr': '59712028672',
19+
'vol_all_currencies': '59712028672', 'from_currency': 'BTC', 'circulating_supply': 19808172.0,
20+
'market_cap': 1876726640000.0}
21+
22+
self.assertEqual(expected, decoded)
23+
24+
def test_decode_message_invalid(self):
25+
websocket = BaseWebSocket(Mock())
26+
base64_message = "invalid_base64_string"
27+
decoded = websocket._decode_message(base64_message)
28+
assert "error" in decoded
29+
assert "raw_base64" in decoded
30+
self.assertEqual(base64_message, decoded["raw_base64"])

yfinance/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .ticker import Ticker
2525
from .tickers import Tickers
2626
from .multi import download
27+
from .live import WebSocket, AsyncWebSocket
2728
from .utils import enable_debug_mode
2829
from .cache import set_tz_cache_location
2930
from .domain.sector import Sector
@@ -39,6 +40,6 @@
3940
import warnings
4041
warnings.filterwarnings('default', category=DeprecationWarning, module='^yfinance')
4142

42-
__all__ = ['download', 'Market', 'Search', 'Ticker', 'Tickers', 'enable_debug_mode', 'set_tz_cache_location', 'Sector', 'Industry']
43+
__all__ = ['download', 'Market', 'Search', 'Ticker', 'Tickers', 'enable_debug_mode', 'set_tz_cache_location', 'Sector', 'Industry', 'WebSocket', 'AsyncWebSocket']
4344
# screener stuff:
4445
__all__ += ['EquityQuery', 'FundQuery', 'screen', 'PREDEFINED_SCREENER_QUERIES']

yfinance/base.py

+11
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from . import utils, cache
3535
from .data import YfData
3636
from .exceptions import YFEarningsDateMissing, YFRateLimitError
37+
from .live import WebSocket
3738
from .scrapers.analysis import Analysis
3839
from .scrapers.fundamentals import Fundamentals
3940
from .scrapers.holders import Holders
@@ -76,6 +77,9 @@ def __init__(self, ticker, session=None, proxy=None):
7677

7778
self._fast_info = None
7879

80+
self._message_handler = None
81+
self.ws = None
82+
7983
@utils.log_indent_decorator
8084
def history(self, *args, **kwargs) -> pd.DataFrame:
8185
return self._lazy_load_price_history().history(*args, **kwargs)
@@ -690,3 +694,10 @@ def get_funds_data(self, proxy=None) -> Optional[FundsData]:
690694
self._funds_data = FundsData(self._data, self.ticker)
691695

692696
return self._funds_data
697+
698+
def live(self, message_handler=None, verbose=True):
699+
self._message_handler = message_handler
700+
701+
self.ws = WebSocket(verbose=verbose)
702+
self.ws.subscribe(self.ticker)
703+
self.ws.listen(self._message_handler)

yfinance/live.py

+277
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
import asyncio
2+
import base64
3+
import json
4+
from typing import List, Optional, Callable
5+
6+
from websockets.sync.client import connect as sync_connect
7+
from websockets.asyncio.client import connect as async_connect
8+
9+
from yfinance import utils
10+
from yfinance.pricing_pb2 import PricingData
11+
from google.protobuf.json_format import MessageToDict
12+
13+
14+
class BaseWebSocket:
15+
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
16+
self.url = url
17+
self.verbose = verbose
18+
self.logger = utils.get_yf_logger()
19+
self._ws = None
20+
self._subscriptions = set()
21+
self._subscription_interval = 15 # seconds
22+
23+
def _decode_message(self, base64_message: str) -> dict:
24+
try:
25+
decoded_bytes = base64.b64decode(base64_message)
26+
pricing_data = PricingData()
27+
pricing_data.ParseFromString(decoded_bytes)
28+
return MessageToDict(pricing_data, preserving_proto_field_name=True)
29+
except Exception as e:
30+
self.logger.error("Failed to decode message: %s", e, exc_info=True)
31+
print("Failed to decode message: %s", e)
32+
return {
33+
'error': str(e),
34+
'raw_base64': base64_message
35+
}
36+
37+
38+
class AsyncWebSocket(BaseWebSocket):
39+
"""
40+
Asynchronous WebSocket client for streaming real time pricing data.
41+
"""
42+
43+
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
44+
"""
45+
Initialize the AsyncWebSocket client.
46+
47+
Args:
48+
url (str): The WebSocket server URL. Defaults to Yahoo Finance's WebSocket URL.
49+
verbose (bool): Flag to enable or disable print statements. Defaults to True.
50+
"""
51+
super().__init__(url, verbose)
52+
self._message_handler = None # Callable to handle messages
53+
self._heartbeat_task = None # Task to send heartbeat subscribe
54+
55+
async def _connect(self):
56+
if self._ws is None:
57+
self._ws = await async_connect(self.url)
58+
self.logger.info("Connected to WebSocket.")
59+
if self.verbose:
60+
print("Connected to WebSocket.")
61+
62+
async def _periodic_subscribe(self):
63+
while True:
64+
try:
65+
await asyncio.sleep(self._subscription_interval)
66+
67+
if self._subscriptions:
68+
message = {"subscribe": list(self._subscriptions)}
69+
await self._ws.send(json.dumps(message))
70+
71+
if self.verbose:
72+
print(f"Heartbeat subscription sent for symbols: {self._subscriptions}")
73+
except Exception as e:
74+
self.logger.error("Error in heartbeat subscription: %s", e, exc_info=True)
75+
if self.verbose:
76+
print(f"Error in heartbeat subscription: {e}")
77+
break
78+
79+
async def subscribe(self, symbols: str | List[str]):
80+
"""
81+
Subscribe to a stock symbol or a list of stock symbols.
82+
83+
Args:
84+
symbols (str | List[str]): Stock symbol(s) to subscribe to.
85+
"""
86+
await self._connect()
87+
88+
if isinstance(symbols, str):
89+
symbols = [symbols]
90+
91+
self._subscriptions.update(symbols)
92+
93+
message = {"subscribe": list(self._subscriptions)}
94+
await self._ws.send(json.dumps(message))
95+
96+
# Start heartbeat subscription task
97+
if self._heartbeat_task is None:
98+
self._heartbeat_task = asyncio.create_task(self._periodic_subscribe())
99+
100+
self.logger.info(f"Subscribed to symbols: {symbols}")
101+
if self.verbose:
102+
print(f"Subscribed to symbols: {symbols}")
103+
104+
async def unsubscribe(self, symbols: str | List[str]):
105+
"""
106+
Unsubscribe from a stock symbol or a list of stock symbols.
107+
108+
Args:
109+
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
110+
"""
111+
await self._connect()
112+
113+
if isinstance(symbols, str):
114+
symbols = [symbols]
115+
116+
self._subscriptions.difference_update(symbols)
117+
118+
message = {"unsubscribe": list(self._subscriptions)}
119+
await self._ws.send(json.dumps(message))
120+
121+
self.logger.info(f"Unsubscribed from symbols: {symbols}")
122+
if self.verbose:
123+
print(f"Unsubscribed from symbols: {symbols}")
124+
125+
async def listen(self, message_handler: Optional[Callable[[dict], None]] = None):
126+
"""
127+
Start listening to messages from the WebSocket server.
128+
129+
Args:
130+
message_handler (Optional[Callable[[dict], None]]): Optional function to handle received messages.
131+
"""
132+
await self._connect()
133+
self._message_handler = message_handler
134+
135+
self.logger.info("Listening for messages...")
136+
if self.verbose:
137+
print("Listening for messages...")
138+
139+
# Start heartbeat subscription task
140+
if self._heartbeat_task is None:
141+
self._heartbeat_task = asyncio.create_task(self._periodic_subscribe())
142+
143+
try:
144+
async for message in self._ws:
145+
message_json = json.loads(message)
146+
encoded_data = message_json.get("message", "")
147+
decoded_message = self._decode_message(encoded_data)
148+
if self._message_handler:
149+
self._message_handler(decoded_message)
150+
else:
151+
print(decoded_message)
152+
except (KeyboardInterrupt, asyncio.CancelledError) as e:
153+
self.logger.info("WebSocket listening interrupted. Closing connection...")
154+
if self.verbose:
155+
print("WebSocket listening interrupted. Closing connection...")
156+
await self.close()
157+
except Exception as e:
158+
self.logger.error("Error while listening to messages: %s", e, exc_info=True)
159+
if self.verbose:
160+
print("Error while listening to messages: %s", e)
161+
162+
async def close(self):
163+
"""Close the WebSocket connection."""
164+
if self._heartbeat_task:
165+
self._heartbeat_task.cancel()
166+
167+
if self._ws is not None:# and not self._ws.closed:
168+
await self._ws.close()
169+
self.logger.info("WebSocket connection closed.")
170+
if self.verbose:
171+
print("WebSocket connection closed.")
172+
173+
174+
class WebSocket(BaseWebSocket):
175+
"""
176+
Synchronous WebSocket client for streaming real time pricing data.
177+
"""
178+
179+
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
180+
"""
181+
Initialize the WebSocket client.
182+
183+
Args:
184+
url (str): The WebSocket server URL. Defaults to Yahoo Finance's WebSocket URL.
185+
verbose (bool): Flag to enable or disable print statements. Defaults to True.
186+
"""
187+
super().__init__(url, verbose)
188+
189+
def _connect(self):
190+
if self._ws is None:
191+
self._ws = sync_connect(self.url)
192+
self.logger.info("Connected to WebSocket.")
193+
if self.verbose:
194+
print("Connected to WebSocket.")
195+
196+
def subscribe(self, symbols: str | List[str]):
197+
"""
198+
Subscribe to a stock symbol or a list of stock symbols.
199+
200+
Args:
201+
symbols (str | List[str]): Stock symbol(s) to subscribe to.
202+
"""
203+
self._connect()
204+
205+
if isinstance(symbols, str):
206+
symbols = [symbols]
207+
208+
self._subscriptions.update(symbols)
209+
210+
message = {"subscribe": list(self._subscriptions)}
211+
self._ws.send(json.dumps(message))
212+
213+
self.logger.info(f"Subscribed to symbols: {symbols}")
214+
if self.verbose:
215+
print(f"Subscribed to symbols: {symbols}")
216+
217+
def unsubscribe(self, symbols: str | List[str]):
218+
"""
219+
Unsubscribe from a stock symbol or a list of stock symbols.
220+
221+
Args:
222+
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
223+
"""
224+
self._connect()
225+
226+
if isinstance(symbols, str):
227+
symbols = [symbols]
228+
229+
self._subscriptions.difference_update(symbols)
230+
231+
message = {"unsubscribe": list(self._subscriptions)}
232+
self._ws.send(json.dumps(message))
233+
234+
self.logger.info(f"Unsubscribed from symbols: {symbols}")
235+
if self.verbose:
236+
print(f"Unsubscribed from symbols: {symbols}")
237+
238+
def listen(self, message_handler: Optional[Callable[[dict], None]] = None):
239+
"""
240+
Start listening to messages from the WebSocket server.
241+
242+
Args:
243+
message_handler (Optional[Callable[[dict], None]]): Optional function to handle received messages.
244+
"""
245+
self._connect()
246+
247+
self.logger.info("Listening for messages...")
248+
if self.verbose:
249+
print("Listening for messages...")
250+
251+
try:
252+
while True:
253+
message = self._ws.recv()
254+
message_json = json.loads(message)
255+
encoded_data = message_json.get("message", "")
256+
decoded_message = self._decode_message(encoded_data)
257+
258+
if message_handler:
259+
message_handler(decoded_message)
260+
else:
261+
print(decoded_message)
262+
except KeyboardInterrupt:
263+
if self.verbose:
264+
print("Received keyboard interrupt.")
265+
self.close()
266+
except Exception as e:
267+
self.logger.error("Error while listening to messages: %s", e, exc_info=True)
268+
if self.verbose:
269+
print("Error while listening to messages: %s", e)
270+
271+
def close(self):
272+
"""Close the WebSocket connection."""
273+
if self._ws is not None:
274+
self._ws.close()
275+
self.logger.info("WebSocket connection closed.")
276+
if self.verbose:
277+
print("WebSocket connection closed.")

0 commit comments

Comments
 (0)