Skip to content

Commit

Permalink
Merge pull request #280 from supertokens/jwt-rework/fix-tests
Browse files Browse the repository at this point in the history
Update integration server and fix for tests
  • Loading branch information
rishabhpoddar authored May 6, 2023
2 parents 2b05770 + 75b02b9 commit 538437c
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 6 deletions.
2 changes: 1 addition & 1 deletion recipe/session/recipeImplementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ
}

session := response.Session
frontToken := BuildFrontToken(session.UserID, session.ExpiryTime, responseToken.Payload)
frontToken := BuildFrontToken(session.UserID, response.AccessToken.Expiry, responseToken.Payload)

sessionContainerInput := makeSessionContainerInput(response.AccessToken.Token, session.Handle, session.UserID, responseToken.Payload, result, frontToken, response.AntiCsrfToken, nil, &response.RefreshToken, true)
sessionContainer := newSessionContainer(config, &sessionContainerInput)
Expand Down
181 changes: 176 additions & 5 deletions test/frontendIntegration/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package main

import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -58,10 +59,24 @@ func maxVersion(version1 string, version2 string) string {
return version2
}

func isProtectedProp(prop string) bool {
protectedProps := []string{
"sub",
"iat",
"exp",
"sessionHandle",
"parentRefreshTokenHash1",
"refreshTokenHash1",
"antiCsrfToken",
}

return supertokens.DoesSliceContainString(prop, protectedProps)
}

var routes *http.Handler

