From 56b9195969db503e924c24eaa4235f086bf15765 Mon Sep 17 00:00:00 2001 From: spacewander Date: Wed, 6 Mar 2024 14:44:53 +0800 Subject: [PATCH] oidc: handle refresh token Fix #338 Signed-off-by: spacewander --- plugins/oidc/config.go | 14 +++ plugins/oidc/config.pb.go | 51 ++++++-- plugins/oidc/config.pb.validate.go | 32 +++++ plugins/oidc/config.proto | 8 ++ plugins/oidc/config_test.go | 4 +- plugins/oidc/filter.go | 167 +++++++++++++++++++------- plugins/oidc/filter_test.go | 186 +++++++++++++++++++++++++---- 7 files changed, 383 insertions(+), 79 deletions(-) diff --git a/plugins/oidc/config.go b/plugins/oidc/config.go index 4ec0bd30c..6f5fe020f 100644 --- a/plugins/oidc/config.go +++ b/plugins/oidc/config.go @@ -16,6 +16,7 @@ package oidc import ( "context" + "encoding/base64" "net/http" "time" @@ -65,6 +66,8 @@ type config struct { oauth2Config *oauth2.Config verifier *oidc.IDTokenVerifier cookieEncoding *securecookie.SecureCookie + refreshLeeway time.Duration + cookieEntryID string } func (conf *config) ctxWithClient(ctx context.Context) context.Context { @@ -84,6 +87,13 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { } conf.opTimeout = du + du = 10 * time.Second + leeway := conf.GetAccessTokenRefreshLeeway() + if leeway != nil { + du = leeway.AsDuration() + } + conf.refreshLeeway = du + ctx := conf.ctxWithClient(context.Background()) var provider *oidc.Provider var err error @@ -104,6 +114,9 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { return err } + if !conf.DisableAccessTokenRefresh { + conf.Scopes = append(conf.Scopes, oidc.ScopeOfflineAccess) + } conf.oauth2Config = &oauth2.Config{ ClientID: conf.ClientId, ClientSecret: conf.ClientSecret, @@ -116,5 +129,6 @@ func (conf *config) Init(cb api.ConfigCallbackHandler) error { } conf.verifier = provider.Verifier(&oidc.Config{ClientID: conf.ClientId}) conf.cookieEncoding = securecookie.New([]byte(conf.ClientSecret), nil) + conf.cookieEntryID = base64.RawURLEncoding.EncodeToString([]byte(conf.ClientId)) return nil } diff --git a/plugins/oidc/config.pb.go b/plugins/oidc/config.pb.go index 73c8a1c0b..f3318874e 100644 --- a/plugins/oidc/config.pb.go +++ b/plugins/oidc/config.pb.go @@ -56,7 +56,12 @@ type Config struct { // Default to "x-id-token" IdTokenHeader string `protobuf:"bytes,7,opt,name=id_token_header,json=idTokenHeader,proto3" json:"id_token_header,omitempty"` // The timeout to wait for the OIDC provider to respond. Default to 3s. - Timeout *durationpb.Duration `protobuf:"bytes,8,opt,name=timeout,proto3" json:"timeout,omitempty"` + Timeout *durationpb.Duration `protobuf:"bytes,8,opt,name=timeout,proto3" json:"timeout,omitempty"` + DisableAccessTokenRefresh bool `protobuf:"varint,9,opt,name=disable_access_token_refresh,json=disableAccessTokenRefresh,proto3" json:"disable_access_token_refresh,omitempty"` + // The duration to determines how earlier a token should be considered + // expired than its actual expiration time. It is used to avoid late + // expirations due to client-server time mismatches. Default to 10s. + AccessTokenRefreshLeeway *durationpb.Duration `protobuf:"bytes,10,opt,name=access_token_refresh_leeway,json=accessTokenRefreshLeeway,proto3" json:"access_token_refresh_leeway,omitempty"` } func (x *Config) Reset() { @@ -147,6 +152,20 @@ func (x *Config) GetTimeout() *durationpb.Duration { return nil } +func (x *Config) GetDisableAccessTokenRefresh() bool { + if x != nil { + return x.DisableAccessTokenRefresh + } + return false +} + +func (x *Config) GetAccessTokenRefreshLeeway() *durationpb.Duration { + if x != nil { + return x.AccessTokenRefreshLeeway + } + return nil +} + var File_plugins_oidc_config_proto protoreflect.FileDescriptor var file_plugins_oidc_config_proto_rawDesc = []byte{ @@ -156,7 +175,7 @@ var file_plugins_oidc_config_proto_rawDesc = []byte{ 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x22, 0xd6, 0x02, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, + 0x74, 0x6f, 0x22, 0xfb, 0x03, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x2c, 0x0a, 0x0d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, @@ -177,9 +196,20 @@ var file_plugins_oidc_config_proto_rawDesc = []byte{ 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x08, 0xfa, 0x42, 0x05, 0xaa, 0x01, 0x02, - 0x2a, 0x00, 0x52, 0x07, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x42, 0x1b, 0x5a, 0x19, 0x6d, - 0x6f, 0x73, 0x6e, 0x2e, 0x69, 0x6f, 0x2f, 0x68, 0x74, 0x6e, 0x6e, 0x2f, 0x70, 0x6c, 0x75, 0x67, - 0x69, 0x6e, 0x73, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x2a, 0x00, 0x52, 0x07, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x3f, 0x0a, 0x1c, 0x64, + 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x18, 0x09, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x19, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x12, 0x62, 0x0a, 0x1b, + 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x72, 0x65, 0x66, + 0x72, 0x65, 0x73, 0x68, 0x5f, 0x6c, 0x65, 0x65, 0x77, 0x61, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x08, 0xfa, 0x42, + 0x05, 0xaa, 0x01, 0x02, 0x32, 0x00, 0x52, 0x18, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x4c, 0x65, 0x65, 0x77, 0x61, 0x79, + 0x42, 0x1b, 0x5a, 0x19, 0x6d, 0x6f, 0x73, 0x6e, 0x2e, 0x69, 0x6f, 0x2f, 0x68, 0x74, 0x6e, 0x6e, + 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -201,11 +231,12 @@ var file_plugins_oidc_config_proto_goTypes = []interface{}{ } var file_plugins_oidc_config_proto_depIdxs = []int32{ 1, // 0: plugins.oidc.Config.timeout:type_name -> google.protobuf.Duration - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 1, // 1: plugins.oidc.Config.access_token_refresh_leeway:type_name -> google.protobuf.Duration + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_plugins_oidc_config_proto_init() } diff --git a/plugins/oidc/config.pb.validate.go b/plugins/oidc/config.pb.validate.go index 21578fbb0..979408324 100644 --- a/plugins/oidc/config.pb.validate.go +++ b/plugins/oidc/config.pb.validate.go @@ -154,6 +154,38 @@ func (m *Config) validate(all bool) error { } } + // no validation rules for DisableAccessTokenRefresh + + if d := m.GetAccessTokenRefreshLeeway(); d != nil { + dur, err := d.AsDuration(), d.CheckValid() + if err != nil { + err = ConfigValidationError{ + field: "AccessTokenRefreshLeeway", + reason: "value is not a valid duration", + cause: err, + } + if !all { + return err + } + errors = append(errors, err) + } else { + + gte := time.Duration(0*time.Second + 0*time.Nanosecond) + + if dur < gte { + err := ConfigValidationError{ + field: "AccessTokenRefreshLeeway", + reason: "value must be greater than or equal to 0s", + } + if !all { + return err + } + errors = append(errors, err) + } + + } + } + if len(errors) > 0 { return ConfigMultiError(errors) } diff --git a/plugins/oidc/config.proto b/plugins/oidc/config.proto index bd01afece..ee6159456 100644 --- a/plugins/oidc/config.proto +++ b/plugins/oidc/config.proto @@ -42,4 +42,12 @@ message Config { google.protobuf.Duration timeout = 8 [(validate.rules).duration = { gt: {}, }]; + + bool disable_access_token_refresh = 9; + // The duration to determines how earlier a token should be considered + // expired than its actual expiration time. It is used to avoid late + // expirations due to client-server time mismatches. Default to 10s. + google.protobuf.Duration access_token_refresh_leeway = 10 [(validate.rules).duration = { + gte: {}, + }]; } diff --git a/plugins/oidc/config_test.go b/plugins/oidc/config_test.go index 49f47c9c1..d262acbbe 100644 --- a/plugins/oidc/config_test.go +++ b/plugins/oidc/config_test.go @@ -25,7 +25,7 @@ import ( func TestBadIssuer(t *testing.T) { c := config{ Config: Config{ - Issuer: "http://github.com", + Issuer: "http://1.1.1.1", Timeout: &durationpb.Duration{Seconds: 1}, // quick fail }, } @@ -36,7 +36,7 @@ func TestBadIssuer(t *testing.T) { func TestDefaultValue(t *testing.T) { c := config{ Config: Config{ - Issuer: "http://github.com", + Issuer: "http://1.1.1.1", Timeout: &durationpb.Duration{Seconds: 1}, // quick fail }, } diff --git a/plugins/oidc/filter.go b/plugins/oidc/filter.go index aa2b40c6b..9697d8f50 100644 --- a/plugins/oidc/filter.go +++ b/plugins/oidc/filter.go @@ -41,13 +41,14 @@ func factory(c interface{}, callbacks api.FilterCallbackHandler) api.Filter { type filter struct { api.PassThroughFilter - callbacks api.FilterCallbackHandler - config *config + callbacks api.FilterCallbackHandler + config *config + tokenCookie *http.Cookie } type Tokens struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + Oauth2Token *oauth2.Token `json:"oauth_token"` } func generateState(verifier string, secret string, url string) string { @@ -74,6 +75,10 @@ func signState(state string, secret string) string { return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) } +func (f *filter) CookieName(key string) string { + return fmt.Sprintf("htnn_oidc_%s_%s", key, f.config.cookieEntryID) +} + func (f *filter) handleInitRequest(headers api.RequestHeaderMap) api.ResultAction { config := f.config o2conf := config.oauth2Config @@ -90,13 +95,14 @@ func (f *filter) handleInitRequest(headers api.RequestHeaderMap) api.ResultActio oauth2.S256ChallengeOption(verifier), oauth2.SetAuthURLParam("nonce", nonce)) - n, err := config.cookieEncoding.Encode("htnn_oidc_nonce", nonce) + cookieName := f.CookieName("nonce") + n, err := config.cookieEncoding.Encode(cookieName, nonce) if err != nil { api.LogErrorf("failed to encode cookie: %v", err) return &api.LocalResponse{Code: 503, Msg: "failed to encode cookie"} } cookieNonce := &http.Cookie{ - Name: "htnn_oidc_nonce", + Name: cookieName, Value: n, MaxAge: int(time.Hour.Seconds()), HttpOnly: true, @@ -112,11 +118,31 @@ func (f *filter) handleInitRequest(headers api.RequestHeaderMap) api.ResultActio } } -func normalizeExpiry(expiry time.Time, def time.Duration) int { - if expiry.IsZero() { - return int(def.Seconds()) +func (f *filter) calculateTokenTTL(accessTokenExpiry time.Time, idTokenExpiry time.Time, refreshEnabled bool) int { + if refreshEnabled { + // As the access token refresh is enabled, we only need to consider the expiry of id token + return int(time.Until(idTokenExpiry).Seconds()) } - return int(time.Until(expiry).Seconds()) + + // Use the min expiry between id token and access token as the expiry + if accessTokenExpiry.IsZero() { + // According to https://openid.net/specs/openid-connect-core-1_0.html#IDToken, + // the expiry of id token is required. + // Meanwhile, the expiry of access token is optional. + return int(time.Until(idTokenExpiry).Seconds()) + } + return int(min( + time.Until(accessTokenExpiry).Seconds(), + time.Until(idTokenExpiry).Seconds())) +} + +func getIDToken(token *oauth2.Token) (string, bool) { + rawIDToken, ok := token.Extra("id_token").(string) + return rawIDToken, ok +} + +func (f *filter) refreshEnabled(token *oauth2.Token) bool { + return !f.config.DisableAccessTokenRefresh && token.RefreshToken != "" } func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) api.ResultAction { @@ -145,8 +171,7 @@ func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) return &api.LocalResponse{Code: 503, Msg: "failed to exchange code to the token"} } - // TODO: handle refresh_token - rawIDToken, ok := oauth2Token.Extra("id_token").(string) + rawIDToken, ok := getIDToken(oauth2Token) if !ok { api.LogErrorf("failed to lookup id token: %v", err) return &api.LocalResponse{Code: 503, Msg: "failed to lookup id token"} @@ -155,18 +180,19 @@ func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) idToken, err := config.verifier.Verify(ctx, rawIDToken) if err != nil { api.LogInfof("bad token: %s", err) - return &api.LocalResponse{Code: 403, Msg: "bad token"} + return &api.LocalResponse{Code: 503, Msg: "bad token"} } if !config.SkipNonceVerify { - nonce := headers.Cookie("htnn_oidc_nonce") + cookieName := f.CookieName("nonce") + nonce := headers.Cookie(cookieName) if nonce == nil { api.LogInfof("bad nonce, expected %s", idToken.Nonce) return &api.LocalResponse{Code: 403, Msg: "bad nonce"} } var p string - err := config.cookieEncoding.Decode("htnn_oidc_nonce", nonce.Value, &p) + err := config.cookieEncoding.Decode(cookieName, nonce.Value, &p) if err != nil || p != idToken.Nonce { if err != nil { api.LogInfof("bad nonce: %s, expected %s", err, idToken.Nonce) @@ -177,32 +203,11 @@ func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) } } - value := Tokens{ - IDToken: rawIDToken, - AccessToken: oauth2Token.AccessToken, - } - token, err := config.cookieEncoding.Encode("htnn_oidc_token", &value) + cookie, err := f.saveTokenAsCookie(ctx, oauth2Token, rawIDToken) if err != nil { - api.LogErrorf("failed to encode cookie: %v", err) - return &api.LocalResponse{Code: 503, Msg: "failed to encode cookie"} + return &api.LocalResponse{Code: 503, Msg: "failed to save token"} } - // Use the min expiry between id token and access token as the expiry - // According to https://openid.net/specs/openid-connect-core-1_0.html#IDToken, - // the expiry of id token is required. - // Meanwhile, the expiry of access token is optional. - // To be roburst & security, we assume an empty expiry means a 360-days expiry. - fallbackTTL := 360 * 24 * time.Hour - ttl := min( - normalizeExpiry(idToken.Expiry, fallbackTTL), - normalizeExpiry(oauth2Token.Expiry, fallbackTTL), - ) - cookie := &http.Cookie{ - Name: "htnn_oidc_token", - Value: token, - MaxAge: ttl, - HttpOnly: true, - } return &api.LocalResponse{ Code: http.StatusFound, Header: http.Header{ @@ -214,20 +219,58 @@ func (f *filter) handleCallback(headers api.RequestHeaderMap, query url.Values) func (f *filter) attachInfo(headers api.RequestHeaderMap, encodedToken string) api.ResultAction { config := f.config + ctx := context.Background() - value := Tokens{} - err := config.cookieEncoding.Decode("htnn_oidc_token", encodedToken, &value) + tokens := &Tokens{} + cookieName := f.CookieName("token") + err := config.cookieEncoding.Decode(cookieName, encodedToken, tokens) if err != nil { - api.LogInfof("bad oidc cookie: %s, err: %s", encodedToken, err.Error()) + api.LogInfof("bad oidc cookie: %s, err: %v", encodedToken, err) return &api.LocalResponse{Code: 403, Msg: "bad oidc cookie"} } - headers.Set("authorization", fmt.Sprintf("Bearer %s", value.AccessToken)) - headers.Set(config.IdTokenHeader, value.IDToken) + + oauth2Token := tokens.Oauth2Token + rawIDToken := tokens.IDToken + if f.refreshEnabled(oauth2Token) { + tokenSrc := config.oauth2Config.TokenSource(context.Background(), oauth2Token) + tokenSrc = oauth2.ReuseTokenSourceWithExpiry(oauth2Token, tokenSrc, config.refreshLeeway) + possibleRefreshedToken, err := tokenSrc.Token() + if err != nil { + api.LogWarnf("failed to refresh access token %s, err: %v, refresh token: %s", + oauth2Token.AccessToken, err, oauth2Token.RefreshToken) + return &api.LocalResponse{Code: 401} + } + + if possibleRefreshedToken.AccessToken != oauth2Token.AccessToken { + // token refreshed + oauth2Token = possibleRefreshedToken + newIDToken, ok := getIDToken(oauth2Token) + if ok { + rawIDToken = newIDToken + } + + f.tokenCookie, err = f.saveTokenAsCookie(ctx, possibleRefreshedToken, rawIDToken) + if err != nil { + return &api.LocalResponse{Code: 503, Msg: "failed to save token"} + } + } + + } else { + ok := oauth2Token.Valid() + if !ok { + api.LogInfo("access token is not valid") + return &api.LocalResponse{Code: 401} + } + } + + headers.Set("authorization", fmt.Sprintf("%s %s", oauth2Token.Type(), oauth2Token.AccessToken)) + headers.Set(config.IdTokenHeader, rawIDToken) return api.Continue } func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction { - token := headers.Cookie("htnn_oidc_token") + cookieName := f.CookieName("token") + token := headers.Cookie(cookieName) if token != nil { return f.attachInfo(headers, token.Value) } @@ -240,3 +283,39 @@ func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api return f.handleCallback(headers, query) } + +func (f *filter) saveTokenAsCookie(ctx context.Context, oauth2Token *oauth2.Token, rawIDToken string) (*http.Cookie, error) { + idToken, err := f.config.verifier.Verify(ctx, rawIDToken) + if err != nil { + api.LogErrorf("bad token: %v", err) + return nil, err + } + + cookieName := f.CookieName("token") + token, err := f.config.cookieEncoding.Encode(cookieName, Tokens{ + Oauth2Token: oauth2Token, + IDToken: rawIDToken, + }) + if err != nil { + api.LogErrorf("failed to encode cookie: %v", err) + return nil, err + } + + ttl := f.calculateTokenTTL(oauth2Token.Expiry, idToken.Expiry, f.refreshEnabled(oauth2Token)) + cookie := &http.Cookie{ + Name: cookieName, + Value: token, + MaxAge: ttl, + HttpOnly: true, + } + + api.LogInfof("token saved as cookie %+v, client id: %s", cookie, f.config.ClientId) + return cookie, nil +} + +func (f *filter) EncodeHeaders(headers api.ResponseHeaderMap, endStream bool) api.ResultAction { + if f.tokenCookie != nil { + headers.Add("set-cookie", f.tokenCookie.String()) + } + return api.Continue +} diff --git a/plugins/oidc/filter_test.go b/plugins/oidc/filter_test.go index 0f4fa27e7..032e183de 100644 --- a/plugins/oidc/filter_test.go +++ b/plugins/oidc/filter_test.go @@ -42,6 +42,7 @@ func getCfg() *config { oauth2Config: &oauth2.Config{}, verifier: &oidc.IDTokenVerifier{}, cookieEncoding: securecookie.New([]byte("dSYo5hBwjX_DC57_tfZHlfrDel"), nil), + cookieEntryID: "id", } } @@ -63,6 +64,8 @@ func TestInitRequest(t *testing.T) { func TestCallback(t *testing.T) { conf := getCfg() + conf.DisableAccessTokenRefresh = true + verifier := oauth2.GenerateVerifier() state := generateState(verifier, conf.ClientSecret, "https://127.0.0.1:2379/x?y=1") rawIDToken := "rawIDToken" @@ -72,7 +75,7 @@ func TestCallback(t *testing.T) { }).WithExtra(map[string]interface{}{ "id_token": rawIDToken, }) - nonce, _ := conf.cookieEncoding.Encode("htnn_oidc_nonce", "xxx") + nonce, _ := conf.cookieEncoding.Encode("htnn_oidc_nonce_id", "xxx") tests := []struct { name string @@ -85,7 +88,7 @@ func TestCallback(t *testing.T) { { name: "sanity", state: state, - cookie: "htnn_oidc_nonce=" + nonce, + cookie: "htnn_oidc_nonce_id=" + nonce, mock: func() *gomonkey.Patches { patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{ @@ -108,16 +111,40 @@ func TestCallback(t *testing.T) { assert.Equal(t, "Bearer accessToken", bearer) }, }, + { + name: "ttl with access token expiry", + state: state, + cookie: "htnn_oidc_nonce_id=" + nonce, + mock: func() *gomonkey.Patches { + token := (&oauth2.Token{ + Expiry: time.Now().Add(2 * time.Minute), + AccessToken: accessToken, + }).WithExtra(map[string]interface{}{ + "id_token": rawIDToken, + }) + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) + patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{ + Nonce: "xxx", Expiry: time.Now().Add(2 * time.Hour), + }, nil) + return patches + }, + checkRedirectClientBack: func(f *filter, headers http.Header) { + s := headers.Get("Location") + assert.Equal(t, "https://127.0.0.1:2379/x?y=1", s) + cookie := headers.Get("Set-Cookie") + assert.Contains(t, cookie, "Max-Age=119;") + }, + }, { name: "bad state", state: state + "x", - cookie: "htnn_oidc_nonce=" + nonce, + cookie: "htnn_oidc_nonce_id=" + nonce, res: &api.LocalResponse{Code: 403, Msg: "bad state"}, }, { name: "failed to exchange", state: state, - cookie: "htnn_oidc_nonce=" + nonce, + cookie: "htnn_oidc_nonce_id=" + nonce, mock: func() *gomonkey.Patches { patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", nil, errors.New("timed out")) return patches @@ -127,7 +154,7 @@ func TestCallback(t *testing.T) { { name: "failed to lookup token", state: state, - cookie: "htnn_oidc_nonce=" + nonce, + cookie: "htnn_oidc_nonce_id=" + nonce, mock: func() *gomonkey.Patches { patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", &oauth2.Token{}, nil) return patches @@ -137,18 +164,18 @@ func TestCallback(t *testing.T) { { name: "bad token", state: state, - cookie: "htnn_oidc_nonce=" + nonce, + cookie: "htnn_oidc_nonce_id=" + nonce, mock: func() *gomonkey.Patches { patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) patches.ApplyMethodReturn(conf.verifier, "Verify", nil, errors.New("ouch")) return patches }, - res: &api.LocalResponse{Code: 403, Msg: "bad token"}, + res: &api.LocalResponse{Code: 503, Msg: "bad token"}, }, { name: "bad nonce", state: state, - cookie: "htnn_oidc_nonce=xxy", + cookie: "htnn_oidc_nonce_id=xxy", mock: func() *gomonkey.Patches { patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{Nonce: "xxx"}, nil) @@ -175,7 +202,7 @@ func TestCallback(t *testing.T) { } cb := envoy.NewFilterCallbackHandler() - f := factory(getCfg(), cb).(*filter) + f := factory(conf, cb).(*filter) h := http.Header{} h.Set(":path", "/echo?code=123&state="+tt.state) h.Set("cookie", tt.cookie) @@ -187,6 +214,7 @@ func TestCallback(t *testing.T) { if tt.checkRedirectClientBack != nil { resp := res.(*api.LocalResponse) + assert.Equal(t, http.StatusFound, resp.Code, resp.Msg) tt.checkRedirectClientBack(f, resp.Header) } }) @@ -197,7 +225,7 @@ func TestBadOIDCTokenCookie(t *testing.T) { cb := envoy.NewFilterCallbackHandler() f := factory(getCfg(), cb).(*filter) h := http.Header{} - h.Set("Cookie", "htnn_oidc_token=xxx") + h.Set("Cookie", "htnn_oidc_token_id=xxx") hdr := envoy.NewRequestHeaderMap(h) res := f.DecodeHeaders(hdr, true) resp := res.(*api.LocalResponse) @@ -205,19 +233,30 @@ func TestBadOIDCTokenCookie(t *testing.T) { assert.Equal(t, "bad oidc cookie", resp.Msg) } +type mockTokenSource struct { +} + +func (m *mockTokenSource) Token() (*oauth2.Token, error) { + return nil, nil +} + func TestAttachInfo(t *testing.T) { conf := getCfg() verifier := oauth2.GenerateVerifier() state := generateState(verifier, conf.ClientSecret, "https://127.0.0.1:2379/x?y=1") rawIDToken := "rawIDToken" + rawIDToken2 := "rawIDToken2" accessToken := "accessToken" + accessToken2 := "accessToken2" + refreshToken := "refreshToken" token := (&oauth2.Token{ - AccessToken: accessToken, - Expiry: time.Now().Add(1 * time.Hour), + AccessToken: accessToken, + RefreshToken: refreshToken, + Expiry: time.Now().Add(1 * time.Hour), }).WithExtra(map[string]interface{}{ "id_token": rawIDToken, }) - nonce, _ := conf.cookieEncoding.Encode("htnn_oidc_nonce", "xxx") + nonce, _ := conf.cookieEncoding.Encode("htnn_oidc_nonce_id", "xxx") patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "Exchange", token, nil) patches.ApplyMethodReturn(conf.verifier, "Verify", &oidc.IDToken{ @@ -226,22 +265,123 @@ func TestAttachInfo(t *testing.T) { defer patches.Reset() cb := envoy.NewFilterCallbackHandler() - f := factory(getCfg(), cb).(*filter) + f := factory(conf, cb).(*filter) h := http.Header{} h.Set(":path", "/echo?code=123&state="+state) - h.Set("cookie", "htnn_oidc_nonce="+nonce) + h.Set("cookie", "htnn_oidc_nonce_id="+nonce) hdr := envoy.NewRequestHeaderMap(h) res := f.DecodeHeaders(hdr, true) resp := res.(*api.LocalResponse) cookie := resp.Header.Get("Set-Cookie") - assert.Contains(t, cookie, "Max-Age=3599;") + assert.Contains(t, cookie, "Max-Age=7199;") // the ttl is from id token expiry v := strings.SplitN(strings.Split(cookie, ";")[0], "=", 2)[1] - h = http.Header{} - hdr = envoy.NewRequestHeaderMap(h) - assert.Equal(t, api.Continue, f.attachInfo(hdr, v)) - bearer, _ := hdr.Get("authorization") - idTokenSet, _ := hdr.Get("my-id-token") - assert.Equal(t, "Bearer accessToken", bearer) - assert.Equal(t, rawIDToken, idTokenSet) + + expiredToken, _ := conf.cookieEncoding.Encode("htnn_oidc_token_id", Tokens{ + Oauth2Token: &oauth2.Token{ + AccessToken: "expiredToken", + Expiry: time.Now().Add(-1 * time.Hour), + RefreshToken: refreshToken, + }, + }) + expiredAndUnrefreshableToken, _ := conf.cookieEncoding.Encode("htnn_oidc_token_id", Tokens{ + Oauth2Token: &oauth2.Token{ + AccessToken: "expiredToken", + Expiry: time.Now().Add(-1 * time.Hour), + }, + }) + + refreshedAccessToken := (&oauth2.Token{ + AccessToken: accessToken2, + Expiry: time.Now().Add(1 * time.Hour), + }).WithExtra(map[string]interface{}{ + "id_token": rawIDToken2, + }) + + tests := []struct { + name string + encodedToken string + mock func() *gomonkey.Patches + config *config + res api.ResultAction + authorization string + idTokenSet string + checkCookie func(cookie string) + }{ + { + name: "sanity", + encodedToken: v, + mock: func() *gomonkey.Patches { + tkSrc := &mockTokenSource{} + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "TokenSource", tkSrc) + return patches + }, + res: api.Continue, + authorization: "Bearer accessToken", + idTokenSet: rawIDToken, + }, + { + name: "refresh token", + encodedToken: expiredToken, + mock: func() *gomonkey.Patches { + tkSrc := &mockTokenSource{} + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "TokenSource", tkSrc) + patches.ApplyMethodReturn(tkSrc, "Token", refreshedAccessToken, nil) + return patches + }, + res: api.Continue, + authorization: "Bearer accessToken2", + idTokenSet: rawIDToken2, + checkCookie: func(cookie string) { + // A new cookie should be set + assert.Contains(t, cookie, "Max-Age=3599;") + }, + }, + { + name: "unrefreshable token", + encodedToken: expiredAndUnrefreshableToken, + res: &api.LocalResponse{Code: 401}, + }, + { + name: "failed to refresh token", + encodedToken: expiredToken, + mock: func() *gomonkey.Patches { + tkSrc := &mockTokenSource{} + patches := gomonkey.ApplyMethodReturn(conf.oauth2Config, "TokenSource", tkSrc) + patches.ApplyMethodReturn(tkSrc, "Token", nil, errors.New("failed to refresh")) + return patches + }, + res: &api.LocalResponse{Code: 401}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mock != nil { + patches := tt.mock() + defer patches.Reset() + } + + conf := getCfg() + if tt.config != nil { + conf = tt.config + } + f := factory(conf, cb).(*filter) + h = http.Header{} + hdr = envoy.NewRequestHeaderMap(h) + + assert.Equal(t, tt.res, f.attachInfo(hdr, tt.encodedToken)) + bearer, _ := hdr.Get("authorization") + idTokenSet, _ := hdr.Get("my-id-token") + assert.Equal(t, tt.authorization, bearer) + assert.Equal(t, tt.idTokenSet, idTokenSet) + + if tt.checkCookie != nil { + h = http.Header{} + hdr := envoy.NewResponseHeaderMap(h) + f.EncodeHeaders(hdr, true) + cookie, _ := hdr.Get("Set-Cookie") + tt.checkCookie(cookie) + } + }) + } }