Skip to content

Commit

Permalink
use lestrrat jwx lib for jwks
Browse files Browse the repository at this point in the history
  • Loading branch information
ecrupper committed May 21, 2024
1 parent ab20991 commit 75770cd
Show file tree
Hide file tree
Showing 22 changed files with 453 additions and 204 deletions.
3 changes: 1 addition & 2 deletions api/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/gin-gonic/gin"

"github.com/go-vela/server/api/types"
"github.com/go-vela/server/database"
"github.com/go-vela/server/util"
)
Expand Down Expand Up @@ -40,5 +39,5 @@ func GetJWKS(c *gin.Context) {
return
}

c.JSON(http.StatusOK, types.JWKS{Keys: keys})
c.JSON(http.StatusOK, keys)
}
31 changes: 13 additions & 18 deletions api/types/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

package types

import "github.com/golang-jwt/jwt/v5"

// OpenIDConfig is a struct that represents the OpenID Connect configuration.
//
// swagger:model OpenIDConfig
Expand All @@ -12,22 +14,15 @@ type OpenIDConfig struct {
Algorithms []string `json:"id_token_signing_alg_values_supported"`
}

// JWKS is a slice of JWKs.
//
// swagger:model JWKS
type JWKS struct {
Keys []JWK `json:"keys"`
}

