Skip to content

Commit

Permalink
more test transfers
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Apr 10, 2024
1 parent 6845699 commit 325bfae
Show file tree
Hide file tree
Showing 6 changed files with 803 additions and 1,935 deletions.
342 changes: 342 additions & 0 deletions recipe/thirdparty/override_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/supertokens/supertokens-golang/recipe/passwordless"
"github.com/supertokens/supertokens-golang/recipe/passwordless/plessmodels"
"github.com/supertokens/supertokens-golang/recipe/session"
"github.com/supertokens/supertokens-golang/recipe/session/sessmodels"
"github.com/supertokens/supertokens-golang/recipe/thirdparty/tpmodels"
Expand All @@ -34,6 +36,346 @@ import (
"gopkg.in/h2non/gock.v1"
)

func TestOverridingAPIs(t *testing.T) {
var userRef *tpmodels.User
var newUser bool
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
session.Init(&sessmodels.TypeInput{
GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod {
return sessmodels.CookieTransferMethod
},
}),
passwordless.Init(plessmodels.TypeInput{
FlowType: "USER_INPUT_CODE_AND_MAGIC_LINK",
ContactMethodEmail: plessmodels.ContactMethodEmailConfig{
Enabled: true,
},
}),
Init(&tpmodels.TypeInput{
SignInAndUpFeature: tpmodels.TypeInputSignInAndUp{
Providers: []tpmodels.ProviderInput{customProvider1},
},
Override: &tpmodels.OverrideStruct{
APIs: func(originalImplementation tpmodels.APIInterface) tpmodels.APIInterface {
originalThirdPartySignInUpPost := *originalImplementation.SignInUpPOST
*originalImplementation.SignInUpPOST = func(provider *tpmodels.TypeProvider, input tpmodels.TypeSignInUpInput, tenantId string, options tpmodels.APIOptions, userContext supertokens.UserContext) (tpmodels.SignInUpPOSTResponse, error) {
resp, err := originalThirdPartySignInUpPost(provider, input, tenantId, options, userContext)
userRef = &resp.OK.User
newUser = resp.OK.CreatedNewUser
return resp, err
}
return originalImplementation
},
},
}),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
if err != nil {
t.Error(err.Error())
}
apiV, err := q.GetQuerierAPIVersion()
if err != nil {
t.Error(err.Error())
}

if unittesting.MaxVersion(apiV, "2.11") == "2.11" {
return
}

mux := http.NewServeMux()

mux.HandleFunc("/user", func(rw http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("userId")
fetchedUser, err := GetUserByID(userId)
if err != nil {
t.Error(err.Error())
}
jsonResp, err := json.Marshal(fetchedUser)
if err != nil {
t.Errorf("Error happened in JSON marshal. Err: %s", err)
}
rw.WriteHeader(200)
rw.Write(jsonResp)
})

defer gock.OffAll()
gock.New("https://test.com").
Post("/oauth/token").
Persist().
Reply(200).
JSON(map[string]string{})

testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

formFields := map[string]interface{}{
"thirdPartyId": "custom",
"redirectURIInfo": map[string]interface{}{
"redirectURIOnProviderDashboard": testServer.URL + "/callback",
"redirectURIQueryParams": map[string]interface{}{
"code": "abcdefghj",
},
},
}

postBody, err := json.Marshal(formFields)
if err != nil {
t.Error(err.Error())
}

gock.New(testServer.URL).EnableNetworking().Persist()
gock.New("http://localhost:8080/").EnableNetworking().Persist()

resp, err := http.Post(testServer.URL+"/auth/signinup", "application/json", bytes.NewBuffer(postBody))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

signUpResponse := *unittesting.HttpResponseToConsumableInformation(resp.Body)
fetchedUser := signUpResponse["user"].(map[string]interface{})

assert.NotNil(t, userRef)
assert.True(t, newUser)
assert.Equal(t, fetchedUser["email"], userRef.Email)
assert.Equal(t, fetchedUser["id"], userRef.ID)
assert.Equal(t, fetchedUser["thirdParty"].(map[string]interface{})["id"], userRef.ThirdParty.ID)
assert.Equal(t, fetchedUser["thirdParty"].(map[string]interface{})["userId"], userRef.ThirdParty.UserID)

userRef = nil
assert.Nil(t, userRef)

formFields = map[string]interface{}{
"thirdPartyId": "custom",
"redirectURIInfo": map[string]interface{}{
"redirectURIOnProviderDashboard": testServer.URL + "/callback",
"redirectURIQueryParams": map[string]interface{}{
"code": "abcdefghj",
},
},
}

postBody, err = json.Marshal(formFields)
if err != nil {
t.Error(err.Error())
}

resp, err = http.Post(testServer.URL+"/auth/signinup", "application/json", bytes.NewBuffer(postBody))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

signInResponse := *unittesting.HttpResponseToConsumableInformation(resp.Body)
fetchedUserFromSignIn := signInResponse["user"].(map[string]interface{})

assert.NotNil(t, userRef)
assert.False(t, newUser)
assert.Equal(t, fetchedUserFromSignIn["email"], userRef.Email)
assert.Equal(t, fetchedUserFromSignIn["id"], userRef.ID)
assert.Equal(t, fetchedUserFromSignIn["thirdParty"].(map[string]interface{})["id"], userRef.ThirdParty.ID)
assert.Equal(t, fetchedUserFromSignIn["thirdParty"].(map[string]interface{})["userId"], userRef.ThirdParty.UserID)
}

