diff --git a/google/google.go b/google/google.go index d765ca5..4a43271 100644 --- a/google/google.go +++ b/google/google.go @@ -11,7 +11,9 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "os" + "strings" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" @@ -44,6 +46,8 @@ func init() { gob.Register(goauth.Userinfo{}) } +var loginURL string + func randToken() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -74,6 +78,19 @@ func Setup(redirectURL, credFile string, scopes []string, secret []byte) { } } +// SetupFromString accepts string values for ouath2 Configs +func SetupFromString(redirectURL, clientID string, clientSecret string, scopes []string, secret []byte) { + store = cookie.NewStore(secret) + + conf = &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: scopes, + Endpoint: google.Endpoint, + } +} + func Session(name string) gin.HandlerFunc { return sessions.Sessions(name, store) } @@ -100,29 +117,39 @@ func GetLoginURL(state string) string { return conf.AuthCodeURL(state) } +func WithLoginURL(s string) error { + s = strings.TrimSpace(s) + url, err := url.ParseRequestURI(s) + if err != nil { + return err + } + loginURL = url.String() + return nil +} + // Auth is the google authorization middleware. You can use them to protect a routergroup. // Example: // -// private.Use(google.Auth()) -// private.GET("/", UserInfoHandler) -// private.GET("/api", func(ctx *gin.Context) { -// ctx.JSON(200, gin.H{"message": "Hello from private for groups"}) -// }) +// private.Use(google.Auth()) +// private.GET("/", UserInfoHandler) +// private.GET("/api", func(ctx *gin.Context) { +// ctx.JSON(200, gin.H{"message": "Hello from private for groups"}) +// }) // -// // Requires google oauth pkg to be imported as `goauth "google.golang.org/api/oauth2/v2"` -// func UserInfoHandler(ctx *gin.Context) { -// var ( -// res goauth.Userinfo -// ok bool -// ) +// // Requires google oauth pkg to be imported as `goauth "google.golang.org/api/oauth2/v2"` +// func UserInfoHandler(ctx *gin.Context) { +// var ( +// res goauth.Userinfo +// ok bool +// ) // -// val := ctx.MustGet("user") -// if res, ok = val.(goauth.Userinfo); !ok { -// res = goauth.Userinfo{Name: "no user"} -// } +// val := ctx.MustGet("user") +// if res, ok = val.(goauth.Userinfo); !ok { +// res = goauth.Userinfo{Name: "no user"} +// } // -// ctx.JSON(http.StatusOK, gin.H{"Hello": "from private", "user": res.Email}) -// } +// ctx.JSON(http.StatusOK, gin.H{"Hello": "from private", "user": res.Email}) +// } func Auth() gin.HandlerFunc { return func(ctx *gin.Context) { // Handle the exchange code to initiate a transport. @@ -137,7 +164,11 @@ func Auth() gin.HandlerFunc { retrievedState := session.Get(stateKey) if retrievedState != ctx.Query(stateKey) { - ctx.AbortWithError(http.StatusUnauthorized, fmt.Errorf("invalid session state: %s", retrievedState)) + if loginURL != "" { + ctx.Redirect(302, loginURL) + } else { + ctx.AbortWithError(http.StatusUnauthorized, fmt.Errorf("invalid session state: %s", retrievedState)) + } return } diff --git a/google/google_test.go b/google/google_test.go new file mode 100644 index 0000000..cb887f2 --- /dev/null +++ b/google/google_test.go @@ -0,0 +1,56 @@ +package google + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSetupFromString(t *testing.T) { + t.Run("should assign store and config accordingly", func(t *testing.T) { + store = nil + conf = nil + SetupFromString("http://fake.fake", "clientid", "clientsecret", []string{}, []byte("secret")) + assert.NotNil(t, conf) + assert.NotNil(t, store) + assert.Equal(t, conf.ClientID, "clientid") + assert.Equal(t, conf.ClientSecret, "clientsecret") + }) +} + +func TestWithLoginURL(t *testing.T) { + + var testCases = []struct { + description string + urlParm string + expectURLLogin string + isErrNil bool + }{ + { + description: "should assign a valid url without error", + urlParm: "http://fake.fake", + expectURLLogin: "http://fake.fake", + isErrNil: true, + }, + { + description: "should assign a sanitizable url without error", + urlParm: " http://fake.fake ", + expectURLLogin: "http://fake.fake", + isErrNil: true, + }, + { + description: "should not assign an invalid url, and should return an error", + urlParm: "not a parseable url", + expectURLLogin: "", + isErrNil: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.description, func(t *testing.T) { + loginURL = "" + err := WithLoginURL(testCase.urlParm) + assert.Equal(t, testCase.expectURLLogin, loginURL) + assert.Equal(t, testCase.isErrNil, err == nil) + }) + } +}