@@ -17,6 +17,7 @@ package pyth
1717import (
1818 "context"
1919 "errors"
20+ "sync"
2021 "time"
2122
2223 "github.com/cenkalti/backoff/v4"
@@ -26,16 +27,63 @@ import (
2627 "go.uber.org/zap"
2728)
2829
30+ // StreamPriceAccounts creates a new stream of price account updates.
31+ func (c * Client ) StreamPriceAccounts () * PriceAccountStream {
32+ ctx , cancel := context .WithCancel (context .Background ())
33+ stream := & PriceAccountStream {
34+ cancel : cancel ,
35+ updates : make (chan PriceAccountUpdate ),
36+ client : c ,
37+ }
38+ stream .errLock .Lock ()
39+ go stream .runWrapper (ctx )
40+ return stream
41+ }
42+
43+ // PriceAccountUpdate is a real-time update carrying a price account change.
2944type PriceAccountUpdate struct {
3045 Slot uint64
3146 * PriceAccount
3247}
3348
34- // StreamPriceAccounts sends an update to Prometheus any time a Pyth oracle account changes.
35- func (c * Client ) StreamPriceAccounts (ctx context.Context , updates chan <- PriceAccountUpdate ) error {
49+ // PriceAccountStream is an ongoing stream of on-chain price account updates.
50+ type PriceAccountStream struct {
51+ cancel context.CancelFunc
52+ updates chan PriceAccountUpdate
53+ client * Client
54+ err error
55+ errLock sync.Mutex
56+ }
57+
58+ // Updates returns a channel with new price account updates.
59+ func (p * PriceAccountStream ) Updates () <- chan PriceAccountUpdate {
60+ return p .updates
61+ }
62+
63+ // Err returns the reason why the price account stream is closed.
64+ // Will block until the stream has actually closed.
65+ // Returns nil if closure was expected.
66+ func (p * PriceAccountStream ) Err () error {
67+ p .errLock .Lock ()
68+ defer p .errLock .Unlock ()
69+ return p .err
70+ }
71+
72+ // Close must be called when no more updates are needed.
73+ func (p * PriceAccountStream ) Close () {
74+ p .cancel ()
75+ }
76+
77+ func (p * PriceAccountStream ) runWrapper (ctx context.Context ) {
78+ defer p .errLock .Unlock ()
79+ p .err = p .run (ctx )
80+ }
81+
82+ func (p * PriceAccountStream ) run (ctx context.Context ) error {
83+ defer close (p .updates )
3684 const retryInterval = 3 * time .Second
3785 return backoff .Retry (func () error {
38- err := c . streamPriceAccounts (ctx , updates )
86+ err := p . runConn (ctx )
3987 switch {
4088 case errors .Is (err , context .Canceled ), errors .Is (err , context .DeadlineExceeded ):
4189 return backoff .Permanent (err )
@@ -45,8 +93,8 @@ func (c *Client) StreamPriceAccounts(ctx context.Context, updates chan<- PriceAc
4593 }, backoff .WithContext (backoff .NewConstantBackOff (retryInterval ), ctx ))
4694}
4795
48- func (c * Client ) streamPriceAccounts (ctx context.Context , updates chan <- PriceAccountUpdate ) error {
49- client , err := ws .Connect (ctx , c .WebSocketURL )
96+ func (p * PriceAccountStream ) runConn (ctx context.Context ) error {
97+ client , err := ws .Connect (ctx , p . client .WebSocketURL )
5098 if err != nil {
5199 return err
52100 }
@@ -62,7 +110,7 @@ func (c *Client) streamPriceAccounts(ctx context.Context, updates chan<- PriceAc
62110 defer metricsWsActiveConns .Dec ()
63111
64112 sub , err := client .ProgramSubscribeWithOpts (
65- c .Env .Program ,
113+ p . client .Env .Program ,
66114 rpc .CommitmentConfirmed ,
67115 solana .EncodingBase64Zstd ,
68116 []rpc.RPCFilter {
@@ -83,17 +131,13 @@ func (c *Client) streamPriceAccounts(ctx context.Context, updates chan<- PriceAc
83131
84132 // Stream updates.
85133 for {
86- if err := c .readNextUpdate (ctx , sub , updates ); err != nil {
134+ if err := p .readNextUpdate (ctx , sub ); err != nil {
87135 return err
88136 }
89137 }
90138}
91139
92- func (c * Client ) readNextUpdate (
93- ctx context.Context ,
94- sub * ws.ProgramSubscription ,
95- updates chan <- PriceAccountUpdate ,
96- ) error {
140+ func (p * PriceAccountStream ) readNextUpdate (ctx context.Context , sub * ws.ProgramSubscription ) error {
97141 // If no update comes in within 20 seconds, bail.
98142 const readTimeout = 20 * time .Second
99143 ctx , cancel := context .WithTimeout (ctx , readTimeout )
@@ -102,7 +146,7 @@ func (c *Client) readNextUpdate(
102146 <- ctx .Done ()
103147 // Terminate subscription if above timer has expired.
104148 if errors .Is (ctx .Err (), context .DeadlineExceeded ) {
105- c .Log .Warn ("Read deadline exceeded, terminating WebSocket connection" ,
149+ p . client .Log .Warn ("Read deadline exceeded, terminating WebSocket connection" ,
106150 zap .Duration ("timeout" , readTimeout ))
107151 sub .Unsubscribe ()
108152 }
@@ -116,7 +160,7 @@ func (c *Client) readNextUpdate(
116160 metricsWsEventsTotal .Inc ()
117161
118162 // Decode update.
119- if update .Value .Account .Owner != c .Env .Program {
163+ if update .Value .Account .Owner != p . client .Env .Program {
120164 return nil
121165 }
122166 accountData := update .Value .Account .Data .GetBinary ()
@@ -125,7 +169,7 @@ func (c *Client) readNextUpdate(
125169 }
126170 priceAcc := new (PriceAccount )
127171 if err := priceAcc .UnmarshalBinary (accountData ); err != nil {
128- c .Log .Warn ("Failed to unmarshal priceAcc account" , zap .Error (err ))
172+ p . client .Log .Warn ("Failed to unmarshal priceAcc account" , zap .Error (err ))
129173 return nil
130174 }
131175
@@ -137,7 +181,7 @@ func (c *Client) readNextUpdate(
137181 select {
138182 case <- ctx .Done ():
139183 return ctx .Err ()
140- case updates <- msg :
184+ case p . updates <- msg :
141185 return nil
142186 }
143187}
0 commit comments