Skip to content

Commit

Permalink
fix: cpu contention when reading JWKs and suppress generating duplica…
Browse files Browse the repository at this point in the history
…te JWKs

Previously each concurrent caller would need to lock a shared mutex when reading or writing a given JWK set.
The read path now doesn't require locking a mutex at all and instead returns valid query results directly.
The write path is now protected by a concurrency control mechanism (using x/sync/singleflight) to ensure only one JWK set is generated and persisted.
Note: Duplicate JWK sets may still be improperly generated if running more than one Hydra instance in a high traffic environment.
  • Loading branch information
terev committed Nov 10, 2024
1 parent 825c24d commit 7aa20e9
Show file tree
Hide file tree
Showing 17 changed files with 111 additions and 120 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ format: .bin/goimports .bin/ory node_modules
mocks: .bin/mockgen
mockgen -package oauth2_test -destination oauth2/oauth2_provider_mock_test.go github.com/ory/fosite OAuth2Provider
mockgen -package jwk_test -destination jwk/registry_mock_test.go -source=jwk/registry.go
mockgen -package jwk_test -destination jwk/manager_mock_test.go -source=jwk/manager.go
go generate ./...

# Generates the SDKs
Expand Down
4 changes: 1 addition & 3 deletions cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"encoding/pem"
"sync"

"github.com/gofrs/uuid"

"github.com/go-jose/go-jose/v3"

"github.com/ory/hydra/v2/driver"
Expand Down Expand Up @@ -58,7 +56,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con
}

// no certificates configured: self-sign a new cert
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d.SoftwareKeyManager(), TlsKeyName, "", "RS256")
if err != nil {
d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair")
return nil // in case Fatal is hooked
Expand Down
4 changes: 3 additions & 1 deletion hsm/manager_hsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ import (
"net/http"
"sync"

"github.com/ory/hydra/v2/driver/config"
"github.com/ory/x/otelx"

"github.com/ory/hydra/v2/driver/config"

"github.com/pkg/errors"

"github.com/pborman/uuid"

"github.com/ory/fosite"

"github.com/ory/hydra/v2/jwk"

"github.com/miekg/pkcs11"
Expand Down
3 changes: 2 additions & 1 deletion hsm/manager_nohsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (
"context"
"sync"

"github.com/ory/hydra/v2/driver/config"
"github.com/ory/x/logrusx"

"github.com/ory/hydra/v2/driver/config"

"github.com/pkg/errors"

"github.com/ory/hydra/v2/jwk"
Expand Down
4 changes: 2 additions & 2 deletions internal/mock/config_cookie.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 3 additions & 12 deletions jwk/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ import (
"github.com/ory/herodot"
"github.com/ory/x/httprouterx"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/x/urlx"

"github.com/ory/x/errorsx"
Expand Down Expand Up @@ -101,17 +98,11 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) {
for _, set := range wellKnownKeys {
set := set
eg.Go(func() error {
k, err := h.r.KeyManager().GetKeySet(ctx, set)
if errors.Is(err, x.ErrNotFound) {
h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set)
k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig")
if err != nil {
return err
}
} else if err != nil {
keySet, err := GetOrGenerateKeySet(ctx, h.r.KeyManager(), set, "", string(jose.RS256))
if err != nil {
return err
}
keys <- ExcludePrivateKeys(k)
keys <- ExcludePrivateKeys(keySet)
return nil
})
}
Expand Down
68 changes: 25 additions & 43 deletions jwk/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,51 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"sync"

hydra "github.com/ory/hydra-client-go/v2"

"github.com/ory/x/josex"

"github.com/ory/x/errorsx"

hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/x"

"github.com/ory/x/errorsx"

jose "github.com/go-jose/go-jose/v3"
"github.com/pkg/errors"
)

var mapLock sync.RWMutex
var locks = map[string]*sync.RWMutex{}

func getLock(set string) *sync.RWMutex {
mapLock.Lock()
defer mapLock.Unlock()
if _, ok := locks[set]; !ok {
locks[set] = new(sync.RWMutex)
}
return locks[set]
}

func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error {
_, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg)
_, err := GetOrGenerateKeySetPrivateKey(ctx, r.KeyManager(), set, set, alg)
return err
}

func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) {
getLock(set).Lock()
defer getLock(set).Unlock()

keys, err := m.GetKeySet(ctx, set)
if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 {
r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set)
keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
if err != nil {
return nil, err
}
} else if err != nil {
func GetOrGenerateKeySetPrivateKey(ctx context.Context, m Manager, set, kid, alg string) (*jose.JSONWebKey, error) {
keySet, err := GetOrGenerateKeySet(ctx, m, set, kid, alg)
if err != nil {
return nil, err
}

privKey, privKeyErr := FindPrivateKey(keys)
if privKeyErr == nil {
privKey, err := FindPrivateKey(keySet)
if err == nil {
return privKey, nil
} else {
r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set)
}

keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
if err != nil {
return nil, err
}
keySet, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
if err != nil {
return nil, err
}

privKey, err := FindPrivateKey(keys)
if err != nil {
return nil, err
}
return privKey, nil
return FindPrivateKey(keySet)
}

func GetOrGenerateKeySet(ctx context.Context, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) {
keys, err := m.GetKeySet(ctx, set)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return nil, err
} else if keys != nil && len(keys.Keys) > 0 {
return keys, nil
}

return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
}

func First(keys []jose.JSONWebKey) *jose.JSONWebKey {
Expand Down
30 changes: 14 additions & 16 deletions jwk/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@ import (
"strings"
"testing"

gomock "github.com/golang/mock/gomock"
"github.com/pborman/uuid"
"github.com/pkg/errors"

hydra "github.com/ory/hydra-client-go/v2"

"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/cryptosigner"
"github.com/golang/mock/gomock"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
)

type fakeSigner struct {
Expand Down Expand Up @@ -210,7 +209,6 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) {

func TestGetOrGenerateKeys(t *testing.T) {
t.Parallel()
reg := internal.NewMockedRegistry(t, &contextx.Default{})

setId := uuid.NewUUID().String()
keyId := uuid.NewUUID().String()
Expand All @@ -226,46 +224,46 @@ func TestGetOrGenerateKeys(t *testing.T) {
return NewMockManager(ctrl)
}

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, ""))
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256")
assert.NoError(t, err)
assert.Equal(t, privKey, &keySet.Keys[0])
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "key not found")
})
Expand Down
4 changes: 2 additions & 2 deletions jwk/jwt_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
"github.com/ory/x/josex"

