Skip to content

Commit

Permalink
feat: Add ability to provide dynamic query parameters (#44)
Browse files Browse the repository at this point in the history
keelerm84 authored Dec 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent c7ba525 commit efc5d86
Showing 3 changed files with 105 additions and 5 deletions.
19 changes: 14 additions & 5 deletions stream.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"sync"
"time"
)
@@ -14,11 +15,12 @@ import (
// It will try and reconnect if the connection is lost, respecting both
// received retry delays and event id's.
type Stream struct {
c *http.Client
req *http.Request
lastEventID string
readTimeout time.Duration
retryDelay *retryDelayStrategy
c *http.Client
req *http.Request
queryParamsFunc *func(existing url.Values) url.Values
lastEventID string
readTimeout time.Duration
retryDelay *retryDelayStrategy
// Events emits the events received by the stream
Events chan Event
// Errors emits any errors encountered while reading events from the stream.
@@ -187,6 +189,10 @@ func newStream(request *http.Request, configuredOptions streamOptions) *Stream {
closer: make(chan struct{}),
}

if configuredOptions.queryParamsFunc != nil {
stream.queryParamsFunc = configuredOptions.queryParamsFunc
}

if configuredOptions.errorHandler == nil {
// The Errors channel is only used if there is no error handler.
stream.Errors = make(chan error)
@@ -231,6 +237,9 @@ func (stream *Stream) connect() (io.ReadCloser, error) {
stream.req.Header.Set("Last-Event-ID", stream.lastEventID)
}
req := *stream.req
if stream.queryParamsFunc != nil {
req.URL.RawQuery = (*stream.queryParamsFunc)(req.URL.Query()).Encode()
}

// All but the initial connection will need to regenerate the body
if stream.connections > 0 && req.GetBody != nil {
18 changes: 18 additions & 0 deletions stream_options.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ package eventsource

import (
"net/http"
"net/url"
"time"
)

@@ -16,6 +17,7 @@ type streamOptions struct {
retryResetInterval time.Duration
initialRetryTimeout time.Duration
errorHandler StreamErrorHandler
queryParamsFunc *func(existing url.Values) url.Values
}

// StreamOption is a common interface for optional configuration parameters that can be
@@ -24,6 +26,22 @@ type StreamOption interface {
apply(s *streamOptions) error
}

type dynamicQueryParamsOption struct {
queryParamsFunc func(existing url.Values) url.Values
}

func (o dynamicQueryParamsOption) apply(s *streamOptions) error {
s.queryParamsFunc = &o.queryParamsFunc
return nil
}

// StreamOptionDynamicQueryParams returns an option that sets a function to
// generate query parameters each time the stream needs to make a fresh
// connection.
func StreamOptionDynamicQueryParams(f func(existing url.Values) url.Values) StreamOption {
return dynamicQueryParamsOption{queryParamsFunc: f}
}

type readTimeoutOption struct {
timeout time.Duration
}
73 changes: 73 additions & 0 deletions stream_requests_test.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@ import (
"bytes"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"

@@ -45,6 +47,77 @@ func TestStreamSendsLastEventID(t *testing.T) {
assert.Equal(t, lastID, r0.Request.Header.Get("Last-Event-ID"))
}

func TestCanReplaceStreamQueryParameters(t *testing.T) {
streamHandler, streamControl := httphelpers.SSEHandler(nil)
defer streamControl.Close()
handler, requestsCh := httphelpers.RecordingHandler(streamHandler)

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

option := StreamOptionDynamicQueryParams(func(existing url.Values) url.Values {
return url.Values{
"filter": []string{"my-custom-filter"},
"basis": []string{"last-known-basis"},
}
})

stream := mustSubscribe(t, httpServer.URL, option)
defer stream.Close()

r0 := <-requestsCh
assert.Equal(t, "my-custom-filter", r0.Request.URL.Query().Get("filter"))
assert.Equal(t, "last-known-basis", r0.Request.URL.Query().Get("basis"))
}

func TestCanUpdateStreamQueryParameters(t *testing.T) {
streamHandler, streamControl := httphelpers.SSEHandler(nil)
defer streamControl.Close()
handler, requestsCh := httphelpers.RecordingHandler(streamHandler)

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

option := StreamOptionDynamicQueryParams(func(existing url.Values) url.Values {
if existing.Has("count") {
count, _ := strconv.Atoi(existing.Get("count"))

if count == 1 {
existing.Set("count", strconv.Itoa(count+1))
return existing
}

return url.Values{}
}

return url.Values{
"initial": []string{"payload is set"},
"count": []string{"1"},
}
})

stream := mustSubscribe(t, httpServer.URL, option, StreamOptionInitialRetry(time.Millisecond))
defer stream.Close()

r0 := <-requestsCh
assert.Equal(t, "payload is set", r0.Request.URL.Query().Get("initial"))
assert.Equal(t, "1", r0.Request.URL.Query().Get("count"))

streamControl.EndAll()
<-stream.Errors // Accept the error to unblock the retry handler

r1 := <-requestsCh
assert.Equal(t, "payload is set", r1.Request.URL.Query().Get("initial"))
assert.Equal(t, "2", r1.Request.URL.Query().Get("count"))

streamControl.EndAll()
<-stream.Errors // Accept the error to unblock the retry handler

r2 := <-requestsCh
assert.False(t, r2.Request.URL.Query().Has("initial"))
assert.False(t, r2.Request.URL.Query().Has("count"))
}

func TestStreamReconnectWithRequestBodySendsBodyTwice(t *testing.T) {
body := []byte("my-body")

0 comments on commit efc5d86

Please sign in to comment.