Skip to content

Commit

Permalink
Merge pull request #376 from supertokens/access-token-validation
Browse files Browse the repository at this point in the history
feat: Add ValidateAccessToken function to providers
  • Loading branch information
rishabhpoddar authored Oct 4, 2023
2 parents adf83ee + 5dd3901 commit d65e7f6
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.1] - 2023-10-03

### Changes

- Added `ValidateAccessToken` to the configuration for social login providers, this function allows you to verify the access token returned by the social provider. If you are using Github as a provider, there is a default implementation provided for this function.

## [0.16.0] - 2023-09-27

### Fixes
Expand Down
5 changes: 3 additions & 2 deletions recipe/session/accessTokenVersions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,11 +1004,12 @@ func TestShouldThrowWhenRefreshInLegacySessionsWithProtectedProp(t *testing.T) {
assert.True(t, cookiesAfterRefresh["frontToken"] == "remove")
}

/**
/*
*
We want to make sure that for access token claims that can be null, the SDK does not fail access token validation if the
core does not send them as part of the payload.
For this we verify that validation passes when the keys are nil, empty or a different type
# For this we verify that validation passes when the keys are nil, empty or a different type
For now this test checks for:
- antiCsrfToken
Expand Down
82 changes: 82 additions & 0 deletions recipe/thirdparty/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
package thirdparty

import (
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -697,3 +701,81 @@ func TestPassingScopesInConfigForGithub(t *testing.T) {
"scope": {"test-scope-1 test-scope-2"},
}, authParams)
}

func TestThatSignInUpFailsIfValidateAccessTokenReturnsError(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(
&tpmodels.TypeInput{
SignInAndUpFeature: tpmodels.TypeInputSignInAndUp{
Providers: []tpmodels.ProviderInput{
{
Override: func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider {
originalImplementation.ExchangeAuthCodeForOAuthTokens = func(redirectURIInfo tpmodels.TypeRedirectURIInfo, userContext supertokens.UserContext) (tpmodels.TypeOAuthTokens, error) {
return map[string]interface{}{
"access_token": "wrongaccesstoken",
"id_token": "wrongidtoken",
}, nil
}

return originalImplementation
},
Config: tpmodels.ProviderConfig{
ThirdPartyId: "custom",
Clients: []tpmodels.ProviderClientConfig{
{
ClientID: "test",
ClientSecret: "test-secret",
Scope: []string{"test-scope-1", "test-scope-2"},
},
},
ValidateAccessToken: func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error {
if accessToken == "wrongaccesstoken" {
return errors.New("Invalid access token")
}

return nil
},
},
},
},
},
},
),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)

if err != nil {
t.Error(err.Error())
}

mux := http.NewServeMux()
testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

req, err := http.NewRequest(http.MethodPost, testServer.URL+"/auth/signinup", strings.NewReader(`{"thirdPartyId": "custom", "redirectURIInfo": {"redirectURIOnProviderDashboard": "http://127.0.0.1/callback", "redirectURIQueryParams": {"code": "abcdefghj"}}}`))
if err != nil {
t.Error(err.Error())
}

res, err := http.DefaultClient.Do(req)

data2, err := io.ReadAll(res.Body)
assert.NoError(t, err)
respString := string(data2)
respString = strings.Replace(respString, "\n", "", -1)
assert.Equal(t, respString, "Invalid access token")
}
1 change: 1 addition & 0 deletions recipe/thirdparty/providers/config_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func getProviderConfigForClient(config tpmodels.ProviderConfig, clientConfig tpm
OIDCDiscoveryEndpoint: config.OIDCDiscoveryEndpoint,
UserInfoMap: config.UserInfoMap,
ValidateIdTokenPayload: config.ValidateIdTokenPayload,
ValidateAccessToken: config.ValidateAccessToken,
RequireEmail: config.RequireEmail,
GenerateFakeEmail: config.GenerateFakeEmail,
}
Expand Down
37 changes: 37 additions & 0 deletions recipe/thirdparty/providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package providers

import (
"encoding/base64"
"errors"
"fmt"

Expand All @@ -36,6 +37,42 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
input.Config.TokenEndpoint = "https://github.com/login/oauth/access_token"
}

if input.Config.ValidateAccessToken == nil {
input.Config.ValidateAccessToken = func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error {
basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret))
wrongClientIdError := errors.New("Access token does not belong to your application")

resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
"access_token": accessToken,
}, map[string]interface{}{
"Authorization": "Basic " + basicAuthToken,
"Content-Type": "application/json",
})

if err != nil {
return errors.New("Invalid access token")
}

app, appOk := resp["app"]

if !appOk {
return wrongClientIdError
}

clientId, clientIdOk := app.(map[string]interface{})["client_id"]

if !clientIdOk {
return wrongClientIdError
}

if clientId != clientConfig.ClientID {
return wrongClientIdError
}

return nil
}
}

oOverride := input.Override

input.Override = func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider {
Expand Down
7 changes: 7 additions & 0 deletions recipe/thirdparty/providers/oauth2_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ func oauth2_GetUserInfo(config tpmodels.ProviderConfigForClientType, oAuthTokens
}
}

if config.ValidateAccessToken != nil && accessTokenOk {
err := config.ValidateAccessToken(accessToken, config, userContext)
if err != nil {
return tpmodels.TypeUserInfo{}, err
}
}

if accessTokenOk && config.UserInfoEndpoint != "" {
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
Expand Down
2 changes: 2 additions & 0 deletions recipe/thirdparty/tpmodels/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ type ProviderConfig struct {
RequireEmail *bool `json:"requireEmail,omitempty"`

ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"`
ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"`
GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string `json:"-"`
}

Expand Down Expand Up @@ -158,6 +159,7 @@ type ProviderConfigForClientType struct {
OIDCDiscoveryEndpoint string
UserInfoMap TypeUserInfoMap
ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error
ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error

RequireEmail *bool
GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string
Expand Down
2 changes: 1 addition & 1 deletion supertokens/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
)

// VERSION current version of the lib
const VERSION = "0.16.0"
const VERSION = "0.16.1"

var (
cdiSupported = []string{"3.0"}
Expand Down

0 comments on commit d65e7f6

Please sign in to comment.