diff --git a/recipe/emailpassword/network_interceptor_test.go b/recipe/emailpassword/network_interceptor_test.go new file mode 100644 index 00000000..605d448f --- /dev/null +++ b/recipe/emailpassword/network_interceptor_test.go @@ -0,0 +1,57 @@ +package emailpassword + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/supertokens/supertokens-golang/supertokens" + "github.com/supertokens/supertokens-golang/test/unittesting" +) + +var isNetworkIntercepted = false + +func TestNetworkInterceptorDuringSignIn(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) *http.Request { + isNetworkIntercepted = true + return request + }, + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + WebsiteDomain: "supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + + unittesting.StartUpST("localhost", "8080") + + defer AfterEach() + + err := supertokens.Init(configValue) + + if err != nil { + t.Error(err.Error()) + } + + mux := http.NewServeMux() + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer testServer.Close() + + res, err := unittesting.SignInRequest("random@gmail.com", "validpass123", testServer.URL) + + if err != nil { + t.Error(err.Error()) + } + + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, true, isNetworkIntercepted) +} diff --git a/supertokens/querier.go b/supertokens/querier.go index 9d664645..ec6d8aa7 100644 --- a/supertokens/querier.go +++ b/supertokens/querier.go @@ -131,7 +131,6 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map if err != nil { return nil, err } - req = querierInterceptor(req, nil) apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { @@ -147,6 +146,10 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map req.Header.Set("rid", q.RIDToCore) } + if querierInterceptor != nil { + req = querierInterceptor(req, nil) + } + client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) @@ -189,6 +192,10 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa req.Header.Set("rid", q.RIDToCore) } + if querierInterceptor != nil { + req = querierInterceptor(req, nil) + } + client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) @@ -225,6 +232,10 @@ func (q *Querier) SendGetRequest(path string, params map[string]string) (map[str req.Header.Set("rid", q.RIDToCore) } + if querierInterceptor != nil { + req = querierInterceptor(req, nil) + } + client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) @@ -262,6 +273,10 @@ func (q *Querier) SendGetRequestWithResponseHeaders(path string, params map[stri req.Header.Set("rid", q.RIDToCore) } + if querierInterceptor != nil { + req = querierInterceptor(req, nil) + } + client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) @@ -296,6 +311,10 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[ req.Header.Set("rid", q.RIDToCore) } + if querierInterceptor != nil { + req = querierInterceptor(req, nil) + } + client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil)