Skip to content

Commit

Permalink
CLOUDP-252326 add missing auth dependencies (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
husniMDB authored Jun 26, 2024
1 parent e32fddf commit 1269096
Show file tree
Hide file tree
Showing 7 changed files with 870 additions and 6 deletions.
174 changes: 174 additions & 0 deletions auth/device_flow.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2022 MongoDB Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
"context"
"errors"
"net/http"
"net/url"
"strings"
"time"

"go.mongodb.org/ops-manager/opsmngr"
)

const authExpiredError = "DEVICE_AUTHORIZATION_EXPIRED"

// DeviceCode holds information about the authorization-in-progress.
type DeviceCode struct {
UserCode string `json:"user_code"` //nolint:tagliatelle // UserCode is the code presented to users
VerificationURI string `json:"verification_uri"` //nolint:tagliatelle // VerificationURI is the URI where users will need to confirm the code
DeviceCode string `json:"device_code"` //nolint:tagliatelle // DeviceCode is the internal code to confirm the status of the flow
ExpiresIn int `json:"expires_in"` //nolint:tagliatelle // ExpiresIn when the code will expire
Interval int `json:"interval"` // Interval how often to verify the status of the code

timeNow func() time.Time
timeSleep func(time.Duration)
}

type RegistrationConfig struct {
RegistrationURL string `json:"registrationUrl"`
}

const deviceBasePath = "api/private/unauth/account/device"

// RequestCode initiates the authorization flow by requesting a code.
func (c *Config) RequestCode(ctx context.Context) (*DeviceCode, *opsmngr.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/authorize",
url.Values{
"client_id": {c.ClientID},
"scope": {strings.Join(c.Scopes, " ")},
},
)
if err != nil {
return nil, nil, err
}
var r *DeviceCode
resp, err2 := c.Do(ctx, req, &r)
return r, resp, err2
}

// GetToken gets a device token.
func (c *Config) GetToken(ctx context.Context, deviceCode string) (*Token, *opsmngr.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
url.Values{
"client_id": {c.ClientID},
"device_code": {deviceCode},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
},
)
if err != nil {
return nil, nil, err
}
var t *Token
resp, err2 := c.Do(ctx, req, &t)
if err2 != nil {
return nil, resp, err2
}
return t, resp, err2
}

// ErrTimeout is returned when polling the server for the granted token has timed out.
var ErrTimeout = errors.New("authentication timed out")

// PollToken polls the server until an access token is granted or denied.
func (c *Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *opsmngr.Response, error) {
timeNow := code.timeNow
if timeNow == nil {
timeNow = time.Now
}
timeSleep := code.timeSleep
if timeSleep == nil {
timeSleep = time.Sleep
}

checkInterval := time.Duration(code.Interval) * time.Second
expiresAt := timeNow().Add(time.Duration(code.ExpiresIn) * time.Second)

for {
timeSleep(checkInterval)
token, resp, err := c.GetToken(ctx, code.DeviceCode)
var target *opsmngr.ErrorResponse
if errors.As(err, &target) && target.ErrorCode == "DEVICE_AUTHORIZATION_PENDING" {
continue
}
if err != nil {
return nil, resp, err
}

if timeNow().After(expiresAt) {
return nil, nil, ErrTimeout
}
return token, resp, nil
}
}

// RefreshToken takes a refresh token and gets a new access token.
func (c *Config) RefreshToken(ctx context.Context, token string) (*Token, *opsmngr.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
url.Values{
"client_id": {c.ClientID},
"refresh_token": {token},
"scope": {strings.Join(c.Scopes, " ")},
"grant_type": {"refresh_token"},
},
)
if err != nil {
return nil, nil, err
}
var t *Token
resp, err2 := c.Do(ctx, req, &t)
if err2 != nil {
return nil, resp, err2
}
return t, resp, err2
}

// RevokeToken takes an access or refresh token and revokes it.
func (c *Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*opsmngr.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/revoke",
url.Values{
"client_id": {c.ClientID},
"token": {token},
"token_type_hint": {tokenTypeHint},
},
)
if err != nil {
return nil, err
}

return c.Do(ctx, req, nil)
}

// RegistrationConfig retrieves the config used for registration.
func (c *Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *opsmngr.Response, error) {
req, err := c.NewRequest(ctx, http.MethodGet, deviceBasePath+"/registration", url.Values{})
if err != nil {
return nil, nil, err
}
var rc *RegistrationConfig
resp, err := c.Do(ctx, req, &rc)
if err != nil {
return nil, resp, err
}
return rc, resp, err
}

// IsTimeoutErr checks if the given error is for the case where the device flow has expired.
func IsTimeoutErr(err error) bool {
var target *opsmngr.ErrorResponse
return errors.Is(err, ErrTimeout) || (errors.As(err, &target) && target.ErrorCode == authExpiredError)
}
221 changes: 221 additions & 0 deletions auth/device_flow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright 2022 MongoDB Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
"fmt"
"net/http"
"testing"

