diff --git a/recipe/session/recipeImplementation.go b/recipe/session/recipeImplementation.go index 5b8aff0a..1658a607 100644 --- a/recipe/session/recipeImplementation.go +++ b/recipe/session/recipeImplementation.go @@ -176,7 +176,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ } session := response.Session - frontToken := BuildFrontToken(session.UserID, session.ExpiryTime, responseToken.Payload) + frontToken := BuildFrontToken(session.UserID, response.AccessToken.Expiry, responseToken.Payload) sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, responseToken.Payload, result, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true) sessionContainer := newSessionContainer(config, &sessionContainerInput) diff --git a/test/frontendIntegration/main.go b/test/frontendIntegration/main.go index 3fc5f65b..400afa0b 100644 --- a/test/frontendIntegration/main.go +++ b/test/frontendIntegration/main.go @@ -17,6 +17,7 @@ package main import ( + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -58,10 +59,24 @@ func maxVersion(version1 string, version2 string) string { return version2 } +func isProtectedProp(prop string) bool { + protectedProps := []string{ + "sub", + "iat", + "exp", + "sessionHandle", + "parentRefreshTokenHash1", + "refreshTokenHash1", + "antiCsrfToken", + } + + return supertokens.DoesSliceContainString(prop, protectedProps) +} + var routes *http.Handler func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) { - if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT { + if maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION && enableJWT { port := "8080" if len(os.Args) == 2 { port = os.Args[1] @@ -81,6 +96,7 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) { }, RecipeList: []supertokens.Recipe{ session.Init(&sessmodels.TypeInput{ + ExposeAccessTokenToFrontendInCookieBasedAuth: true, ErrorHandlers: &sessmodels.ErrorHandlers{ OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error { res.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -93,7 +109,61 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) { Override: &sessmodels.OverrideStruct{ Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface { ogCNS := *originalImplementation.CreateNewSession - (*originalImplementation.CreateNewSession) = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { + *originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { + if accessTokenPayload == nil { + accessTokenPayload = map[string]interface{}{} + } + accessTokenPayload["customClaim"] = "customValue" + + return ogCNS(userID, accessTokenPayload, sessionDataInDatabase, disableAntiCsrf, userContext) + } + return originalImplementation + }, + APIs: func(originalImplementation sessmodels.APIInterface) sessmodels.APIInterface { + originalImplementation.RefreshPOST = nil + return originalImplementation + }, + }, + }), + }, + }) + + if err != nil { + panic(err.Error()) + } + } else if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT { + port := "8080" + if len(os.Args) == 2 { + port = os.Args[1] + } + antiCsrf := "NONE" + if enableAntiCsrf { + antiCsrf = "VIA_TOKEN" + } + err := supertokens.Init(supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:9000", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "0.0.0.0:" + port, + WebsiteDomain: "http://localhost.org:8080", + }, + RecipeList: []supertokens.Recipe{ + session.Init(&sessmodels.TypeInput{ + ErrorHandlers: &sessmodels.ErrorHandlers{ + OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error { + res.Header().Set("Content-Type", "text/html; charset=utf-8") + res.WriteHeader(401) + res.Write([]byte("")) + return nil + }, + }, + AntiCsrf: &antiCsrf, + Override: &sessmodels.OverrideStruct{ + Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface { + ogCNS := *originalImplementation.CreateNewSession + *originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { if accessTokenPayload == nil { accessTokenPayload = map[string]interface{}{} } @@ -166,6 +236,8 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) { setEnableJWT(rw, r) } else if r.URL.Path == "/login" && r.Method == "POST" { login(rw, r) + } else if r.URL.Path == "/login-2.18" && r.Method == "POST" { + login218(rw, r) } else if r.URL.Path == "/beforeeach" && r.Method == "POST" { beforeeach(rw, r) } else if r.URL.Path == "/testUserConfig" && r.Method == "POST" { @@ -273,7 +345,8 @@ func reinitialiseBackendConfig(w http.ResponseWriter, r *http.Request) { func featureFlag(response http.ResponseWriter, request *http.Request) { json.NewEncoder(response).Encode(map[string]interface{}{ - "sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting, + "sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting, + "v3AccessToken": maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION, }) } @@ -348,7 +421,21 @@ func updateJwt(response http.ResponseWriter, request *http.Request) { var body map[string]interface{} _ = json.NewDecoder(request.Body).Decode(&body) userSession := session.GetSessionFromRequestContext(request.Context()) - userSession.MergeIntoAccessTokenPayload(body) + + sessionAccessTokenPayload := userSession.GetAccessTokenPayload() + clearing := map[string]interface{}{} + + for k := range sessionAccessTokenPayload { + if !isProtectedProp(k) { + clearing[k] = nil + } + } + + for k, v := range body { + clearing[k] = v + } + + userSession.MergeIntoAccessTokenPayload(clearing) json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload()) } @@ -356,7 +443,26 @@ func updateJwtWithHandle(response http.ResponseWriter, request *http.Request) { var body map[string]interface{} _ = json.NewDecoder(request.Body).Decode(&body) userSession := session.GetSessionFromRequestContext(request.Context()) - session.MergeIntoAccessTokenPayload(userSession.GetHandle(), body) + sessionInformation, err := session.GetSessionInformation(userSession.GetHandle()) + + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + customClaimsInPayload := sessionInformation.CustomClaimsInAccessTokenPayload + clearing := map[string]interface{}{} + + for k := range customClaimsInPayload { + clearing[k] = nil + } + + for k, v := range body { + clearing[k] = v + } + + session.MergeIntoAccessTokenPayload(userSession.GetHandle(), clearing) json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload()) } @@ -407,6 +513,71 @@ func login(response http.ResponseWriter, request *http.Request) { response.Write([]byte(sess.GetUserID())) } +func login218(response http.ResponseWriter, request *http.Request) { + var body map[string]interface{} + _ = json.NewDecoder(request.Body).Decode(&body) + + userID := body["userId"].(string) + payload := body["payload"].(map[string]interface{}) + + querier, err := supertokens.GetNewQuerierInstanceOrThrowError("session") + + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + querier.SetApiVersionForTests("2.18") + resp, err := querier.SendPostRequest("/recipe/session", map[string]interface{}{ + "userId": userID, + "userDataInJWT": payload, + "userDataInDatabase": map[string]interface{}{}, + "enableAntiCsrf": false, + }) + + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + querier.SetApiVersionForTests("") + + responseByte, err := json.Marshal(resp) + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + var sessionResp sessmodels.CreateOrRefreshAPIResponse + err = json.Unmarshal(responseByte, &sessionResp) + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + legacyAccessToken := sessionResp.AccessToken.Token + legacyRefreshToken := sessionResp.RefreshToken.Token + + frontTokenJson := json.NewEncoder(response).Encode(map[string]interface{}{ + "uid": userID, + "ate": session.GetCurrTimeInMS() + 3600000, + "up": payload, + }) + + parsed, _ := json.Marshal(frontTokenJson) + data := []byte(parsed) + + frontToken := base64.StdEncoding.EncodeToString(data) + + response.Header().Set("st-access-token", legacyAccessToken) + response.Header().Set("st-refresh-token", legacyRefreshToken) + response.Header().Set("front-token", frontToken) + response.Write([]byte("")) +} + func fail(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) w.Write([]byte(""))