Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(db): return user on created and updated #935

Merged
merged 3 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/admin/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func UpdateUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(input)
u, err := database.FromContext(c).UpdateUser(input)
if err != nil {
retErr := fmt.Errorf("unable to update user %d: %w", input.GetID(), err)

Expand All @@ -75,5 +75,5 @@ func UpdateUser(c *gin.Context) {
return
}

c.JSON(http.StatusOK, input)
c.JSON(http.StatusOK, u)
}
4 changes: 2 additions & 2 deletions api/auth/get_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func GetAuthToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to create the user in the database
err = database.FromContext(c).CreateUser(u)
_, err = database.FromContext(c).CreateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to create user %s: %w", u.GetName(), err)

Expand Down Expand Up @@ -154,7 +154,7 @@ func GetAuthToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user in the database
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
2 changes: 1 addition & 1 deletion api/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func Logout(c *gin.Context) {
u.SetRefreshToken("")

// send API call to update the user in the database
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
5 changes: 1 addition & 4 deletions api/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func CreateUser(c *gin.Context) {
}).Infof("creating new user %s", input.GetName())

// send API call to create the user
err = database.FromContext(c).CreateUser(input)
user, err := database.FromContext(c).CreateUser(input)
if err != nil {
retErr := fmt.Errorf("unable to create user: %w", err)

Expand All @@ -81,8 +81,5 @@ func CreateUser(c *gin.Context) {
return
}

// send API call to capture the created user
user, _ := database.FromContext(c).GetUserForName(input.GetName())

c.JSON(http.StatusCreated, user)
}
2 changes: 1 addition & 1 deletion api/user/create_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func CreateToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
2 changes: 1 addition & 1 deletion api/user/delete_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func DeleteToken(c *gin.Context) {
u.SetRefreshToken(rt)

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
_, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand Down
5 changes: 1 addition & 4 deletions api/user/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func UpdateUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
u, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", user, err)

Expand All @@ -117,8 +117,5 @@ func UpdateUser(c *gin.Context) {
return
}

// send API call to capture the updated user
u, _ = database.FromContext(c).GetUserForName(user)

c.JSON(http.StatusOK, u)
}
12 changes: 1 addition & 11 deletions api/user/update_current.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func UpdateCurrentUser(c *gin.Context) {
}