func TestOverridingFunctions(t *testing.T) {
var userRef *tpmodels.User
var newUser bool
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
session.Init(&sessmodels.TypeInput{
GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod {
return sessmodels.CookieTransferMethod
},
}),
passwordless.Init(plessmodels.TypeInput{
FlowType: "USER_INPUT_CODE_AND_MAGIC_LINK",
ContactMethodEmail: plessmodels.ContactMethodEmailConfig{
Enabled: true,
},
}),
Init(&tpmodels.TypeInput{
SignInAndUpFeature: tpmodels.TypeInputSignInAndUp{
Providers: []tpmodels.ProviderInput{customProvider1},
},
Override: &tpmodels.OverrideStruct{
Functions: func(originalImplementation tpmodels.RecipeInterface) tpmodels.RecipeInterface {
originalThirdPartySignInUp := *originalImplementation.SignInUp
*originalImplementation.SignInUp = func(thirdPartyID, thirdPartyUserID, email string, oAuthTokens tpmodels.TypeOAuthTokens, rawUserInfoFromProvider tpmodels.TypeRawUserInfoFromProvider, tenantId string, userContext supertokens.UserContext) (tpmodels.SignInUpResponse, error) {
resp, err := originalThirdPartySignInUp(thirdPartyID, thirdPartyUserID, email, oAuthTokens, rawUserInfoFromProvider, tenantId, userContext)
userRef = &resp.OK.User
newUser = resp.OK.CreatedNewUser
return resp, err
}
originalGetUserById := *originalImplementation.GetUserByID
*originalImplementation.GetUserByID = func(userID string, userContext supertokens.UserContext) (*tpmodels.User, error) {
resp, err := originalGetUserById(userID, userContext)
userRef = resp
return resp, err
}
return originalImplementation
},
},
}),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
if err != nil {
t.Error(err.Error())
}
apiV, err := q.GetQuerierAPIVersion()
if err != nil {
t.Error(err.Error())
}

if unittesting.MaxVersion(apiV, "2.11") == "2.11" {
return
}

mux := http.NewServeMux()

mux.HandleFunc("/user", func(rw http.ResponseWriter, r *http.Request) {
userId := r.URL.Query().Get("userId")
fetchedUser, err := GetUserByID(userId)
if err != nil {
t.Error(err.Error())
}
jsonResp, err := json.Marshal(fetchedUser)
if err != nil {
t.Errorf("Error happened in JSON marshal. Err: %s", err)
}
rw.WriteHeader(200)
rw.Write(jsonResp)
})

defer gock.OffAll()
gock.New("https://test.com").
Post("/oauth/token").
Persist().
Reply(200).
JSON(map[string]string{})

testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

formFields := map[string]interface{}{
"thirdPartyId": "custom",
"redirectURIInfo": map[string]interface{}{
"redirectURIOnProviderDashboard": testServer.URL + "/callback",
"redirectURIQueryParams": map[string]interface{}{
"code": "abcdefghj",
},
},
}

postBody, err := json.Marshal(formFields)
if err != nil {
t.Error(err.Error())
}

gock.New(testServer.URL).EnableNetworking().Persist()
gock.New("http://localhost:8080/").EnableNetworking().Persist()

resp, err := http.Post(testServer.URL+"/auth/signinup", "application/json", bytes.NewBuffer(postBody))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

signUpResponse := *unittesting.HttpResponseToConsumableInformation(resp.Body)
fetchedUser := signUpResponse["user"].(map[string]interface{})

assert.NotNil(t, userRef)
assert.True(t, newUser)
assert.Equal(t, fetchedUser["email"], userRef.Email)
assert.Equal(t, fetchedUser["id"], userRef.ID)
assert.Equal(t, fetchedUser["thirdParty"].(map[string]interface{})["id"], userRef.ThirdParty.ID)
assert.Equal(t, fetchedUser["thirdParty"].(map[string]interface{})["userId"], userRef.ThirdParty.UserID)

userRef = nil
assert.Nil(t, userRef)

formFields = map[string]interface{}{
"thirdPartyId": "custom",
"redirectURIInfo": map[string]interface{}{
"redirectURIOnProviderDashboard": testServer.URL + "/callback",
"redirectURIQueryParams": map[string]interface{}{
"code": "abcdefghj",
},
},
}

postBody, err = json.Marshal(formFields)
if err != nil {
t.Error(err.Error())
}

resp, err = http.Post(testServer.URL+"/auth/signinup", "application/json", bytes.NewBuffer(postBody))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

signInResponse := *unittesting.HttpResponseToConsumableInformation(resp.Body)
fetchedUserFromSignIn := signInResponse["user"].(map[string]interface{})

assert.NotNil(t, userRef)
assert.False(t, newUser)
assert.Equal(t, fetchedUserFromSignIn["email"], userRef.Email)
assert.Equal(t, fetchedUserFromSignIn["id"], userRef.ID)
assert.Equal(t, fetchedUserFromSignIn["thirdParty"].(map[string]interface{})["id"], userRef.ThirdParty.ID)
assert.Equal(t, fetchedUserFromSignIn["thirdParty"].(map[string]interface{})["userId"], userRef.ThirdParty.UserID)

userRef = nil
assert.Nil(t, userRef)

req, err := http.NewRequest(http.MethodPost, testServer.URL+"/user", nil)
assert.NoError(t, err)

query := req.URL.Query()
query.Add("userId", fetchedUserFromSignIn["id"].(string))
req.URL.RawQuery = query.Encode()

res, err := http.DefaultClient.Do(req)

assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

userByIdResponse := *unittesting.HttpResponseToConsumableInformation(res.Body)

assert.NotNil(t, userRef)
assert.Equal(t, userByIdResponse["email"], userRef.Email)
assert.Nil(t, userByIdResponse["phoneNumber"])
assert.Equal(t, userByIdResponse["id"], userRef.ID)
assert.Equal(t, userByIdResponse["thirdParty"].(map[string]interface{})["id"], userRef.ThirdParty.ID)
assert.Equal(t, userByIdResponse["thirdParty"].(map[string]interface{})["userId"], userRef.ThirdParty.UserID)
}

func TestOverrideFunctions(t *testing.T) {
var createdNewUser bool
var user tpmodels.User
Expand Down
Loading

0 comments on commit 325bfae

Please sign in to comment.