Skip to content

Commit

Permalink
fix: plumb ctx into permission db methods (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 authored Mar 7, 2024
1 parent a90f74b commit cce7e6f
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 56 deletions.
4 changes: 2 additions & 2 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (s ManagementResource) ListPermissions(response http.ResponseWriter, reques
if sqlFilter, err := queryFilters.BuildSQLFilter(); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "error building SQL for filter", request), response)
return
} else if permissions, err = s.db.GetAllPermissions(strings.Join(order, ", "), sqlFilter); err != nil {
} else if permissions, err = s.db.GetAllPermissions(request.Context(), strings.Join(order, ", "), sqlFilter); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
Expand All @@ -270,7 +270,7 @@ func (s ManagementResource) GetPermission(response http.ResponseWriter, request

if permissionID, err := strconv.Atoi(rawPermissionID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response)
} else if permission, err := s.db.GetPermission(permissionID); err != nil {
} else if permission, err := s.db.GetPermission(request.Context(), permissionID); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
api.WriteBasicResponse(request.Context(), permission, http.StatusOK, response)
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func TestManagementResource_ListPermissions_DBError(t *testing.T) {

endpoint := "/api/v2/permissions"
mockDB := dbmocks.NewMockDatabase(mockCtrl)
mockDB.EXPECT().GetAllPermissions("authority desc, name", model.SQLFilter{SQLString: "name = ?", Params: []any{"foo"}}).Return(model.Permissions{}, fmt.Errorf("foo"))
mockDB.EXPECT().GetAllPermissions(gomock.Any(), "authority desc, name", model.SQLFilter{SQLString: "name = ?", Params: []any{"foo"}}).Return(model.Permissions{}, fmt.Errorf("foo"))

config, err := config.NewDefaultConfiguration()
require.Nilf(t, err, "Failed to create default configuration: %v", err)
Expand Down Expand Up @@ -395,7 +395,7 @@ func TestManagementResource_ListPermissions(t *testing.T) {
}

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetAllPermissions("authority desc, name", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Permissions{perm1, perm2}, nil)
mockDB.EXPECT().GetAllPermissions(gomock.Any(), "authority desc, name", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Permissions{perm1, perm2}, nil)

ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})
if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil {
Expand Down
35 changes: 11 additions & 24 deletions cmd/api/src/database/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,47 +176,34 @@ func (s *BloodhoundDB) LookupRoleByName(name string) (model.Role, error) {

// GetAllPermissions retrieves all rows from the Permissions table
// SELECT * FROM permissions
func (s *BloodhoundDB) GetAllPermissions(order string, filter model.SQLFilter) (model.Permissions, error) {
func (s *BloodhoundDB) GetAllPermissions(ctx context.Context, order string, filter model.SQLFilter) (model.Permissions, error) {
var (
permissions model.Permissions
result *gorm.DB
cursor = s.db.WithContext(ctx)
)

if order == "" && filter.SQLString == "" {
result = s.db.Find(&permissions)
} else if order != "" && filter.SQLString == "" {
result = s.db.Order(order).Find(&permissions)
} else if order == "" && filter.SQLString != "" {
result = s.db.Where(filter.SQLString, filter.Params).Find(&permissions)
} else {
result = s.db.Where(filter.SQLString, filter.Params).Order(order).Find(&permissions)
if order != "" {
cursor = cursor.Order(order)
}

if filter.SQLString != "" {
cursor = cursor.Where(filter.SQLString, filter.Params)
}

return permissions, CheckError(result)
return permissions, CheckError(cursor.Find(&permissions))
}

// GetPermission retrieves a row in the Permissions table corresponding to the ID provided
// SELECT * FROM permissions WHERE permission_id = ...
func (s *BloodhoundDB) GetPermission(id int) (model.Permission, error) {
func (s *BloodhoundDB) GetPermission(ctx context.Context, id int) (model.Permission, error) {
var (
permission model.Permission
result = s.db.First(&permission, id)
result = s.db.WithContext(ctx).First(&permission, id)
)

return permission, CheckError(result)
}

// CreatePermission creates a new permission row with the struct provided
// INSERT INTO permissions (id, authority, name) VALUES (ID, authority, name)
func (s *BloodhoundDB) CreatePermission(permission model.Permission) (model.Permission, error) {
var (
updatedPermission = permission
result = s.db.Create(&updatedPermission)
)

return updatedPermission, CheckError(result)
}

// InitializeSAMLAuth creates new SAMLProvider, User and Installation entries based on the input provided
func (s *BloodhoundDB) InitializeSAMLAuth(adminUser model.User, samlProvider model.SAMLProvider) (model.SAMLProvider, model.Installation, error) {
var (
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestDatabase_InitializePermissions(t *testing.T) {
t.Fatalf("Failed preparing DB: %v", err)
}

if permissions, err := dbInst.GetAllPermissions("", model.SQLFilter{}); err != nil {
if permissions, err := dbInst.GetAllPermissions(context.Background(), "", model.SQLFilter{}); err != nil {
t.Fatalf("Error fetching permissions: %v", err)
} else {
templates := auth.Permissions().All()
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestDatabase_UpdateRole(t *testing.T) {

if role, found := roles.FindByName(auth.RoleReadOnly); !found {
t.Fatal("Unable to find role")
} else if allPermissions, err := dbInst.GetAllPermissions("", model.SQLFilter{}); err != nil {
} else if allPermissions, err := dbInst.GetAllPermissions(context.Background(), "", model.SQLFilter{}); err != nil {
t.Fatalf("Failed fetching all permissions: %v", err)
} else {
role.Permissions = allPermissions
Expand Down
5 changes: 2 additions & 3 deletions cmd/api/src/database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ type Database interface {
LookupRoleByName(name string) (model.Role, error)

// Permissions
GetAllPermissions(order string, filter model.SQLFilter) (model.Permissions, error)
GetPermission(id int) (model.Permission, error)
CreatePermission(permission model.Permission) (model.Permission, error)
GetAllPermissions(ctx context.Context, order string, filter model.SQLFilter) (model.Permissions, error)
GetPermission(ctx context.Context, id int) (model.Permission, error)

InitializeSAMLAuth(adminUser model.User, samlProvider model.SAMLProvider) (model.SAMLProvider, model.Installation, error)
InitializeSecretAuth(adminUser model.User, authSecret model.AuthSecret) (model.Installation, error)
Expand Down
31 changes: 8 additions & 23 deletions cmd/api/src/database/mocks/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cce7e6f

Please sign in to comment.