diff --git a/client/client.go b/client/client.go index 175a57351..1ac995176 100644 --- a/client/client.go +++ b/client/client.go @@ -63,8 +63,7 @@ func (c *DatabricksClient) GetOAuthToken(ctx context.Context, authDetails string // Do sends an HTTP request against path. func (c *DatabricksClient) Do(ctx context.Context, method, path string, - headers map[string]string, request, response any, - visitors ...func(*http.Request) error) error { + headers map[string]string, request, response any, visitors ...func(*http.Request) error) error { opts := []httpclient.DoOption{} for _, v := range visitors { opts = append(opts, httpclient.WithRequestVisitor(v)) diff --git a/config/api_client.go b/config/api_client.go index 0913e35b3..6a824e6d7 100644 --- a/config/api_client.go +++ b/config/api_client.go @@ -9,6 +9,7 @@ import ( "time" "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/credentials" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/useragent" @@ -73,17 +74,18 @@ func (c *Config) NewApiClient() (*httpclient.ApiClient, error) { return nil }, }, - TransientErrors: []string{ - "REQUEST_LIMIT_EXCEEDED", // This is temporary workaround for SCIM API returning 500. Remove when it's fixed - }, ErrorMapper: apierr.GetAPIError, - ErrorRetriable: func(ctx context.Context, err error) bool { - var apiErr *apierr.APIError - if errors.As(err, &apiErr) { - return apiErr.IsRetriable(ctx) - } - return false - }, + ErrorRetriable: httpclient.CombineRetriers( + func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { + var apiErr *apierr.APIError + if errors.As(err, &apiErr) { + return apiErr.IsRetriable(ctx) + } + return false + }, + httpclient.RetryUrlErrors, + httpclient.RetryTransientErrors([]string{"REQUEST_LIMIT_EXCEEDED"}), + ), }), nil } diff --git a/config/config.go b/config/config.go index fcf69d2cb..f2c948844 100644 --- a/config/config.go +++ b/config/config.go @@ -311,13 +311,16 @@ func (c *Config) EnsureResolved() error { HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second, Transport: c.HTTPTransport, ErrorMapper: c.refreshTokenErrorMapper, - TransientErrors: []string{ - "throttled", - "too many requests", - "429", - "request limit exceeded", - "rate limit", - }, + ErrorRetriable: httpclient.CombineRetriers( + httpclient.DefaultErrorRetriable, + httpclient.RetryTransientErrors([]string{ + "throttled", + "too many requests", + "429", + "request limit exceeded", + "rate limit", + }), + ), }) if c.azureTenantIdFetchClient == nil { c.azureTenantIdFetchClient = &http.Client{ diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 2130fd3bb..41374ca0e 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "runtime" - "strings" "time" "github.com/databricks/databricks-sdk-go/common" @@ -35,9 +34,8 @@ type ClientConfig struct { DebugTruncateBytes int RateLimitPerSecond int - ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error - ErrorRetriable func(ctx context.Context, err error) bool - TransientErrors []string + ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error + ErrorRetriable ErrorRetryer Transport http.RoundTripper } @@ -130,7 +128,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio // merge client-wide and request-specific visitors visitors = append(visitors, o.in) } - } // Use default AuthVisitor if none is provided if authVisitor == nil { @@ -170,45 +167,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio return nil } -func (c *ApiClient) isRetriable(ctx context.Context, err error) bool { - if c.config.ErrorRetriable(ctx, err) { - return true - } - if isRetriableUrlError(err) { - // all IO errors are retriable - logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) - return true - } - message := err.Error() - // Handle transient errors for retries - for _, substring := range c.config.TransientErrors { - if strings.Contains(message, substring) { - logger.Debugf(ctx, "Attempting retry because of %#v", substring) - return true - } - } - // some API's recommend retries on HTTP 500, but we'll add that later - return false -} - -// Common error-handling logic for all responses that may need to be retried. -// -// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed -// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt. -// -// Always returns nil for the first parameter as there is no meaningful response body to return in the error case. -// -// If it is certain that an error should not be retried, use failRequest() instead. -func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) { - if !c.isRetriable(ctx, err) { - return nil, retries.Halt(err) - } - if resetErr := body.Reset(); resetErr != nil { - return nil, retries.Halt(resetErr) - } - return nil, retries.Continue(err) -} - // Fails the request with a retries.Err to halt future retries. func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) { err = fmt.Errorf("%s: %w", msg, err) @@ -299,7 +257,16 @@ func (c *ApiClient) attempt( // proactively release the connections in HTTP connection pool c.httpClient.CloseIdleConnections() - return c.handleError(ctx, err, requestBody) + + // Non-retriable errors can be returned immediately. + if !c.config.ErrorRetriable(ctx, request, &responseWrapper, err) { + return nil, retries.Halt(err) + } + // Retriable errors may require the request body to be reset. + if resetErr := requestBody.Reset(); resetErr != nil { + return nil, retries.Halt(resetErr) + } + return nil, retries.Continue(err) } } diff --git a/httpclient/errors.go b/httpclient/errors.go index 540c6b885..99efced29 100644 --- a/httpclient/errors.go +++ b/httpclient/errors.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/databricks/databricks-sdk-go/common" + "github.com/databricks/databricks-sdk-go/logger" ) type HttpError struct { @@ -45,17 +46,39 @@ func DefaultErrorMapper(ctx context.Context, resp common.ResponseWrapper) error } } -func DefaultErrorRetriable(ctx context.Context, err error) bool { - var httpError *HttpError - if errors.As(err, &httpError) { - if httpError.StatusCode == http.StatusTooManyRequests { - return true - } - if httpError.StatusCode == http.StatusGatewayTimeout { - return true +type ErrorRetryer func(context.Context, *http.Request, *common.ResponseWrapper, error) bool + +func DefaultErrorRetriable(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool { + return CombineRetriers( + RetryOnTooManyRequests, + RetryOnGatewayTimeout, + RetryUrlErrors, + )(ctx, req, resp, err) +} + +func RetryOnTooManyRequests(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool { + if resp.Response == nil { + return false + } + return resp.Response.StatusCode == http.StatusTooManyRequests +} + +func RetryOnGatewayTimeout(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool { + if resp.Response == nil { + return false + } + return resp.Response.StatusCode == http.StatusGatewayTimeout +} + +func CombineRetriers(retriers ...ErrorRetryer) ErrorRetryer { + return func(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool { + for _, retrier := range retriers { + if retrier(ctx, req, resp, err) { + return true + } } + return false } - return false } var urlErrorTransientErrorMessages = []string{ @@ -66,15 +89,30 @@ var urlErrorTransientErrorMessages = []string{ "i/o timeout", } -func isRetriableUrlError(err error) bool { +func RetryUrlErrors(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { var urlError *url.Error if !errors.As(err, &urlError) { return false } for _, msg := range urlErrorTransientErrorMessages { if strings.Contains(err.Error(), msg) { + logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) return true } } return false } + +func RetryTransientErrors(errors []string) ErrorRetryer { + return func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { + message := err.Error() + // Handle transient errors for retries + for _, substring := range errors { + if strings.Contains(message, substring) { + logger.Debugf(ctx, "Attempting retry because of %#v", substring) + return true + } + } + return false + } +}