diff --git a/api/ws/client.go b/api/ws/client.go index 34178bd..81d1678 100644 --- a/api/ws/client.go +++ b/api/ws/client.go @@ -10,6 +10,7 @@ import ( "github.com/amir-the-h/okex" "github.com/amir-the-h/okex/events" "github.com/gorilla/websocket" + "io" "net/http" "sync" "time" @@ -31,6 +32,7 @@ type ClientWs struct { sendChan map[bool]chan []byte url map[bool]okex.BaseURL conn map[bool]*websocket.Conn + dialer *websocket.Dialer apiKey string secretKey []byte passphrase string @@ -55,19 +57,18 @@ const ( func NewClient(ctx context.Context, apiKey, secretKey, passphrase string, url map[bool]okex.BaseURL) *ClientWs { ctx, cancel := context.WithCancel(ctx) c := &ClientWs{ - apiKey: apiKey, - secretKey: []byte(secretKey), - passphrase: passphrase, - ctx: ctx, - Cancel: cancel, - url: url, - sendChan: map[bool]chan []byte{true: make(chan []byte, 3), false: make(chan []byte, 3)}, - DoneChan: make(chan interface{}), - StructuredEventChan: make(chan interface{}), - RawEventChan: make(chan *events.Basic), - conn: make(map[bool]*websocket.Conn), - lastTransmit: make(map[bool]*time.Time), - mu: map[bool]*sync.RWMutex{true: {}, false: {}}, + apiKey: apiKey, + secretKey: []byte(secretKey), + passphrase: passphrase, + ctx: ctx, + Cancel: cancel, + url: url, + sendChan: map[bool]chan []byte{true: make(chan []byte, 3), false: make(chan []byte, 3)}, + DoneChan: make(chan interface{}), + conn: make(map[bool]*websocket.Conn), + dialer: websocket.DefaultDialer, + lastTransmit: make(map[bool]*time.Time), + mu: map[bool]*sync.RWMutex{true: {}, false: {}}, } c.Private = NewPrivate(c) c.Public = NewPublic(c) @@ -131,20 +132,24 @@ func (c *ClientWs) Login() error { // Users can choose to subscribe to one or more channels, and the total length of multiple channels cannot exceed 4096 bytes. // // https://www.okex.com/docs-v5/en/#websocket-api-subscribe -func (c *ClientWs) Subscribe(p bool, ch []okex.ChannelName, args map[string]string) error { - count := 1 - if len(ch) != 0 { - count = len(ch) - } - tmpArgs := make([]map[string]string, count) - tmpArgs[0] = args - for i, name := range ch { - tmpArgs[i] = map[string]string{} - tmpArgs[i]["channel"] = string(name) - for k, v := range args { - tmpArgs[i][k] = v +func (c *ClientWs) Subscribe(p bool, ch []okex.ChannelName, args ...map[string]string) error { + chCount := max(len(ch), 1) + tmpArgs := make([]map[string]string, chCount*len(args)) + + n := 0 + for i := 0; i < chCount; i++ { + for _, arg := range args { + tmpArgs[n] = make(map[string]string) + for k, v := range arg { + tmpArgs[n][k] = v + } + if len(ch) > 0 { + tmpArgs[n]["channel"] = string(ch[i]) + } + n++ } } + return c.Send(p, okex.SubscribeOperation, tmpArgs) } @@ -205,6 +210,16 @@ func (c *ClientWs) SetChannels(errCh chan *events.Error, subCh chan *events.Subs c.SuccessChan = sCh } +// SetDialer sets a custom dialer for the WebSocket connection. +func (c *ClientWs) SetDialer(dialer *websocket.Dialer) { + c.dialer = dialer +} + +func (c *ClientWs) SetEventChannels(structuredEventCh chan interface{}, rawEventCh chan *events.Basic) { + c.StructuredEventChan = structuredEventCh + c.RawEventChan = rawEventCh +} + // WaitForAuthorization waits for the auth response and try to log in if it was needed func (c *ClientWs) WaitForAuthorization() error { if c.Authorized { @@ -225,16 +240,23 @@ func (c *ClientWs) WaitForAuthorization() error { func (c *ClientWs) dial(p bool) error { c.mu[p].Lock() - conn, res, err := websocket.DefaultDialer.Dial(string(c.url[p]), nil) + conn, res, err := c.dialer.Dial(string(c.url[p]), nil) if err != nil { var statusCode int if res != nil { statusCode = res.StatusCode } - c.mu[p].Unlock() return fmt.Errorf("error %d: %w", statusCode, err) } - defer res.Body.Close() + c.conn[p] = conn + c.mu[p].Unlock() + + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + fmt.Printf("error closing body: %v\n", err) + } + }(res.Body) go func() { err := c.receiver(p) if err != nil { @@ -247,10 +269,10 @@ func (c *ClientWs) dial(p bool) error { fmt.Printf("sender error: %v\n", err) } }() - c.conn[p] = conn - c.mu[p].Unlock() + return nil } + func (c *ClientWs) sender(p bool) error { ticker := time.NewTicker(time.Millisecond * 300) defer ticker.Stop() @@ -279,7 +301,11 @@ func (c *ClientWs) sender(p bool) error { return err } case <-ticker.C: - if c.conn[p] != nil && (c.lastTransmit[p] == nil || (c.lastTransmit[p] != nil && time.Since(*c.lastTransmit[p]) > PingPeriod)) { + c.mu[p].RLock() + conn := c.conn[p] + lastTransmit := c.lastTransmit[p] + c.mu[p].RUnlock() + if conn != nil && (lastTransmit == nil || (lastTransmit != nil && time.Since(*lastTransmit) > PingPeriod)) { go func() { c.sendChan[p] <- []byte("ping") }() @@ -289,6 +315,7 @@ func (c *ClientWs) sender(p bool) error { } } } + func (c *ClientWs) receiver(p bool) error { for { select { @@ -326,6 +353,7 @@ func (c *ClientWs) receiver(p bool) error { } } } + func (c *ClientWs) sign(method, path string) (string, string) { t := time.Now().UTC().Unix() ts := fmt.Sprint(t) @@ -335,6 +363,7 @@ func (c *ClientWs) sign(method, path string) (string, string) { h.Write(p) return ts, base64.StdEncoding.EncodeToString(h.Sum(nil)) } + func (c *ClientWs) handleCancel(msg string) error { go func() { c.DoneChan <- msg @@ -342,35 +371,34 @@ func (c *ClientWs) handleCancel(msg string) error { return fmt.Errorf("operation cancelled: %s", msg) } -// TODO: break each case into a separate function func (c *ClientWs) process(data []byte, e *events.Basic) bool { switch e.Event { case "error": e := events.Error{} _ = json.Unmarshal(data, &e) - go func() { + if c.ErrChan != nil { c.ErrChan <- &e - }() + } return true case "subscribe": e := events.Subscribe{} _ = json.Unmarshal(data, &e) - go func() { - if c.SubscribeChan != nil { - c.SubscribeChan <- &e - } + if c.SubscribeChan != nil { + c.SubscribeChan <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "unsubscribe": e := events.Unsubscribe{} _ = json.Unmarshal(data, &e) - go func() { - if c.UnsubscribeCh != nil { - c.UnsubscribeCh <- &e - } + if c.UnsubscribeCh != nil { + c.UnsubscribeCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "login": if time.Since(*c.AuthRequested).Seconds() > 30 { @@ -381,12 +409,12 @@ func (c *ClientWs) process(data []byte, e *events.Basic) bool { c.Authorized = true e := events.Login{} _ = json.Unmarshal(data, &e) - go func() { - if c.LoginChan != nil { - c.LoginChan <- &e - } + if c.LoginChan != nil { + c.LoginChan <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } if c.Private.Process(data, e) { @@ -403,14 +431,14 @@ func (c *ClientWs) process(data []byte, e *events.Basic) bool { } e := events.Success{} _ = json.Unmarshal(data, &e) - go func() { - if c.SuccessChan != nil { - c.SuccessChan <- &e - } + if c.SuccessChan != nil { + c.SuccessChan <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } - go func() { c.RawEventChan <- e }() + c.RawEventChan <- e return false } diff --git a/api/ws/public.go b/api/ws/public.go index af13a23..45c201e 100644 --- a/api/ws/public.go +++ b/api/ws/public.go @@ -83,7 +83,7 @@ func (c *Public) UTickers(req requests.Tickers, rCh ...bool) error { } // OpenInterest -// Retrieve the open interest. Data will by pushed every 3 seconds. +// Retrieve the open interest. Data will be pushed every 3 seconds. // // https://www.okex.com/docs-v5/en/#websocket-api-public-channels-open-interest-channel func (c *Public) OpenInterest(req requests.OpenInterest, ch ...chan *public.OpenInterest) error { @@ -106,7 +106,7 @@ func (c *Public) UOpenInterest(req requests.OpenInterest, rCh ...bool) error { } // Candlesticks -// Retrieve the open interest. Data will by pushed every 3 seconds. +// Retrieve the open interest. Data will be pushed every 3 seconds. // // https://www.okex.com/docs-v5/en/#websocket-api-public-channels-candlesticks-channel func (c *Public) Candlesticks(req requests.Candlesticks, ch ...chan *public.Candlesticks) error { @@ -248,17 +248,21 @@ func (c *Public) UPriceLimit(req requests.PriceLimit, rCh ...bool) error { } // OrderBook -// Retrieve order book data. +// Retrieve order book data for multiple instruments. // // Use books for 400 depth levels, book5 for 5 depth levels, books50-l2-tbt tick-by-tick 50 depth levels, and books-l2-tbt for tick-by-tick 400 depth levels. // // https://www.okex.com/docs-v5/en/#websocket-api-public-channels-order-book-channel -func (c *Public) OrderBook(req requests.OrderBook, ch ...chan *public.OrderBook) error { - m := okex.S2M(req) +func (c *Public) OrderBook(reqs []requests.OrderBook, ch ...chan *public.OrderBook) error { if len(ch) > 0 { c.obCh = ch[0] } - return c.Subscribe(false, []okex.ChannelName{}, m) + var subscriptions []map[string]string + for _, req := range reqs { + m := okex.S2M(req) + subscriptions = append(subscriptions, m) + } + return c.Subscribe(false, []okex.ChannelName{}, subscriptions...) } // UOrderBook @@ -375,196 +379,176 @@ func (c *Public) Process(data []byte, e *events.Basic) bool { switch ch { case "instruments": e := public.Instruments{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.iCh != nil { - c.iCh <- &e - } + if c.iCh != nil { + c.iCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "tickers": e := public.Tickers{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.tCh != nil { - c.tCh <- &e - } + if c.tCh != nil { + c.tCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "open-interest": e := public.OpenInterest{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.oiCh != nil { - c.oiCh <- &e - } + if c.oiCh != nil { + c.oiCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "trades": e := public.Trades{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.trCh != nil { - c.trCh <- &e - } + if c.trCh != nil { + c.trCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "estimated-price": e := public.EstimatedDeliveryExercisePrice{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.edepCh != nil { - c.edepCh <- &e - } + if c.edepCh != nil { + c.edepCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "mark-price": e := public.MarkPrice{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.mpCh != nil { - c.mpCh <- &e - } + if c.mpCh != nil { + c.mpCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "price-limit": e := public.PriceLimit{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.plCh != nil { - c.plCh <- &e - } + if c.plCh != nil { + c.plCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "opt-summary": e := public.OPTIONSummary{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.osCh != nil { - c.osCh <- &e - } + if c.osCh != nil { + c.osCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "funding-rate": - e := public.OPTIONSummary{} - err := json.Unmarshal(data, &e) - if err != nil { + e := public.FundingRate{} + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.osCh != nil { - c.osCh <- &e - } + if c.frCh != nil { + c.frCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true case "index-tickers": e := public.IndexTickers{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.itCh != nil { - c.itCh <- &e - } + if c.itCh != nil { + c.itCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true default: - // special cases - // market price candlestick channel chName := fmt.Sprint(ch) - // market price channels if strings.Contains(chName, "mark-price-candle") { e := public.MarkPriceCandlesticks{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.mpcCh != nil { - c.mpcCh <- &e - } + if c.mpcCh != nil { + c.mpcCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } - // index chandlestick channels if strings.Contains(chName, "index-candle") { e := public.IndexCandlesticks{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.icCh != nil { - c.icCh <- &e - } + if c.icCh != nil { + c.icCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } - // candlestick channels if strings.Contains(chName, "candle") { e := public.Candlesticks{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.cCh != nil { - c.cCh <- &e - } + if c.cCh != nil { + c.cCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } - // order book channels if strings.Contains(chName, "books") { e := public.OrderBook{} - err := json.Unmarshal(data, &e) - if err != nil { + if err := json.Unmarshal(data, &e); err != nil { return false } - go func() { - if c.obCh != nil { - c.obCh <- &e - } + if c.obCh != nil { + c.obCh <- &e + } + if c.StructuredEventChan != nil { c.StructuredEventChan <- e - }() + } return true } } diff --git a/examples/books.go b/examples/books.go new file mode 100644 index 0000000..9e393b3 --- /dev/null +++ b/examples/books.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "crypto/tls" + "github.com/amir-the-h/okex" + "github.com/amir-the-h/okex/api" + "github.com/amir-the-h/okex/events/public" + requests "github.com/amir-the-h/okex/requests/ws/public" + "github.com/gorilla/websocket" + "log" + "net/http" + _ "net/http/pprof" + "time" +) + +func main() { + + // Start the pprof server + go func() { + log.Println("Starting pprof server on localhost:6060") + if err := http.ListenAndServe("localhost:6060", nil); err != nil { + log.Fatalf("could not start pprof server: %v", err) + } + }() + + apiKey := "" + secretKey := "" + passphrase := "" + ctx := context.Background() + client, err := api.NewClient(ctx, apiKey, secretKey, passphrase, okex.NormalServer) + if err != nil { + log.Fatalln(err) + } + + orderBookRequests := []requests.OrderBook{ + {InstID: "BTC-USDT", Channel: "books"}, + {InstID: "ETH-USDT", Channel: "books"}, + {InstID: "LTC-USDT", Channel: "books"}, + {InstID: "XRP-USDT", Channel: "books"}, + {InstID: "EOS-USDT", Channel: "books"}, + {InstID: "BCH-USDT", Channel: "books"}, + {InstID: "ETC-USDT", Channel: "books"}, + {InstID: "BSV-USDT", Channel: "books"}, + {InstID: "TRX-USDT", Channel: "books"}, + {InstID: "LINK-USDT", Channel: "books"}, + {InstID: "ADA-USDT", Channel: "books"}, + {InstID: "DOT-USDT", Channel: "books"}, + {InstID: "UNI-USDT", Channel: "books"}, + } + + client.Ws.Public.SetDialer(&websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }) + obCh := make(chan *public.OrderBook) + err = client.Ws.Public.OrderBook(orderBookRequests, obCh) + if err != nil { + log.Fatalln(err) + } + + // Listen for updates + for update := range obCh { + log.Printf("Received order book update: %+v\n", update) + insId, _ := update.Arg.Get("instId") + log.Printf("Instrument ID: %s\n", insId) + } +} diff --git a/go.mod b/go.mod index d4a64a0..68006a2 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/amir-the-h/okex -go 1.17 +go 1.21 require github.com/gorilla/websocket v1.4.2