// send API call to update the user
err = database.FromContext(c).UpdateUser(u)
u, err = database.FromContext(c).UpdateUser(u)
if err != nil {
retErr := fmt.Errorf("unable to update user %s: %w", u.GetName(), err)

Expand All @@ -91,15 +91,5 @@ func UpdateCurrentUser(c *gin.Context) {
return
}

// send API call to capture the updated user
u, err = database.FromContext(c).GetUserForName(u.GetName())
if err != nil {
retErr := fmt.Errorf("unable to get updated user %s: %w", u.GetName(), err)

util.HandleError(c, http.StatusNotFound, retErr)

return
}

c.JSON(http.StatusOK, u)
}
9 changes: 2 additions & 7 deletions database/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ func testUsers(t *testing.T, db Interface, resources *Resources) {

// create the users
for _, user := range resources.Users {
err := db.CreateUser(user)
_, err := db.CreateUser(user)
if err != nil {
t.Errorf("unable to create user %d: %v", user.GetID(), err)
}
Expand Down Expand Up @@ -1711,16 +1711,11 @@ func testUsers(t *testing.T, db Interface, resources *Resources) {
// update the users
for _, user := range resources.Users {
user.SetActive(false)
err = db.UpdateUser(user)
got, err := db.UpdateUser(user)
if err != nil {
t.Errorf("unable to update user %d: %v", user.GetID(), err)
}

// lookup the user by ID
got, err := db.GetUser(user.GetID())
if err != nil {
t.Errorf("unable to get user %d by ID: %v", user.GetID(), err)
}
if !reflect.DeepEqual(got, user) {
t.Errorf("GetUser() is %v, want %v", got, user)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ func TestUser_Engine_CountUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
19 changes: 12 additions & 7 deletions database/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// CreateUser creates a new user in the database.
func (e *engine) CreateUser(u *library.User) error {
func (e *engine) CreateUser(u *library.User) (*library.User, error) {
e.logger.WithFields(logrus.Fields{
"user": u.GetName(),
}).Tracef("creating user %s in the database", u.GetName())
Expand All @@ -30,20 +30,25 @@ func (e *engine) CreateUser(u *library.User) error {
// https://pkg.go.dev/github.com/go-vela/types/database#User.Validate
err := user.Validate()
if err != nil {
return err
return nil, err
}

// encrypt the fields for the user
//
// https://pkg.go.dev/github.com/go-vela/types/database#User.Encrypt
err = user.Encrypt(e.config.EncryptionKey)
if err != nil {
return fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
return nil, fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
}

// send query to the database
return e.client.
Table(constants.TableUser).
Create(user).
Error
result := e.client.Table(constants.TableUser).Create(user)

// decrypt fields to return user
err = user.Decrypt(e.config.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt user %s: %w", u.GetName(), err)
}

return user.ToLibrary(), result.Error
}
7 changes: 6 additions & 1 deletion database/user/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package user

import (
"reflect"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand Down Expand Up @@ -55,7 +56,7 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING "id"`).
// run tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := test.database.CreateUser(_user)
got, err := test.database.CreateUser(_user)

if test.failure {
if err == nil {
Expand All @@ -68,6 +69,10 @@ VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING "id"`).
if err != nil {
t.Errorf("CreateUser for %s returned err: %v", test.name, err)
}

if !reflect.DeepEqual(got, _user) {
t.Errorf("CreateUser for %s returned %s, want %s", test.name, got, _user)
}
})
}
}
2 changes: 1 addition & 1 deletion database/user/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestUser_Engine_DeleteUser(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion database/user/get_name_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestUser_Engine_GetUserForName(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion database/user/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestUser_Engine_GetUser(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type UserInterface interface {
// CountUsers defines a function that gets the count of all users.
CountUsers() (int64, error)
// CreateUser defines a function that creates a new user.
CreateUser(*library.User) error
CreateUser(*library.User) (*library.User, error)
// DeleteUser defines a function that deletes an existing user.
DeleteUser(*library.User) error
// GetUser defines a function that gets a user by ID.
Expand All @@ -41,5 +41,5 @@ type UserInterface interface {
// ListLiteUsers defines a function that gets a lite list of users.
ListLiteUsers(int, int) ([]*library.User, int64, error)
// UpdateUser defines a function that updates an existing user.
UpdateUser(*library.User) error
UpdateUser(*library.User) (*library.User, error)
}
4 changes: 2 additions & 2 deletions database/user/list_lite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestUser_Engine_ListLiteUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions database/user/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestUser_Engine_ListUsers(t *testing.T) {
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_userOne)
_, err := _sqlite.CreateUser(_userOne)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}

err = _sqlite.CreateUser(_userTwo)
_, err = _sqlite.CreateUser(_userTwo)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand Down
19 changes: 12 additions & 7 deletions database/user/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

// UpdateUser updates an existing user in the database.
func (e *engine) UpdateUser(u *library.User) error {
func (e *engine) UpdateUser(u *library.User) (*library.User, error) {
e.logger.WithFields(logrus.Fields{
"user": u.GetName(),
}).Tracef("updating user %s in the database", u.GetName())
Expand All @@ -30,20 +30,25 @@ func (e *engine) UpdateUser(u *library.User) error {
// https://pkg.go.dev/github.com/go-vela/types/database#User.Validate
err := user.Validate()
if err != nil {
return err
return nil, err
}

// encrypt the fields for the user
//
// https://pkg.go.dev/github.com/go-vela/types/database#User.Encrypt
err = user.Encrypt(e.config.EncryptionKey)
if err != nil {
return fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
return nil, fmt.Errorf("unable to encrypt user %s: %w", u.GetName(), err)
}

// send query to the database
return e.client.
Table(constants.TableUser).
Save(user).
Error
result := e.client.Table(constants.TableUser).Save(user)

// decrypt fields to return user
err = user.Decrypt(e.config.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("unable to decrypt user %s: %w", u.GetName(), err)
}

return user.ToLibrary(), result.Error
}
9 changes: 7 additions & 2 deletions database/user/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package user

import (
"reflect"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand All @@ -31,7 +32,7 @@ WHERE "id" = $8`).
_sqlite := testSqlite(t)
defer func() { _sql, _ := _sqlite.client.DB(); _sql.Close() }()

err := _sqlite.CreateUser(_user)
_, err := _sqlite.CreateUser(_user)
if err != nil {
t.Errorf("unable to create test user for sqlite: %v", err)
}
Expand All @@ -57,7 +58,7 @@ WHERE "id" = $8`).
// run tests
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err = test.database.UpdateUser(_user)
got, err := test.database.UpdateUser(_user)

if test.failure {
if err == nil {
Expand All @@ -70,6 +71,10 @@ WHERE "id" = $8`).
if err != nil {
t.Errorf("UpdateUser for %s returned err: %v", test.name, err)
}

if !reflect.DeepEqual(got, _user) {
t.Errorf("UpdateUser for %s returned %s, want %s", test.name, got, _user)
}
})
}
}
Loading