Skip to content

Commit

Permalink
fix: plumb ctx into roles db methods (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 authored Mar 7, 2024
1 parent cce7e6f commit 143bfa0
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 215 deletions.
8 changes: 4 additions & 4 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (s ManagementResource) ListRoles(response http.ResponseWriter, request *htt
if sqlFilter, err := queryFilters.BuildSQLFilter(); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "error building SQL for filter", request), response)
return
} else if roles, err = s.db.GetAllRoles(strings.Join(order, ", "), sqlFilter); err != nil {
} else if roles, err = s.db.GetAllRoles(request.Context(), strings.Join(order, ", "), sqlFilter); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
api.WriteBasicResponse(request.Context(), v2.ListRolesResponse{Roles: roles}, http.StatusOK, response)
Expand All @@ -348,7 +348,7 @@ func (s ManagementResource) GetRole(response http.ResponseWriter, request *http.

if roleID, err := strconv.ParseInt(rawRoleID, 10, 32); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response)
} else if role, err := s.db.GetRole(int32(roleID)); err != nil {
} else if role, err := s.db.GetRole(request.Context(), int32(roleID)); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
api.WriteBasicResponse(request.Context(), role, http.StatusOK, response)
Expand Down Expand Up @@ -428,7 +428,7 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response)
} else if len(createUserRequest.Roles) > 1 {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, ErrorResponseDetailsNumRoles, request), response)
} else if roles, err := s.db.GetRoles(createUserRequest.Roles); err != nil {
} else if roles, err := s.db.GetRoles(request.Context(), createUserRequest.Roles); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
userTemplate.Roles = roles
Expand Down Expand Up @@ -520,7 +520,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response)
} else if len(updateUserRequest.Roles) > 1 {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "a user can only have one role", request), response)
} else if roles, err := s.db.GetRoles(updateUserRequest.Roles); err != nil {
} else if roles, err := s.db.GetRoles(request.Context(), updateUserRequest.Roles); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
user.Roles = roles
Expand Down
40 changes: 20 additions & 20 deletions cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {

defer mockCtrl.Finish()

mockDB.EXPECT().GetRoles(gomock.Eq(goodRoles)).Return(model.Roles{}, nil).AnyTimes()
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSAMLProvider(samlProviderID).Return(model.SAMLProvider{}, nil).Times(2)
Expand Down Expand Up @@ -563,7 +563,7 @@ func TestManagementResource_ListRoles_DBError(t *testing.T) {

endpoint := "/api/v2/auth/roles"
mockDB := dbmocks.NewMockDatabase(mockCtrl)
mockDB.EXPECT().GetAllRoles("description desc, name", model.SQLFilter{}).Return(model.Roles{}, fmt.Errorf("foo"))
mockDB.EXPECT().GetAllRoles(gomock.Any(), "description desc, name", model.SQLFilter{}).Return(model.Roles{}, fmt.Errorf("foo"))

config, err := config.NewDefaultConfiguration()
require.Nilf(t, err, "Failed to create default configuration: %v", err)
Expand Down Expand Up @@ -613,7 +613,7 @@ func TestManagementResource_ListRoles(t *testing.T) {
}

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetAllRoles("description desc, name", model.SQLFilter{}).Return(model.Roles{role1, role2}, nil)
mockDB.EXPECT().GetAllRoles(gomock.Any(), "description desc, name", model.SQLFilter{}).Return(model.Roles{role1, role2}, nil)

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil {
Expand Down Expand Up @@ -656,7 +656,7 @@ func TestManagementResource_ListRoles_Filtered(t *testing.T) {
}

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetAllRoles("", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Roles{role1}, nil)
mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Roles{role1}, nil)

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil {
Expand Down Expand Up @@ -1054,8 +1054,8 @@ func TestCreateUser_Failure(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil).AnyTimes()
mockDB.EXPECT().GetRoles(badRole).Return(model.Roles{}, fmt.Errorf("db error"))
mockDB.EXPECT().GetRoles(gomock.Not(badRole)).Return(model.Roles{}, nil).AnyTimes()
mockDB.EXPECT().GetRoles(gomock.Any(), badRole).Return(model.Roles{}, fmt.Errorf("db error"))
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Not(badRole)).Return(model.Roles{}, nil).AnyTimes()
mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
mockDB.EXPECT().CreateUser(gomock.Any(), badUser).Return(model.User{}, fmt.Errorf("db error"))

Expand Down Expand Up @@ -1170,7 +1170,7 @@ func TestCreateUser_Success(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
Expand Down Expand Up @@ -1223,7 +1223,7 @@ func TestCreateUser_ResetPassword(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)

input := struct {
Expand Down Expand Up @@ -1296,7 +1296,7 @@ func TestManagementResource_UpdateUser_IDMalformed(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
Expand Down Expand Up @@ -1359,7 +1359,7 @@ func TestManagementResource_UpdateUser_GetUserError(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(model.User{}, fmt.Errorf("foo"))

Expand Down Expand Up @@ -1423,10 +1423,10 @@ func TestManagementResource_UpdateUser_GetRolesError(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, fmt.Errorf("foo"))
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, fmt.Errorf("foo"))

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
input := v2.CreateUserRequest{
Expand Down Expand Up @@ -1481,10 +1481,10 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{
Name: "admin",
Description: "admin",
Permissions: model.Permissions{model.Permission{
Expand Down Expand Up @@ -1562,10 +1562,10 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{
Name: "admin",
Description: "admin",
Permissions: model.Permissions{model.Permission{
Expand Down Expand Up @@ -1643,10 +1643,10 @@ func TestManagementResource_UpdateUser_DBError(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{
Name: "admin",
Description: "admin",
Permissions: model.Permissions{model.Permission{
Expand Down Expand Up @@ -1868,10 +1868,10 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) {
Duration: appcfg.DefaultPasswordExpirationWindow,
}),
}, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil)
mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes()
mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil)
mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{
Name: "admin",
Description: "admin",
Permissions: model.Permissions{model.Permission{
Expand Down
20 changes: 5 additions & 15 deletions cmd/api/src/api/v2/integration/auth.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
// Copyright 2023 Specter Ops, Inc.
//
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// SPDX-License-Identifier: Apache-2.0

package integration

import (
"github.com/specterops/bloodhound/src/model"
"github.com/gofrs/uuid"
"github.com/specterops/bloodhound/src/model"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -64,16 +64,6 @@ func (s *Context) GetRolesByName(roleNames ...string) model.Roles {
return foundRoles
}

func (s *Context) SetUserRole(userID uuid.UUID, roleName string) {
err := s.AdminClient().UserAddRole(userID, s.GetRolesByName(roleName)[0].ID)
require.Nilf(s.TestCtrl, err, "Failed to set role for user %s: %v", userID.String(), err)
}

func (s *Context) RemoveUserRole(userID uuid.UUID, roleName string) {
err := s.AdminClient().UserRemoveRole(userID, s.GetRolesByName(roleName)[0].ID)
require.Nilf(s.TestCtrl, err, "Failed to remove role for user %s: %v", userID.String(), err)
}

func (s *Context) ListUsers() model.Users {
listUsersResponse, err := s.AdminClient().ListUsers()
require.Nilf(s.TestCtrl, err, "Failed to list users: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/bootstrap/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func MigrateGraph(ctx context.Context, db graph.Database, schema graph.Schema) e
}

// MigrateDB runs database migrations on PG
func MigrateDB(cfg config.Configuration, db database.Database) error {
func MigrateDB(ctx context.Context, cfg config.Configuration, db database.Database) error {
if err := db.Migrate(); err != nil {
return err
}
Expand All @@ -79,7 +79,7 @@ func MigrateDB(cfg config.Configuration, db database.Database) error {

secretDigester := cfg.Crypto.Argon2.NewDigester()

if roles, err := db.GetAllRoles("", model.SQLFilter{}); err != nil {
if roles, err := db.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil {
return fmt.Errorf("error while attempting to fetch user roles: %w", err)
} else if secretDigest, err := secretDigester.Digest(cfg.DefaultAdmin.Password); err != nil {
return fmt.Errorf("error while attempting to digest secret for user: %w", err)
Expand Down
70 changes: 12 additions & 58 deletions cmd/api/src/database/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,87 +88,41 @@ func (s contextInitializer) InitContextFromToken(ctx context.Context, authToken
return auth.Context{}, ErrNotFound
}

func (s *BloodhoundDB) CreateRole(role model.Role) (model.Role, error) {
var (
updatedRole = role
result = s.db.Create(&updatedRole)
)

return updatedRole, CheckError(result)
}

// UpdateRole updates permissions for the row matching the provided Role struct
// UPDATE roles SET permissions=.... WHERE role_id = ...
func (s *BloodhoundDB) UpdateRole(role model.Role) error {
// Update permissions first
if err := s.db.Model(&role).Association("Permissions").Replace(&role.Permissions); err != nil {
return err
}

result := s.db.Save(&role)
return CheckError(result)
}

// GetAllRoles retrieves all available roles in the db
// SELECT * FROM roles
func (s *BloodhoundDB) GetAllRoles(order string, filter model.SQLFilter) (model.Roles, error) {
func (s *BloodhoundDB) GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error) {
var (
roles model.Roles
result *gorm.DB
cursor = s.preload(model.RoleAssociations()).WithContext(ctx)
)

if order == "" && filter.SQLString == "" {
result = s.preload(model.RoleAssociations()).Find(&roles)
} else if order == "" && filter.SQLString != "" {
result = s.preload(model.RoleAssociations()).Where(filter.SQLString, filter.Params).Find(&roles)
} else if order != "" && filter.SQLString == "" {
result = s.preload(model.RoleAssociations()).Order(order).Find(&roles)
} else {
result = s.preload(model.RoleAssociations()).Where(filter.SQLString, filter.Params).Order(order).Find(&roles)
if order != "" && filter.SQLString == "" {
cursor = cursor.Order(order)
}
if filter.SQLString != "" {
cursor = cursor.Where(filter.SQLString, filter.Params)
}

return roles, CheckError(result)
return roles, CheckError(cursor.Find(&roles))
}

// GetRoles retrieves all rows in the Roles table corresponding to the provided list of IDs
// SELECT * FROM roles where ID in (...)
func (s *BloodhoundDB) GetRoles(ids []int32) (model.Roles, error) {
var (
roles model.Roles
result = s.preload(model.RoleAssociations()).Where("id in ?", ids).Find(&roles)
)

return roles, CheckError(result)
}

// GetRolesByName retrieves all rows in the Roles table corresponding to the provided list of role names
// SELECT * FROM roles WHERE role_name IN (..)
func (s *BloodhoundDB) GetRolesByName(names []string) (model.Roles, error) {
func (s *BloodhoundDB) GetRoles(ctx context.Context, ids []int32) (model.Roles, error) {
var (
roles model.Roles
result = s.preload(model.RoleAssociations()).Where("name in ?", names).Find(&roles)
result = s.preload(model.RoleAssociations()).WithContext(ctx).Where("id in ?", ids).Find(&roles)
)

return roles, CheckError(result)
}

// GetRole retrieves the role associated with the provided ID
// SELECT * FROM roles WHERE role_id = ....
func (s *BloodhoundDB) GetRole(id int32) (model.Role, error) {
var (
role model.Role
result = s.preload(model.RoleAssociations()).First(&role, id)
)

return role, CheckError(result)
}

// LookupRoleByName retrieves a row from the Roles table corresponding to the role name provided
// SELECT * FROM roles WHERE role_name = ....
func (s *BloodhoundDB) LookupRoleByName(name string) (model.Role, error) {
func (s *BloodhoundDB) GetRole(ctx context.Context, id int32) (model.Role, error) {
var (
role model.Role
result = s.preload(model.RoleAssociations()).Where("name = ?", name).First(&role)
result = s.preload(model.RoleAssociations()).WithContext(ctx).First(&role, id)
)

return role, CheckError(result)
Expand Down
37 changes: 1 addition & 36 deletions cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func initAndGetRoles(t *testing.T) (database.Database, model.Roles) {
t.Fatalf("Failed preparing DB: %v", err)
}

if roles, err := dbInst.GetAllRoles("", model.SQLFilter{}); err != nil {
if roles, err := dbInst.GetAllRoles(context.Background(), "", model.SQLFilter{}); err != nil {
t.Fatalf("Error fetching roles: %v", err)
} else {
return dbInst, roles
Expand Down Expand Up @@ -140,41 +140,6 @@ func TestDatabase_InitializeRoles(t *testing.T) {
}
}

func TestDatabase_UpdateRole(t *testing.T) {
dbInst, roles := initAndGetRoles(t)

if role, found := roles.FindByName(auth.RoleReadOnly); !found {
t.Fatal("Unable to find role")
} else if allPermissions, err := dbInst.GetAllPermissions(context.Background(), "", model.SQLFilter{}); err != nil {
t.Fatalf("Failed fetching all permissions: %v", err)
} else {
role.Permissions = allPermissions

if err := dbInst.UpdateRole(role); err != nil {
t.Fatalf("Failed updating role %s: %v", role.Name, err)
}

if updatedRole, err := dbInst.GetRole(role.ID); err != nil {
t.Fatalf("Failed fetching updated role %s: %v", role.Name, err)
} else {
for _, permission := range role.Permissions {
found := false

for _, updatedPermission := range updatedRole.Permissions {
if permission.Equals(updatedPermission) {
found = true
break
}
}

if !found {
t.Fatalf("Updated role %s missing expected permission %s", role.Name, permission)
}
}
}
}
}

func TestDatabase_CreateGetDeleteUser(t *testing.T) {
var (
ctx = context.Background()
Expand Down
Loading

0 comments on commit 143bfa0

Please sign in to comment.