Skip to content

Commit

Permalink
Merge pull request #286 from supertokens/feat/usercontext-request-helper
Browse files Browse the repository at this point in the history
feat: Add a helper function to read the original request from the user context inside overrides
  • Loading branch information
rishabhpoddar authored May 23, 2023
2 parents 714d2e0 + f6b0a7b commit b7905ad
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 1 deletion.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.12.4] - 2023-05-23

### Changes

- Added a new `GetRequestFromUserContext` function that can be used to read the original network request from the user context in overridden APIs and recipe functions

## [0.12.3] - 2023-05-22

### Added
Expand Down
98 changes: 98 additions & 0 deletions recipe/emailpassword/userContext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,101 @@ func TestDefaultUserContext(t *testing.T) {
assert.True(t, signInAPIContextWorks)
assert.True(t, createNewSessionContextWorks)
}

func TestGetRequestFromUserContext(t *testing.T) {
signInContextWorks := false
signInAPIContextWorks := false
createNewSessionContextWorks := false

configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(&epmodels.TypeInput{
Override: &epmodels.OverrideStruct{
Functions: func(originalImplementation epmodels.RecipeInterface) epmodels.RecipeInterface {
originalSignIn := *originalImplementation.SignIn
newSignIn := func(email string, password string, userContext supertokens.UserContext) (epmodels.SignInResponse, error) {
requestFromUserContext := supertokens.GetRequestFromUserContext(userContext)
if requestFromUserContext != nil {
assert.True(t, requestFromUserContext.Method == "POST")
assert.True(t, requestFromUserContext.RequestURI == "/auth/signin")
signInContextWorks = true
}
return originalSignIn(email, password, userContext)
}
*originalImplementation.SignIn = newSignIn
return originalImplementation
},

APIs: func(originalImplementation epmodels.APIInterface) epmodels.APIInterface {
originalSignInPOST := *originalImplementation.SignInPOST
newSignInPOST := func(formFields []epmodels.TypeFormField, options epmodels.APIOptions, userContext supertokens.UserContext) (epmodels.SignInPOSTResponse, error) {
requestFromUserContext := supertokens.GetRequestFromUserContext(userContext)
if requestFromUserContext != nil {
assert.True(t, requestFromUserContext.Method == "POST")
assert.True(t, requestFromUserContext.RequestURI == "/auth/signin")
signInAPIContextWorks = true
}
return originalSignInPOST(formFields, options, userContext)
}
*originalImplementation.SignInPOST = newSignInPOST
return originalImplementation
},
},
}),
session.Init(&sessmodels.TypeInput{
GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod {
return sessmodels.CookieTransferMethod
},

Override: &sessmodels.OverrideStruct{
Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface {
originalCreateNewSession := *originalImplementation.CreateNewSession
newCreateNewSession := func(userID string, accessTokenPayload map[string]interface{}, sessionDataInDatabase map[string]interface{}, disableAntiCsrf *bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) {
requestFromUserContext := supertokens.GetRequestFromUserContext(userContext)
if requestFromUserContext != nil {
assert.True(t, requestFromUserContext.Method == "POST")
assert.True(t, requestFromUserContext.RequestURI == "/auth/signin" || requestFromUserContext.RequestURI == "/auth/signup")
createNewSessionContextWorks = true
}
return originalCreateNewSession(userID, accessTokenPayload, sessionDataInDatabase, disableAntiCsrf, userContext)
}
*originalImplementation.CreateNewSession = newCreateNewSession
return originalImplementation
},
},
}),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}
mux := http.NewServeMux()
testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

unittesting.SignupRequest("[email protected]", "validpass123", testServer.URL)

res1, err := unittesting.SignInRequest("[email protected]", "validpass123", testServer.URL)

if err != nil {
t.Error(err.Error())
}

assert.Equal(t, 200, res1.StatusCode)
assert.True(t, signInContextWorks)
assert.True(t, signInAPIContextWorks)
assert.True(t, createNewSessionContextWorks)
}
2 changes: 1 addition & 1 deletion supertokens/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
)

// VERSION current version of the lib
const VERSION = "0.12.3"
const VERSION = "0.12.4"

var (
cdiSupported = []string{"2.21"}
Expand Down
4 changes: 4 additions & 0 deletions supertokens/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,7 @@ func GetUsersNewestFirst(paginationToken *string, limit *int, includeRecipeIds *
func DeleteUser(userId string) error {
return deleteUser(userId)
}

func GetRequestFromUserContext(userContext UserContext) *http.Request {
return getRequestFromUserContext(userContext)
}
21 changes: 21 additions & 0 deletions supertokens/supertokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"flag"
"net/http"
"reflect"
"strconv"
"strings"
)
Expand Down Expand Up @@ -378,3 +379,23 @@ func ResetForTest() {
func IsRunningInTestMode() bool {
return flag.Lookup("test.v") != nil || IsTestFlag
}

func getRequestFromUserContext(userContext UserContext) *http.Request {
if userContext == nil {
return nil
}

_userContext := *userContext
defaultObj, ok := _userContext["_default"]

if !ok {
return nil
}

emptyMap := map[string]interface{}{}
if reflect.TypeOf(defaultObj).Kind() != reflect.TypeOf(emptyMap).Kind() {
return nil
}

return defaultObj.(map[string]interface{})["request"].(*http.Request)
}

0 comments on commit b7905ad

Please sign in to comment.