Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gabor committed Aug 22, 2024
1 parent a1c7fd2 commit dbedd92
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 1 deletion.
34 changes: 33 additions & 1 deletion backend/httpclient/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net"
"net/http"

"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana-plugin-sdk-go/backend/proxy"
)

Expand Down Expand Up @@ -41,6 +42,30 @@ func New(opts ...Options) (*http.Client, error) {
return c, nil
}

type reportSizeRoundtripper struct {
next http.RoundTripper
}

func newReportSizeRoundtripper(next http.RoundTripper) http.RoundTripper {
return &reportSizeRoundtripper{next: next}
}

func (rt *reportSizeRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.next.RoundTrip(req)

if err != nil {
return res, err
}

res.Body = ReportSizeReader(res.Body, func(size int) {
log.DefaultLogger.Debug("downstream response info", "bytes", size, "url", req.URL.String())
})

return res, err
}

var _ http.RoundTripper = &reportSizeRoundtripper{}

// GetTransport creates a new http.RoundTripper given provided options.
// If opts is nil the http.DefaultTransport will be returned.
// If no middlewares are provided the DefaultMiddlewares() will be used. If you
Expand Down Expand Up @@ -93,7 +118,14 @@ func GetTransport(opts ...Options) (http.RoundTripper, error) {
return nil, err
}

return roundTripperFromMiddlewares(clientOpts, clientOpts.Middlewares, transport)
_, hasDatasourceTypeLabel := clientOpts.Labels["datasource_type"]

var rt http.RoundTripper = transport
if hasDatasourceTypeLabel {
rt = newReportSizeRoundtripper(rt)
}

return roundTripperFromMiddlewares(clientOpts, clientOpts.Middlewares, rt)
}

// GetTLSConfig creates a new tls.Config given provided options.
Expand Down
43 changes: 43 additions & 0 deletions backend/httpclient/report_size_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package httpclient

import (
"io"
)

type ReportSizeCallback func(int)

func ReportSizeReader(r io.ReadCloser, callback ReportSizeCallback) io.ReadCloser {
return &reportSizeReader{
r: r,
bytesRead: 0,
sizeReported: false,
callback: callback,
}
}

type reportSizeReader struct {
r io.ReadCloser
bytesRead int
sizeReported bool
callback ReportSizeCallback
}

func (rsr *reportSizeReader) Read(p []byte) (int, error) {
count, err := rsr.r.Read(p)
rsr.bytesRead += count

// we want to handle the case when:
// 1. Read() returns an error
// 2. user calls Read() again
// we only want to report the size on [1], not on [2].
// i do not know if this is allowed or not, but better be safe
if (!rsr.sizeReported) && (err != nil) {
rsr.sizeReported = true
rsr.callback(rsr.bytesRead)
}
return count, err
}

func (rsr *reportSizeReader) Close() error {
return rsr.r.Close()
}
76 changes: 76 additions & 0 deletions backend/httpclient/report_size_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package httpclient

import (
"bytes"
"errors"
"io"
"testing"

"github.com/stretchr/testify/require"
)

type testReportSizeReaderReader struct {
r io.ReadCloser
overrideErr error
}

// this reader allows customizing the error "at the end", to help with tests
func newTestReportSizeReaderReader(b []byte, overrideErr error) io.ReadCloser {
r := io.NopCloser(bytes.NewReader(b))
return &testReportSizeReaderReader{
r: r,
overrideErr: overrideErr,
}
}

func (r *testReportSizeReaderReader) Read(p []byte) (int, error) {
count, err := r.r.Read(p)
if err != nil {
err = r.overrideErr
}

return count, err
}

func (r *testReportSizeReaderReader) Close() error {
return r.r.Close()
}

func TestReportSizeReader(t *testing.T) {

tcs := []struct {
name string
b []byte
err error
}{
{name: "test1", b: []byte("hello world"), err: io.EOF},
{name: "test2", b: []byte("hello world2"), err: errors.New("test error")},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// we want to make sure things work when the data is received in multiple batches,
// so we intentionally read the data in small blocks
p := make([]byte, 2)
var err error
byteCount := 0
c := 0

reader := newTestReportSizeReaderReader(tc.b, tc.err)
reportHappened := false
wrapped := ReportSizeReader(reader, func(size int) {
reportHappened = true
require.Equal(t, len(tc.b), size) // test that we report the right number
require.Equal(t, len(tc.b), byteCount) // test that we return the correct sizes in Read() calls
})

for err == nil {
c, err = wrapped.Read(p)
byteCount += c
}
wrapped.Close()

require.Equal(t, tc.err, err)
require.Equal(t, true, reportHappened)
})
}
}

0 comments on commit dbedd92

Please sign in to comment.