"github.com/go-test/deep"
"go.mongodb.org/ops-manager/opsmngr"
)

func TestConfig_RequestCode(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/authorize", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
fmt.Fprintf(w, `{
"user_code": "QW3PYV7R",
"verification_uri": "%s/account/connect",
"device_code": "61eef18e310968047ff5e02a",
"expires_in": 600,
"interval": 10
}`, baseURLPath)
})

results, _, err := config.RequestCode(ctx)
if err != nil {
t.Fatalf("RequestCode returned error: %v", err)
}

expected := &DeviceCode{
UserCode: "QW3PYV7R",
VerificationURI: baseURLPath + "/account/connect",
DeviceCode: "61eef18e310968047ff5e02a",
ExpiresIn: 600,
Interval: 10,
}

if diff := deep.Equal(results, expected); diff != nil {
t.Error(diff)
}
}

func TestConfig_GetToken(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
fmt.Fprint(w, `{
"access_token": "secret1",
"refresh_token": "secret2",
"scope": "openid",
"id_token": "idtoken",
"token_type": "Bearer",
"expires_in": 3600
}`)
})
code := &DeviceCode{
DeviceCode: "61eef18e310968047ff5e02a",
ExpiresIn: 600,
Interval: 10,
}
results, _, err := config.GetToken(ctx, code.DeviceCode)
if err != nil {
t.Fatalf("GetToken returned error: %v", err)
}

expected := &Token{
AccessToken: "secret1",
RefreshToken: "secret2",
Scope: "openid",
IDToken: "idtoken",
TokenType: "Bearer",
ExpiresIn: 3600,
}

if diff := deep.Equal(results, expected); diff != nil {
t.Error(diff)
}
}

func TestConfig_RefreshToken(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
fmt.Fprint(w, `{
"access_token": "secret1",
"refresh_token": "secret2",
"scope": "openid",
"id_token": "idtoken",
"token_type": "Bearer",
"expires_in": 3600
}`)
})

results, _, err := config.RefreshToken(ctx, "secret2")
if err != nil {
t.Fatalf("RefreshToken returned error: %v", err)
}

expected := &Token{
AccessToken: "secret1",
RefreshToken: "secret2",
Scope: "openid",
IDToken: "idtoken",
TokenType: "Bearer",
ExpiresIn: 3600,
}

if diff := deep.Equal(results, expected); diff != nil {
t.Error(diff)
}
}

func TestConfig_PollToken(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
_, _ = fmt.Fprint(w, `{
"access_token": "secret1",
"refresh_token": "secret2",
"scope": "openid",
"id_token": "idtoken",
"token_type": "Bearer",
"expires_in": 3600
}`)
})
code := &DeviceCode{
DeviceCode: "61eef18e310968047ff5e02a",
ExpiresIn: 600,
Interval: 10,
}
results, _, err := config.PollToken(ctx, code)
if err != nil {
t.Fatalf("PollToken returned error: %v", err)
}

expected := &Token{
AccessToken: "secret1",
RefreshToken: "secret2",
Scope: "openid",
IDToken: "idtoken",
TokenType: "Bearer",
ExpiresIn: 3600,
}

if diff := deep.Equal(results, expected); diff != nil {
t.Error(diff)
}
}

func TestConfig_RevokeToken(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/revoke", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
})

_, err := config.RevokeToken(ctx, "a", "refresh_token")
if err != nil {
t.Fatalf("RequestCode returned error: %v", err)
}
}

func TestConfig_RegistrationConfig(t *testing.T) {
config, mux, teardown := setup()
defer teardown()

mux.HandleFunc("/api/private/unauth/account/device/registration", func(w http.ResponseWriter, r *http.Request) {
if http.MethodGet != r.Method {
t.Errorf("Request method = %v, expected %v", r.Method, http.MethodGet)
}

fmt.Fprint(w, `{
"registrationUrl": "http://localhost:8080/account/register/cli"
}`)
})

results, _, err := config.RegistrationConfig(ctx)
if err != nil {
t.Fatalf("RegistrationConfig returned error: %v", err)
}

expected := &RegistrationConfig{
RegistrationURL: "http://localhost:8080/account/register/cli",
}

if diff := deep.Equal(results, expected); diff != nil {
t.Error(diff)
}
}

func TestIsTimeoutErr(t *testing.T) {
err := &opsmngr.ErrorResponse{
ErrorCode: "DEVICE_AUTHORIZATION_EXPIRED",
}
if !IsTimeoutErr(err) {
t.Error("expected to be a timeout error")
}
}
Loading

0 comments on commit 1269096

Please sign in to comment.