From 630cea128e523ddd996814ffcb805c9e8ea227d8 Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Tue, 29 Aug 2023 11:44:47 +0100 Subject: [PATCH] Add authentication support Add support for authentication with the remote service. Basic HTTP and OAuth2 schemes are supported (as well as "passthrough" which disables authentication). This is implemented by associating an IAuthorizer with the Client, which sets the Authrization HTTP header in Client's requests. Signed-off-by: Sergei Trofimov --- auth/basic.go | 71 ++++++++++++++ auth/basic_test.go | 54 +++++++++++ auth/iauthenticator.go | 8 ++ auth/method.go | 42 +++++++++ auth/null.go | 13 +++ auth/oauth2.go | 122 +++++++++++++++++++++++++ auth/oauth2_test.go | 52 +++++++++++ common/client.go | 41 +++++++-- go.mod | 6 ++ go.sum | 24 +++++ management/management.go | 8 +- management/management_test.go | 6 +- provisioning/provisioning.go | 23 ++++- provisioning/provisioning_test.go | 2 +- verification/challengeresponse.go | 20 ++-- verification/challengeresponse_test.go | 2 +- 16 files changed, 466 insertions(+), 28 deletions(-) create mode 100644 auth/basic.go create mode 100644 auth/basic_test.go create mode 100644 auth/iauthenticator.go create mode 100644 auth/method.go create mode 100644 auth/null.go create mode 100644 auth/oauth2.go create mode 100644 auth/oauth2_test.go diff --git a/auth/basic.go b/auth/basic.go new file mode 100644 index 0000000..08cd6ab --- /dev/null +++ b/auth/basic.go @@ -0,0 +1,71 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package auth + +import ( + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/mitchellh/mapstructure" +) + +type BasicAuthenticator struct { + Username string + Password string +} + +func (o *BasicAuthenticator) Configure(cfg map[string]interface{}) error { + decoded := struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + Rest map[string]interface{} `mapstructure:",remain"` + }{} + + if err := mapstructure.Decode(cfg, &decoded); err != nil { + return err + } + + o.Username = decoded.Username + o.Password = decoded.Password + + if err := o.validate(); err != nil { + return err + } + + if len(decoded.Rest) > 0 { + var unexpected []string + for k := range decoded.Rest { + unexpected = append(unexpected, k) + } + return fmt.Errorf("unexpected fields in config: %s", + strings.Join(unexpected, ", ")) + } + + return nil +} + +func (o *BasicAuthenticator) EncodeHeader() (string, error) { + if err := o.validate(); err != nil { + return "", err + } + + credsRaw := fmt.Sprintf("%s:%s", o.Username, o.Password) + credsEncoded := base64.StdEncoding.EncodeToString([]byte(credsRaw)) + header := fmt.Sprintf("Basic %s", credsEncoded) + + return header, nil +} + +func (o *BasicAuthenticator) validate() error { + if o.Username == "" { + return errors.New("missing username") + } + + if o.Password == "" { + return errors.New("missing password") + } + + return nil +} diff --git a/auth/basic_test.go b/auth/basic_test.go new file mode 100644 index 0000000..ba99162 --- /dev/null +++ b/auth/basic_test.go @@ -0,0 +1,54 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBasic_Configure(t *testing.T) { + var ba BasicAuthenticator + + err := ba.Configure(map[string]interface{}{ + "username": "user1", + "password": "Passw0rd!", + }) + require.NoError(t, err) + assert.Equal(t, "user1", ba.Username) + assert.Equal(t, "Passw0rd!", ba.Password) + + err = ba.Configure(map[string]interface{}{ + "username": "user1", + }) + assert.EqualError(t, err, "missing password") + + err = ba.Configure(map[string]interface{}{ + "password": "Passw0rd!", + }) + assert.EqualError(t, err, "missing username") + + err = ba.Configure(map[string]interface{}{ + "username": "user1", + "password": "Passw0rd!", + "full name": "User One", + }) + assert.EqualError(t, err, "unexpected fields in config: full name") +} + +func TestBasic_EncodeHeader(t *testing.T) { + var ba BasicAuthenticator + + _, err := ba.EncodeHeader() + assert.EqualError(t, err, "missing username") + + err = ba.Configure(map[string]interface{}{ + "username": "user1", + "password": "Passw0rd!", + }) + require.NoError(t, err) + + header, err := ba.EncodeHeader() + require.NoError(t, err) + assert.Equal(t, "Basic dXNlcjE6UGFzc3cwcmQh", header) +} diff --git a/auth/iauthenticator.go b/auth/iauthenticator.go new file mode 100644 index 0000000..d11f8ba --- /dev/null +++ b/auth/iauthenticator.go @@ -0,0 +1,8 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package auth + +type IAuthenticator interface { + Configure(cfg map[string]interface{}) error + EncodeHeader() (string, error) +} diff --git a/auth/method.go b/auth/method.go new file mode 100644 index 0000000..62f02d2 --- /dev/null +++ b/auth/method.go @@ -0,0 +1,42 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import "fmt" + +// Method is the enumeration of authentication methods supported by Veraison +// service. It implements the pflag.Value interface. +type Method string + +const ( + MethodPassthrough Method = "passthrough" + MethodBasic Method = "basic" + MethodOauth2 Method = "oauth2" +) + +// String representation of the Method +func (o *Method) String() string { + return string(*o) +} + +// Set the value of the Method +func (o *Method) Set(v string) error { + switch v { + case "none", "passthrough": + *o = MethodPassthrough + case "basic": + *o = MethodBasic + case "oauth2": + *o = MethodOauth2 + default: + return fmt.Errorf("unexpected Method %q", v) + } + + return nil +} + +// Type returns the string representing the type name (used by pflag). +func (o *Method) Type() string { + return "Method" +} diff --git a/auth/null.go b/auth/null.go new file mode 100644 index 0000000..6d74fa3 --- /dev/null +++ b/auth/null.go @@ -0,0 +1,13 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package auth + +type NullAuthenticator struct{} + +func (o *NullAuthenticator) Configure(cfg map[string]interface{}) error { + return nil +} + +func (o *NullAuthenticator) EncodeHeader() (string, error) { + return "", nil +} diff --git a/auth/oauth2.go b/auth/oauth2.go new file mode 100644 index 0000000..cf22409 --- /dev/null +++ b/auth/oauth2.go @@ -0,0 +1,122 @@ +// Copyright 2023 Contributors to the Veraison project. +// SPDX-License-Identifier: Apache-2.0 +package auth + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/mitchellh/mapstructure" + "golang.org/x/oauth2" +) + +type Oauth2Authenticator struct { + TokenURL string + ClientID string + ClientSecret string + Username string + Password string + + Token *oauth2.Token +} + +func (o *Oauth2Authenticator) Configure(cfg map[string]interface{}) error { + decoded := struct { + TokenURL string `mapstructure:"token_url" valid:"url"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + Rest map[string]interface{} `mapstructure:",remain"` + }{} + + if err := mapstructure.Decode(cfg, &decoded); err != nil { + return err + } + + o.ClientID = decoded.ClientID + o.ClientSecret = decoded.ClientSecret + o.TokenURL = decoded.TokenURL + o.Username = decoded.Username + o.Password = decoded.Password + + if err := o.validate(); err != nil { + return err + } + + if len(decoded.Rest) > 0 { + var unexpected []string + for k := range decoded.Rest { + unexpected = append(unexpected, k) + } + return fmt.Errorf("unexpected fields in config: %s", + strings.Join(unexpected, ", ")) + } + + return nil +} + +func (o *Oauth2Authenticator) EncodeHeader() (string, error) { + var err error + + if o.Token == nil || o.Token.Expiry.Before(time.Now()) { + o.Token, err = o.obtainToken() + if err != nil { + return "", err + } + } + + header := fmt.Sprintf("Bearer %s", o.Token.AccessToken) + + return header, nil +} + +func (o *Oauth2Authenticator) obtainToken() (*oauth2.Token, error) { + if err := o.validate(); err != nil { + return nil, err + } + + ctx := context.Background() + conf := &oauth2.Config{ + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + Scopes: []string{"openid"}, + Endpoint: oauth2.Endpoint{ + TokenURL: o.TokenURL, + }, + } + + return conf.PasswordCredentialsToken(ctx, o.Username, o.Password) +} + +func (o *Oauth2Authenticator) validate() error { + if o.ClientID == "" { + return errors.New("missing client_id") + } + + if o.ClientSecret == "" { + return errors.New("missing client_secret") + } + + if o.TokenURL == "" { + return errors.New("missing token_url") + } + + if _, err := url.Parse(o.TokenURL); err != nil { + return fmt.Errorf("invalid token_url: %w", err) + } + + if o.Username == "" { + return errors.New("missing username") + } + + if o.Password == "" { + return errors.New("missing password") + } + + return nil +} diff --git a/auth/oauth2_test.go b/auth/oauth2_test.go new file mode 100644 index 0000000..ddef2c6 --- /dev/null +++ b/auth/oauth2_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOauth2_Configure(t *testing.T) { + var oa2a Oauth2Authenticator + + err := oa2a.Configure(map[string]interface{}{ + "client_id": "myclient", + "client_secret": "deadbeef", + "username": "user1", + "password": "Passw0rd!", + "token_url": "http://example.com", + }) + require.NoError(t, err) + assert.Equal(t, "user1", oa2a.Username) + assert.Equal(t, "Passw0rd!", oa2a.Password) + assert.Equal(t, "myclient", oa2a.ClientID) + assert.Equal(t, "deadbeef", oa2a.ClientSecret) + assert.Equal(t, "http://example.com", oa2a.TokenURL) + + err = oa2a.Configure(map[string]interface{}{ + "client_id": "myclient", + "client_secret": "deadbeef", + "username": "user1", + "token_url": "http://example.com", + }) + assert.EqualError(t, err, "missing password") + + err = oa2a.Configure(map[string]interface{}{ + "client_id": "myclient", + "client_secret": "deadbeef", + "token_url": "http://example.com", + "password": "Passw0rd!", + }) + assert.EqualError(t, err, "missing username") + + err = oa2a.Configure(map[string]interface{}{ + "client_id": "myclient", + "client_secret": "deadbeef", + "username": "user1", + "password": "Passw0rd!", + "token_url": "http://example.com", + "full name": "User One", + }) + assert.EqualError(t, err, "unexpected fields in config: full name") +} diff --git a/common/client.go b/common/client.go index 80460f7..08cebce 100644 --- a/common/client.go +++ b/common/client.go @@ -6,26 +6,34 @@ package common import ( "bytes" "fmt" + "io" "net/http" "time" + + "github.com/veraison/apiclient/auth" ) -// Client holds configuration data associated with the HTTP(s) session +// Client holds configuration data associated with the HTTP(s) session, and a +// reference to an IAuthenticator that is used to provide Authorization headers +// for requests. type Client struct { HTTPClient http.Client + Auth auth.IAuthenticator } -// NewClient instantiates a new Client with a fixed 5s timeout -func NewClient() *Client { +// NewClient instantiates a new Client with a fixed 5s timeout. The client will +// use the provided IAuthenticator for requests, if it is not nil +func NewClient(a auth.IAuthenticator) *Client { return &Client{ HTTPClient: http.Client{ Timeout: 5 * time.Second, }, + Auth: a, } } func (c Client) DeleteResource(uri string) error { - req, err := http.NewRequest("DELETE", uri, http.NoBody) + req, err := c.newRequest("DELETE", uri, http.NoBody) if err != nil { return fmt.Errorf("DELETE %q, request creation failed: %w", uri, err) } @@ -45,7 +53,7 @@ func (c Client) DeleteResource(uri string) error { } func (c Client) PostResource(body []byte, ct, accept, uri string) (*http.Response, error) { - req, err := http.NewRequest("POST", uri, bytes.NewBuffer(body)) + req, err := c.newRequest("POST", uri, bytes.NewBuffer(body)) if err != nil { return nil, fmt.Errorf("POST %q, request creation failed: %w", uri, err) } @@ -57,7 +65,7 @@ func (c Client) PostResource(body []byte, ct, accept, uri string) (*http.Respons } func (c Client) PostEmptyResource(accept, uri string) (*http.Response, error) { - req, err := http.NewRequest("POST", uri, http.NoBody) + req, err := c.newRequest("POST", uri, http.NoBody) if err != nil { return nil, fmt.Errorf("POST %q, request creation failed: %w", uri, err) } @@ -68,7 +76,7 @@ func (c Client) PostEmptyResource(accept, uri string) (*http.Response, error) { } func (c Client) GetResource(accept, uri string) (*http.Response, error) { - req, err := http.NewRequest("GET", uri, http.NoBody) + req, err := c.newRequest("GET", uri, http.NoBody) if err != nil { return nil, fmt.Errorf("POST %q, request creation failed: %w", uri, err) } @@ -78,6 +86,25 @@ func (c Client) GetResource(accept, uri string) (*http.Response, error) { return c.send(req) } +func (c Client) newRequest(method, uri string, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, uri, body) + if err != nil { + return nil, err + } + + if c.Auth != nil { + header, err := c.Auth.EncodeHeader() + if err != nil { + return nil, fmt.Errorf("could not get Authorization header: %w", err) + } + if header != "" { + req.Header.Set("Authorization", header) + } + } + + return req, nil +} + func (c Client) send(req *http.Request) (*http.Response, error) { hc := &c.HTTPClient diff --git a/go.mod b/go.mod index aaed54a..71be539 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/google/uuid v1.3.0 + github.com/mitchellh/mapstructure v1.5.0 github.com/moogar0880/problems v0.1.1 github.com/stretchr/testify v1.8.2 github.com/veraison/cmw v0.1.0 @@ -12,7 +13,12 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.4.0 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect + golang.org/x/net v0.14.0 // indirect + golang.org/x/oauth2 v0.11.0 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/protobuf v1.31.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bee98f6..fe82697 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moogar0880/problems v0.1.1 h1:bktLhq8NDG/czU2ZziYNigBFksx13RaYe5AVdNmHDT4= github.com/moogar0880/problems v0.1.1/go.mod h1:5Dxrk2sD7BfBAgnOzQ1yaTiuCYdGPUh49L8Vhfky62c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -20,6 +27,23 @@ github.com/veraison/cmw v0.1.0 h1:vD6tBlGPROCW/HlDcG1jh+XUJi5ihrjXatKZBjrv8mU= github.com/veraison/cmw v0.1.0/go.mod h1:WoBrlgByc6C1FeHhdze1/bQx1kv5d1sWKO5ezEf4Hs4= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/oauth2 v0.11.0 h1:vPL4xzxBM4niKCW6g9whtaWVXTJf1U5e4aZxxFx/gbU= +golang.org/x/oauth2 v0.11.0/go.mod h1:LdF7O/8bLR/qWK9DrpXmbHLTouvRHK0SgJl0GmDBchk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/management/management.go b/management/management.go index 7c3af37..d93b833 100644 --- a/management/management.go +++ b/management/management.go @@ -10,6 +10,7 @@ import ( "net/url" "github.com/google/uuid" + "github.com/veraison/apiclient/auth" "github.com/veraison/apiclient/common" ) @@ -33,9 +34,10 @@ type Service struct { } // NewService creates a new Service instance using the provided endpoint -// URI and the default HTTP client. -func NewService(uri string) (*Service, error) { - m := Service{Client: common.NewClient()} +// URI and the default HTTP client. If the supplied IAuthenticator is not nil, +// that will be used to set the Authorization header in the service requests. +func NewService(uri string, a auth.IAuthenticator) (*Service, error) { + m := Service{Client: common.NewClient(a)} if err := m.SetEndpointURI(uri); err != nil { return nil, err diff --git a/management/management_test.go b/management/management_test.go index 32bb9df..0f85262 100644 --- a/management/management_test.go +++ b/management/management_test.go @@ -28,13 +28,13 @@ var ( ) func TestService_NewService(t *testing.T) { - _, err := NewService(string([]byte{0x7f})) + _, err := NewService(string([]byte{0x7f}), nil) assert.EqualError(t, err, "malformed URI: parse \"\\x7f\": net/url: invalid control character in URL") - _, err = NewService("test") + _, err = NewService("test", nil) assert.EqualError(t, err, "URI is not absolute: \"test\"") - service, err := NewService("http://veraison.example:9999/test/v1") + service, err := NewService("http://veraison.example:9999/test/v1", nil) assert.NoError(t, err) assert.Equal(t, "veraison.example:9999", service.EndPointURI.Host) } diff --git a/provisioning/provisioning.go b/provisioning/provisioning.go index 2251777..f8675d5 100644 --- a/provisioning/provisioning.go +++ b/provisioning/provisioning.go @@ -11,6 +11,7 @@ import ( "net/url" "time" + "github.com/veraison/apiclient/auth" "github.com/veraison/apiclient/common" ) @@ -28,9 +29,10 @@ type SubmitSession struct { // SubmitConfig holds the context of an endorsement submission API session type SubmitConfig struct { - Client *common.Client // HTTP(s) client connection configuration - SubmitURI string // URI of the /submit endpoint - DeleteSession bool // explicitly DELETE the session object after we are done + Client *common.Client // HTTP(s) client connection configuration + SubmitURI string // URI of the /submit endpoint + DeleteSession bool // explicitly DELETE the session object after we are done + Auth auth.IAuthenticator // when set, Auth supplies the Authorization header for requests } // SetClient sets the HTTP(s) client connection configuration @@ -38,6 +40,11 @@ func (cfg *SubmitConfig) SetClient(client *common.Client) error { if client == nil { return errors.New("no client supplied") } + + if cfg.Auth != nil { + client.Auth = cfg.Auth + } + cfg.Client = client return nil } @@ -60,6 +67,14 @@ func (cfg *SubmitConfig) SetDeleteSession(session bool) { cfg.DeleteSession = session } +// SetAuth sets the IAuthenticator that will be used +func (cfg *SubmitConfig) SetAuth(a auth.IAuthenticator) { + cfg.Auth = a + if cfg.Client != nil { + cfg.Client.Auth = cfg.Auth + } +} + // Run implements the endorsement submission API. If the session does not // complete synchronously, this call will block until either the session state // moves out of the processing state, or the MaxAttempts*PollPeriod threshold is @@ -70,7 +85,7 @@ func (cfg SubmitConfig) Run(endorsement []byte, mediaType string) error { } if cfg.Client == nil { - cfg.Client = common.NewClient() + cfg.Client = common.NewClient(cfg.Auth) } // POST endorsement to the /submit endpoint diff --git a/provisioning/provisioning_test.go b/provisioning/provisioning_test.go index 99399d0..cadaae4 100644 --- a/provisioning/provisioning_test.go +++ b/provisioning/provisioning_test.go @@ -37,7 +37,7 @@ func TestSubmitConfig_check_no_submit_uri(t *testing.T) { func TestSubmitConfig_SetClient_ok(t *testing.T) { tv := SubmitConfig{} - client := common.NewClient() + client := common.NewClient(nil) err := tv.SetClient(client) assert.NoError(t, err) } diff --git a/verification/challengeresponse.go b/verification/challengeresponse.go index 5fa6885..3617267 100644 --- a/verification/challengeresponse.go +++ b/verification/challengeresponse.go @@ -13,6 +13,7 @@ import ( "net/url" "time" + "github.com/veraison/apiclient/auth" "github.com/veraison/apiclient/common" "github.com/veraison/cmw" ) @@ -38,13 +39,14 @@ var cmwInfoMap = map[CmwWrap]cmwInfo{ // ChallengeResponseConfig holds the configuration for one or more // challenge-response exchanges type ChallengeResponseConfig struct { - Nonce []byte // an explicit nonce supplied by the user - NonceSz uint // the size of a nonce to be provided by server - EvidenceBuilder EvidenceBuilder // Evidence generation logics supplied by the user - NewSessionURI string // URI of the "/newSession" endpoint - Client *common.Client // HTTP(s) client connection configuration - DeleteSession bool // explicitly DELETE the session object after we are done - Wrap CmwWrap // when set, wrap the supplied evidence as a Conceptual Message Wrapper(CMW) + Nonce []byte // an explicit nonce supplied by the user + NonceSz uint // the size of a nonce to be provided by server + EvidenceBuilder EvidenceBuilder // Evidence generation logics supplied by the user + NewSessionURI string // URI of the "/newSession" endpoint + Client *common.Client // HTTP(s) client connection configuration + DeleteSession bool // explicitly DELETE the session object after we are done + Wrap CmwWrap // when set, wrap the supplied evidence as a Conceptual Message Wrapper(CMW) + Auth auth.IAuthenticator // when set, Auth supplies the Authorization header for requests } // Blob wraps a base64 encoded value together with its media type @@ -146,7 +148,7 @@ func (cfg ChallengeResponseConfig) Run() ([]byte, error) { // Attach the default client if the user hasn't supplied one if cfg.Client == nil { - cfg.Client = common.NewClient() + cfg.Client = common.NewClient(cfg.Auth) } newSessionCtx, sessionURI, err := cfg.newSession() @@ -196,7 +198,7 @@ func (cfg ChallengeResponseConfig) NewSession() (*ChallengeResponseSession, stri // Attach the default client if the user hasn't supplied one if cfg.Client == nil { - cfg.Client = common.NewClient() + cfg.Client = common.NewClient(cfg.Auth) } return cfg.newSession() diff --git a/verification/challengeresponse_test.go b/verification/challengeresponse_test.go index 32ae440..e3cd7cb 100644 --- a/verification/challengeresponse_test.go +++ b/verification/challengeresponse_test.go @@ -67,7 +67,7 @@ func TestChallengeResponseConfig_SetNonceSz_zero_noncesz(t *testing.T) { } func TestChallengeResponseConfig_SetClient_ok(t *testing.T) { cfg := ChallengeResponseConfig{} - client := common.NewClient() + client := common.NewClient(nil) err := cfg.SetClient(client) assert.NoError(t, err) }