From b5e125e6bd60b4ffb34be580d35bb3d6f5135b7f Mon Sep 17 00:00:00 2001 From: Marcus Efraimsson Date: Thu, 10 Oct 2024 14:55:16 +0200 Subject: [PATCH] Error source HTTP client middleware --- backend/data_adapter.go | 9 +- backend/error_source.go | 118 ++-------- backend/error_source_test.go | 142 ------------ backend/httpclient/error_source_middleware.go | 24 ++ backend/httpclient/http_client.go | 1 + backend/httpclient/http_client_test.go | 3 +- backend/httpclient/provider_test.go | 9 +- backend/request_status.go | 30 +-- backend/status/doc.go | 2 + backend/status/status_source.go | 213 ++++++++++++++++++ backend/status/status_source_test.go | 193 ++++++++++++++++ .../errorsource/error_source_middleware.go | 11 +- 12 files changed, 481 insertions(+), 274 deletions(-) delete mode 100644 backend/error_source_test.go create mode 100644 backend/httpclient/error_source_middleware.go create mode 100644 backend/status/doc.go create mode 100644 backend/status/status_source.go create mode 100644 backend/status/status_source_test.go diff --git a/backend/data_adapter.go b/backend/data_adapter.go index 827efd85d..31d69f429 100644 --- a/backend/data_adapter.go +++ b/backend/data_adapter.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend/status" "github.com/grafana/grafana-plugin-sdk-go/genproto/pluginv2" ) @@ -29,9 +30,9 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR var innerErr error resp, innerErr = a.queryDataHandler.QueryData(ctx, parsedReq) - status := RequestStatusFromQueryDataResponse(resp, innerErr) + requestStatus := RequestStatusFromQueryDataResponse(resp, innerErr) if innerErr != nil { - return status, innerErr + return requestStatus, innerErr } else if resp == nil { return RequestStatusError, errors.New("both response and error are nil, but one must be provided") } @@ -41,7 +42,7 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR // and if there's no plugin error var hasPluginError, hasDownstreamError bool for refID, r := range resp.Responses { - if r.Error == nil || isCancelledError(r.Error) { + if r.Error == nil || status.IsCancelledError(r.Error) { continue } @@ -81,7 +82,7 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR } } - return status, nil + return requestStatus, nil }) if err != nil { return nil, err diff --git a/backend/error_source.go b/backend/error_source.go index 8c157cf30..4ff4d2462 100644 --- a/backend/error_source.go +++ b/backend/error_source.go @@ -2,125 +2,62 @@ package backend import ( "context" - "errors" "fmt" - "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend/status" ) // ErrorSource type defines the source of the error -type ErrorSource string +type ErrorSource = status.Source const ( // ErrorSourcePlugin error originates from plugin. - ErrorSourcePlugin ErrorSource = "plugin" + ErrorSourcePlugin = status.SourcePlugin // ErrorSourceDownstream error originates from downstream service. - ErrorSourceDownstream ErrorSource = "downstream" + ErrorSourceDownstream = status.SourceDownstream // DefaultErrorSource is the default [ErrorSource] that should be used when it is not explicitly set. - DefaultErrorSource ErrorSource = ErrorSourcePlugin + DefaultErrorSource = status.SourcePlugin ) -func (es ErrorSource) IsValid() bool { - return es == ErrorSourceDownstream || es == ErrorSourcePlugin +// ErrorSourceFromHTTPError returns an [ErrorSource] based on provided error. +func ErrorSourceFromHTTPError(err error) ErrorSource { + return status.SourceFromHTTPError(err) } -// ErrorSourceFromStatus returns an [ErrorSource] based on provided HTTP status code. +// ErrorSourceFromHTTPStatus returns an [ErrorSource] based on provided HTTP status code. func ErrorSourceFromHTTPStatus(statusCode int) ErrorSource { - switch statusCode { - case http.StatusMethodNotAllowed, - http.StatusNotAcceptable, - http.StatusPreconditionFailed, - http.StatusRequestEntityTooLarge, - http.StatusRequestHeaderFieldsTooLarge, - http.StatusRequestURITooLong, - http.StatusExpectationFailed, - http.StatusUpgradeRequired, - http.StatusRequestedRangeNotSatisfiable, - http.StatusNotImplemented: - return ErrorSourcePlugin - } - - return ErrorSourceDownstream + return status.SourceFromHTTPStatus(statusCode) } -type errorWithSourceImpl struct { - source ErrorSource - err error +// IsDownstreamError return true if provided error is an error with downstream source or +// a timeout error or a cancelled error. +func IsDownstreamError(err error) bool { + return status.IsDownstreamError(err) } -func IsDownstreamError(err error) bool { - e := errorWithSourceImpl{ - source: ErrorSourceDownstream, - } - if errors.Is(err, e) { - return true - } - - type errorWithSource interface { - ErrorSource() ErrorSource - } - - // nolint:errorlint - if errWithSource, ok := err.(errorWithSource); ok && errWithSource.ErrorSource() == ErrorSourceDownstream { - return true - } - - if isHTTPTimeoutError(err) || isCancelledError(err) { - return true - } - - return false +// IsDownstreamError return true if provided error is an error with downstream source or +// a HTTP timeout error or a cancelled error or a connection reset/refused error or dns not found error. +func IsDownstreamHTTPError(err error) bool { + return status.IsDownstreamHTTPError(err) } func DownstreamError(err error) error { - return errorWithSourceImpl{ - source: ErrorSourceDownstream, - err: err, - } + return status.DownstreamError(err) } func DownstreamErrorf(format string, a ...any) error { return DownstreamError(fmt.Errorf(format, a...)) } -func (e errorWithSourceImpl) ErrorSource() ErrorSource { - return e.source -} - -func (e errorWithSourceImpl) Error() string { - return fmt.Errorf("%s error: %w", e.source, e.err).Error() -} - -// Implements the interface used by [errors.Is]. -func (e errorWithSourceImpl) Is(err error) bool { - if errWithSource, ok := err.(errorWithSourceImpl); ok { - return errWithSource.ErrorSource() == e.source - } - - return false -} - -func (e errorWithSourceImpl) Unwrap() error { - return e.err -} - -type errorSourceCtxKey struct{} - -// errorSourceFromContext returns the error source stored in the context. -// If no error source is stored in the context, [DefaultErrorSource] is returned. func errorSourceFromContext(ctx context.Context) ErrorSource { - value, ok := ctx.Value(errorSourceCtxKey{}).(*ErrorSource) - if ok { - return *value - } - return DefaultErrorSource + return status.SourceFromContext(ctx) } -// initErrorSource initialize the status source for the context. +// initErrorSource initialize the error source for the context. func initErrorSource(ctx context.Context) context.Context { - s := DefaultErrorSource - return context.WithValue(ctx, errorSourceCtxKey{}, &s) + return status.InitSource(ctx) } // WithErrorSource mutates the provided context by setting the error source to @@ -128,12 +65,7 @@ func initErrorSource(ctx context.Context) context.Context { // will not be mutated and an error returned. This means that [initErrorSource] // has to be called before this function. func WithErrorSource(ctx context.Context, s ErrorSource) error { - v, ok := ctx.Value(errorSourceCtxKey{}).(*ErrorSource) - if !ok { - return errors.New("the provided context does not have a status source") - } - *v = s - return nil + return status.WithSource(ctx, s) } // WithDownstreamErrorSource mutates the provided context by setting the error source to @@ -141,5 +73,5 @@ func WithErrorSource(ctx context.Context, s ErrorSource) error { // will not be mutated and an error returned. This means that [initErrorSource] has to be // called before this function. func WithDownstreamErrorSource(ctx context.Context) error { - return WithErrorSource(ctx, ErrorSourceDownstream) + return status.WithDownstreamSource(ctx) } diff --git a/backend/error_source_test.go b/backend/error_source_test.go deleted file mode 100644 index 94ebe1354..000000000 --- a/backend/error_source_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package backend - -import ( - "context" - "errors" - "fmt" - "net" - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestErrorSource(t *testing.T) { - var es ErrorSource - require.False(t, es.IsValid()) - require.True(t, ErrorSourceDownstream.IsValid()) - require.True(t, ErrorSourcePlugin.IsValid()) -} - -func TestIsDownstreamError(t *testing.T) { - tcs := []struct { - name string - err error - expected bool - }{ - { - name: "nil", - err: nil, - expected: false, - }, - { - name: "downstream error", - err: DownstreamError(nil), - expected: true, - }, - { - name: "timeout network error", - err: newFakeNetworkError(true, false), - expected: true, - }, - { - name: "wrapped timeout network error", - err: fmt.Errorf("oh no. err %w", newFakeNetworkError(true, false)), - expected: true, - }, - { - name: "temporary timeout network error", - err: newFakeNetworkError(true, true), - expected: true, - }, - { - name: "non-timeout network error", - err: newFakeNetworkError(false, false), - expected: false, - }, - { - name: "os.ErrDeadlineExceeded", - err: os.ErrDeadlineExceeded, - expected: true, - }, - { - name: "os.ErrDeadlineExceeded", - err: fmt.Errorf("error: %w", os.ErrDeadlineExceeded), - expected: true, - }, - { - name: "wrapped os.ErrDeadlineExceeded", - err: errors.Join(fmt.Errorf("oh no"), os.ErrDeadlineExceeded), - expected: true, - }, - { - name: "other error", - err: fmt.Errorf("other error"), - expected: false, - }, - { - name: "context.Canceled", - err: context.Canceled, - expected: true, - }, - { - name: "wrapped context.Canceled", - err: fmt.Errorf("error: %w", context.Canceled), - expected: true, - }, - { - name: "joined context.Canceled", - err: errors.Join(fmt.Errorf("oh no"), context.Canceled), - expected: true, - }, - { - name: "gRPC canceled error", - err: status.Error(codes.Canceled, "canceled"), - expected: true, - }, - { - name: "wrapped gRPC canceled error", - err: fmt.Errorf("error: %w", status.Error(codes.Canceled, "canceled")), - expected: true, - }, - { - name: "joined gRPC canceled error", - err: errors.Join(fmt.Errorf("oh no"), status.Error(codes.Canceled, "canceled")), - expected: true, - }, - } - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - assert.Equalf(t, tc.expected, IsDownstreamError(tc.err), "IsDownstreamError(%v)", tc.err) - }) - } -} - -var _ net.Error = &fakeNetworkError{} - -type fakeNetworkError struct { - timeout bool - temporary bool -} - -func newFakeNetworkError(timeout, temporary bool) *fakeNetworkError { - return &fakeNetworkError{ - timeout: timeout, - temporary: temporary, - } -} - -func (d *fakeNetworkError) Error() string { - return "dummy timeout error" -} - -func (d *fakeNetworkError) Timeout() bool { - return d.timeout -} - -func (d *fakeNetworkError) Temporary() bool { - return d.temporary -} diff --git a/backend/httpclient/error_source_middleware.go b/backend/httpclient/error_source_middleware.go new file mode 100644 index 000000000..17594443a --- /dev/null +++ b/backend/httpclient/error_source_middleware.go @@ -0,0 +1,24 @@ +package httpclient + +import ( + "net/http" + + "github.com/grafana/grafana-plugin-sdk-go/backend/status" +) + +// ErrorSourceMiddlewareName is the middleware name used by ErrorSourceMiddleware. +const ErrorSourceMiddlewareName = "ErrorSource" + +// ErrorSourceMiddleware inspect the response error and wraps it in a [status.DownstreamError] if [status.IsDownstreamHTTPError] returns true. +func ErrorSourceMiddleware() Middleware { + return NamedMiddlewareFunc(ErrorSourceMiddlewareName, func(_ Options, next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + res, err := next.RoundTrip(req) + if err != nil && status.IsDownstreamHTTPError(err) { + return res, status.DownstreamError(err) + } + + return res, err + }) + }) +} diff --git a/backend/httpclient/http_client.go b/backend/httpclient/http_client.go index 40a4d39d4..0d88f72e8 100644 --- a/backend/httpclient/http_client.go +++ b/backend/httpclient/http_client.go @@ -210,6 +210,7 @@ func DefaultMiddlewares() []Middleware { BasicAuthenticationMiddleware(), CustomHeadersMiddleware(), ContextualMiddleware(), + ErrorSourceMiddleware(), } } diff --git a/backend/httpclient/http_client_test.go b/backend/httpclient/http_client_test.go index 570fe0833..e8875d180 100644 --- a/backend/httpclient/http_client_test.go +++ b/backend/httpclient/http_client_test.go @@ -55,11 +55,12 @@ func TestNewClient(t *testing.T) { require.NoError(t, err) require.NotNil(t, client) - require.Len(t, usedMiddlewares, 4) + require.Len(t, usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, usedMiddlewares[3].(MiddlewareName).MiddlewareName()) + require.Equal(t, ErrorSourceMiddlewareName, usedMiddlewares[4].(MiddlewareName).MiddlewareName()) }) t.Run("New() with opts middleware should return expected http.Client", func(t *testing.T) { diff --git a/backend/httpclient/provider_test.go b/backend/httpclient/provider_test.go index b5331321a..deb7525ce 100644 --- a/backend/httpclient/provider_test.go +++ b/backend/httpclient/provider_test.go @@ -24,11 +24,12 @@ func TestProvider(t *testing.T) { client, err := ctx.provider.New() require.NoError(t, err) require.NotNil(t, client) - require.Len(t, ctx.usedMiddlewares, 4) + require.Len(t, ctx.usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, ctx.usedMiddlewares[3].(MiddlewareName).MiddlewareName()) + require.Equal(t, ErrorSourceMiddlewareName, ctx.usedMiddlewares[4].(MiddlewareName).MiddlewareName()) }) t.Run("Transport should use default middlewares", func(t *testing.T) { @@ -36,11 +37,12 @@ func TestProvider(t *testing.T) { transport, err := ctx.provider.GetTransport() require.NoError(t, err) require.NotNil(t, transport) - require.Len(t, ctx.usedMiddlewares, 4) + require.Len(t, ctx.usedMiddlewares, 5) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, ctx.usedMiddlewares[3].(MiddlewareName).MiddlewareName()) + require.Equal(t, ErrorSourceMiddlewareName, ctx.usedMiddlewares[4].(MiddlewareName).MiddlewareName()) }) t.Run("New() with options and no middleware should return expected http client and transport", func(t *testing.T) { @@ -81,7 +83,7 @@ func TestProvider(t *testing.T) { require.Equal(t, DefaultTimeoutOptions.Timeout, client.Timeout) t.Run("Should use configured middlewares and implement MiddlewareName", func(t *testing.T) { - require.Len(t, pCtx.usedMiddlewares, 7) + require.Len(t, pCtx.usedMiddlewares, 8) require.Equal(t, "mw1", pCtx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw2", pCtx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw3", pCtx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) @@ -89,6 +91,7 @@ func TestProvider(t *testing.T) { require.Equal(t, BasicAuthenticationMiddlewareName, pCtx.usedMiddlewares[4].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, pCtx.usedMiddlewares[5].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, pCtx.usedMiddlewares[6].(MiddlewareName).MiddlewareName()) + require.Equal(t, ErrorSourceMiddlewareName, pCtx.usedMiddlewares[7].(MiddlewareName).MiddlewareName()) }) t.Run("When roundtrip should call expected middlewares", func(t *testing.T) { diff --git a/backend/request_status.go b/backend/request_status.go index 36280295e..e2e490633 100644 --- a/backend/request_status.go +++ b/backend/request_status.go @@ -2,14 +2,9 @@ package backend import ( "context" - "errors" - "net" - "os" "strings" - grpccodes "google.golang.org/grpc/codes" - grpcstatus "google.golang.org/grpc/status" - + "github.com/grafana/grafana-plugin-sdk-go/backend/status" "github.com/grafana/grafana-plugin-sdk-go/genproto/pluginv2" ) @@ -31,15 +26,15 @@ func (status RequestStatus) String() string { } func RequestStatusFromError(err error) RequestStatus { - status := RequestStatusOK + requestStatus := RequestStatusOK if err != nil { - status = RequestStatusError - if isCancelledError(err) { - status = RequestStatusCancelled + requestStatus = RequestStatusError + if status.IsCancelledError(err) { + requestStatus = RequestStatusCancelled } } - return status + return requestStatus } func RequestStatusFromErrorString(errString string) RequestStatus { @@ -103,16 +98,3 @@ func RequestStatusFromProtoQueryDataResponse(res *pluginv2.QueryDataResponse, er return status } - -func isCancelledError(err error) bool { - return errors.Is(err, context.Canceled) || grpcstatus.Code(err) == grpccodes.Canceled -} - -func isHTTPTimeoutError(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - return true - } - - return errors.Is(err, os.ErrDeadlineExceeded) // replacement for os.IsTimeout(err) -} diff --git a/backend/status/doc.go b/backend/status/doc.go new file mode 100644 index 000000000..a148e000a --- /dev/null +++ b/backend/status/doc.go @@ -0,0 +1,2 @@ +// Package status provides utilities for status and errors. +package status diff --git a/backend/status/status_source.go b/backend/status/status_source.go new file mode 100644 index 000000000..0dedf9342 --- /dev/null +++ b/backend/status/status_source.go @@ -0,0 +1,213 @@ +package status + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "syscall" + + grpccodes "google.golang.org/grpc/codes" + grpcstatus "google.golang.org/grpc/status" +) + +// Source type defines the status source. +type Source string + +const ( + // SourcePlugin status originates from plugin. + SourcePlugin Source = "plugin" + + // SourceDownstream status originates from downstream service. + SourceDownstream Source = "downstream" + + // DefaultSource is the default [Source] that should be used when it is not explicitly set. + DefaultSource Source = SourcePlugin +) + +// IsValid return true if es is [SourceDownstream] or [SourcePlugin]. +func (s Source) IsValid() bool { + return s == SourceDownstream || s == SourcePlugin +} + +// String returns the string representation of s. If s is not valid, [DefaultSource] is returned. +func (s Source) String() string { + if !s.IsValid() { + return string(DefaultSource) + } + + return string(s) +} + +// SourceFromHTTPError returns a [Source] based on provided error. +func SourceFromHTTPError(err error) Source { + if IsDownstreamHTTPError(err) { + return SourceDownstream + } + return SourcePlugin +} + +// ErrorSourceFromStatus returns a [Source] based on provided HTTP status code. +func SourceFromHTTPStatus(statusCode int) Source { + switch statusCode { + case http.StatusMethodNotAllowed, + http.StatusNotAcceptable, + http.StatusPreconditionFailed, + http.StatusRequestEntityTooLarge, + http.StatusRequestHeaderFieldsTooLarge, + http.StatusRequestURITooLong, + http.StatusExpectationFailed, + http.StatusUpgradeRequired, + http.StatusRequestedRangeNotSatisfiable, + http.StatusNotImplemented: + return SourcePlugin + } + + return SourceDownstream +} + +type errorWithSourceImpl struct { + source Source + err error +} + +// DownstreamError creates a new error with status [SourceDownstream]. +func DownstreamError(err error) error { + return errorWithSourceImpl{ + source: SourceDownstream, + err: err, + } +} + +// DownstreamError creates a new error with status [SourceDownstream] and formats +// according to a format specifier and returns the string as a value that satisfies error. +func DownstreamErrorf(format string, a ...any) error { + return DownstreamError(fmt.Errorf(format, a...)) +} + +func (e errorWithSourceImpl) ErrorSource() Source { + return e.source +} + +func (e errorWithSourceImpl) Error() string { + return fmt.Errorf("%s error: %w", e.source, e.err).Error() +} + +// Implements the interface used by [errors.Is]. +func (e errorWithSourceImpl) Is(err error) bool { + if errWithSource, ok := err.(errorWithSourceImpl); ok { + return errWithSource.ErrorSource() == e.source + } + + return false +} + +func (e errorWithSourceImpl) Unwrap() error { + return e.err +} + +// IsDownstreamError return true if provided error is an error with downstream source or +// a timeout error or a cancelled error. +func IsDownstreamError(err error) bool { + e := errorWithSourceImpl{ + source: SourceDownstream, + } + if errors.Is(err, e) { + return true + } + + type errorWithSource interface { + ErrorSource() Source + } + + // nolint:errorlint + if errWithSource, ok := err.(errorWithSource); ok && errWithSource.ErrorSource() == SourceDownstream { + return true + } + + return isHTTPTimeoutError(err) || IsCancelledError(err) +} + +// IsDownstreamHTTPError return true if provided error is an error with downstream source or +// a HTTP timeout error or a cancelled error or a connection reset/refused error or dns not found error. +func IsDownstreamHTTPError(err error) bool { + return IsDownstreamError(err) || + isConnectionResetOrRefusedError(err) || + isDNSNotFoundError(err) +} + +// InCancelledError returns true if err is context.Canceled or is gRPC status Canceled. +func IsCancelledError(err error) bool { + return errors.Is(err, context.Canceled) || grpcstatus.Code(err) == grpccodes.Canceled +} + +func isHTTPTimeoutError(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + return errors.Is(err, os.ErrDeadlineExceeded) // replacement for os.IsTimeout(err) +} + +func isConnectionResetOrRefusedError(err error) bool { + var netErr *net.OpError + if errors.As(err, &netErr) { + var sysErr *os.SyscallError + if errors.As(netErr.Err, &sysErr) { + return errors.Is(sysErr.Err, syscall.ECONNRESET) || errors.Is(sysErr.Err, syscall.ECONNREFUSED) + } + } + + return false +} + +func isDNSNotFoundError(err error) bool { + var dnsError *net.DNSError + if errors.As(err, &dnsError) && dnsError.IsNotFound { + return true + } + + return false +} + +type sourceCtxKey struct{} + +// SourceFromContext returns the source stored in the context. +// If no source is stored in the context, [DefaultSource] is returned. +func SourceFromContext(ctx context.Context) Source { + value, ok := ctx.Value(sourceCtxKey{}).(*Source) + if ok { + return *value + } + return DefaultSource +} + +// InitSource initialize the source for the context. +func InitSource(ctx context.Context) context.Context { + s := DefaultSource + return context.WithValue(ctx, sourceCtxKey{}, &s) +} + +// WithSource mutates the provided context by setting the source to +// s. If the provided context does not have a source, the context +// will not be mutated and an error returned. This means that [InitSource] +// has to be called before this function. +func WithSource(ctx context.Context, s Source) error { + v, ok := ctx.Value(sourceCtxKey{}).(*Source) + if !ok { + return errors.New("the provided context does not have a status source") + } + *v = s + return nil +} + +// WithDownstreamSource mutates the provided context by setting the source to +// [SourceDownstream]. If the provided context does not have a source, the context +// will not be mutated and an error returned. This means that [InitSource] has to be +// called before this function. +func WithDownstreamSource(ctx context.Context) error { + return WithSource(ctx, SourceDownstream) +} diff --git a/backend/status/status_source_test.go b/backend/status/status_source_test.go new file mode 100644 index 000000000..91ef362a5 --- /dev/null +++ b/backend/status/status_source_test.go @@ -0,0 +1,193 @@ +package status + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestSource(t *testing.T) { + var s Source + require.False(t, s.IsValid()) + require.Equal(t, "plugin", s.String()) + require.True(t, SourceDownstream.IsValid()) + require.Equal(t, "downstream", SourceDownstream.String()) + require.True(t, SourcePlugin.IsValid()) + require.Equal(t, "plugin", SourcePlugin.String()) +} + +func TestIsDownstreamError(t *testing.T) { + tcs := []struct { + name string + err error + expected bool + }{ + { + name: "nil", + err: nil, + expected: false, + }, + { + name: "downstream error", + err: DownstreamError(nil), + expected: true, + }, + { + name: "timeout network error", + err: newFakeNetworkError(true, false), + expected: true, + }, + { + name: "temporary timeout network error", + err: newFakeNetworkError(true, true), + expected: true, + }, + { + name: "non-timeout network error", + err: newFakeNetworkError(false, false), + expected: false, + }, + { + name: "os.ErrDeadlineExceeded", + err: os.ErrDeadlineExceeded, + expected: true, + }, + { + name: "other error", + err: fmt.Errorf("other error"), + expected: false, + }, + { + name: "context.Canceled", + err: context.Canceled, + expected: true, + }, + { + name: "gRPC canceled error", + err: status.Error(codes.Canceled, "canceled"), + expected: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + wrappedErr := fmt.Errorf("error: %w", tc.err) + joinedErr := errors.Join(errors.New("oh no"), tc.err) + assert.Equalf(t, tc.expected, IsDownstreamError(tc.err), "IsDownstreamHTTPError(%v)", tc.err) + assert.Equalf(t, tc.expected, IsDownstreamError(wrappedErr), "wrapped IsDownstreamHTTPError(%v)", wrappedErr) + assert.Equalf(t, tc.expected, IsDownstreamError(joinedErr), "joined IsDownstreamHTTPError(%v)", joinedErr) + }) + } +} + +func TestIsDownstreamHTTPError(t *testing.T) { + tcs := []struct { + name string + err error + expected bool + }{ + { + name: "nil", + err: nil, + expected: false, + }, + { + name: "downstream error", + err: DownstreamError(nil), + expected: true, + }, + { + name: "timeout network error", + err: newFakeNetworkError(true, false), + expected: true, + }, + { + name: "temporary timeout network error", + err: newFakeNetworkError(true, true), + expected: true, + }, + { + name: "non-timeout network error", + err: newFakeNetworkError(false, false), + expected: false, + }, + { + name: "os.ErrDeadlineExceeded", + err: os.ErrDeadlineExceeded, + expected: true, + }, + { + name: "other error", + err: fmt.Errorf("other error"), + expected: false, + }, + { + name: "context.Canceled", + err: context.Canceled, + expected: true, + }, + { + name: "gRPC canceled error", + err: status.Error(codes.Canceled, "canceled"), + expected: true, + }, + { + name: "connection reset error", + err: &net.OpError{Err: &os.SyscallError{Err: syscall.ECONNREFUSED}}, + expected: true, + }, + { + name: "connection refused error", + err: &net.OpError{Err: &os.SyscallError{Err: syscall.ECONNREFUSED}}, + expected: true, + }, + { + name: "DNS not found error", + err: &net.DNSError{IsNotFound: true}, + expected: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + wrappedErr := fmt.Errorf("error: %w", tc.err) + joinedErr := errors.Join(errors.New("oh no"), tc.err) + assert.Equalf(t, tc.expected, IsDownstreamHTTPError(tc.err), "IsDownstreamHTTPError(%v)", tc.err) + assert.Equalf(t, tc.expected, IsDownstreamHTTPError(wrappedErr), "wrapped IsDownstreamHTTPError(%v)", wrappedErr) + assert.Equalf(t, tc.expected, IsDownstreamHTTPError(joinedErr), "joined IsDownstreamHTTPError(%v)", joinedErr) + }) + } +} + +var _ net.Error = &fakeNetworkError{} + +type fakeNetworkError struct { + timeout bool + temporary bool +} + +func newFakeNetworkError(timeout, temporary bool) *fakeNetworkError { + return &fakeNetworkError{ + timeout: timeout, + temporary: temporary, + } +} + +func (d *fakeNetworkError) Error() string { + return "dummy timeout error" +} + +func (d *fakeNetworkError) Timeout() bool { + return d.timeout +} + +func (d *fakeNetworkError) Temporary() bool { + return d.temporary +} diff --git a/experimental/errorsource/error_source_middleware.go b/experimental/errorsource/error_source_middleware.go index 95ff79b10..4245ffeeb 100644 --- a/experimental/errorsource/error_source_middleware.go +++ b/experimental/errorsource/error_source_middleware.go @@ -2,12 +2,11 @@ package errorsource import ( "errors" - "net" "net/http" - "syscall" "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana-plugin-sdk-go/backend/status" ) // Middleware captures error source metric @@ -26,13 +25,11 @@ func RoundTripper(_ httpclient.Options, next http.RoundTripper) http.RoundTrippe } return res, Error{source: errorSource, err: err} } - if errors.Is(err, syscall.ECONNREFUSED) { - return res, Error{source: backend.ErrorSourceDownstream, err: err} - } - var dnsError *net.DNSError - if errors.As(err, &dnsError) && dnsError.IsNotFound { + + if status.IsDownstreamHTTPError(err) { return res, Error{source: backend.ErrorSourceDownstream, err: err} } + return res, err }) }