"github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"

"github.com/ory/fosite"

"github.com/ory/hydra/v2/driver/config"

"github.com/pkg/errors"
Expand Down Expand Up @@ -40,7 +40,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st
}

func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) {
private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r.KeyManager(), j.setID, "", string(jose.RS256))
if err == nil {
return private, nil
}
Expand Down
3 changes: 2 additions & 1 deletion jwk/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (

"github.com/pkg/errors"

"github.com/ory/x/errorsx"

"github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/errorsx"

jose "github.com/go-jose/go-jose/v3"
"github.com/gofrs/uuid"
Expand Down
4 changes: 2 additions & 2 deletions jwk/manager_mock_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion jwk/manager_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

"github.com/ory/hydra/v2/x"
"github.com/ory/x/otelx"

"github.com/ory/hydra/v2/x"
)

const tracingComponent = "github.com/ory/hydra/v2/jwk"
Expand Down
4 changes: 2 additions & 2 deletions jwk/registry_mock_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion oauth2/oauth2_provider_mock_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions persistence/sql/migratest/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/internal"
"github.com/ory/x/contextx"

"github.com/ory/hydra/v2/internal"

"github.com/bradleyjkemp/cupaloy/v2"
"github.com/fatih/structs"
"github.com/gofrs/uuid"
Expand All @@ -28,10 +29,11 @@ import (
"github.com/ory/x/networkx"
"github.com/ory/x/sqlxx"

"github.com/ory/x/popx"

"github.com/ory/hydra/v2/flow"
testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid"
"github.com/ory/hydra/v2/persistence/sql"
"github.com/ory/x/popx"

"github.com/ory/x/sqlcon/dockertest"

Expand Down Expand Up @@ -76,7 +78,7 @@ func TestMigrations(t *testing.T) {
connections["postgres"] = dockertest.ConnectToTestPostgreSQLPop(t)
},
func() {
connections["mysql"] = dockertest.ConnectToTestMySQLPop(t)
g connections["mysql"] = dockertest.ConnectToTestMySQLPop(t)

Check failure on line 81 in persistence/sql/migratest/migration_test.go

View workflow job for this annotation

GitHub Actions / format

expected ';', found connections
},
func() {
connections["cockroach"] = dockertest.ConnectToTestCockroachDBPop(t)
Expand Down
Loading

0 comments on commit 7aa20e9

Please sign in to comment.