From 5cdef8c5dbd65369ef90f1b8d7ea2b7348ffc862 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Fri, 27 Apr 2018 15:46:30 -0700 Subject: [PATCH] Add token refresh operations that take a context (#274) Token refresh operations make HTTP requests, thus they should support contexts so they can be cancelled etc. Simplified retry logic in refreshInternal(). --- CHANGELOG.md | 6 ++++ autorest/adal/token.go | 56 +++++++++++++++++++++++++------------ autorest/adal/token_test.go | 35 +++++++++++++++++++++++ autorest/authorization.go | 22 ++++++++------- autorest/version.go | 2 +- 5 files changed, 92 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4320d72c0..19937de50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # CHANGELOG +## v10.7.0 + +### New Features + +- Added *WithContext() methods to ADAL token refresh operations. + ## v10.6.2 - Fixed a bug on device authentication. diff --git a/autorest/adal/token.go b/autorest/adal/token.go index 24641b621..b7d5c6071 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -15,6 +15,7 @@ package adal // limitations under the License. import ( + "context" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -77,6 +78,13 @@ type Refresher interface { EnsureFresh() error } +// RefresherWithContext is an interface for token refresh functionality +type RefresherWithContext interface { + RefreshWithContext(ctx context.Context) error + RefreshExchangeWithContext(ctx context.Context, resource string) error + EnsureFreshWithContext(ctx context.Context) error +} + // TokenRefreshCallback is the type representing callbacks that will be called after // a successful token refresh type TokenRefreshCallback func(Token) error @@ -528,12 +536,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError // EnsureFresh will refresh the token if it will expire within the refresh window (as set by // RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. func (spt *ServicePrincipalToken) EnsureFresh() error { + return spt.EnsureFreshWithContext(context.Background()) +} + +// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by +// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use. +func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error { if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) { // take the write lock then check to see if the token was already refreshed spt.refreshLock.Lock() defer spt.refreshLock.Unlock() if spt.token.WillExpireIn(spt.refreshWithin) { - return spt.refreshInternal(spt.resource) + return spt.refreshInternal(ctx, spt.resource) } } return nil @@ -555,17 +569,29 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { // Refresh obtains a fresh token for the Service Principal. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) Refresh() error { + return spt.RefreshWithContext(context.Background()) +} + +// RefreshWithContext obtains a fresh token for the Service Principal. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(spt.resource) + return spt.refreshInternal(ctx, spt.resource) } // RefreshExchange refreshes the token, but for a different resource. // This method is not safe for concurrent use and should be syncrhonized. func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { + return spt.RefreshExchangeWithContext(context.Background(), resource) +} + +// RefreshExchangeWithContext refreshes the token, but for a different resource. +// This method is not safe for concurrent use and should be syncrhonized. +func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() - return spt.refreshInternal(resource) + return spt.refreshInternal(ctx, resource) } func (spt *ServicePrincipalToken) getGrantType() string { @@ -587,12 +613,12 @@ func isIMDS(u url.URL) bool { return u.Host == imds.Host && u.Path == imds.Path } -func (spt *ServicePrincipalToken) refreshInternal(resource string) error { +func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) } - + req = req.WithContext(ctx) if !isIMDS(spt.oauthConfig.TokenEndpoint) { v := url.Values{} v.Set("client_id", spt.clientID) @@ -683,24 +709,18 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) { for attempt < maxAttempts { resp, err = sender.Do(req) - if err != nil { + if err != nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) { return } - if resp.StatusCode == http.StatusOK { - return - } - if containsInt(retries, resp.StatusCode) { - delayed := false - if resp.StatusCode == http.StatusTooManyRequests { - delayed = delay(resp, req.Cancel) - } - if !delayed { - time.Sleep(time.Second) + if !delay(resp, req.Context().Done()) { + select { + case <-time.After(time.Second): attempt++ + case <-req.Context().Done(): + err = req.Context().Err() + return } - } else { - return } } return diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index 194c76795..6719e9616 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -15,6 +15,7 @@ package adal // limitations under the License. import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -27,6 +28,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -185,6 +187,39 @@ func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) { } } +func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + endpoint, _ := GetMSIVMEndpoint() + + spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource") + if err != nil { + t.Fatalf("Failed to get MSI SPT: %v", err) + } + + c := mocks.NewSender() + c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5) + + var wg sync.WaitGroup + wg.Add(1) + start := time.Now() + end := time.Now() + + go func() { + spt.SetSender(c) + err = spt.RefreshWithContext(ctx) + end = time.Now() + wg.Done() + }() + + cancel() + wg.Wait() + time.Sleep(5 * time.Millisecond) + + if end.Sub(start) >= time.Second { + t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel") + } +} + func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) { spt := newServicePrincipalToken() diff --git a/autorest/authorization.go b/autorest/authorization.go index c51eac0a7..77eff45bd 100644 --- a/autorest/authorization.go +++ b/autorest/authorization.go @@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator { return PreparerFunc(func(r *http.Request) (*http.Request, error) { r, err := p.Prepare(r) if err == nil { - refresher, ok := ba.tokenProvider.(adal.Refresher) - if ok { - err := refresher.EnsureFresh() - if err != nil { - var resp *http.Response - if tokError, ok := err.(adal.TokenRefreshError); ok { - resp = tokError.Response() - } - return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, - "Failed to refresh the Token for request to %s", r.URL) + // the ordering is important here, prefer RefresherWithContext if available + if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok { + err = refresher.EnsureFreshWithContext(r.Context()) + } else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok { + err = refresher.EnsureFresh() + } + if err != nil { + var resp *http.Response + if tokError, ok := err.(adal.TokenRefreshError); ok { + resp = tokError.Response() } + return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp, + "Failed to refresh the Token for request to %s", r.URL) } return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken()))) } diff --git a/autorest/version.go b/autorest/version.go index 4ad7754ad..efa7d8e12 100644 --- a/autorest/version.go +++ b/autorest/version.go @@ -16,5 +16,5 @@ package autorest // Version returns the semantic version (see http://semver.org). func Version() string { - return "v10.5.0" + return "v10.7.0" }