// JWK represents a JSON Web Key parsed with fields as the correct Go types.
type JWK struct {
Algorithm string `json:"alg"`
Use string `json:"use"`
X5t string `json:"x5t"`
Kid string `json:"kid"`
Kty string `json:"kty"`
X5c []string `json:"x5c"`

N string `json:"n"` // modulus
E string `json:"e"` // public exponent
// OpenIDClaims struct is an extension of the JWT standard claims. It
// includes information relevant to OIDC services.
type OpenIDClaims struct {
BuildNumber int `json:"build_number,omitempty"`
Actor string `json:"actor,omitempty"`
Repo string `json:"repo,omitempty"`
TokenType string `json:"token_type,omitempty"`
Image string `json:"image,omitempty"`
Request string `json:"request,omitempty"`
Commands bool `json:"commands,omitempty"`
jwt.RegisteredClaims
}
62 changes: 32 additions & 30 deletions database/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/adhocore/gronx"
"github.com/google/go-cmp/cmp"
"github.com/lestrrat-go/jwx/jwk"

api "github.com/go-vela/server/api/types"
"github.com/go-vela/server/api/types/settings"
Expand All @@ -20,7 +21,7 @@ import (
"github.com/go-vela/server/database/deployment"
"github.com/go-vela/server/database/executable"
"github.com/go-vela/server/database/hook"
"github.com/go-vela/server/database/jwk"
dbJWK "github.com/go-vela/server/database/jwk"
"github.com/go-vela/server/database/log"
"github.com/go-vela/server/database/pipeline"
"github.com/go-vela/server/database/repo"
Expand All @@ -44,7 +45,7 @@ type Resources struct {
Deployments []*library.Deployment
Executables []*library.BuildExecutable
Hooks []*library.Hook
JWKs []api.JWK
JWKs jwk.Set
Logs []*library.Log
Pipelines []*library.Pipeline
Repos []*api.Repo
Expand Down Expand Up @@ -863,7 +864,7 @@ func testJWKs(t *testing.T, db Interface, resources *Resources) {
// create a variable to track the number of methods called for jwks
methods := make(map[string]bool)
// capture the element type of the jwk interface
element := reflect.TypeOf(new(jwk.JWKInterface)).Elem()
element := reflect.TypeOf(new(dbJWK.JWKInterface)).Elem()
// iterate through all methods found in the jwk interface
for i := 0; i < element.NumMethod(); i++ {
// skip tracking the methods to create indexes and tables for jwks
Expand All @@ -876,10 +877,14 @@ func testJWKs(t *testing.T, db Interface, resources *Resources) {
methods[element.Method(i).Name] = false
}

for _, jwk := range resources.JWKs {
err := db.CreateJWK(context.TODO(), jwk)
for i := 0; i < resources.JWKs.Len(); i++ {
jk, _ := resources.JWKs.Get(i)

jkPub, _ := jk.(jwk.RSAPublicKey)

err := db.CreateJWK(context.TODO(), jkPub)
if err != nil {
t.Errorf("unable to create jwk %s: %v", jwk.Kid, err)
t.Errorf("unable to create jwk %s: %v", jkPub.KeyID(), err)
}
}
methods["CreateJWK"] = true
Expand All @@ -895,14 +900,18 @@ func testJWKs(t *testing.T, db Interface, resources *Resources) {

methods["ListJWKs"] = true

for _, jwk := range resources.JWKs {
got, err := db.GetActiveJWK(context.TODO(), jwk.Kid)
for i := 0; i < resources.JWKs.Len(); i++ {
jk, _ := resources.JWKs.Get(i)

jkPub, _ := jk.(jwk.RSAPublicKey)

got, err := db.GetActiveJWK(context.TODO(), jkPub.KeyID())
if err != nil {
t.Errorf("unable to get jwk %s: %v", jwk.Kid, err)
t.Errorf("unable to get jwk %s: %v", jkPub.KeyID(), err)
}

if !cmp.Equal(jwk, got) {
t.Errorf("GetJWK() is %v, want %v", got, jwk)
if !cmp.Equal(jkPub, got) {
t.Errorf("GetJWK() is %v, want %v", got, jkPub)
}
}

Expand All @@ -913,8 +922,12 @@ func testJWKs(t *testing.T, db Interface, resources *Resources) {
t.Errorf("unable to rotate keys: %v", err)
}

for _, jwk := range resources.JWKs {
_, err := db.GetActiveJWK(context.TODO(), jwk.Kid)
for i := 0; i < resources.JWKs.Len(); i++ {
jk, _ := resources.JWKs.Get(i)

jkPub, _ := jk.(jwk.RSAPublicKey)

_, err := db.GetActiveJWK(context.TODO(), jkPub.KeyID())
if err == nil {
t.Errorf("GetActiveJWK() should return err after rotation")
}
Expand Down Expand Up @@ -2557,23 +2570,12 @@ func newResources() *Resources {
hookThree.SetLink("https://github.com/github/octocat/settings/hooks/1")
hookThree.SetWebhookID(78910)

jwkOne := api.JWK{
Algorithm: "RS256",
Kid: "c8da1302-07d6-11ea-882f-4893bca275b8",
Kty: "rsa",
Use: "sig",
N: "123456",
E: "123",
}
jwkOne := testutils.JWK()
jwkTwo := testutils.JWK()

jwkTwo := api.JWK{
Algorithm: "RS256",
Kid: "c8da1302-07d6-11ea-882f-4893bca275b9",
Kty: "rsa",
Use: "sig",
N: "789101",
E: "456",
}
jwkSet := jwk.NewSet()
jwkSet.Add(jwkOne)
jwkSet.Add(jwkTwo)

logServiceOne := new(library.Log)
logServiceOne.SetID(1)
Expand Down Expand Up @@ -2840,7 +2842,7 @@ func newResources() *Resources {
Deployments: []*library.Deployment{deploymentOne, deploymentTwo},
Executables: []*library.BuildExecutable{executableOne, executableTwo},
Hooks: []*library.Hook{hookOne, hookTwo, hookThree},
JWKs: []api.JWK{jwkOne, jwkTwo},
JWKs: jwkSet,
Logs: []*library.Log{logServiceOne, logServiceTwo, logStepOne, logStepTwo},
Pipelines: []*library.Pipeline{pipelineOne, pipelineTwo},
Repos: []*api.Repo{repoOne, repoTwo},
Expand Down
13 changes: 4 additions & 9 deletions database/jwk/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,22 @@ import (
"context"
"database/sql"

"github.com/lestrrat-go/jwx/jwk"
"github.com/sirupsen/logrus"

api "github.com/go-vela/server/api/types"
"github.com/go-vela/server/constants"
"github.com/go-vela/server/database/types"
)

// CreateJWK creates a new JWK in the database.
func (e *engine) CreateJWK(_ context.Context, j api.JWK) error {
func (e *engine) CreateJWK(_ context.Context, j jwk.RSAPublicKey) error {
e.logger.WithFields(logrus.Fields{
"jwk": j.Kid,
}).Tracef("creating key %s in the database", j.Kid)
"jwk": j.KeyID(),
}).Tracef("creating key %s in the database", j.KeyID())

key := types.JWKFromAPI(j)
key.Active = sql.NullBool{Bool: true, Valid: true}

err := key.Validate()
if err != nil {
return err
}

// send query to the database
return e.client.Table(constants.TableJWK).Create(key).Error
}
10 changes: 7 additions & 3 deletions database/jwk/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package jwk

import (
"context"
"encoding/json"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand All @@ -13,8 +14,11 @@ import (

func TestJWK_Engine_CreateJWK(t *testing.T) {
// setup types
_jwk := testutils.APIJWK()
_jwk.Kid = "c8da1302-07d6-11ea-882f-4893bca275b8"
_jwk := testutils.JWK()
_jwkBytes, err := json.Marshal(_jwk)
if err != nil {
t.Errorf("unable to marshal JWK: %v", err)
}

_postgres, _mock := testPostgres(t)
defer func() { _sql, _ := _postgres.client.DB(); _sql.Close() }()
Expand All @@ -23,7 +27,7 @@ func TestJWK_Engine_CreateJWK(t *testing.T) {
_mock.ExpectExec(`INSERT INTO "jwks"
("id","active","key")
VALUES ($1,$2,$3)`).
WithArgs("c8da1302-07d6-11ea-882f-4893bca275b8", true, `{"alg":"","use":"","x5t":"","kid":"c8da1302-07d6-11ea-882f-4893bca275b8","kty":"","x5c":null,"n":"","e":""}`).
WithArgs(_jwk.KeyID(), true, _jwkBytes).
WillReturnResult(sqlmock.NewResult(1, 1))

_sqlite := testSqlite(t)
Expand Down
4 changes: 2 additions & 2 deletions database/jwk/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ package jwk
import (
"context"

Check failure on line 7 in database/jwk/get.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] database/jwk/get.go#L7

File is not `gci`-ed with --skip-generated -s standard -s default -s blank -s dot -s prefix(github.com/go-vela) --custom-order (gci)
Raw output
database/jwk/get.go:7: File is not `gci`-ed with --skip-generated -s standard -s default -s blank -s dot -s prefix(github.com/go-vela) --custom-order (gci)
api "github.com/go-vela/server/api/types"
"github.com/go-vela/server/constants"
"github.com/go-vela/server/database/types"
"github.com/lestrrat-go/jwx/jwk"
)

// GetActiveJWK gets a JWK by UUID (kid) from the database if active.
func (e *engine) GetActiveJWK(_ context.Context, id string) (api.JWK, error) {
func (e *engine) GetActiveJWK(_ context.Context, id string) (jwk.RSAPublicKey, error) {
e.logger.Tracef("getting key %s from the database", id)

// variable to store query results
Expand Down
27 changes: 13 additions & 14 deletions database/jwk/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,39 @@ package jwk

import (
"context"
"encoding/json"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/google/go-cmp/cmp"
"github.com/lestrrat-go/jwx/jwk"

api "github.com/go-vela/server/api/types"
"github.com/go-vela/server/database/testutils"
)

func TestJWK_Engine_GetJWK(t *testing.T) {
// setup types
_jwk := testutils.APIJWK()
_jwk.Kid = "c8da1302-07d6-11ea-882f-4893bca275b8"
_jwk.Algorithm = "RS256"
_jwk.Kty = "rsa"
_jwk.Use = "sig"
_jwk.N = "123456"
_jwk.E = "123"
_jwk := testutils.JWK()
_jwkBytes, err := json.Marshal(_jwk)
if err != nil {
t.Errorf("unable to marshal JWK: %v", err)
}

_postgres, _mock := testPostgres(t)
defer func() { _sql, _ := _postgres.client.DB(); _sql.Close() }()

// create expected result in mock
_rows := sqlmock.NewRows(
[]string{"id", "active", "key"},
).AddRow("c8da1302-07d6-11ea-882f-4893bca275b8", true, []byte(`{"alg":"RS256","use":"sig","x5t":"","kid":"c8da1302-07d6-11ea-882f-4893bca275b8","kty":"rsa","x5c":null,"n":"123456","e":"123"}`))
).AddRow(_jwk.KeyID(), true, _jwkBytes)

// ensure the mock expects the query
_mock.ExpectQuery(`SELECT * FROM "jwks" WHERE id = $1 AND active = $2 LIMIT $3`).WithArgs("c8da1302-07d6-11ea-882f-4893bca275b8", true, 1).WillReturnRows(_rows)
_mock.ExpectQuery(`SELECT * FROM "jwks" WHERE id = $1 AND active = $2 LIMIT $3`).WithArgs(_jwk.KeyID(), true, 1).WillReturnRows(_rows)

_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateJWK(context.TODO(), _jwk)
err = _sqlite.CreateJWK(context.TODO(), _jwk)
if err != nil {
t.Errorf("unable to create test repo for sqlite: %v", err)
}
Expand All @@ -47,7 +46,7 @@ func TestJWK_Engine_GetJWK(t *testing.T) {
failure bool
name string
database *engine
want api.JWK
want jwk.RSAPublicKey
}{
{
failure: false,
Expand All @@ -66,7 +65,7 @@ func TestJWK_Engine_GetJWK(t *testing.T) {
// run tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := test.database.GetActiveJWK(context.TODO(), "c8da1302-07d6-11ea-882f-4893bca275b8")
got, err := test.database.GetActiveJWK(context.TODO(), _jwk.KeyID())

if test.failure {
if err == nil {
Expand All @@ -80,7 +79,7 @@ func TestJWK_Engine_GetJWK(t *testing.T) {
t.Errorf("GetActiveJWK for %s returned err: %v", test.name, err)
}

if diff := cmp.Diff(got, test.want); diff != "" {
if diff := cmp.Diff(test.want, got, jwkOpts); diff != "" {
t.Errorf("GetActiveJWK mismatch (-want +got):\n%s", diff)
}
})
Expand Down
8 changes: 4 additions & 4 deletions database/jwk/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package jwk
import (
"context"

api "github.com/go-vela/server/api/types"
"github.com/lestrrat-go/jwx/jwk"
)

// JWKInterface represents the Vela interface for JWK
Expand All @@ -23,11 +23,11 @@ type JWKInterface interface {
// https://en.wikipedia.org/wiki/Data_manipulation_language

// CreateJWK defines a function that creates a JWK.
CreateJWK(context.Context, api.JWK) error
CreateJWK(context.Context, jwk.RSAPublicKey) error
// RotateKeys defines a function that rotates JWKs.
RotateKeys(context.Context) error
// ListJWKs defines a function that lists all JWKs configured.
ListJWKs(context.Context) ([]api.JWK, error)
ListJWKs(context.Context) (jwk.Set, error)
// GetJWK defines a function that gets a JWK by the provided key ID.
GetActiveJWK(context.Context, string) (api.JWK, error)
GetActiveJWK(context.Context, string) (jwk.RSAPublicKey, error)
}
Loading

0 comments on commit 75770cd

Please sign in to comment.