Skip to content

Commit

Permalink
fix: Split cookie to meet the cookie length requirement (#1305)
Browse files Browse the repository at this point in the history
Signed-off-by: jyu6 <[email protected]>
Co-authored-by: jyu6 <[email protected]>
  • Loading branch information
jy4096 and jyu6 authored Nov 1, 2023
1 parent 30e1d4d commit 449dfd3
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 49 deletions.
3 changes: 0 additions & 3 deletions cmd/commands/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ func NewServerCommand() *cobra.Command {
baseHref string
disableAuth bool
dexServerAddr string
dexProxyAddr string
serverAddr string
)

Expand All @@ -56,7 +55,6 @@ func NewServerCommand() *cobra.Command {
BaseHref: baseHref,
DisableAuth: disableAuth,
DexServerAddr: dexServerAddr,
DexProxyAddr: dexProxyAddr,
ServerAddr: serverAddr,
}
server := svrcmd.NewServer(opts)
Expand All @@ -70,7 +68,6 @@ func NewServerCommand() *cobra.Command {
command.Flags().StringVar(&baseHref, "base-href", "/", "Base href for Numaflow server, defaults to '/'.")
command.Flags().BoolVar(&disableAuth, "disable-auth", false, "Whether to disable authentication and authorization, defaults to false.")
command.Flags().StringVar(&dexServerAddr, "dex-server-addr", "http://numaflow-dex-server:5556/dex", "The actual address of the Dex server for the reverse proxy to target.")
command.Flags().StringVar(&dexProxyAddr, "dex-proxy-addr", "https://localhost:8443/dex", "The proxy address of the Dex server.")
command.Flags().StringVar(&serverAddr, "server-addr", "https://localhost:8443", "The address of the Numaflow server.")
return command
}
52 changes: 31 additions & 21 deletions server/apis/v1/dexauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type DexObject struct {
}

// NewDexObject returns a new DexObject.
func NewDexObject(baseURL string, baseHref string, proxyURL string) (*DexObject, error) {
func NewDexObject(baseURL string, baseHref string, dexURL string) (*DexObject, error) {
issuerURL, err := url.JoinPath(baseURL, "/dex")
if err != nil {
return nil, err
Expand All @@ -61,7 +61,7 @@ func NewDexObject(baseURL string, baseHref string, proxyURL string) (*DexObject,
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client.Transport = NewDexRewriteURLRoundTripper(proxyURL, client.Transport)
client.Transport = NewDexRewriteURLRoundTripper(dexURL, client.Transport)
return &DexObject{
clientID: common.AppClientID,
issuerURL: issuerURL,
Expand Down Expand Up @@ -100,17 +100,27 @@ func (d *DexObject) oauth2Config(scopes []string) (*oauth2.Config, error) {

func (d *DexObject) Authenticate(c *gin.Context) (*authn.UserInfo, error) {
var userInfo authn.UserInfo
userIdentityTokenStr, err := c.Cookie(common.UserIdentityCookieName)
cookies := c.Request.Cookies()
userIdentityTokenStr, err := common.JoinCookies(common.UserIdentityCookieName, cookies)
if err != nil {
return nil, fmt.Errorf("failed to get user identity token from cookie: %v", err)
return nil, fmt.Errorf("failed to retrieve user identity token: %v", err)
}
if userIdentityTokenStr == "" {
return nil, fmt.Errorf("failed to retrieve user identity token: empty token")
}
if err = json.Unmarshal([]byte(userIdentityTokenStr), &userInfo); err != nil {
return nil, fmt.Errorf("failed to parse user identity token: %v", err)
return nil, fmt.Errorf("user is not authenticated, err: %s", err.Error())
}
_, err = d.verify(c, userInfo.IDToken)
idToken, err := d.verify(c.Request.Context(), userInfo.IDToken)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to verify ID token: %w", err)
}

var claims authn.IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("error decoding ID token claims: %w", err)
}
userInfo = authn.NewUserInfo(&claims, userInfo.IDToken, userInfo.RefreshToken)
return &userInfo, nil
}

Expand Down Expand Up @@ -214,17 +224,8 @@ func (d *DexObject) handleCallback(c *gin.Context) {
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
return
}

idToken, err := d.verify(r.Context(), rawIDToken)
if err != nil {
errMsg := fmt.Sprintf("Failed to verify ID token: %v", err)
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
return
}

var claims authn.IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
errMsg := fmt.Sprintf("error decoding ID token claims: %v", err)
if rawIDToken == "" {
errMsg := "Failed to get id_token: empty raw ID Token"
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
return
}
Expand All @@ -236,14 +237,24 @@ func (d *DexObject) handleCallback(c *gin.Context) {
return
}

res := authn.NewUserInfo(claims, rawIDToken, refreshToken)
// no need to include claims in the cookie
res := authn.NewUserInfo(nil, rawIDToken, refreshToken)
tokenStr, err := json.Marshal(res)
if err != nil {
errMsg := fmt.Sprintf("Failed to convert to token string: %v", err)
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
return
}
c.SetCookie(common.UserIdentityCookieName, string(tokenStr), common.UserIdentityCookieMaxAge, "/", "", true, true)

cookies, err := common.MakeCookieMetadata(common.UserIdentityCookieName, string(tokenStr))
if err != nil {
errMsg := fmt.Sprintf("Failed to create cookies: %v", err)
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
return
}
for _, cookie := range cookies {
c.SetCookie(cookie.Key, cookie.Value, common.StateCookieMaxAge, "/", "", true, true)
}
c.JSON(http.StatusOK, NewNumaflowAPIResponse(nil, res))
}

Expand Down Expand Up @@ -277,7 +288,6 @@ func NewDexReverseProxy(target string) func(c *gin.Context) {
proxyUrl, _ := url.Parse(target)
c.Request.URL.Path = c.Param("name")
proxy := httputil.NewSingleHostReverseProxy(proxyUrl)
fmt.Println("proxy", proxyUrl, c.Request.URL.Path)
proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
Expand Down
40 changes: 33 additions & 7 deletions server/apis/v1/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ type handler struct {
kubeClient kubernetes.Interface
metricsClient *metricsversiond.Clientset
numaflowClient dfv1clients.NumaflowV1alpha1Interface
dexObj *DexObject
}

// NewHandler is used to provide a new instance of the handler type
func NewHandler() (*handler, error) {
func NewHandler(dexObj *DexObject) (*handler, error) {
var (
k8sRestConfig *rest.Config
err error
Expand All @@ -82,24 +83,49 @@ func NewHandler() (*handler, error) {
kubeClient: kubeClient,
metricsClient: metricsClient,
numaflowClient: numaflowClient,
dexObj: dexObj,
}, nil
}

// AuthInfo loads and returns auth info from cookie
func (h *handler) AuthInfo(c *gin.Context) {
userIdentityTokenStr, err := c.Cookie(common.UserIdentityCookieName)
if h.dexObj == nil {
errMsg := "User is not authenticated: missing Dex"
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
}
cookies := c.Request.Cookies()
userIdentityTokenStr, err := common.JoinCookies(common.UserIdentityCookieName, cookies)
if err != nil {
errMsg := fmt.Sprintf("User is not authenticated, err: %s", err.Error())
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
return
}
if userIdentityTokenStr == "" {
errMsg := "User is not authenticated, err: empty Token"
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
return
}
var userInfo authn.UserInfo
if err = json.Unmarshal([]byte(userIdentityTokenStr), &userInfo); err != nil {
errMsg := fmt.Sprintf("User is not authenticated, err: %s", err.Error())
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
return
}

idToken, err := h.dexObj.verify(c.Request.Context(), userInfo.IDToken)
if err != nil {
errMsg := fmt.Sprintf("user is not authenticated, err: %s", err.Error())
errMsg := fmt.Sprintf("Failed to verify ID token: %s", err)
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
return
}
userInfo := &authn.UserInfo{}
if err = json.Unmarshal([]byte(userIdentityTokenStr), userInfo); err != nil {
errMsg := fmt.Sprintf("user is not authenticated, err: %s", err.Error())
var claims authn.IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
errMsg := fmt.Sprintf("Error decoding ID token claims: %s", err)
c.JSON(http.StatusUnauthorized, NewNumaflowAPIResponse(&errMsg, nil))
return
}
res := authn.NewUserInfo(userInfo.IDTokenClaims, userInfo.IDToken, userInfo.RefreshToken)

res := authn.NewUserInfo(&claims, userInfo.IDToken, userInfo.RefreshToken)
c.JSON(http.StatusOK, NewNumaflowAPIResponse(nil, res))
}

Expand Down
20 changes: 19 additions & 1 deletion server/apis/v1/noauthhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package v1

import (
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"

Expand Down Expand Up @@ -47,6 +49,22 @@ func (h *noAuthHandler) Callback(c *gin.Context) {

// Logout is used to remove auth cookie ending a user's session.
func (h *noAuthHandler) Logout(c *gin.Context) {
c.SetCookie(common.UserIdentityCookieName, "", -1, "/", "", true, true)
cookies := c.Request.Cookies()
tokenString, err := common.JoinCookies(common.UserIdentityCookieName, cookies)
if err != nil {
errMsg := fmt.Sprintf("Failed to retrieve user identity token: %v", err)
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
}
if tokenString == "" {
errMsg := "Failed to retrieve user identity token: empty token"
c.JSON(http.StatusOK, NewNumaflowAPIResponse(&errMsg, nil))
}

for _, cookie := range cookies {
if !strings.HasPrefix(cookie.Name, common.UserIdentityCookieName) {
continue
}
c.SetCookie(cookie.Name, "", -1, "/", "", true, true)
}
c.JSON(http.StatusOK, NewNumaflowAPIResponse(nil, nil))
}
8 changes: 4 additions & 4 deletions server/authn/user_id_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ type IDTokenClaims struct {
// UserInfo includes information about the user identity
// It holds the IDTokenClaims, IDToken and RefreshToken for the user
type UserInfo struct {
IDTokenClaims IDTokenClaims `json:"id_token_claims"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
IDTokenClaims *IDTokenClaims `json:"id_token_claims,omitempty"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
}

func NewUserInfo(itc IDTokenClaims, idToken string, refreshToken string) UserInfo {
func NewUserInfo(itc *IDTokenClaims, idToken string, refreshToken string) UserInfo {
return UserInfo{
IDTokenClaims: itc,
IDToken: idToken,
Expand Down
4 changes: 0 additions & 4 deletions server/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ type ServerOptions struct {
BaseHref string
DisableAuth bool
DexServerAddr string
DexProxyAddr string
ServerAddr string
}

Expand Down Expand Up @@ -83,7 +82,6 @@ func (s *server) Start() {
routes.AuthInfo{
DisableAuth: s.options.DisableAuth,
DexServerAddr: s.options.DexServerAddr,
DexProxyAddr: s.options.DexProxyAddr,
ServerAddr: s.options.ServerAddr,
},
s.options.BaseHref,
Expand All @@ -100,7 +98,6 @@ func (s *server) Start() {
"version", numaflow.GetVersion(),
"disable-auth", s.options.DisableAuth,
"dex-server-addr", s.options.DexServerAddr,
"dex-proxy-addr", s.options.DexProxyAddr,
"server-addr", s.options.ServerAddr)
if err := server.ListenAndServe(); err != nil {
panic(err)
Expand All @@ -116,7 +113,6 @@ func (s *server) Start() {
"version", numaflow.GetVersion(),
"disable-auth", s.options.DisableAuth,
"dex-server-addr", s.options.DexServerAddr,
"dex-proxy-addr", s.options.DexProxyAddr,
"server-addr", s.options.ServerAddr)
if err := server.ListenAndServeTLS("", ""); err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion server/common/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ const (
AppClientID = "numaflow-server-app"
StateCookieName = "numaflow-oauthstate"
StateCookieMaxAge = 60 * 5
UserIdentityCookieName = "user-identity-token"
UserIdentityCookieName = "numaflow.token"
UserIdentityCookieMaxAge = 60 * 60
)
100 changes: 100 additions & 0 deletions server/common/cookie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package common

import (
"fmt"
"math"
"net/http"
"net/url"
"strconv"
"strings"
)

const (
maxCookieLength = 4096
maxValueLength = maxCookieLength - 500
// max number of chunks a cookie can be broken into
maxCookieNumber = 10
)

type IdentityCookie struct {
Key string
Value string
}

// MakeCookieMetadata generates a string representing a Web cookie. Yum!
func MakeCookieMetadata(key, value string) ([]IdentityCookie, error) {
numberOfCookies := int(math.Ceil(float64(len(value)) / float64(maxValueLength)))
if numberOfCookies > maxCookieNumber {
return nil, fmt.Errorf("the authentication token is %d characters long and requires %d cookies but the max number of cookies is %d", len(value), numberOfCookies, maxCookieNumber)
}
return splitCookie(key, value), nil
}

// splitCookie splits a cookie because browser requires cookie to be < 4kb.
// In order to support cookies longer than 4kb, we split cookie into multiple chunks.
func splitCookie(key, value string) []IdentityCookie {
var cookies []IdentityCookie
valueLength := len(value)
numberOfChunks := int(math.Ceil(float64(valueLength) / float64(maxValueLength)))

var end int
for i, j := 0, 0; i < valueLength; i, j = i+maxValueLength, j+1 {
end = i + maxValueLength
if end > valueLength {
end = valueLength
}
var cookie IdentityCookie
if j == 0 {
cookie = IdentityCookie{
Key: key,
Value: fmt.Sprintf("%d:%s", numberOfChunks, value[i:end]),
}
} else {
cookie = IdentityCookie{
Key: fmt.Sprintf("%s-%d", key, j),
Value: value[i:end],
}
}
cookies = append(cookies, cookie)
}
return cookies
}

// JoinCookies combines chunks of cookie based on Key as prefix. It returns cookie
// Value as string. cookieString is of format key1=value1; key2=value2; key3=value3
// first chunk will be of format numaflow.token=<numberOfChunks>:token; attributes
func JoinCookies(key string, cookieList []*http.Cookie) (string, error) {
cookies := make(map[string]string)
for _, cookie := range cookieList {
if !strings.HasPrefix(cookie.Name, key) {
continue
}
val, _ := url.QueryUnescape(cookie.Value)
cookies[cookie.Name] = val
}

var sb strings.Builder
var numOfChunks int
var err error
var token string
var ok bool

if token, ok = cookies[key]; !ok {
return "", fmt.Errorf("failed to retrieve cookie %s", key)
}
parts := strings.Split(token, ":")

if len(parts) >= 2 {
if numOfChunks, err = strconv.Atoi(parts[0]); err != nil {
return "", err
}
sb.WriteString(strings.Join(parts[1:], ":"))
} else {
return "", fmt.Errorf("invalid cookie for key %s", key)
}

for i := 1; i < numOfChunks; i++ {
sb.WriteString(cookies[fmt.Sprintf("%s-%d", key, i)])
}
return sb.String(), nil
}
Loading

0 comments on commit 449dfd3

Please sign in to comment.