diff --git a/auth/device_flow.go b/auth/device_flow.go new file mode 100644 index 0000000..7b42998 --- /dev/null +++ b/auth/device_flow.go @@ -0,0 +1,174 @@ +// Copyright 2022 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" + "time" + + "go.mongodb.org/ops-manager/opsmngr" +) + +const authExpiredError = "DEVICE_AUTHORIZATION_EXPIRED" + +// DeviceCode holds information about the authorization-in-progress. +type DeviceCode struct { + UserCode string `json:"user_code"` //nolint:tagliatelle // UserCode is the code presented to users + VerificationURI string `json:"verification_uri"` //nolint:tagliatelle // VerificationURI is the URI where users will need to confirm the code + DeviceCode string `json:"device_code"` //nolint:tagliatelle // DeviceCode is the internal code to confirm the status of the flow + ExpiresIn int `json:"expires_in"` //nolint:tagliatelle // ExpiresIn when the code will expire + Interval int `json:"interval"` // Interval how often to verify the status of the code + + timeNow func() time.Time + timeSleep func(time.Duration) +} + +type RegistrationConfig struct { + RegistrationURL string `json:"registrationUrl"` +} + +const deviceBasePath = "api/private/unauth/account/device" + +// RequestCode initiates the authorization flow by requesting a code. +func (c *Config) RequestCode(ctx context.Context) (*DeviceCode, *opsmngr.Response, error) { + req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/authorize", + url.Values{ + "client_id": {c.ClientID}, + "scope": {strings.Join(c.Scopes, " ")}, + }, + ) + if err != nil { + return nil, nil, err + } + var r *DeviceCode + resp, err2 := c.Do(ctx, req, &r) + return r, resp, err2 +} + +// GetToken gets a device token. +func (c *Config) GetToken(ctx context.Context, deviceCode string) (*Token, *opsmngr.Response, error) { + req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token", + url.Values{ + "client_id": {c.ClientID}, + "device_code": {deviceCode}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + }, + ) + if err != nil { + return nil, nil, err + } + var t *Token + resp, err2 := c.Do(ctx, req, &t) + if err2 != nil { + return nil, resp, err2 + } + return t, resp, err2 +} + +// ErrTimeout is returned when polling the server for the granted token has timed out. +var ErrTimeout = errors.New("authentication timed out") + +// PollToken polls the server until an access token is granted or denied. +func (c *Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *opsmngr.Response, error) { + timeNow := code.timeNow + if timeNow == nil { + timeNow = time.Now + } + timeSleep := code.timeSleep + if timeSleep == nil { + timeSleep = time.Sleep + } + + checkInterval := time.Duration(code.Interval) * time.Second + expiresAt := timeNow().Add(time.Duration(code.ExpiresIn) * time.Second) + + for { + timeSleep(checkInterval) + token, resp, err := c.GetToken(ctx, code.DeviceCode) + var target *opsmngr.ErrorResponse + if errors.As(err, &target) && target.ErrorCode == "DEVICE_AUTHORIZATION_PENDING" { + continue + } + if err != nil { + return nil, resp, err + } + + if timeNow().After(expiresAt) { + return nil, nil, ErrTimeout + } + return token, resp, nil + } +} + +// RefreshToken takes a refresh token and gets a new access token. +func (c *Config) RefreshToken(ctx context.Context, token string) (*Token, *opsmngr.Response, error) { + req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token", + url.Values{ + "client_id": {c.ClientID}, + "refresh_token": {token}, + "scope": {strings.Join(c.Scopes, " ")}, + "grant_type": {"refresh_token"}, + }, + ) + if err != nil { + return nil, nil, err + } + var t *Token + resp, err2 := c.Do(ctx, req, &t) + if err2 != nil { + return nil, resp, err2 + } + return t, resp, err2 +} + +// RevokeToken takes an access or refresh token and revokes it. +func (c *Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*opsmngr.Response, error) { + req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/revoke", + url.Values{ + "client_id": {c.ClientID}, + "token": {token}, + "token_type_hint": {tokenTypeHint}, + }, + ) + if err != nil { + return nil, err + } + + return c.Do(ctx, req, nil) +} + +// RegistrationConfig retrieves the config used for registration. +func (c *Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *opsmngr.Response, error) { + req, err := c.NewRequest(ctx, http.MethodGet, deviceBasePath+"/registration", url.Values{}) + if err != nil { + return nil, nil, err + } + var rc *RegistrationConfig + resp, err := c.Do(ctx, req, &rc) + if err != nil { + return nil, resp, err + } + return rc, resp, err +} + +// IsTimeoutErr checks if the given error is for the case where the device flow has expired. +func IsTimeoutErr(err error) bool { + var target *opsmngr.ErrorResponse + return errors.Is(err, ErrTimeout) || (errors.As(err, &target) && target.ErrorCode == authExpiredError) +} diff --git a/auth/device_flow_test.go b/auth/device_flow_test.go new file mode 100644 index 0000000..1165ae5 --- /dev/null +++ b/auth/device_flow_test.go @@ -0,0 +1,221 @@ +// Copyright 2022 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "fmt" + "net/http" + "testing" + + "github.com/go-test/deep" + "go.mongodb.org/ops-manager/opsmngr" +) + +func TestConfig_RequestCode(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/authorize", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r) + fmt.Fprintf(w, `{ + "user_code": "QW3PYV7R", + "verification_uri": "%s/account/connect", + "device_code": "61eef18e310968047ff5e02a", + "expires_in": 600, + "interval": 10 + }`, baseURLPath) + }) + + results, _, err := config.RequestCode(ctx) + if err != nil { + t.Fatalf("RequestCode returned error: %v", err) + } + + expected := &DeviceCode{ + UserCode: "QW3PYV7R", + VerificationURI: baseURLPath + "/account/connect", + DeviceCode: "61eef18e310968047ff5e02a", + ExpiresIn: 600, + Interval: 10, + } + + if diff := deep.Equal(results, expected); diff != nil { + t.Error(diff) + } +} + +func TestConfig_GetToken(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r) + fmt.Fprint(w, `{ + "access_token": "secret1", + "refresh_token": "secret2", + "scope": "openid", + "id_token": "idtoken", + "token_type": "Bearer", + "expires_in": 3600 + }`) + }) + code := &DeviceCode{ + DeviceCode: "61eef18e310968047ff5e02a", + ExpiresIn: 600, + Interval: 10, + } + results, _, err := config.GetToken(ctx, code.DeviceCode) + if err != nil { + t.Fatalf("GetToken returned error: %v", err) + } + + expected := &Token{ + AccessToken: "secret1", + RefreshToken: "secret2", + Scope: "openid", + IDToken: "idtoken", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + if diff := deep.Equal(results, expected); diff != nil { + t.Error(diff) + } +} + +func TestConfig_RefreshToken(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r) + fmt.Fprint(w, `{ + "access_token": "secret1", + "refresh_token": "secret2", + "scope": "openid", + "id_token": "idtoken", + "token_type": "Bearer", + "expires_in": 3600 + }`) + }) + + results, _, err := config.RefreshToken(ctx, "secret2") + if err != nil { + t.Fatalf("RefreshToken returned error: %v", err) + } + + expected := &Token{ + AccessToken: "secret1", + RefreshToken: "secret2", + Scope: "openid", + IDToken: "idtoken", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + if diff := deep.Equal(results, expected); diff != nil { + t.Error(diff) + } +} + +func TestConfig_PollToken(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r) + _, _ = fmt.Fprint(w, `{ + "access_token": "secret1", + "refresh_token": "secret2", + "scope": "openid", + "id_token": "idtoken", + "token_type": "Bearer", + "expires_in": 3600 + }`) + }) + code := &DeviceCode{ + DeviceCode: "61eef18e310968047ff5e02a", + ExpiresIn: 600, + Interval: 10, + } + results, _, err := config.PollToken(ctx, code) + if err != nil { + t.Fatalf("PollToken returned error: %v", err) + } + + expected := &Token{ + AccessToken: "secret1", + RefreshToken: "secret2", + Scope: "openid", + IDToken: "idtoken", + TokenType: "Bearer", + ExpiresIn: 3600, + } + + if diff := deep.Equal(results, expected); diff != nil { + t.Error(diff) + } +} + +func TestConfig_RevokeToken(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/revoke", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r) + }) + + _, err := config.RevokeToken(ctx, "a", "refresh_token") + if err != nil { + t.Fatalf("RequestCode returned error: %v", err) + } +} + +func TestConfig_RegistrationConfig(t *testing.T) { + config, mux, teardown := setup() + defer teardown() + + mux.HandleFunc("/api/private/unauth/account/device/registration", func(w http.ResponseWriter, r *http.Request) { + if http.MethodGet != r.Method { + t.Errorf("Request method = %v, expected %v", r.Method, http.MethodGet) + } + + fmt.Fprint(w, `{ + "registrationUrl": "http://localhost:8080/account/register/cli" + }`) + }) + + results, _, err := config.RegistrationConfig(ctx) + if err != nil { + t.Fatalf("RegistrationConfig returned error: %v", err) + } + + expected := &RegistrationConfig{ + RegistrationURL: "http://localhost:8080/account/register/cli", + } + + if diff := deep.Equal(results, expected); diff != nil { + t.Error(diff) + } +} + +func TestIsTimeoutErr(t *testing.T) { + err := &opsmngr.ErrorResponse{ + ErrorCode: "DEVICE_AUTHORIZATION_EXPIRED", + } + if !IsTimeoutErr(err) { + t.Error("expected to be a timeout error") + } +} diff --git a/auth/oauth.go b/auth/oauth.go new file mode 100644 index 0000000..4f83454 --- /dev/null +++ b/auth/oauth.go @@ -0,0 +1,209 @@ +// Copyright 2022 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "runtime" + "strings" + + "go.mongodb.org/ops-manager/opsmngr" +) + +const defaultBaseURL = opsmngr.DefaultBaseURL + +var ( + userAgent = fmt.Sprintf("go-mongodbopsmanager/%s (%s;%s)", opsmngr.ClientVersion, runtime.GOOS, runtime.GOARCH) +) + +type Config struct { + client *http.Client + ClientID string + AuthURL *url.URL + UserAgent string + Scopes []string + + // copy raw server response to the Response struct + withRaw bool +} + +type ConfigOpt func(*Config) error + +func NewConfig(httpClient *http.Client) *Config { + if httpClient == nil { + httpClient = http.DefaultClient + } + + baseURL, _ := url.Parse(defaultBaseURL) + c := &Config{ + client: httpClient, + AuthURL: baseURL, + UserAgent: userAgent, + } + return c +} + +func NewConfigWithOptions(httpClient *http.Client, opts ...ConfigOpt) (*Config, error) { + c := NewConfig(httpClient) + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + return c, nil +} + +// SetAuthURL is a config option for setting the base URL. +func SetAuthURL(bu string) ConfigOpt { + return func(c *Config) error { + u, err := url.Parse(bu) + if err != nil { + return err + } + + c.AuthURL = u + return nil + } +} + +// SetWithRaw is a client option for getting raw atlas server response within Response structure. +func SetWithRaw() ConfigOpt { + return func(c *Config) error { + c.withRaw = true + return nil + } +} + +// SetUserAgent is a config option for setting the user agent. +func SetUserAgent(ua string) ConfigOpt { + return func(c *Config) error { + c.UserAgent = fmt.Sprintf("%s %s", ua, userAgent) + return nil + } +} + +// SetClientID is a config option for setting the ClientID. +func SetClientID(clientID string) ConfigOpt { + return func(c *Config) error { + c.ClientID = clientID + return nil + } +} + +// SetScopes is a config option for setting the Scopes. +func SetScopes(scopes []string) ConfigOpt { + return func(c *Config) error { + c.Scopes = scopes + return nil + } +} + +// A TokenSource is anything that can return a token. +type TokenSource interface { + // Token returns a token or an error. + // Token must be safe for concurrent use by multiple goroutines. + // The returned Token must not be modified. + Token() (*Token, error) +} + +func (c *Config) Do(ctx context.Context, req *http.Request, v interface{}) (*opsmngr.Response, error) { + resp, err := opsmngr.DoRequestWithClient(ctx, c.client, req) + if err != nil { + // If we got an error, and the context has been canceled, + // the context's error is probably more useful. + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + return nil, err + } + + defer resp.Body.Close() + + r := &opsmngr.Response{Response: resp} + body := resp.Body + + if c.withRaw { + raw := new(bytes.Buffer) + _, err = io.Copy(raw, body) + if err != nil { + return r, err + } + + r.Raw = raw.Bytes() + body = io.NopCloser(raw) + } + + if err2 := r.CheckResponse(body); err2 != nil { + return r, err2 + } + + if v != nil { + if w, ok := v.(io.Writer); ok { + _, err = io.Copy(w, body) + if err != nil { + return nil, err + } + } else { + decErr := json.NewDecoder(body).Decode(v) + if errors.Is(decErr, io.EOF) { + decErr = nil // ignore EOF errors caused by empty response body + } + if decErr != nil { + err = decErr + } + } + } + return r, err +} + +func (c *Config) NewRequest(ctx context.Context, method, urlStr string, v url.Values) (*http.Request, error) { + if !strings.HasSuffix(c.AuthURL.Path, "/") { + return nil, fmt.Errorf("base URL must have a trailing slash, but %q does not", c.AuthURL) + } + rel, err := url.Parse(urlStr) + if err != nil { + return nil, err + } + + u := c.AuthURL.ResolveReference(rel) + + req, err := http.NewRequestWithContext(ctx, method, u.String(), strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Accept", "application/json") + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } + return req, nil +} diff --git a/auth/oauth_test.go b/auth/oauth_test.go new file mode 100644 index 0000000..eeac3e9 --- /dev/null +++ b/auth/oauth_test.go @@ -0,0 +1,211 @@ +// Copyright 2022 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "reflect" + "testing" + + "go.mongodb.org/ops-manager/opsmngr" +) + +const ( + // baseURLPath is a non-empty Client.BaseURL path to use during tests, + // to ensure relative URLs are used for all endpoints. + baseURLPath = "/api-v1" +) + +var ( + ctx = context.TODO() +) + +// setup sets up a test HTTP server along with an auth.Config that is +// configured to talk to that test server. Tests should register handlers on +// mux which provide mock responses for the API method being tested. +func setup() (config *Config, mux *http.ServeMux, teardown func()) { + // mux is the HTTP request multiplexer used with the test server. + mux = http.NewServeMux() + + // We want to ensure that tests catch mistakes where the endpoint URL is + // specified as absolute rather than relative. It only makes a difference + // when there's a non-empty base URL path. So, use that. + apiHandler := http.NewServeMux() + apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux)) + apiHandler.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + fmt.Fprintln(os.Stderr, "FAIL: Client.BaseURL path prefix is not preserved in the request URL:") + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "\t"+req.URL.String()) + fmt.Fprintln(os.Stderr) + fmt.Fprintln(os.Stderr, "\tDid you accidentally use an absolute endpoint URL rather than relative?") + http.Error(w, "Client.BaseURL path prefix is not preserved in the request URL.", http.StatusInternalServerError) + }) + + // server is a test HTTP server used to provide mock API responses. + server := httptest.NewServer(apiHandler) + + // client is the Atlas client being tested and is + // configured to use test server. + config = NewConfig(nil) + u, _ := url.Parse(server.URL + baseURLPath + "/") + config.AuthURL = u + + return config, mux, server.Close +} + +func testMethod(t *testing.T, r *http.Request) { + t.Helper() + if http.MethodPost != r.Method { + t.Errorf("Request method = %v, expected %v", r.Method, http.MethodPost) + } +} + +func testClientDefaultBaseURL(t *testing.T, c *Config) { + t.Helper() + if c.AuthURL == nil || c.AuthURL.String() != defaultBaseURL { + t.Errorf("NewConfig BaseURL = %v, expected %v", c.AuthURL, defaultBaseURL) + } +} + +func testClientDefaultUserAgent(t *testing.T, c *Config) { + t.Helper() + if c.UserAgent != userAgent { + t.Errorf("NewConfig UserAgent = %v, expected %v", c.UserAgent, userAgent) + } +} + +func testClientDefaults(t *testing.T, c *Config) { + t.Helper() + testClientDefaultBaseURL(t, c) + testClientDefaultUserAgent(t, c) +} + +func TestNewConfig(t *testing.T) { + c := NewConfig(nil) + testClientDefaults(t, c) +} + +func TestNewConfigWithOptions(t *testing.T) { + c, err := NewConfigWithOptions(nil) + + if err != nil { + t.Fatalf("NewConfigWithOptions(): %v", err) + } + testClientDefaults(t, c) +} + +func TestNewRequest_withCustomUserAgent(t *testing.T) { + ua := fmt.Sprintf("testing/%s", opsmngr.ClientVersion) + c, err := NewConfigWithOptions(nil, SetUserAgent(ua)) + + if err != nil { + t.Fatalf("NewConfigWithOptions() unexpected error: %v", err) + } + + req, _ := c.NewRequest(ctx, http.MethodGet, "/foo", nil) + + expected := fmt.Sprintf("%s %s", ua, userAgent) + if got := req.Header.Get("User-Agent"); got != expected { + t.Errorf("NewConfigWithOptions() UserAgent = %s; expected %s", got, expected) + } +} + +func TestNewPlainRequest(t *testing.T) { + c := NewConfig(nil) + + requestPath := "foo" + + inURL, outURL := requestPath, defaultBaseURL+requestPath + req, _ := c.NewRequest(ctx, http.MethodGet, inURL, nil) + + // test relative URL was expanded + if req.URL.String() != outURL { + t.Errorf("NewPlainRequest(%v) URL = %v, expected %v", inURL, req.URL, outURL) + } + + // test accept content type is correct + accept := req.Header.Get("Accept") + if accept != "application/json" { + t.Errorf("NewPlainRequest() Accept = %v, expected %v", accept, "application/json") + } + contentType := req.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("NewPlainRequest() Accept = %v, expected %v", contentType, "application/x-www-form-urlencoded") + } + // test default user-agent is attached to the request + uA := req.Header.Get("User-Agent") + if c.UserAgent != uA { + t.Errorf("NewPlainRequest() User-Agent = %v, expected %v", uA, c.UserAgent) + } +} + +func TestDo(t *testing.T) { + client, mux, teardown := setup() + defer teardown() + + type foo struct { + A string + } + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if m := http.MethodGet; m != r.Method { + t.Errorf("Request method = %v, expected %v", r.Method, m) + } + fmt.Fprint(w, `{"A":"a"}`) + }) + + req, _ := client.NewRequest(ctx, http.MethodGet, ".", nil) + body := new(foo) + _, err := client.Do(context.Background(), req, body) + if err != nil { + t.Fatalf("Do(): %v", err) + } + + expected := &foo{"a"} + if !reflect.DeepEqual(body, expected) { + t.Errorf("Response body = %v, expected %v", body, expected) + } +} + +func TestCustomUserAgent(t *testing.T) { + ua := fmt.Sprintf("testing/%s", opsmngr.ClientVersion) + c, err := NewConfigWithOptions(nil, SetUserAgent(ua)) + + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + + expected := fmt.Sprintf("%s %s", ua, userAgent) + if got := c.UserAgent; got != expected { + t.Errorf("New() UserAgent = %s; expected %s", got, expected) + } +} + +func TestCustomBaseURL(t *testing.T) { + const baseURL = "http://localhost/foo" + c, err := NewConfigWithOptions(nil, SetAuthURL(baseURL)) + + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + if got := c.AuthURL.String(); got != baseURL { + t.Errorf("New() BaseURL = %s; expected %s", got, baseURL) + } +} diff --git a/auth/token.go b/auth/token.go new file mode 100644 index 0000000..8f91bd1 --- /dev/null +++ b/auth/token.go @@ -0,0 +1,47 @@ +// Copyright 2022 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "net/http" + "time" +) + +type Token struct { + AccessToken string `json:"access_token"` //nolint:tagliatelle // used as in the API + RefreshToken string `json:"refresh_token"` //nolint:tagliatelle // used as in the API + Scope string `json:"scope"` + IDToken string `json:"id_token"` //nolint:tagliatelle // used as in the API + TokenType string `json:"token_type"` //nolint:tagliatelle // used as in the API + ExpiresIn int `json:"expires_in"` //nolint:tagliatelle // used as in the API + Expiry time.Time +} + +func (t *Token) SetAuthHeader(r *http.Request) { + r.Header.Set("Authorization", "Bearer "+t.AccessToken) +} + +const expiryDelta = 10 * time.Second + +func (t *Token) expired() bool { + if t.Expiry.IsZero() { + return false + } + return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now()) +} + +func (t *Token) Valid() bool { + return t != nil && t.AccessToken != "" && !t.expired() +} diff --git a/opsmngr/opsmngr.go b/opsmngr/opsmngr.go index 8274023..dc33100 100644 --- a/opsmngr/opsmngr.go +++ b/opsmngr/opsmngr.go @@ -32,11 +32,13 @@ import ( ) const ( - defaultBaseURL = "https://cloud.mongodb.com/" + DefaultBaseURL = "https://cloud.mongodb.com/" userAgent = "go-ops-manager" jsonMediaType = "application/json" gzipMediaType = "application/gzip" plainMediaType = "text/plain" + // ClientVersion of the current API client. Should be set to the next version planned to be released. + ClientVersion = "0.56.0" ) type HTTPClient interface { @@ -209,7 +211,7 @@ func NewClient(httpClient HTTPClient) *Client { httpClient = http.DefaultClient } - baseURL, _ := url.Parse(defaultBaseURL) + baseURL, _ := url.Parse(DefaultBaseURL) c := &Client{ client: httpClient, diff --git a/opsmngr/opsmngr_test.go b/opsmngr/opsmngr_test.go index c645084..c40e183 100644 --- a/opsmngr/opsmngr_test.go +++ b/opsmngr/opsmngr_test.go @@ -98,8 +98,8 @@ func testURLParseError(t *testing.T, err error) { func testClientDefaultBaseURL(t *testing.T, c *Client) { t.Helper() - if c.BaseURL == nil || c.BaseURL.String() != defaultBaseURL { - t.Errorf("NewClient AuthURL = %v, expected %v", c.BaseURL, defaultBaseURL) + if c.BaseURL == nil || c.BaseURL.String() != DefaultBaseURL { + t.Errorf("NewClient AuthURL = %v, expected %v", c.BaseURL, DefaultBaseURL) } } @@ -139,7 +139,7 @@ type testRequestBody struct { func TestNewRequest_withUserData(t *testing.T) { c := NewClient(nil) - inURL, outURL := requestPath, defaultBaseURL+requestPath + inURL, outURL := requestPath, DefaultBaseURL+requestPath inBody, outBody := &testRequestBody{TestName: "l", TestUserData: "u"}, `{"testName":"l","testCounter":0,`+ `"testUserData":"u"}`+"\n" @@ -293,7 +293,7 @@ func TestNewPlainRequest_badURL(t *testing.T) { func TestNewPlainRequest(t *testing.T) { c := NewClient(nil) - inURL, outURL := requestPath, defaultBaseURL+requestPath + inURL, outURL := requestPath, DefaultBaseURL+requestPath req, _ := c.NewPlainRequest(ctx, http.MethodGet, inURL) // test relative URL was expanded