func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT {
if maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION && enableJWT {
port := "8080"
if len(os.Args) == 2 {
port = os.Args[1]
Expand All @@ -81,6 +96,7 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
},
RecipeList: []supertokens.Recipe{
session.Init(&sessmodels.TypeInput{
ExposeAccessTokenToFrontendInCookieBasedAuth: true,
ErrorHandlers: &sessmodels.ErrorHandlers{
OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error {
res.Header().Set("Content-Type", "text/html; charset=utf-8")
Expand All @@ -93,7 +109,61 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
Override: &sessmodels.OverrideStruct{
Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface {
ogCNS := *originalImplementation.CreateNewSession
(*originalImplementation.CreateNewSession) = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
*originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
if accessTokenPayload == nil {
accessTokenPayload = map[string]interface{}{}
}
accessTokenPayload["customClaim"] = "customValue"

return ogCNS(userID, accessTokenPayload, sessionDataInDatabase, disableAntiCsrf, userContext)
}
return originalImplementation
},
APIs: func(originalImplementation sessmodels.APIInterface) sessmodels.APIInterface {
originalImplementation.RefreshPOST = nil
return originalImplementation
},
},
}),
},
})

if err != nil {
panic(err.Error())
}
} else if maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && enableJWT {
port := "8080"
if len(os.Args) == 2 {
port = os.Args[1]
}
antiCsrf := "NONE"
if enableAntiCsrf {
antiCsrf = "VIA_TOKEN"
}
err := supertokens.Init(supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:9000",
},
AppInfo: supertokens.AppInfo{
AppName: "SuperTokens",
APIDomain: "0.0.0.0:" + port,
WebsiteDomain: "http://localhost.org:8080",
},
RecipeList: []supertokens.Recipe{
session.Init(&sessmodels.TypeInput{
ErrorHandlers: &sessmodels.ErrorHandlers{
OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error {
res.Header().Set("Content-Type", "text/html; charset=utf-8")
res.WriteHeader(401)
res.Write([]byte(""))
return nil
},
},
AntiCsrf: &antiCsrf,
Override: &sessmodels.OverrideStruct{
Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface {
ogCNS := *originalImplementation.CreateNewSession
*originalImplementation.CreateNewSession = func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
if accessTokenPayload == nil {
accessTokenPayload = map[string]interface{}{}
}
Expand Down Expand Up @@ -166,6 +236,8 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) {
setEnableJWT(rw, r)
} else if r.URL.Path == "/login" && r.Method == "POST" {
login(rw, r)
} else if r.URL.Path == "/login-2.18" && r.Method == "POST" {
login218(rw, r)
} else if r.URL.Path == "/beforeeach" && r.Method == "POST" {
beforeeach(rw, r)
} else if r.URL.Path == "/testUserConfig" && r.Method == "POST" {
Expand Down Expand Up @@ -273,7 +345,8 @@ func reinitialiseBackendConfig(w http.ResponseWriter, r *http.Request) {

func featureFlag(response http.ResponseWriter, request *http.Request) {
json.NewEncoder(response).Encode(map[string]interface{}{
"sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting,
"sessionJwt": maxVersion(supertokens.VERSION, "0.3.1") == supertokens.VERSION && lastEnableJWTSetting,
"v3AccessToken": maxVersion(supertokens.VERSION, "0.12.0") == supertokens.VERSION,
})
}

Expand Down Expand Up @@ -348,15 +421,48 @@ func updateJwt(response http.ResponseWriter, request *http.Request) {
var body map[string]interface{}
_ = json.NewDecoder(request.Body).Decode(&body)
userSession := session.GetSessionFromRequestContext(request.Context())
userSession.MergeIntoAccessTokenPayload(body)

sessionAccessTokenPayload := userSession.GetAccessTokenPayload()
clearing := map[string]interface{}{}

for k := range sessionAccessTokenPayload {
if !isProtectedProp(k) {
clearing[k] = nil
}
}

for k, v := range body {
clearing[k] = v
}

userSession.MergeIntoAccessTokenPayload(clearing)
json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload())
}

func updateJwtWithHandle(response http.ResponseWriter, request *http.Request) {
var body map[string]interface{}
_ = json.NewDecoder(request.Body).Decode(&body)
userSession := session.GetSessionFromRequestContext(request.Context())
session.MergeIntoAccessTokenPayload(userSession.GetHandle(), body)
sessionInformation, err := session.GetSessionInformation(userSession.GetHandle())

if err != nil {
response.WriteHeader(500)
response.Write([]byte(""))
return
}

customClaimsInPayload := sessionInformation.CustomClaimsInAccessTokenPayload
clearing := map[string]interface{}{}

for k := range customClaimsInPayload {
clearing[k] = nil
}

for k, v := range body {
clearing[k] = v
}

session.MergeIntoAccessTokenPayload(userSession.GetHandle(), clearing)
json.NewEncoder(response).Encode(userSession.GetAccessTokenPayload())
}

Expand Down Expand Up @@ -407,6 +513,71 @@ func login(response http.ResponseWriter, request *http.Request) {
response.Write([]byte(sess.GetUserID()))
}

func login218(response http.ResponseWriter, request *http.Request) {
var body map[string]interface{}
_ = json.NewDecoder(request.Body).Decode(&body)

userID := body["userId"].(string)
payload := body["payload"].(map[string]interface{})

querier, err := supertokens.GetNewQuerierInstanceOrThrowError("session")

if err != nil {
response.WriteHeader(500)
response.Write([]byte(""))
return
}

querier.SetApiVersionForTests("2.18")
resp, err := querier.SendPostRequest("/recipe/session", map[string]interface{}{
"userId": userID,
"userDataInJWT": payload,
"userDataInDatabase": map[string]interface{}{},
"enableAntiCsrf": false,
})

if err != nil {
response.WriteHeader(500)
response.Write([]byte(""))
return
}

querier.SetApiVersionForTests("")

responseByte, err := json.Marshal(resp)
if err != nil {
response.WriteHeader(500)
response.Write([]byte(""))
return
}
var sessionResp sessmodels.CreateOrRefreshAPIResponse
err = json.Unmarshal(responseByte, &sessionResp)
if err != nil {
response.WriteHeader(500)
response.Write([]byte(""))
return
}

legacyAccessToken := sessionResp.AccessToken.Token
legacyRefreshToken := sessionResp.RefreshToken.Token

frontTokenJson := json.NewEncoder(response).Encode(map[string]interface{}{
"uid": userID,
"ate": session.GetCurrTimeInMS() + 3600000,
"up": payload,
})

parsed, _ := json.Marshal(frontTokenJson)
data := []byte(parsed)

frontToken := base64.StdEncoding.EncodeToString(data)

response.Header().Set("st-access-token", legacyAccessToken)
response.Header().Set("st-refresh-token", legacyRefreshToken)
response.Header().Set("front-token", frontToken)
response.Write([]byte(""))
}

func fail(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(""))
Expand Down

0 comments on commit 538437c

Please sign in to comment.