From 819f0b3e8771d1ce4f2c5a3b5c87b90ac2427847 Mon Sep 17 00:00:00 2001 From: Nathanael Liechti Date: Tue, 21 Nov 2023 16:16:50 +0100 Subject: [PATCH] feat(oidc): optionally query OIDC UserInfo to gather group claims (#12062) Signed-off-by: Nathanael Liechti --- docs/operator-manual/user-management/index.md | 14 + server/server.go | 32 +- server/server_test.go | 125 +++++++- util/cache/inmemory.go | 4 + util/oidc/oidc.go | 182 ++++++++++- util/oidc/oidc_test.go | 294 +++++++++++++++++- util/settings/settings.go | 34 ++ util/test/testutil.go | 10 + 8 files changed, 665 insertions(+), 30 deletions(-) diff --git a/docs/operator-manual/user-management/index.md b/docs/operator-manual/user-management/index.md index 8a5ba7802676d..496dd17a83e9f 100644 --- a/docs/operator-manual/user-management/index.md +++ b/docs/operator-manual/user-management/index.md @@ -387,6 +387,20 @@ For a simple case this can be: oidc.config: | requestedIDTokenClaims: {"groups": {"essential": true}} ``` + +### Retrieving group claims when not in the token + +Some OIDC providers don't return the group information for a user in the ID token, even if explicitly requested using the `requestedIDTokenClaims` setting (Okta for example). They instead provide the groups on the user info endpoint. With the following config, Argo CD queries the user info endpoint during login for groups information of a user: + +```yaml +oidc.config: | + enableUserInfoGroups: true + userInfoPath: /userinfo + userInfoCacheExpiration: "5m" +``` + +**Note: If you omit the `userInfoCacheExpiration` setting or if it's greater than the expiration of the ID token, the argocd-server will cache group information as long as the ID token is valid!** + ### Configuring a custom logout URL for your OIDC provider Optionally, if your OIDC provider exposes a logout API and you wish to configure a custom logout URL for the purposes of invalidating diff --git a/server/server.go b/server/server.go index e52416927143b..d9f1638024c51 100644 --- a/server/server.go +++ b/server/server.go @@ -1121,7 +1121,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { // Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex) var err error mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(a.DexServerAddr, a.BaseHRef, a.DexTLSConfig)) - a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef) + a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef, cacheutil.NewRedisCache(a.RedisClient, a.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) errorsutil.CheckError(err) mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin) mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback) @@ -1315,7 +1315,35 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error if err != nil { return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - return claims, newToken, nil + + // Some SSO implementations (Okta) require a call to + // the OIDC user info path to get attributes like groups + // we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims + // otherwise this would cause a panic + var groupClaims jwt.MapClaims + if groupClaims, ok = claims.(jwt.MapClaims); !ok { + if tmpClaims, ok := claims.(*jwt.MapClaims); ok { + groupClaims = *tmpClaims + } + } + iss := jwtutil.StringField(groupClaims, "iss") + if iss != util_session.SessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" { + userInfo, unauthorized, err := a.ssoClientApp.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath()) + if unauthorized { + log.Errorf("error while quering userinfo endpoint: %v", err) + return claims, "", status.Errorf(codes.Unauthenticated, "invalid session") + } + if err != nil { + log.Errorf("error fetching user info endpoint: %v", err) + return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") + } + if groupClaims["sub"] != userInfo["sub"] { + return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") + } + groupClaims["groups"] = userInfo["groups"] + } + + return groupClaims, newToken, nil } // getToken extracts the token from gRPC metadata or cookie headers diff --git a/server/server_test.go b/server/server_test.go index 303f938871f38..acfb32e57e5d4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,8 +32,10 @@ import ( "github.com/argoproj/argo-cd/v2/server/rbacpolicy" "github.com/argoproj/argo-cd/v2/test" "github.com/argoproj/argo-cd/v2/util/assets" + "github.com/argoproj/argo-cd/v2/util/cache" cacheutil "github.com/argoproj/argo-cd/v2/util/cache" appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate" + "github.com/argoproj/argo-cd/v2/util/oidc" "github.com/argoproj/argo-cd/v2/util/rbac" settings_util "github.com/argoproj/argo-cd/v2/util/settings" testutil "github.com/argoproj/argo-cd/v2/util/test" @@ -533,7 +535,7 @@ func dexMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.Re } } -func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool) (argocd *ArgoCDServer, oidcURL string) { +func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool, additionalOIDCConfig settings_util.OIDCConfig) (argocd *ArgoCDServer, oidcURL string) { cm := test.NewFakeConfigMap() if anonymousEnabled { cm.Data["users.anonymous.enabled"] = "true" @@ -562,13 +564,12 @@ connectors: clientID: test-client clientSecret: $dex.oidc.clientSecret` } else { - oidcConfig := settings_util.OIDCConfig{ - Name: "Okta", - Issuer: oidcServer.URL, - ClientID: "argo-cd", - ClientSecret: "$oidc.okta.clientSecret", - } - oidcConfigString, err := yaml.Marshal(oidcConfig) + // override required oidc config fields but keep other configs as passed in + additionalOIDCConfig.Name = "Okta" + additionalOIDCConfig.Issuer = oidcServer.URL + additionalOIDCConfig.ClientID = "argo-cd" + additionalOIDCConfig.ClientSecret = "$oidc.okta.clientSecret" + oidcConfigString, err := yaml.Marshal(additionalOIDCConfig) require.NoError(t, err) cm.Data["oidc.config"] = string(oidcConfigString) // Avoid bothering with certs for local tests. @@ -589,9 +590,109 @@ connectors: argoCDOpts.DexServerAddr = ts.URL } argocd = NewServer(context.Background(), argoCDOpts) + var err error + argocd.ssoClientApp, err = oidc.NewClientApp(argocd.settings, argocd.DexServerAddr, argocd.DexTLSConfig, argocd.BaseHRef, cache.NewInMemoryCache(24*time.Hour)) + require.NoError(t, err) return argocd, oidcServer.URL } +func TestGetClaims(t *testing.T) { + + defaultExpiry := jwt.NewNumericDate(time.Now().Add(time.Hour * 24)) + defaultExpiryUnix := float64(defaultExpiry.Unix()) + + type testData struct { + test string + claims jwt.MapClaims + expectedErrorContains string + expectedClaims jwt.MapClaims + expectNewToken bool + additionalOIDCConfig settings_util.OIDCConfig + } + var tests = []testData{ + { + test: "GetClaims", + claims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "", + expectedClaims: jwt.MapClaims{ + "aud": "argo-cd", + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{}, + }, + { + // note: a passing test with user info groups can never be achieved since the user never logged in properly + // therefore the oidcClient's cache contains no accessToken for the user info endpoint + // and since the oidcClient cache is unexported (for good reasons) we can't mock this behaviour + test: "GetClaimsWithUserInfoGroupsEnabled", + claims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiry, + "sub": "randomUser", + }, + expectedErrorContains: "invalid session", + expectedClaims: jwt.MapClaims{ + "aud": common.ArgoCDClientAppID, + "exp": defaultExpiryUnix, + "sub": "randomUser", + }, + expectNewToken: false, + additionalOIDCConfig: settings_util.OIDCConfig{ + EnableUserInfoGroups: true, + UserInfoPath: "/userinfo", + UserInfoCacheExpiration: "5m", + }, + }, + } + + for _, testData := range tests { + testDataCopy := testData + + t.Run(testDataCopy.test, func(t *testing.T) { + t.Parallel() + + // Must be declared here to avoid race. + ctx := context.Background() //nolint:ineffassign,staticcheck + + argocd, oidcURL := getTestServer(t, false, true, false, testDataCopy.additionalOIDCConfig) + + // create new JWT and store it on the context to simulate an incoming request + testDataCopy.claims["iss"] = oidcURL + testDataCopy.expectedClaims["iss"] = oidcURL + token := jwt.NewWithClaims(jwt.SigningMethodRS512, testDataCopy.claims) + key, err := jwt.ParseRSAPrivateKeyFromPEM(testutil.PrivateKey) + require.NoError(t, err) + tokenString, err := token.SignedString(key) + require.NoError(t, err) + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(apiclient.MetaDataTokenKey, tokenString)) + + gotClaims, newToken, err := argocd.getClaims(ctx) + + // Note: testutil.oidcMockHandler currently doesn't implement reissuing expired tokens + // so newToken will always be empty + if testDataCopy.expectNewToken { + assert.NotEmpty(t, newToken) + } + if testDataCopy.expectedClaims == nil { + assert.Nil(t, gotClaims) + } else { + assert.Equal(t, testDataCopy.expectedClaims, gotClaims) + } + if testDataCopy.expectedErrorContains != "" { + assert.ErrorContains(t, err, testDataCopy.expectedErrorContains, "getClaims should have thrown an error and return an error") + } else { + assert.NoError(t, err) + } + }) + } +} + func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Marshaling single strings to strings is typical, so we test for this relatively common behavior. jwt.MarshalSingleStringAsArray = false @@ -723,7 +824,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex) + argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex, settings_util.OIDCConfig{}) if testDataCopy.useDex { testDataCopy.claims.Issuer = fmt.Sprintf("%s/api/dex", oidcURL) @@ -779,7 +880,7 @@ func TestAuthenticate_no_request_metadata(t *testing.T) { t.Run(testDataCopy.test, func(t *testing.T) { t.Parallel() - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx := context.Background() ctx, err := argocd.Authenticate(ctx) @@ -825,7 +926,7 @@ func TestAuthenticate_no_SSO(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true) + argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true, settings_util.OIDCConfig{}) token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{Issuer: fmt.Sprintf("%s/api/dex", dexURL)}) tokenString, err := token.SignedString([]byte("key")) require.NoError(t, err) @@ -933,7 +1034,7 @@ func TestAuthenticate_bad_request_metadata(t *testing.T) { // Must be declared here to avoid race. ctx := context.Background() //nolint:ineffassign,staticcheck - argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true) + argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{}) ctx = metadata.NewIncomingContext(context.Background(), testDataCopy.metadata) ctx, err := argocd.Authenticate(ctx) diff --git a/util/cache/inmemory.go b/util/cache/inmemory.go index 53e690925d940..f75688c275546 100644 --- a/util/cache/inmemory.go +++ b/util/cache/inmemory.go @@ -16,6 +16,10 @@ func NewInMemoryCache(expiration time.Duration) *InMemoryCache { } } +func init() { + gob.Register([]interface{}{}) +} + // compile-time validation of adherance of the CacheClient contract var _ CacheClient = &InMemoryCache{} diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 3df3166490172..2c376cc7e5b5b 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -6,6 +6,7 @@ import ( "fmt" "html" "html/template" + "io" "net" "net/http" "net/url" @@ -21,9 +22,12 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" + httputil "github.com/argoproj/argo-cd/v2/util/http" + jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" "github.com/argoproj/argo-cd/v2/util/rand" "github.com/argoproj/argo-cd/v2/util/settings" ) @@ -31,9 +35,11 @@ import ( var InvalidRedirectURLError = fmt.Errorf("invalid return URL") const ( - GrantTypeAuthorizationCode = "authorization_code" - GrantTypeImplicit = "implicit" - ResponseTypeCode = "code" + GrantTypeAuthorizationCode = "authorization_code" + GrantTypeImplicit = "implicit" + ResponseTypeCode = "code" + UserInfoResponseCachePrefix = "userinfo_response" + AccessTokenCachePrefix = "access_token" ) // OIDCConfiguration holds a subset of interested fields from the OIDC configuration spec @@ -57,6 +63,8 @@ type ClientApp struct { redirectURI string // URL of the issuer (e.g. https://argocd.example.com/api/dex) issuerURL string + // the path where the issuer providers user information (e.g /user-info for okta) + userInfoPath string // The URL endpoint at which the ArgoCD server is accessed. baseHRef string // client is the HTTP client which is used to query the IDp @@ -70,6 +78,8 @@ type ClientApp struct { encryptionKey []byte // provider is the OIDC provider provider Provider + // clientCache represent a cache of sso artifact + clientCache cache.CacheClient } func GetScopesOrDefault(scopes []string) []string { @@ -81,7 +91,7 @@ func GetScopesOrDefault(scopes []string) []string { // NewClientApp will register the Argo CD client app (either via Dex or external OIDC) and return an // object which has HTTP handlers for handling the HTTP responses for login and callback -func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string) (*ClientApp, error) { +func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTlsConfig *dex.DexTLSConfig, baseHRef string, cacheClient cache.CacheClient) (*ClientApp, error) { redirectURL, err := settings.RedirectURL() if err != nil { return nil, err @@ -95,8 +105,10 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTl clientSecret: settings.OAuth2ClientSecret(), redirectURI: redirectURL, issuerURL: settings.IssuerURL(), + userInfoPath: settings.UserInfoPath(), baseHRef: baseHRef, encryptionKey: encryptionKey, + clientCache: cacheClient, } log.Infof("Creating client app (%s)", a.clientID) u, err := url.Parse(settings.URL) @@ -376,6 +388,26 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + // save the accessToken in memory for later use + encToken, err := crypto.Encrypt([]byte(token.AccessToken), a.encryptionKey) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, "failed encrypting token", http.StatusInternalServerError) + log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON) + return + } + sub := jwtutil.StringField(claims, "sub") + err = a.clientCache.Set(&cache.Item{ + Key: formatAccessTokenCacheKey(AccessTokenCachePrefix, sub), + Object: encToken, + Expiration: getTokenExpiration(claims), + }) + if err != nil { + claimsJSON, _ := json.Marshal(claims) + http.Error(w, fmt.Sprintf("claims=%s, err=%v", claimsJSON, err), http.StatusInternalServerError) + return + } + if idTokenRAW != "" { cookies, err := httputil.MakeCookieMetadata(common.AuthCookieName, idTokenRAW, flags...) if err != nil { @@ -509,3 +541,145 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc } return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil } + +// GetUserInfo queries the IDP userinfo endpoint for claims +func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { + sub := jwtutil.StringField(actualClaims, "sub") + var claims jwt.MapClaims + var encClaims []byte + + // in case we got it in the cache, we just return the item + clientCacheKey := formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, sub) + if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { + claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) + if err != nil { + log.Errorf("decrypting the cached claims failed (sub=%s): %s", sub, err) + } else { + err = json.Unmarshal(claimsRaw, &claims) + if err != nil { + log.Errorf("cannot unmarshal cached claims structure: %s", err) + } else { + // return the cached claims since they are not yet expired, were successfully decrypted and unmarshaled + return claims, false, err + } + } + } + + // check if the accessToken for the user is still present + var encAccessToken []byte + err := a.clientCache.Get(formatAccessTokenCacheKey(AccessTokenCachePrefix, sub), &encAccessToken) + // without an accessToken we can't query the user info endpoint + // thus the user needs to reauthenticate for argocd to get a new accessToken + if err == cache.ErrCacheMiss { + return claims, true, fmt.Errorf("no accessToken for %s: %w", sub, err) + } else if err != nil { + return claims, true, fmt.Errorf("couldn't read accessToken from cache for %s: %w", sub, err) + } + + accessToken, err := crypto.Decrypt(encAccessToken, a.encryptionKey) + if err != nil { + return claims, true, fmt.Errorf("couldn't decrypt accessToken for %s: %w", sub, err) + } + + url := issuerURL + userInfoPath + request, err := http.NewRequest("GET", url, nil) + + if err != nil { + err = fmt.Errorf("failed creating new http request: %w", err) + return claims, false, err + } + + bearer := fmt.Sprintf("Bearer %s", accessToken) + request.Header.Set("Authorization", bearer) + + response, err := a.client.Do(request) + if err != nil { + return claims, false, fmt.Errorf("failed to query userinfo endpoint of IDP: %w", err) + } + defer response.Body.Close() + if response.StatusCode == http.StatusUnauthorized { + return claims, true, err + } + + // according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation + // the response should be validated + header := response.Header.Get("content-type") + rawBody, err := io.ReadAll(response.Body) + if err != nil { + return claims, false, fmt.Errorf("got error reading response body: %w", err) + } + switch header { + case "application/jwt": + // if body is JWT, first validate it before extracting claims + idToken, err := a.provider.Verify(string(rawBody), a.settings) + if err != nil { + return claims, false, fmt.Errorf("user info response in jwt format not valid: %w", err) + } + err = idToken.Claims(claims) + if err != nil { + return claims, false, fmt.Errorf("cannot get claims from userinfo jwt: %w", err) + } + default: + // if body is json, unsigned and unencrypted claims can be deserialized + err = json.Unmarshal(rawBody, &claims) + if err != nil { + return claims, false, fmt.Errorf("failed to decode response body to struct: %w", err) + } + } + + // in case response was successfully validated and there was no error, put item in cache + // but first let's determine the expiry of the cache + var cacheExpiry time.Duration + settingExpiry := a.settings.UserInfoCacheExpiration() + tokenExpiry := getTokenExpiration(claims) + + // only use configured expiry if the token lives longer and the expiry is configured + // if the token has no expiry, use the expiry of the actual token + // otherwise use the expiry of the token + if settingExpiry < tokenExpiry && settingExpiry != 0 { + cacheExpiry = settingExpiry + } else if tokenExpiry < 0 { + cacheExpiry = getTokenExpiration(actualClaims) + } else { + cacheExpiry = tokenExpiry + } + + rawClaims, err := json.Marshal(claims) + if err != nil { + return claims, false, fmt.Errorf("couldn't marshal claim to json: %w", err) + } + encClaims, err = crypto.Encrypt(rawClaims, a.encryptionKey) + if err != nil { + return claims, false, fmt.Errorf("couldn't encrypt user info response: %w", err) + } + + err = a.clientCache.Set(&cache.Item{ + Key: clientCacheKey, + Object: encClaims, + Expiration: cacheExpiry, + }) + if err != nil { + return claims, false, fmt.Errorf("couldn't put item to cache: %w", err) + } + + return claims, false, nil +} + +// getTokenExpiration returns a time.Duration until the token expires +func getTokenExpiration(claims jwt.MapClaims) time.Duration { + // get duration until token expires + exp := jwtutil.Float64Field(claims, "exp") + tm := time.Unix(int64(exp), 0) + tokenExpiry := time.Until(tm) + return tokenExpiry +} + +// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache +func formatUserInfoResponseCacheKey(prefix, sub string) string { + return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) +} + +// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache +func formatAccessTokenCacheKey(prefix, sub string) string { + return fmt.Sprintf("%s_%s", prefix, sub) +} diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index fe5fa77eed3b5..cd1d3fa1bf789 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -11,8 +11,10 @@ import ( "os" "strings" "testing" + "time" gooidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -20,6 +22,7 @@ import ( "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" "github.com/argoproj/argo-cd/v2/util" + "github.com/argoproj/argo-cd/v2/util/cache" "github.com/argoproj/argo-cd/v2/util/crypto" "github.com/argoproj/argo-cd/v2/util/dex" "github.com/argoproj/argo-cd/v2/util/settings" @@ -126,7 +129,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -141,7 +144,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -166,7 +169,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/login", nil) @@ -179,7 +182,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -211,7 +214,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), // The base href (the last argument for NewClientApp) is what HandleLogin will fall back to when no explicit // redirect URL is given. - app, err := NewClientApp(cdSettings, "", nil, "/") + app, err := NewClientApp(cdSettings, "", nil, "/", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w := httptest.NewRecorder() @@ -254,7 +257,7 @@ clientID: xxx clientSecret: yyy requestedScopes: ["oidc"]`, oidcTestServer.URL), } - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -269,7 +272,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), cdSettings.OIDCTLSInsecureSkipVerify = true - app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -294,7 +297,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), require.NoError(t, err) cdSettings.Certificate = &cert - app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com") + app, err := NewClientApp(cdSettings, dexTestServer.URL, nil, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req := httptest.NewRequest(http.MethodGet, "https://argocd.example.com/auth/callback", nil) @@ -307,7 +310,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL), t.Fatal("did not receive expected certificate verification failure error") } - app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com") + app, err = NewClientApp(cdSettings, dexTestServer.URL, &dex.DexTLSConfig{StrictValidation: false}, "https://argocd.example.com", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) w = httptest.NewRecorder() @@ -406,7 +409,7 @@ func TestGenerateAppState(t *testing.T) { signature, err := util.MakeSignature(32) require.NoError(t, err) expectedReturnURL := "http://argocd.example.com/" - app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "") + app, err := NewClientApp(&settings.ArgoCDSettings{ServerSignature: signature, URL: expectedReturnURL}, "", nil, "", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) generateResponse := httptest.NewRecorder() state, err := app.generateAppState(expectedReturnURL, generateResponse) @@ -443,7 +446,7 @@ func TestGenerateAppState_XSS(t *testing.T) { URL: "https://argocd.example.com", ServerSignature: signature, }, - "", nil, "", + "", nil, "", cache.NewInMemoryCache(24*time.Hour), ) require.NoError(t, err) @@ -495,7 +498,7 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { encrypted, err := crypto.Encrypt([]byte("123"), key) require.NoError(t, err) - app, err := NewClientApp(cdSettings, "", nil, "/argo-cd") + app, err := NewClientApp(cdSettings, "", nil, "/argo-cd", cache.NewInMemoryCache(24*time.Hour)) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: common.StateCookieName, Value: hex.EncodeToString(encrypted)}) @@ -503,3 +506,270 @@ func TestGenerateAppState_NoReturnURL(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "/argo-cd", returnURL) } + +func TestGetUserInfo(t *testing.T) { + + var tests = []struct { + name string + userInfoPath string + expectedOutput interface{} + expectError bool + expectUnauthenticated bool + expectedCacheItems []struct { // items to check in cache after function call + key string + value string + expectEncrypted bool + expectError bool + } + idpHandler func(w http.ResponseWriter, r *http.Request) + idpClaims jwt.MapClaims // as per specification sub and exp are REQUIRED fields + cache cache.CacheClient + cacheItems []struct { // items to put in cache before execution + key string + value string + encrypt bool + } + }{ + { + name: "call UserInfo with wrong userInfoPath", + userInfoPath: "/user", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, "randomUser"), + expectError: true, + }, + }, + idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(AccessTokenCachePrefix, "randomUser"), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with bad accessToken", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: false, + expectUnauthenticated: true, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, "randomUser"), + expectError: true, + }, + }, + idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(AccessTokenCachePrefix, "randomUser"), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with garbage returned", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, "randomUser"), + expectError: true, + }, + }, + idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + notevenJsongarbage + ` + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusTeapot) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(AccessTokenCachePrefix, "randomUser"), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo without accessToken in cache", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims(nil), + expectError: true, + expectUnauthenticated: true, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, "randomUser"), + expectError: true, + }, + }, + idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + { + "groups":["githubOrg:engineers"] + }` + w.Header().Set("content-type", "application/json") + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + }, + { + name: "call UserInfo with valid accessToken in cache", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{"groups": []interface{}{"githubOrg:engineers"}}, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(UserInfoResponseCachePrefix, "randomUser"), + value: "{\"groups\":[\"githubOrg:engineers\"]}", + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + userInfoBytes := ` + { + "groups":["githubOrg:engineers"] + }` + w.Header().Set("content-type", "application/json") + _, err := w.Write([]byte(userInfoBytes)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(AccessTokenCachePrefix, "randomUser"), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(tt.idpHandler)) + defer ts.Close() + + signature, err := util.MakeSignature(32) + require.NoError(t, err) + cdSettings := &settings.ArgoCDSettings{ServerSignature: signature} + encryptionKey, err := cdSettings.GetServerEncryptionKey() + assert.NoError(t, err) + a, _ := NewClientApp(cdSettings, "", nil, "/argo-cd", tt.cache) + + for _, item := range tt.cacheItems { + var newValue []byte + newValue = []byte(item.value) + if item.encrypt { + newValue, err = crypto.Encrypt([]byte(item.value), encryptionKey) + assert.NoError(t, err) + } + err := a.clientCache.Set(&cache.Item{ + Key: item.key, + Object: newValue, + }) + require.NoError(t, err) + } + + got, unauthenticated, err := a.GetUserInfo(tt.idpClaims, ts.URL, tt.userInfoPath) + assert.Equal(t, tt.expectedOutput, got) + assert.Equal(t, tt.expectUnauthenticated, unauthenticated) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + for _, item := range tt.expectedCacheItems { + var tmpValue []byte + err := a.clientCache.Get(item.key, &tmpValue) + if item.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + if item.expectEncrypted { + tmpValue, err = crypto.Decrypt(tmpValue, encryptionKey) + require.NoError(t, err) + } + assert.Equal(t, item.value, string(tmpValue)) + } + } + }) + } + +} diff --git a/util/settings/settings.go b/util/settings/settings.go index bc091e8b818ec..baff450aa817e 100644 --- a/util/settings/settings.go +++ b/util/settings/settings.go @@ -161,6 +161,9 @@ func (o *oidcConfig) toExported() *OIDCConfig { ClientID: o.ClientID, ClientSecret: o.ClientSecret, CLIClientID: o.CLIClientID, + UserInfoPath: o.UserInfoPath, + EnableUserInfoGroups: o.EnableUserInfoGroups, + UserInfoCacheExpiration: o.UserInfoCacheExpiration, RequestedScopes: o.RequestedScopes, RequestedIDTokenClaims: o.RequestedIDTokenClaims, LogoutURL: o.LogoutURL, @@ -175,6 +178,9 @@ type OIDCConfig struct { ClientID string `json:"clientID,omitempty"` ClientSecret string `json:"clientSecret,omitempty"` CLIClientID string `json:"cliClientID,omitempty"` + EnableUserInfoGroups bool `json:"enableUserInfoGroups,omitempty"` + UserInfoPath string `json:"userInfoPath,omitempty"` + UserInfoCacheExpiration string `json:"userInfoCacheExpiration,omitempty"` RequestedScopes []string `json:"requestedScopes,omitempty"` RequestedIDTokenClaims map[string]*oidc.Claim `json:"requestedIDTokenClaims,omitempty"` LogoutURL string `json:"logoutURL,omitempty"` @@ -1850,6 +1856,34 @@ func (a *ArgoCDSettings) IssuerURL() string { return "" } +// UserInfoGroupsEnabled returns whether group claims should be fetch from UserInfo endpoint +func (a *ArgoCDSettings) UserInfoGroupsEnabled() bool { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.EnableUserInfoGroups + } + return false +} + +// UserInfoPath returns the sub-path on which the IDP exposes the UserInfo endpoint +func (a *ArgoCDSettings) UserInfoPath() string { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil { + return oidcConfig.UserInfoPath + } + return "" +} + +// UserInfoCacheExpiration returns the expiry time of the UserInfo cache +func (a *ArgoCDSettings) UserInfoCacheExpiration() time.Duration { + if oidcConfig := a.OIDCConfig(); oidcConfig != nil && oidcConfig.UserInfoCacheExpiration != "" { + userInfoCacheExpiration, err := time.ParseDuration(oidcConfig.UserInfoCacheExpiration) + if err != nil { + log.Warnf("Failed to parse 'oidc.config.userInfoCacheExpiration' key: %v", err) + } + return userInfoCacheExpiration + } + return 0 +} + func (a *ArgoCDSettings) OAuth2ClientID() string { if oidcConfig := a.OIDCConfig(); oidcConfig != nil { return oidcConfig.ClientID diff --git a/util/test/testutil.go b/util/test/testutil.go index 6fdbd4151d82c..1cb23bc08bb3e 100644 --- a/util/test/testutil.go +++ b/util/test/testutil.go @@ -168,6 +168,16 @@ func oidcMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.R "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], "claims_supported": ["sub", "aud", "exp"] }`, url)) + require.NoError(t, err) + case "/userinfo": + w.Header().Set("content-type", "application/json") + _, err := io.WriteString(w, fmt.Sprintf(` +{ + "groups":["githubOrg:engineers"], + "iss": "%[1]s", + "sub": "randomUser" +}`, url)) + require.NoError(t, err) case "/keys": pubKey, err := jwt.ParseRSAPublicKeyFromPEM(Cert)