Skip to content
This repository was archived by the owner on May 31, 2023. It is now read-only.

Commit dcd41db

Browse files
author
Richard Patel
committed
stream: refactor StreamPriceAccounts
1 parent 5b2b616 commit dcd41db

File tree

2 files changed

+94
-16
lines changed

2 files changed

+94
-16
lines changed

stream.go

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package pyth
1717
import (
1818
"context"
1919
"errors"
20+
"sync"
2021
"time"
2122

2223
"github.com/cenkalti/backoff/v4"
@@ -26,16 +27,63 @@ import (
2627
"go.uber.org/zap"
2728
)
2829

30+
// StreamPriceAccounts creates a new stream of price account updates.
31+
func (c *Client) StreamPriceAccounts() *PriceAccountStream {
32+
ctx, cancel := context.WithCancel(context.Background())
33+
stream := &PriceAccountStream{
34+
cancel: cancel,
35+
updates: make(chan PriceAccountUpdate),
36+
client: c,
37+
}
38+
stream.errLock.Lock()
39+
go stream.runWrapper(ctx)
40+
return stream
41+
}
42+
43+
// PriceAccountUpdate is a real-time update carrying a price account change.
2944
type PriceAccountUpdate struct {
3045
Slot uint64
3146
*PriceAccount
3247
}
3348

34-
// StreamPriceAccounts sends an update to Prometheus any time a Pyth oracle account changes.
35-
func (c *Client) StreamPriceAccounts(ctx context.Context, updates chan<- PriceAccountUpdate) error {
49+
// PriceAccountStream is an ongoing stream of on-chain price account updates.
50+
type PriceAccountStream struct {
51+
cancel context.CancelFunc
52+
updates chan PriceAccountUpdate
53+
client *Client
54+
err error
55+
errLock sync.Mutex
56+
}
57+
58+
// Updates returns a channel with new price account updates.
59+
func (p *PriceAccountStream) Updates() <-chan PriceAccountUpdate {
60+
return p.updates
61+
}
62+
63+
// Err returns the reason why the price account stream is closed.
64+
// Will block until the stream has actually closed.
65+
// Returns nil if closure was expected.
66+
func (p *PriceAccountStream) Err() error {
67+
p.errLock.Lock()
68+
defer p.errLock.Unlock()
69+
return p.err
70+
}
71+
72+
// Close must be called when no more updates are needed.
73+
func (p *PriceAccountStream) Close() {
74+
p.cancel()
75+
}
76+
77+
func (p *PriceAccountStream) runWrapper(ctx context.Context) {
78+
defer p.errLock.Unlock()
79+
p.err = p.run(ctx)
80+
}
81+
82+
func (p *PriceAccountStream) run(ctx context.Context) error {
83+
defer close(p.updates)
3684
const retryInterval = 3 * time.Second
3785
return backoff.Retry(func() error {
38-
err := c.streamPriceAccounts(ctx, updates)
86+
err := p.runConn(ctx)
3987
switch {
4088
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
4189
return backoff.Permanent(err)
@@ -45,8 +93,8 @@ func (c *Client) StreamPriceAccounts(ctx context.Context, updates chan<- PriceAc
4593
}, backoff.WithContext(backoff.NewConstantBackOff(retryInterval), ctx))
4694
}
4795

48-
func (c *Client) streamPriceAccounts(ctx context.Context, updates chan<- PriceAccountUpdate) error {
49-
client, err := ws.Connect(ctx, c.WebSocketURL)
96+
func (p *PriceAccountStream) runConn(ctx context.Context) error {
97+
client, err := ws.Connect(ctx, p.client.WebSocketURL)
5098
if err != nil {
5199
return err
52100
}
@@ -62,7 +110,7 @@ func (c *Client) streamPriceAccounts(ctx context.Context, updates chan<- PriceAc
62110
defer metricsWsActiveConns.Dec()
63111

64112
sub, err := client.ProgramSubscribeWithOpts(
65-
c.Env.Program,
113+
p.client.Env.Program,
66114
rpc.CommitmentConfirmed,
67115
solana.EncodingBase64Zstd,
68116
[]rpc.RPCFilter{
@@ -83,17 +131,13 @@ func (c *Client) streamPriceAccounts(ctx context.Context, updates chan<- PriceAc
83131

84132
// Stream updates.
85133
for {
86-
if err := c.readNextUpdate(ctx, sub, updates); err != nil {
134+
if err := p.readNextUpdate(ctx, sub); err != nil {
87135
return err
88136
}
89137
}
90138
}
91139

92-
func (c *Client) readNextUpdate(
93-
ctx context.Context,
94-
sub *ws.ProgramSubscription,
95-
updates chan<- PriceAccountUpdate,
96-
) error {
140+
func (p *PriceAccountStream) readNextUpdate(ctx context.Context, sub *ws.ProgramSubscription) error {
97141
// If no update comes in within 20 seconds, bail.
98142
const readTimeout = 20 * time.Second
99143
ctx, cancel := context.WithTimeout(ctx, readTimeout)
@@ -102,7 +146,7 @@ func (c *Client) readNextUpdate(
102146
<-ctx.Done()
103147
// Terminate subscription if above timer has expired.
104148
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
105-
c.Log.Warn("Read deadline exceeded, terminating WebSocket connection",
149+
p.client.Log.Warn("Read deadline exceeded, terminating WebSocket connection",
106150
zap.Duration("timeout", readTimeout))
107151
sub.Unsubscribe()
108152
}
@@ -116,7 +160,7 @@ func (c *Client) readNextUpdate(
116160
metricsWsEventsTotal.Inc()
117161

118162
// Decode update.
119-
if update.Value.Account.Owner != c.Env.Program {
163+
if update.Value.Account.Owner != p.client.Env.Program {
120164
return nil
121165
}
122166
accountData := update.Value.Account.Data.GetBinary()
@@ -125,7 +169,7 @@ func (c *Client) readNextUpdate(
125169
}
126170
priceAcc := new(PriceAccount)
127171
if err := priceAcc.UnmarshalBinary(accountData); err != nil {
128-
c.Log.Warn("Failed to unmarshal priceAcc account", zap.Error(err))
172+
p.client.Log.Warn("Failed to unmarshal priceAcc account", zap.Error(err))
129173
return nil
130174
}
131175

@@ -137,7 +181,7 @@ func (c *Client) readNextUpdate(
137181
select {
138182
case <-ctx.Done():
139183
return ctx.Err()
140-
case updates <- msg:
184+
case p.updates <- msg:
141185
return nil
142186
}
143187
}

stream_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2022 Blockdaemon Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package pyth
16+
17+
import (
18+
"fmt"
19+
"time"
20+
)
21+
22+
func ExampleClient_StreamPriceAccounts() {
23+
client := NewClient(Devnet, testRPC, testWS)
24+
stream := client.StreamPriceAccounts()
25+
// Close stream after a while.
26+
go func() {
27+
<-time.After(3 * time.Second)
28+
stream.Close()
29+
}()
30+
// Print updates.
31+
for update := range stream.Updates() {
32+
fmt.Println(update.Agg.Price)
33+
}
34+
}

0 commit comments

Comments
 (0)