Skip to content
This repository has been archived by the owner on Feb 12, 2025. It is now read-only.

Commit

Permalink
Add token refresh operations that take a context (#274)
Browse files Browse the repository at this point in the history
Token refresh operations make HTTP requests, thus they should support
contexts so they can be cancelled etc.
Simplified retry logic in refreshInternal().
  • Loading branch information
jhendrixMSFT authored Apr 27, 2018
1 parent 1ff2880 commit 5cdef8c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 29 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# CHANGELOG

## v10.7.0

### New Features

- Added *WithContext() methods to ADAL token refresh operations.

## v10.6.2

- Fixed a bug on device authentication.
Expand Down
56 changes: 38 additions & 18 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package adal
// limitations under the License.

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
Expand Down Expand Up @@ -77,6 +78,13 @@ type Refresher interface {
EnsureFresh() error
}

// RefresherWithContext is an interface for token refresh functionality
type RefresherWithContext interface {
RefreshWithContext(ctx context.Context) error
RefreshExchangeWithContext(ctx context.Context, resource string) error
EnsureFreshWithContext(ctx context.Context) error
}

// TokenRefreshCallback is the type representing callbacks that will be called after
// a successful token refresh
type TokenRefreshCallback func(Token) error
Expand Down Expand Up @@ -528,12 +536,18 @@ func newTokenRefreshError(message string, resp *http.Response) TokenRefreshError
// EnsureFresh will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFresh() error {
return spt.EnsureFreshWithContext(context.Background())
}

// EnsureFreshWithContext will refresh the token if it will expire within the refresh window (as set by
// RefreshWithin) and autoRefresh flag is on. This method is safe for concurrent use.
func (spt *ServicePrincipalToken) EnsureFreshWithContext(ctx context.Context) error {
if spt.autoRefresh && spt.token.WillExpireIn(spt.refreshWithin) {
// take the write lock then check to see if the token was already refreshed
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
if spt.token.WillExpireIn(spt.refreshWithin) {
return spt.refreshInternal(spt.resource)
return spt.refreshInternal(ctx, spt.resource)
}
}
return nil
Expand All @@ -555,17 +569,29 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error {
// Refresh obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) Refresh() error {
return spt.RefreshWithContext(context.Background())
}

// RefreshWithContext obtains a fresh token for the Service Principal.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(spt.resource)
return spt.refreshInternal(ctx, spt.resource)
}

// RefreshExchange refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchange(resource string) error {
return spt.RefreshExchangeWithContext(context.Background(), resource)
}

// RefreshExchangeWithContext refreshes the token, but for a different resource.
// This method is not safe for concurrent use and should be syncrhonized.
func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error {
spt.refreshLock.Lock()
defer spt.refreshLock.Unlock()
return spt.refreshInternal(resource)
return spt.refreshInternal(ctx, resource)
}

func (spt *ServicePrincipalToken) getGrantType() string {
Expand All @@ -587,12 +613,12 @@ func isIMDS(u url.URL) bool {
return u.Host == imds.Host && u.Path == imds.Path
}

func (spt *ServicePrincipalToken) refreshInternal(resource string) error {
func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
req, err := http.NewRequest(http.MethodPost, spt.oauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
}

req = req.WithContext(ctx)
if !isIMDS(spt.oauthConfig.TokenEndpoint) {
v := url.Values{}
v.Set("client_id", spt.clientID)
Expand Down Expand Up @@ -683,24 +709,18 @@ func retry(sender Sender, req *http.Request) (resp *http.Response, err error) {

for attempt < maxAttempts {
resp, err = sender.Do(req)
if err != nil {
if err != nil || resp.StatusCode == http.StatusOK || !containsInt(retries, resp.StatusCode) {
return
}

if resp.StatusCode == http.StatusOK {
return
}
if containsInt(retries, resp.StatusCode) {
delayed := false
if resp.StatusCode == http.StatusTooManyRequests {
delayed = delay(resp, req.Cancel)
}
if !delayed {
time.Sleep(time.Second)
if !delay(resp, req.Context().Done()) {
select {
case <-time.After(time.Second):
attempt++
case <-req.Context().Done():
err = req.Context().Err()
return
}
} else {
return
}
}
return
Expand Down
35 changes: 35 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package adal
// limitations under the License.

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
Expand All @@ -27,6 +28,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -185,6 +187,39 @@ func TestServicePrincipalTokenFromMSIRefreshUsesGET(t *testing.T) {
}
}

func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
endpoint, _ := GetMSIVMEndpoint()

spt, err := NewServicePrincipalTokenFromMSI(endpoint, "https://resource")
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}

c := mocks.NewSender()
c.AppendAndRepeatResponse(mocks.NewResponseWithStatus("Internal server error", http.StatusInternalServerError), 5)

var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
end := time.Now()

go func() {
spt.SetSender(c)
err = spt.RefreshWithContext(ctx)
end = time.Now()
wg.Done()
}()

cancel()
wg.Wait()
time.Sleep(5 * time.Millisecond)

if end.Sub(start) >= time.Second {
t.Fatalf("TestServicePrincipalTokenFromMSIRefreshCancel failed to cancel")
}
}

func TestServicePrincipalTokenRefreshSetsMimeType(t *testing.T) {
spt := newServicePrincipalToken()

Expand Down
22 changes: 12 additions & 10 deletions autorest/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,19 @@ func (ba *BearerAuthorizer) WithAuthorization() PrepareDecorator {
return PreparerFunc(func(r *http.Request) (*http.Request, error) {
r, err := p.Prepare(r)
if err == nil {
refresher, ok := ba.tokenProvider.(adal.Refresher)
if ok {
err := refresher.EnsureFresh()
if err != nil {
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
// the ordering is important here, prefer RefresherWithContext if available
if refresher, ok := ba.tokenProvider.(adal.RefresherWithContext); ok {
err = refresher.EnsureFreshWithContext(r.Context())
} else if refresher, ok := ba.tokenProvider.(adal.Refresher); ok {
err = refresher.EnsureFresh()
}
if err != nil {
var resp *http.Response
if tokError, ok := err.(adal.TokenRefreshError); ok {
resp = tokError.Response()
}
return r, NewErrorWithError(err, "azure.BearerAuthorizer", "WithAuthorization", resp,
"Failed to refresh the Token for request to %s", r.URL)
}
return Prepare(r, WithHeader(headerAuthorization, fmt.Sprintf("Bearer %s", ba.tokenProvider.OAuthToken())))
}
Expand Down
2 changes: 1 addition & 1 deletion autorest/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ package autorest

// Version returns the semantic version (see http://semver.org).
func Version() string {
return "v10.5.0"
return "v10.7.0"
}

0 comments on commit 5cdef8c

Please sign in to comment.