From 91e26c448b36f2dad705c896f05f3ddf3cfb0c6d Mon Sep 17 00:00:00 2001 From: dave vader <48764154+plyr4@users.noreply.github.com> Date: Fri, 1 Sep 2023 10:28:27 -0500 Subject: [PATCH] enhance: add context to Users (#941) --- api/admin/user.go | 5 +- api/auth/get_token.go | 8 ++- api/auth/logout.go | 3 +- api/auth/post_token.go | 5 +- api/build/create.go | 2 +- api/build/restart.go | 2 +- api/build/update.go | 2 +- api/metrics.go | 2 +- api/pipeline/template.go | 3 +- api/repo/update.go | 2 +- api/user/create.go | 3 +- api/user/create_token.go | 3 +- api/user/delete.go | 5 +- api/user/delete_token.go | 3 +- api/user/get.go | 3 +- api/user/list.go | 3 +- api/user/update.go | 5 +- api/user/update_current.go | 3 +- api/webhook/post.go | 2 +- cmd/vela-server/schedule.go | 2 +- database/integration_test.go | 14 ++--- database/resource.go | 1 + database/user/count.go | 4 +- database/user/count_test.go | 7 ++- database/user/create.go | 3 +- database/user/create_test.go | 3 +- database/user/delete.go | 4 +- database/user/delete_test.go | 5 +- database/user/get.go | 4 +- database/user/get_name.go | 4 +- database/user/get_name_test.go | 5 +- database/user/get_test.go | 5 +- database/user/index.go | 4 +- database/user/index_test.go | 3 +- database/user/interface.go | 22 ++++--- database/user/list.go | 6 +- database/user/list_lite.go | 6 +- database/user/list_lite_test.go | 7 ++- database/user/list_test.go | 7 ++- database/user/opts.go | 11 ++++ database/user/opts_test.go | 50 +++++++++++++++ database/user/table.go | 4 +- database/user/table_test.go | 3 +- database/user/update.go | 3 +- database/user/update_test.go | 5 +- database/user/user.go | 7 ++- internal/token/refresh.go | 5 +- internal/token/refresh_test.go | 14 +++-- router/middleware/claims/claims_test.go | 5 +- router/middleware/perm/perm.go | 9 ++- router/middleware/perm/perm_test.go | 68 ++++++++++----------- router/middleware/pipeline/pipeline_test.go | 4 +- router/middleware/user/user.go | 3 +- router/middleware/user/user_test.go | 5 +- 54 files changed, 252 insertions(+), 124 deletions(-) diff --git a/api/admin/user.go b/api/admin/user.go index c0c2aacfd..9893c0250 100644 --- a/api/admin/user.go +++ b/api/admin/user.go @@ -53,6 +53,9 @@ import ( func UpdateUser(c *gin.Context) { logrus.Info("Admin: updating user in database") + // capture middleware values + ctx := c.Request.Context() + // capture body from API request input := new(library.User) @@ -66,7 +69,7 @@ func UpdateUser(c *gin.Context) { } // send API call to update the user - u, err := database.FromContext(c).UpdateUser(input) + u, err := database.FromContext(c).UpdateUser(ctx, input) if err != nil { retErr := fmt.Errorf("unable to update user %d: %w", input.GetID(), err) diff --git a/api/auth/get_token.go b/api/auth/get_token.go index 2b6d7da66..81f632f4d 100644 --- a/api/auth/get_token.go +++ b/api/auth/get_token.go @@ -62,6 +62,8 @@ func GetAuthToken(c *gin.Context) { var err error tm := c.MustGet("token-manager").(*token.Manager) + // capture middleware values + ctx := c.Request.Context() // capture the OAuth state if present oAuthState := c.Request.FormValue("state") @@ -97,7 +99,7 @@ func GetAuthToken(c *gin.Context) { } // send API call to capture the user logging in - u, err := database.FromContext(c).GetUserForName(newUser.GetName()) + u, err := database.FromContext(c).GetUserForName(ctx, newUser.GetName()) // create a new user account if len(u.GetName()) == 0 || err != nil { // create the user account @@ -121,7 +123,7 @@ func GetAuthToken(c *gin.Context) { u.SetRefreshToken(rt) // send API call to create the user in the database - _, err = database.FromContext(c).CreateUser(u) + _, err = database.FromContext(c).CreateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to create user %s: %w", u.GetName(), err) @@ -154,7 +156,7 @@ func GetAuthToken(c *gin.Context) { u.SetRefreshToken(rt) // send API call to update the user in the database - _, err = database.FromContext(c).UpdateUser(u) + _, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err) diff --git a/api/auth/logout.go b/api/auth/logout.go index 1bc8f8d58..a496f15cd 100644 --- a/api/auth/logout.go +++ b/api/auth/logout.go @@ -46,6 +46,7 @@ func Logout(c *gin.Context) { m := c.MustGet("metadata").(*types.Metadata) // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -75,7 +76,7 @@ func Logout(c *gin.Context) { u.SetRefreshToken("") // send API call to update the user in the database - _, err = database.FromContext(c).UpdateUser(u) + _, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err) diff --git a/api/auth/post_token.go b/api/auth/post_token.go index 8b386bfcf..55fe47fdf 100644 --- a/api/auth/post_token.go +++ b/api/auth/post_token.go @@ -49,6 +49,9 @@ import ( // process a user logging in using PAT to Vela from // the API. func PostAuthToken(c *gin.Context) { + // capture middleware values + ctx := c.Request.Context() + // attempt to get user from source u, err := scm.FromContext(c).AuthenticateToken(c.Request) if err != nil { @@ -60,7 +63,7 @@ func PostAuthToken(c *gin.Context) { } // check if the user exists - u, err = database.FromContext(c).GetUserForName(u.GetName()) + u, err = database.FromContext(c).GetUserForName(ctx, u.GetName()) if err != nil { retErr := fmt.Errorf("user %s not found", u.GetName()) diff --git a/api/build/create.go b/api/build/create.go index ee0d0f64b..5398c0d91 100644 --- a/api/build/create.go +++ b/api/build/create.go @@ -123,7 +123,7 @@ func CreateBuild(c *gin.Context) { } // send API call to capture the repo owner - u, err = database.FromContext(c).GetUser(r.GetUserID()) + u, err = database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err) diff --git a/api/build/restart.go b/api/build/restart.go index 804ad88bd..0f00d7b6e 100644 --- a/api/build/restart.go +++ b/api/build/restart.go @@ -103,7 +103,7 @@ func RestartBuild(c *gin.Context) { logger.Infof("restarting build %s", entry) // send API call to capture the repo owner - u, err := database.FromContext(c).GetUser(r.GetUserID()) + u, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err) diff --git a/api/build/update.go b/api/build/update.go index d042fdd82..4a6ab2004 100644 --- a/api/build/update.go +++ b/api/build/update.go @@ -170,7 +170,7 @@ func UpdateBuild(c *gin.Context) { b.GetStatus() == constants.StatusKilled || b.GetStatus() == constants.StatusError { // send API call to capture the repo owner - u, err := database.FromContext(c).GetUser(r.GetUserID()) + u, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { logrus.Errorf("unable to get owner for build %s: %v", entry, err) } diff --git a/api/metrics.go b/api/metrics.go index cb3832b9b..3eb1f90e1 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -250,7 +250,7 @@ func recordGauges(c *gin.Context) { // user_count if q.UserCount { // send API call to capture the total number of users - u, err := database.FromContext(c).CountUsers() + u, err := database.FromContext(c).CountUsers(ctx) if err != nil { logrus.Errorf("unable to get count of all users: %v", err) } diff --git a/api/pipeline/template.go b/api/pipeline/template.go index 54233b691..7b00abb5d 100644 --- a/api/pipeline/template.go +++ b/api/pipeline/template.go @@ -82,6 +82,7 @@ func GetTemplates(c *gin.Context) { p := pipeline.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() entry := fmt.Sprintf("%s/%s", r.GetFullName(), p.GetCommit()) @@ -107,7 +108,7 @@ func GetTemplates(c *gin.Context) { } // send API call to capture the repo owner - user, err := database.FromContext(c).GetUser(r.GetUserID()) + user, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { util.HandleError(c, http.StatusBadRequest, fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err)) diff --git a/api/repo/update.go b/api/repo/update.go index 6ed328490..947ddba4d 100644 --- a/api/repo/update.go +++ b/api/repo/update.go @@ -268,7 +268,7 @@ func UpdateRepo(c *gin.Context) { // capture admin name for logging admn := u.GetName() - u, err = database.FromContext(c).GetUser(r.GetUserID()) + u, err = database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get repo owner of %s for platform admin webhook update: %w", r.GetFullName(), err) diff --git a/api/user/create.go b/api/user/create.go index bb7823546..691f448ad 100644 --- a/api/user/create.go +++ b/api/user/create.go @@ -51,6 +51,7 @@ import ( func CreateUser(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // capture body from API request input := new(library.User) @@ -72,7 +73,7 @@ func CreateUser(c *gin.Context) { }).Infof("creating new user %s", input.GetName()) // send API call to create the user - user, err := database.FromContext(c).CreateUser(input) + user, err := database.FromContext(c).CreateUser(ctx, input) if err != nil { retErr := fmt.Errorf("unable to create user: %w", err) diff --git a/api/user/create_token.go b/api/user/create_token.go index 427fab34d..f114f369a 100644 --- a/api/user/create_token.go +++ b/api/user/create_token.go @@ -42,6 +42,7 @@ import ( func CreateToken(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -65,7 +66,7 @@ func CreateToken(c *gin.Context) { u.SetRefreshToken(rt) // send API call to update the user - _, err = database.FromContext(c).UpdateUser(u) + _, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err) diff --git a/api/user/delete.go b/api/user/delete.go index 9e0651e7e..29c63135a 100644 --- a/api/user/delete.go +++ b/api/user/delete.go @@ -50,6 +50,7 @@ func DeleteUser(c *gin.Context) { // capture middleware values u := user.Retrieve(c) user := util.PathParameter(c, "user") + ctx := c.Request.Context() // update engine logger with API metadata // @@ -59,7 +60,7 @@ func DeleteUser(c *gin.Context) { }).Infof("deleting user %s", user) // send API call to capture the user - u, err := database.FromContext(c).GetUserForName(user) + u, err := database.FromContext(c).GetUserForName(ctx, user) if err != nil { retErr := fmt.Errorf("unable to get user %s: %w", user, err) @@ -69,7 +70,7 @@ func DeleteUser(c *gin.Context) { } // send API call to remove the user - err = database.FromContext(c).DeleteUser(u) + err = database.FromContext(c).DeleteUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to delete user %s: %w", u.GetName(), err) diff --git a/api/user/delete_token.go b/api/user/delete_token.go index 05a7a3699..d24ae8bf0 100644 --- a/api/user/delete_token.go +++ b/api/user/delete_token.go @@ -42,6 +42,7 @@ import ( func DeleteToken(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -65,7 +66,7 @@ func DeleteToken(c *gin.Context) { u.SetRefreshToken(rt) // send API call to update the user - _, err = database.FromContext(c).UpdateUser(u) + _, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err) diff --git a/api/user/get.go b/api/user/get.go index 9f0b723d3..bf6e6f519 100644 --- a/api/user/get.go +++ b/api/user/get.go @@ -46,6 +46,7 @@ func GetUser(c *gin.Context) { // capture middleware values u := user.Retrieve(c) user := util.PathParameter(c, "user") + ctx := c.Request.Context() // update engine logger with API metadata // @@ -55,7 +56,7 @@ func GetUser(c *gin.Context) { }).Infof("reading user %s", user) // send API call to capture the user - u, err := database.FromContext(c).GetUserForName(user) + u, err := database.FromContext(c).GetUserForName(ctx, user) if err != nil { retErr := fmt.Errorf("unable to get user %s: %w", user, err) diff --git a/api/user/list.go b/api/user/list.go index 1f879bfef..a1d0c487d 100644 --- a/api/user/list.go +++ b/api/user/list.go @@ -66,6 +66,7 @@ import ( func ListUsers(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -98,7 +99,7 @@ func ListUsers(c *gin.Context) { perPage = util.MaxInt(1, util.MinInt(100, perPage)) // send API call to capture the list of users - users, t, err := database.FromContext(c).ListLiteUsers(page, perPage) + users, t, err := database.FromContext(c).ListLiteUsers(ctx, page, perPage) if err != nil { retErr := fmt.Errorf("unable to get users: %w", err) diff --git a/api/user/update.go b/api/user/update.go index 21a089e04..f65335a0e 100644 --- a/api/user/update.go +++ b/api/user/update.go @@ -61,6 +61,7 @@ func UpdateUser(c *gin.Context) { // capture middleware values u := user.Retrieve(c) user := util.PathParameter(c, "user") + ctx := c.Request.Context() // update engine logger with API metadata // @@ -82,7 +83,7 @@ func UpdateUser(c *gin.Context) { } // send API call to capture the user - u, err = database.FromContext(c).GetUserForName(user) + u, err = database.FromContext(c).GetUserForName(ctx, user) if err != nil { retErr := fmt.Errorf("unable to get user %s: %w", user, err) @@ -108,7 +109,7 @@ func UpdateUser(c *gin.Context) { } // send API call to update the user - u, err = database.FromContext(c).UpdateUser(u) + u, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", user, err) diff --git a/api/user/update_current.go b/api/user/update_current.go index 0eeb4e1be..dd5051277 100644 --- a/api/user/update_current.go +++ b/api/user/update_current.go @@ -55,6 +55,7 @@ import ( func UpdateCurrentUser(c *gin.Context) { // capture middleware values u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -82,7 +83,7 @@ func UpdateCurrentUser(c *gin.Context) { } // send API call to update the user - u, err = database.FromContext(c).UpdateUser(u) + u, err = database.FromContext(c).UpdateUser(ctx, u) if err != nil { retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err) diff --git a/api/webhook/post.go b/api/webhook/post.go index a413a0f0e..be9663d06 100644 --- a/api/webhook/post.go +++ b/api/webhook/post.go @@ -286,7 +286,7 @@ func PostWebhook(c *gin.Context) { // send API call to capture repo owner logrus.Debugf("capturing owner of repository %s", repo.GetFullName()) - u, err := database.FromContext(c).GetUser(repo.GetUserID()) + u, err := database.FromContext(c).GetUser(ctx, repo.GetUserID()) if err != nil { retErr := fmt.Errorf("%s: failed to get owner for %s: %w", baseErr, repo.GetFullName(), err) util.HandleError(c, http.StatusBadRequest, retErr) diff --git a/cmd/vela-server/schedule.go b/cmd/vela-server/schedule.go index b65723177..9d6c77c7f 100644 --- a/cmd/vela-server/schedule.go +++ b/cmd/vela-server/schedule.go @@ -156,7 +156,7 @@ func processSchedule(ctx context.Context, s *library.Schedule, compiler compiler } // send API call to capture the owner for the repo - u, err := database.GetUser(r.GetUserID()) + u, err := database.GetUser(ctx, r.GetUserID()) if err != nil { return fmt.Errorf("unable to get owner for repo %s: %w", r.GetFullName(), err) } diff --git a/database/integration_test.go b/database/integration_test.go index 23f6e9f5c..38f5387c0 100644 --- a/database/integration_test.go +++ b/database/integration_test.go @@ -1656,7 +1656,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { // create the users for _, user := range resources.Users { - _, err := db.CreateUser(user) + _, err := db.CreateUser(context.TODO(), user) if err != nil { t.Errorf("unable to create user %d: %v", user.GetID(), err) } @@ -1664,7 +1664,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { methods["CreateUser"] = true // count the users - count, err := db.CountUsers() + count, err := db.CountUsers(context.TODO()) if err != nil { t.Errorf("unable to count users: %v", err) } @@ -1674,7 +1674,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { methods["CountUsers"] = true // list the users - list, err := db.ListUsers() + list, err := db.ListUsers(context.TODO()) if err != nil { t.Errorf("unable to list users: %v", err) } @@ -1684,7 +1684,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { methods["ListUsers"] = true // lite list the users - list, count, err = db.ListLiteUsers(1, 10) + list, count, err = db.ListLiteUsers(context.TODO(), 1, 10) if err != nil { t.Errorf("unable to list lite users: %v", err) } @@ -1698,7 +1698,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { // lookup the users by name for _, user := range resources.Users { - got, err := db.GetUserForName(user.GetName()) + got, err := db.GetUserForName(context.TODO(), user.GetName()) if err != nil { t.Errorf("unable to get user %d by name: %v", user.GetID(), err) } @@ -1711,7 +1711,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { // update the users for _, user := range resources.Users { user.SetActive(false) - got, err := db.UpdateUser(user) + got, err := db.UpdateUser(context.TODO(), user) if err != nil { t.Errorf("unable to update user %d: %v", user.GetID(), err) } @@ -1725,7 +1725,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) { // delete the users for _, user := range resources.Users { - err = db.DeleteUser(user) + err = db.DeleteUser(context.TODO(), user) if err != nil { t.Errorf("unable to delete user %d: %v", user.GetID(), err) } diff --git a/database/resource.go b/database/resource.go index 6da486717..4ae8ce9d8 100644 --- a/database/resource.go +++ b/database/resource.go @@ -141,6 +141,7 @@ func (e *engine) NewResources(ctx context.Context) error { // create the database agnostic engine for users e.UserInterface, err = user.New( + user.WithContext(e.ctx), user.WithClient(e.client), user.WithEncryptionKey(e.config.EncryptionKey), user.WithLogger(e.logger), diff --git a/database/user/count.go b/database/user/count.go index 074a5ef66..6786f6e92 100644 --- a/database/user/count.go +++ b/database/user/count.go @@ -5,11 +5,13 @@ package user import ( + "context" + "github.com/go-vela/types/constants" ) // CountUsers gets the count of all users from the database. -func (e *engine) CountUsers() (int64, error) { +func (e *engine) CountUsers(ctx context.Context) (int64, error) { e.logger.Tracef("getting count of all users from the database") // variable to store query results diff --git a/database/user/count_test.go b/database/user/count_test.go index be4c77f0c..912e7c511 100644 --- a/database/user/count_test.go +++ b/database/user/count_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -37,12 +38,12 @@ func TestUser_Engine_CountUsers(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_userOne) + _, err := _sqlite.CreateUser(context.TODO(), _userOne) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } - _, err = _sqlite.CreateUser(_userTwo) + _, err = _sqlite.CreateUser(context.TODO(), _userTwo) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -71,7 +72,7 @@ func TestUser_Engine_CountUsers(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CountUsers() + got, err := test.database.CountUsers(context.TODO()) if test.failure { if err == nil { diff --git a/database/user/create.go b/database/user/create.go index 4a45f0d89..33808c2e4 100644 --- a/database/user/create.go +++ b/database/user/create.go @@ -6,6 +6,7 @@ package user import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -15,7 +16,7 @@ import ( ) // CreateUser creates a new user in the database. -func (e *engine) CreateUser(u *library.User) (*library.User, error) { +func (e *engine) CreateUser(ctx context.Context, u *library.User) (*library.User, error) { e.logger.WithFields(logrus.Fields{ "user": u.GetName(), }).Tracef("creating user %s in the database", u.GetName()) diff --git a/database/user/create_test.go b/database/user/create_test.go index 815b3ef88..f08b1372b 100644 --- a/database/user/create_test.go +++ b/database/user/create_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -56,7 +57,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING "id"`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.CreateUser(_user) + got, err := test.database.CreateUser(context.TODO(), _user) if test.failure { if err == nil { diff --git a/database/user/delete.go b/database/user/delete.go index 95b77ff64..7cafb61c8 100644 --- a/database/user/delete.go +++ b/database/user/delete.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // DeleteUser deletes an existing user from the database. -func (e *engine) DeleteUser(u *library.User) error { +func (e *engine) DeleteUser(ctx context.Context, u *library.User) error { e.logger.WithFields(logrus.Fields{ "user": u.GetName(), }).Tracef("deleting user %s from the database", u.GetName()) diff --git a/database/user/delete_test.go b/database/user/delete_test.go index 937df9bc3..76db55ecb 100644 --- a/database/user/delete_test.go +++ b/database/user/delete_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -29,7 +30,7 @@ func TestUser_Engine_DeleteUser(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_user) + _, err := _sqlite.CreateUser(context.TODO(), _user) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -55,7 +56,7 @@ func TestUser_Engine_DeleteUser(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err = test.database.DeleteUser(_user) + err = test.database.DeleteUser(context.TODO(), _user) if test.failure { if err == nil { diff --git a/database/user/get.go b/database/user/get.go index d37275c80..33f9ea7be 100644 --- a/database/user/get.go +++ b/database/user/get.go @@ -5,13 +5,15 @@ package user import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // GetUser gets a user by ID from the database. -func (e *engine) GetUser(id int64) (*library.User, error) { +func (e *engine) GetUser(ctx context.Context, id int64) (*library.User, error) { e.logger.Tracef("getting user %d from the database", id) // variable to store query results diff --git a/database/user/get_name.go b/database/user/get_name.go index 4e8da5550..c975ca883 100644 --- a/database/user/get_name.go +++ b/database/user/get_name.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -12,7 +14,7 @@ import ( ) // GetUserForName gets a user by name from the database. -func (e *engine) GetUserForName(name string) (*library.User, error) { +func (e *engine) GetUserForName(ctx context.Context, name string) (*library.User, error) { e.logger.WithFields(logrus.Fields{ "user": name, }).Tracef("getting user %s from the database", name) diff --git a/database/user/get_name_test.go b/database/user/get_name_test.go index cb93ffd0b..71d72e181 100644 --- a/database/user/get_name_test.go +++ b/database/user/get_name_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -35,7 +36,7 @@ func TestUser_Engine_GetUserForName(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_user) + _, err := _sqlite.CreateUser(context.TODO(), _user) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -64,7 +65,7 @@ func TestUser_Engine_GetUserForName(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetUserForName("foo") + got, err := test.database.GetUserForName(context.TODO(), "foo") if test.failure { if err == nil { diff --git a/database/user/get_test.go b/database/user/get_test.go index 2ecac65f2..9a122d4e2 100644 --- a/database/user/get_test.go +++ b/database/user/get_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -35,7 +36,7 @@ func TestUser_Engine_GetUser(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_user) + _, err := _sqlite.CreateUser(context.TODO(), _user) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -64,7 +65,7 @@ func TestUser_Engine_GetUser(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.GetUser(1) + got, err := test.database.GetUser(context.TODO(), 1) if test.failure { if err == nil { diff --git a/database/user/index.go b/database/user/index.go index 5445e963e..eb2ef5544 100644 --- a/database/user/index.go +++ b/database/user/index.go @@ -4,6 +4,8 @@ package user +import "context" + const ( // CreateUserRefreshIndex represents a query to create an // index on the users table for the refresh_token column. @@ -16,7 +18,7 @@ ON users (refresh_token); ) // CreateUserIndexes creates the indexes for the users table in the database. -func (e *engine) CreateUserIndexes() error { +func (e *engine) CreateUserIndexes(ctx context.Context) error { e.logger.Tracef("creating indexes for users table in the database") // create the refresh_token column index for the users table diff --git a/database/user/index_test.go b/database/user/index_test.go index 55728b77a..c3c96b3de 100644 --- a/database/user/index_test.go +++ b/database/user/index_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestUser_Engine_CreateUserIndexes(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateUserIndexes() + err := test.database.CreateUserIndexes(context.TODO()) if test.failure { if err == nil { diff --git a/database/user/interface.go b/database/user/interface.go index dc67c933c..ffbb62bfb 100644 --- a/database/user/interface.go +++ b/database/user/interface.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/go-vela/types/library" ) @@ -18,28 +20,28 @@ type UserInterface interface { // https://en.wikipedia.org/wiki/Data_definition_language // CreateUserIndexes defines a function that creates the indexes for the users table. - CreateUserIndexes() error + CreateUserIndexes(context.Context) error // CreateUserTable defines a function that creates the users table. - CreateUserTable(string) error + CreateUserTable(context.Context, string) error // User Data Manipulation Language Functions // // https://en.wikipedia.org/wiki/Data_manipulation_language // CountUsers defines a function that gets the count of all users. - CountUsers() (int64, error) + CountUsers(context.Context) (int64, error) // CreateUser defines a function that creates a new user. - CreateUser(*library.User) (*library.User, error) + CreateUser(context.Context, *library.User) (*library.User, error) // DeleteUser defines a function that deletes an existing user. - DeleteUser(*library.User) error + DeleteUser(context.Context, *library.User) error // GetUser defines a function that gets a user by ID. - GetUser(int64) (*library.User, error) + GetUser(context.Context, int64) (*library.User, error) // GetUserForName defines a function that gets a user by name. - GetUserForName(string) (*library.User, error) + GetUserForName(context.Context, string) (*library.User, error) // ListUsers defines a function that gets a list of all users. - ListUsers() ([]*library.User, error) + ListUsers(context.Context) ([]*library.User, error) // ListLiteUsers defines a function that gets a lite list of users. - ListLiteUsers(int, int) ([]*library.User, int64, error) + ListLiteUsers(context.Context, int, int) ([]*library.User, int64, error) // UpdateUser defines a function that updates an existing user. - UpdateUser(*library.User) (*library.User, error) + UpdateUser(context.Context, *library.User) (*library.User, error) } diff --git a/database/user/list.go b/database/user/list.go index 4bc730f27..70e8c26d8 100644 --- a/database/user/list.go +++ b/database/user/list.go @@ -5,13 +5,15 @@ package user import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" ) // ListUsers gets a list of all users from the database. -func (e *engine) ListUsers() ([]*library.User, error) { +func (e *engine) ListUsers(ctx context.Context) ([]*library.User, error) { e.logger.Trace("listing all users from the database") // variables to store query results and return value @@ -20,7 +22,7 @@ func (e *engine) ListUsers() ([]*library.User, error) { users := []*library.User{} // count the results - count, err := e.CountUsers() + count, err := e.CountUsers(ctx) if err != nil { return nil, err } diff --git a/database/user/list_lite.go b/database/user/list_lite.go index ee90bca3b..f5b2bc1fa 100644 --- a/database/user/list_lite.go +++ b/database/user/list_lite.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/go-vela/types/constants" "github.com/go-vela/types/database" "github.com/go-vela/types/library" @@ -13,7 +15,7 @@ import ( // ListLiteUsers gets a lite (only: id, name) list of users from the database. // //nolint:lll // ignore long line length due to variable names -func (e *engine) ListLiteUsers(page, perPage int) ([]*library.User, int64, error) { +func (e *engine) ListLiteUsers(ctx context.Context, page, perPage int) ([]*library.User, int64, error) { e.logger.Trace("listing lite users from the database") // variables to store query results and return values @@ -22,7 +24,7 @@ func (e *engine) ListLiteUsers(page, perPage int) ([]*library.User, int64, error users := []*library.User{} // count the results - count, err := e.CountUsers() + count, err := e.CountUsers(ctx) if err != nil { return users, 0, err } diff --git a/database/user/list_lite_test.go b/database/user/list_lite_test.go index 6dd88aa75..df279c39d 100644 --- a/database/user/list_lite_test.go +++ b/database/user/list_lite_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -49,12 +50,12 @@ func TestUser_Engine_ListLiteUsers(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_userOne) + _, err := _sqlite.CreateUser(context.TODO(), _userOne) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } - _, err = _sqlite.CreateUser(_userTwo) + _, err = _sqlite.CreateUser(context.TODO(), _userTwo) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -94,7 +95,7 @@ func TestUser_Engine_ListLiteUsers(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, _, err := test.database.ListLiteUsers(1, 10) + got, _, err := test.database.ListLiteUsers(context.TODO(), 1, 10) if test.failure { if err == nil { diff --git a/database/user/list_test.go b/database/user/list_test.go index 9e088e72f..dcf1970b4 100644 --- a/database/user/list_test.go +++ b/database/user/list_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -49,12 +50,12 @@ func TestUser_Engine_ListUsers(t *testing.T) { _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_userOne) + _, err := _sqlite.CreateUser(context.TODO(), _userOne) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } - _, err = _sqlite.CreateUser(_userTwo) + _, err = _sqlite.CreateUser(context.TODO(), _userTwo) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -83,7 +84,7 @@ func TestUser_Engine_ListUsers(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.ListUsers() + got, err := test.database.ListUsers(context.TODO()) if test.failure { if err == nil { diff --git a/database/user/opts.go b/database/user/opts.go index 58780c317..135cd789f 100644 --- a/database/user/opts.go +++ b/database/user/opts.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -52,3 +54,12 @@ func WithSkipCreation(skipCreation bool) EngineOpt { return nil } } + +// WithContext sets the context in the database engine for Users. +func WithContext(ctx context.Context) EngineOpt { + return func(e *engine) error { + e.ctx = ctx + + return nil + } +} diff --git a/database/user/opts_test.go b/database/user/opts_test.go index 77fb9ae23..867b4ad9c 100644 --- a/database/user/opts_test.go +++ b/database/user/opts_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -208,3 +209,52 @@ func TestUser_EngineOpt_WithSkipCreation(t *testing.T) { }) } } + +func TestUser_EngineOpt_WithContext(t *testing.T) { + // setup types + e := &engine{config: new(config)} + + // setup tests + tests := []struct { + failure bool + name string + ctx context.Context + want context.Context + }{ + { + failure: false, + name: "context set to TODO", + ctx: context.TODO(), + want: context.TODO(), + }, + { + failure: false, + name: "context set to nil", + ctx: nil, + want: nil, + }, + } + + // run tests + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := WithContext(test.ctx)(e) + + if test.failure { + if err == nil { + t.Errorf("WithContext for %s should have returned err", test.name) + } + + return + } + + if err != nil { + t.Errorf("WithContext returned err: %v", err) + } + + if !reflect.DeepEqual(e.ctx, test.want) { + t.Errorf("WithContext is %v, want %v", e.ctx, test.want) + } + }) + } +} diff --git a/database/user/table.go b/database/user/table.go index 456853770..890e91a1b 100644 --- a/database/user/table.go +++ b/database/user/table.go @@ -5,6 +5,8 @@ package user import ( + "context" + "github.com/go-vela/types/constants" ) @@ -45,7 +47,7 @@ users ( ) // CreateUserTable creates the users table in the database. -func (e *engine) CreateUserTable(driver string) error { +func (e *engine) CreateUserTable(ctx context.Context, driver string) error { e.logger.Tracef("creating users table in the database") // handle the driver provided to create the table diff --git a/database/user/table_test.go b/database/user/table_test.go index 95a2d4c00..4f7b82ce2 100644 --- a/database/user/table_test.go +++ b/database/user/table_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "testing" "github.com/DATA-DOG/go-sqlmock" @@ -41,7 +42,7 @@ func TestUser_Engine_CreateUserTable(t *testing.T) { // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := test.database.CreateUserTable(test.name) + err := test.database.CreateUserTable(context.TODO(), test.name) if test.failure { if err == nil { diff --git a/database/user/update.go b/database/user/update.go index 2a215a75f..04d932e9d 100644 --- a/database/user/update.go +++ b/database/user/update.go @@ -6,6 +6,7 @@ package user import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -15,7 +16,7 @@ import ( ) // UpdateUser updates an existing user in the database. -func (e *engine) UpdateUser(u *library.User) (*library.User, error) { +func (e *engine) UpdateUser(ctx context.Context, u *library.User) (*library.User, error) { e.logger.WithFields(logrus.Fields{ "user": u.GetName(), }).Tracef("updating user %s in the database", u.GetName()) diff --git a/database/user/update_test.go b/database/user/update_test.go index 44f2cb0fe..58a909b5c 100644 --- a/database/user/update_test.go +++ b/database/user/update_test.go @@ -5,6 +5,7 @@ package user import ( + "context" "reflect" "testing" @@ -32,7 +33,7 @@ WHERE "id" = $8`). _sqlite := testSqlite(t) defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }() - _, err := _sqlite.CreateUser(_user) + _, err := _sqlite.CreateUser(context.TODO(), _user) if err != nil { t.Errorf("unable to create test user for sqlite: %v", err) } @@ -58,7 +59,7 @@ WHERE "id" = $8`). // run tests for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := test.database.UpdateUser(_user) + got, err := test.database.UpdateUser(context.TODO(), _user) if test.failure { if err == nil { diff --git a/database/user/user.go b/database/user/user.go index 99e1f9701..67f9413e8 100644 --- a/database/user/user.go +++ b/database/user/user.go @@ -5,6 +5,7 @@ package user import ( + "context" "fmt" "github.com/go-vela/types/constants" @@ -27,6 +28,8 @@ type ( // engine configuration settings used in user functions config *config + ctx context.Context + // gorm.io/gorm database client used in user functions // // https://pkg.go.dev/gorm.io/gorm#DB @@ -67,13 +70,13 @@ func New(opts ...EngineOpt) (*engine, error) { } // create the users table - err := e.CreateUserTable(e.client.Config.Dialector.Name()) + err := e.CreateUserTable(e.ctx, e.client.Config.Dialector.Name()) if err != nil { return nil, fmt.Errorf("unable to create %s table: %w", constants.TableUser, err) } // create the indexes for the users table - err = e.CreateUserIndexes() + err = e.CreateUserIndexes(e.ctx) if err != nil { return nil, fmt.Errorf("unable to create indexes for %s table: %w", constants.TableUser, err) } diff --git a/internal/token/refresh.go b/internal/token/refresh.go index 8cb69f374..9ba8c3187 100644 --- a/internal/token/refresh.go +++ b/internal/token/refresh.go @@ -14,6 +14,9 @@ import ( // Refresh returns a new access token, if the provided refreshToken is valid. func (tm *Manager) Refresh(c *gin.Context, refreshToken string) (string, error) { + // capture middleware values + ctx := c.Request.Context() + // retrieve claims from token claims, err := tm.ParseToken(refreshToken) if err != nil { @@ -21,7 +24,7 @@ func (tm *Manager) Refresh(c *gin.Context, refreshToken string) (string, error) } // look up user in database given claims subject - u, err := database.FromContext(c).GetUserForName(claims.Subject) + u, err := database.FromContext(c).GetUserForName(ctx, claims.Subject) if err != nil { return "", fmt.Errorf("unable to retrieve user %s from database from claims subject: %w", claims.Subject, err) } diff --git a/internal/token/refresh_test.go b/internal/token/refresh_test.go index ee9a30c4d..ee7a96077 100644 --- a/internal/token/refresh_test.go +++ b/internal/token/refresh_test.go @@ -5,6 +5,8 @@ package token import ( + "context" + "net/http" "net/http/httptest" "testing" "time" @@ -51,11 +53,11 @@ func TestTokenManager_Refresh(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(context.TODO(), u) // set up context gin.SetMode(gin.TestMode) @@ -63,6 +65,8 @@ func TestTokenManager_Refresh(t *testing.T) { resp := httptest.NewRecorder() context, _ := gin.CreateTestContext(resp) context.Set("database", db) + req, _ := http.NewRequestWithContext(context, "", "", nil) + context.Request = req // run tests got, err := tm.Refresh(context, rt) @@ -110,11 +114,11 @@ func TestTokenManager_Refresh_Expired(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(context.TODO(), u) // set up context gin.SetMode(gin.TestMode) @@ -122,6 +126,8 @@ func TestTokenManager_Refresh_Expired(t *testing.T) { resp := httptest.NewRecorder() context, _ := gin.CreateTestContext(resp) context.Set("database", db) + req, _ := http.NewRequestWithContext(context, "", "", nil) + context.Request = req // run tests _, err = tm.Refresh(context, rt) diff --git a/router/middleware/claims/claims_test.go b/router/middleware/claims/claims_test.go index fd1e0ccfd..c98f191fa 100644 --- a/router/middleware/claims/claims_test.go +++ b/router/middleware/claims/claims_test.go @@ -5,6 +5,7 @@ package claims import ( + _context "context" "fmt" "net/http" "net/http/httptest" @@ -274,11 +275,11 @@ func TestClaims_Establish_BadToken(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) mto := &token.MintTokenOpts{ User: u, diff --git a/router/middleware/perm/perm.go b/router/middleware/perm/perm.go index 8028275ae..3e02e33be 100644 --- a/router/middleware/perm/perm.go +++ b/router/middleware/perm/perm.go @@ -348,6 +348,7 @@ func MustAdmin() gin.HandlerFunc { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -371,7 +372,7 @@ func MustAdmin() gin.HandlerFunc { // try again using the repo owner token // // https://docs.github.com/en/rest/reference/repos#get-repository-permissions-for-a-user - ro, err := database.FromContext(c).GetUser(r.GetUserID()) + ro, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err) @@ -406,6 +407,7 @@ func MustWrite() gin.HandlerFunc { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -429,7 +431,7 @@ func MustWrite() gin.HandlerFunc { // try again using the repo owner token // // https://docs.github.com/en/rest/reference/repos#get-repository-permissions-for-a-user - ro, err := database.FromContext(c).GetUser(r.GetUserID()) + ro, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err) @@ -466,6 +468,7 @@ func MustRead() gin.HandlerFunc { o := org.Retrieve(c) r := repo.Retrieve(c) u := user.Retrieve(c) + ctx := c.Request.Context() // update engine logger with API metadata // @@ -511,7 +514,7 @@ func MustRead() gin.HandlerFunc { // try again using the repo owner token // // https://docs.github.com/en/rest/reference/repos#get-repository-permissions-for-a-user - ro, err := database.FromContext(c).GetUser(r.GetUserID()) + ro, err := database.FromContext(c).GetUser(ctx, r.GetUserID()) if err != nil { retErr := fmt.Errorf("unable to get owner for %s: %w", r.GetFullName(), err) diff --git a/router/middleware/perm/perm_test.go b/router/middleware/perm/perm_test.go index 2b384250e..aa5ed7543 100644 --- a/router/middleware/perm/perm_test.go +++ b/router/middleware/perm/perm_test.go @@ -60,11 +60,11 @@ func TestPerm_MustPlatformAdmin(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) // setup context gin.SetMode(gin.TestMode) @@ -148,11 +148,11 @@ func TestPerm_MustPlatformAdmin_NotAdmin(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/admin/users", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -275,11 +275,11 @@ func TestPerm_MustWorkerRegisterToken_PlatAdmin(t *testing.T) { } defer func() { - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -538,13 +538,13 @@ func TestPerm_MustBuildAccess_PlatAdmin(t *testing.T) { defer func() { db.DeleteBuild(ctx, b) db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) _, _ = db.CreateBuild(ctx, b) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar/builds/1", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -958,12 +958,12 @@ func TestPerm_MustAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1056,12 +1056,12 @@ func TestPerm_MustAdmin_PlatAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1154,12 +1154,12 @@ func TestPerm_MustAdmin_NotAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1252,12 +1252,12 @@ func TestPerm_MustWrite(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1350,12 +1350,12 @@ func TestPerm_MustWrite_PlatAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1448,12 +1448,12 @@ func TestPerm_MustWrite_RepoAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1546,12 +1546,12 @@ func TestPerm_MustWrite_NotWrite(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1644,12 +1644,12 @@ func TestPerm_MustRead(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1742,12 +1742,12 @@ func TestPerm_MustRead_PlatAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -1926,12 +1926,12 @@ func TestPerm_MustRead_RepoAdmin(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -2024,12 +2024,12 @@ func TestPerm_MustRead_RepoWrite(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -2122,12 +2122,12 @@ func TestPerm_MustRead_RepoPublic(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) @@ -2220,12 +2220,12 @@ func TestPerm_MustRead_NotRead(t *testing.T) { defer func() { db.DeleteRepo(_context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(_context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(_context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(_context.TODO(), u) context.Request, _ = http.NewRequest(http.MethodGet, "/test/foo/bar", nil) context.Request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", tok)) diff --git a/router/middleware/pipeline/pipeline_test.go b/router/middleware/pipeline/pipeline_test.go index ead046065..aaff5fb1a 100644 --- a/router/middleware/pipeline/pipeline_test.go +++ b/router/middleware/pipeline/pipeline_test.go @@ -290,12 +290,12 @@ func TestPipeline_Establish_NoPipeline(t *testing.T) { defer func() { db.DeleteRepo(context.TODO(), r) - db.DeleteUser(u) + db.DeleteUser(context.TODO(), u) db.Close() }() _, _ = db.CreateRepo(context.TODO(), r) - _, _ = db.CreateUser(u) + _, _ = db.CreateUser(context.TODO(), u) // setup context gin.SetMode(gin.TestMode) diff --git a/router/middleware/user/user.go b/router/middleware/user/user.go index ad94a5e0d..9717e7fa6 100644 --- a/router/middleware/user/user.go +++ b/router/middleware/user/user.go @@ -28,6 +28,7 @@ func Retrieve(c *gin.Context) *library.User { func Establish() gin.HandlerFunc { return func(c *gin.Context) { cl := claims.Retrieve(c) + ctx := c.Request.Context() // if token is not a user token or claims were not retrieved, establish empty user to better handle nil checks if cl == nil || !strings.EqualFold(cl.TokenType, constants.UserAccessTokenType) { @@ -42,7 +43,7 @@ func Establish() gin.HandlerFunc { logrus.Debugf("parsing user access token") // lookup user in claims subject in the database - u, err := database.FromContext(c).GetUserForName(cl.Subject) + u, err := database.FromContext(c).GetUserForName(ctx, cl.Subject) if err != nil { util.HandleError(c, http.StatusUnauthorized, err) return diff --git a/router/middleware/user/user_test.go b/router/middleware/user/user_test.go index c90cb53fc..74bdd412e 100644 --- a/router/middleware/user/user_test.go +++ b/router/middleware/user/user_test.go @@ -5,6 +5,7 @@ package user import ( + _context "context" "fmt" "net/http" "net/http/httptest" @@ -92,11 +93,11 @@ func TestUser_Establish(t *testing.T) { } defer func() { - db.DeleteUser(want) + db.DeleteUser(_context.TODO(), want) db.Close() }() - _, _ = db.CreateUser(want) + _, _ = db.CreateUser(_context.TODO(), want) // setup context gin.SetMode(gin.TestMode)