Skip to content

Commit

Permalink
Fix provider deduction from existing sessions so that an argument is …
Browse files Browse the repository at this point in the history
…not needed on the authentication redirect page.

This has been broken since 78433fe.

Also moving the deduction code to be a fallback: if a provider is explicitly
specified it should not be overridden by the existing session. I believe
this is why the Logout() problems with multiple providers existed in the
first place. If deduction wins over explicit parameters, then once you
wrongly click one provider, the cookie will force you to use that
authentication mechanism even if you've navigated to a URL that
explicitly states another provider's name.

Fixed the test harness to check session key names (this was a bug in the
test harness) and added a test to verify that this change works.
  • Loading branch information
akramer committed Aug 11, 2019
1 parent 3b80120 commit e211ab6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
24 changes: 11 additions & 13 deletions gothic/gothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,6 @@ var GetProviderName = getProviderName

func getProviderName(req *http.Request) (string, error) {

// get all the used providers
providers := goth.GetProviders()

// loop over the used providers, if we already have a valid session for any provider (ie. user is already logged-in with a provider), then return that provider name
for _, provider := range providers {
p := provider.Name()
session, _ := Store.Get(req, p+SessionName)
value := session.Values[p]
if _, ok := value.(string); ok {
return p, nil
}
}

// try to get it from the url param "provider"
if p := req.URL.Query().Get("provider"); p != "" {
return p, nil
Expand All @@ -278,6 +265,17 @@ func getProviderName(req *http.Request) (string, error) {
return p, nil
}

// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
providers := goth.GetProviders()
session, _ := Store.Get(req, SessionName)
for _, provider := range providers {
p := provider.Name()
value := session.Values[p]
if _, ok := value.(string); ok {
return p, nil
}
}

// if not found then return an empty string with the corresponding error
return "", errors.New("you must select a provider")
}
Expand Down
46 changes: 36 additions & 10 deletions gothic/gothic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ import (
"github.com/stretchr/testify/assert"
)

type mapKey struct {
r *http.Request
n string
}

type ProviderStore struct {
Store map[*http.Request]*sessions.Session
Store map[mapKey]*sessions.Session
}

func NewProviderStore() *ProviderStore {
return &ProviderStore{map[*http.Request]*sessions.Session{}}
return &ProviderStore{map[mapKey]*sessions.Session{}}
}

func (p ProviderStore) Get(r *http.Request, name string) (*sessions.Session, error) {
s := p.Store[r]
s := p.Store[mapKey{r, name}]
if s == nil {
s, err := p.New(r, name)
return s, err
Expand All @@ -42,12 +47,12 @@ func (p ProviderStore) New(r *http.Request, name string) (*sessions.Session, err
Path: "/",
MaxAge: 86400 * 30,
}
p.Store[r] = s
p.Store[mapKey{r, name}] = s
return s, nil
}

func (p ProviderStore) Save(r *http.Request, w http.ResponseWriter, s *sessions.Session) error {
p.Store[r] = s
p.Store[mapKey{r, s.Name()}] = s
return nil
}

Expand All @@ -68,7 +73,7 @@ func Test_BeginAuthHandler(t *testing.T) {

BeginAuthHandler(res, req)

sess, err := Store.Get(req, "faux"+SessionName)
sess, err := Store.Get(req, SessionName)
if err != nil {
t.Fatalf("error getting faux Gothic session: %v", err)
}
Expand Down Expand Up @@ -128,7 +133,28 @@ func Test_CompleteUserAuth(t *testing.T) {
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)

user, err := CompleteUserAuth(res, req)
a.NoError(err)

a.Equal(user.Name, "Homer Simpson")
a.Equal(user.Email, "[email protected]")
}

func Test_CompleteUserAuthWithSessionDeducedProvider(t *testing.T) {
a := assert.New(t)

res := httptest.NewRecorder()
// Inteintionally omit a provider argument, force looking in session.
req, err := http.NewRequest("GET", "/auth/callback", nil)
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)
Expand All @@ -148,7 +174,7 @@ func Test_Logout(t *testing.T) {
a.NoError(err)

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)
Expand All @@ -160,7 +186,7 @@ func Test_Logout(t *testing.T) {
a.Equal(user.Email, "[email protected]")
err = Logout(res, req)
a.NoError(err)
session, _ = Store.Get(req, "faux"+SessionName)
session, _ = Store.Get(req, SessionName)
a.Equal(session.Values, make(map[interface{}]interface{}))
a.Equal(session.Options.MaxAge, -1)
}
Expand Down Expand Up @@ -188,7 +214,7 @@ func Test_StateValidation(t *testing.T) {
a.NoError(err)

BeginAuthHandler(res, req)
session, _ := Store.Get(req, "faux"+SessionName)
session, _ := Store.Get(req, SessionName)

// Assert that matching states will return a nil error
req, err = http.NewRequest("GET", "/auth/callback?provider=faux&state=state_REAL", nil)
Expand Down

0 comments on commit e211ab6

Please sign in to comment.