diff --git a/cmd/api/src/api/v2/auth/auth.go b/cmd/api/src/api/v2/auth/auth.go index 056dd2b644..553671cc08 100644 --- a/cmd/api/src/api/v2/auth/auth.go +++ b/cmd/api/src/api/v2/auth/auth.go @@ -332,7 +332,7 @@ func (s ManagementResource) ListRoles(response http.ResponseWriter, request *htt if sqlFilter, err := queryFilters.BuildSQLFilter(); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "error building SQL for filter", request), response) return - } else if roles, err = s.db.GetAllRoles(strings.Join(order, ", "), sqlFilter); err != nil { + } else if roles, err = s.db.GetAllRoles(request.Context(), strings.Join(order, ", "), sqlFilter); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), v2.ListRolesResponse{Roles: roles}, http.StatusOK, response) @@ -348,7 +348,7 @@ func (s ManagementResource) GetRole(response http.ResponseWriter, request *http. if roleID, err := strconv.ParseInt(rawRoleID, 10, 32); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) - } else if role, err := s.db.GetRole(int32(roleID)); err != nil { + } else if role, err := s.db.GetRole(request.Context(), int32(roleID)); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), role, http.StatusOK, response) @@ -428,7 +428,7 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) } else if len(createUserRequest.Roles) > 1 { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, ErrorResponseDetailsNumRoles, request), response) - } else if roles, err := s.db.GetRoles(createUserRequest.Roles); err != nil { + } else if roles, err := s.db.GetRoles(request.Context(), createUserRequest.Roles); err != nil { api.HandleDatabaseError(request, response, err) } else { userTemplate.Roles = roles @@ -520,7 +520,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) } else if len(updateUserRequest.Roles) > 1 { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "a user can only have one role", request), response) - } else if roles, err := s.db.GetRoles(updateUserRequest.Roles); err != nil { + } else if roles, err := s.db.GetRoles(request.Context(), updateUserRequest.Roles); err != nil { api.HandleDatabaseError(request, response, err) } else { user.Roles = roles diff --git a/cmd/api/src/api/v2/auth/auth_test.go b/cmd/api/src/api/v2/auth/auth_test.go index cb3a72eb9e..b779634899 100644 --- a/cmd/api/src/api/v2/auth/auth_test.go +++ b/cmd/api/src/api/v2/auth/auth_test.go @@ -136,7 +136,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { defer mockCtrl.Finish() - mockDB.EXPECT().GetRoles(gomock.Eq(goodRoles)).Return(model.Roles{}, nil).AnyTimes() + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil) mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil) mockDB.EXPECT().GetSAMLProvider(samlProviderID).Return(model.SAMLProvider{}, nil).Times(2) @@ -563,7 +563,7 @@ func TestManagementResource_ListRoles_DBError(t *testing.T) { endpoint := "/api/v2/auth/roles" mockDB := dbmocks.NewMockDatabase(mockCtrl) - mockDB.EXPECT().GetAllRoles("description desc, name", model.SQLFilter{}).Return(model.Roles{}, fmt.Errorf("foo")) + mockDB.EXPECT().GetAllRoles(gomock.Any(), "description desc, name", model.SQLFilter{}).Return(model.Roles{}, fmt.Errorf("foo")) config, err := config.NewDefaultConfiguration() require.Nilf(t, err, "Failed to create default configuration: %v", err) @@ -613,7 +613,7 @@ func TestManagementResource_ListRoles(t *testing.T) { } resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) - mockDB.EXPECT().GetAllRoles("description desc, name", model.SQLFilter{}).Return(model.Roles{role1, role2}, nil) + mockDB.EXPECT().GetAllRoles(gomock.Any(), "description desc, name", model.SQLFilter{}).Return(model.Roles{role1, role2}, nil) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -656,7 +656,7 @@ func TestManagementResource_ListRoles_Filtered(t *testing.T) { } resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) - mockDB.EXPECT().GetAllRoles("", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Roles{role1}, nil) + mockDB.EXPECT().GetAllRoles(gomock.Any(), "", model.SQLFilter{SQLString: "name = ?", Params: []any{"a"}}).Return(model.Roles{role1}, nil) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -1054,8 +1054,8 @@ func TestCreateUser_Failure(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil).AnyTimes() - mockDB.EXPECT().GetRoles(badRole).Return(model.Roles{}, fmt.Errorf("db error")) - mockDB.EXPECT().GetRoles(gomock.Not(badRole)).Return(model.Roles{}, nil).AnyTimes() + mockDB.EXPECT().GetRoles(gomock.Any(), badRole).Return(model.Roles{}, fmt.Errorf("db error")) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Not(badRole)).Return(model.Roles{}, nil).AnyTimes() mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockDB.EXPECT().CreateUser(gomock.Any(), badUser).Return(model.User{}, fmt.Errorf("db error")) @@ -1170,7 +1170,7 @@ func TestCreateUser_Success(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) @@ -1223,7 +1223,7 @@ func TestCreateUser_ResetPassword(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) input := struct { @@ -1296,7 +1296,7 @@ func TestManagementResource_UpdateUser_IDMalformed(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) @@ -1359,7 +1359,7 @@ func TestManagementResource_UpdateUser_GetUserError(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(model.User{}, fmt.Errorf("foo")) @@ -1423,10 +1423,10 @@ func TestManagementResource_UpdateUser_GetRolesError(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, fmt.Errorf("foo")) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, fmt.Errorf("foo")) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) input := v2.CreateUserRequest{ @@ -1481,10 +1481,10 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", Description: "admin", Permissions: model.Permissions{model.Permission{ @@ -1562,10 +1562,10 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", Description: "admin", Permissions: model.Permissions{model.Permission{ @@ -1643,10 +1643,10 @@ func TestManagementResource_UpdateUser_DBError(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", Description: "admin", Permissions: model.Permissions{model.Permission{ @@ -1868,10 +1868,10 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{}, nil) mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ + mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", Description: "admin", Permissions: model.Permissions{model.Permission{ diff --git a/cmd/api/src/api/v2/integration/auth.go b/cmd/api/src/api/v2/integration/auth.go index b89f5e346b..53ddf343ca 100644 --- a/cmd/api/src/api/v2/integration/auth.go +++ b/cmd/api/src/api/v2/integration/auth.go @@ -1,24 +1,24 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package integration import ( - "github.com/specterops/bloodhound/src/model" "github.com/gofrs/uuid" + "github.com/specterops/bloodhound/src/model" "github.com/stretchr/testify/require" ) @@ -64,16 +64,6 @@ func (s *Context) GetRolesByName(roleNames ...string) model.Roles { return foundRoles } -func (s *Context) SetUserRole(userID uuid.UUID, roleName string) { - err := s.AdminClient().UserAddRole(userID, s.GetRolesByName(roleName)[0].ID) - require.Nilf(s.TestCtrl, err, "Failed to set role for user %s: %v", userID.String(), err) -} - -func (s *Context) RemoveUserRole(userID uuid.UUID, roleName string) { - err := s.AdminClient().UserRemoveRole(userID, s.GetRolesByName(roleName)[0].ID) - require.Nilf(s.TestCtrl, err, "Failed to remove role for user %s: %v", userID.String(), err) -} - func (s *Context) ListUsers() model.Users { listUsersResponse, err := s.AdminClient().ListUsers() require.Nilf(s.TestCtrl, err, "Failed to list users: %v", err) diff --git a/cmd/api/src/bootstrap/server.go b/cmd/api/src/bootstrap/server.go index f8de3fa299..4db3f716fa 100644 --- a/cmd/api/src/bootstrap/server.go +++ b/cmd/api/src/bootstrap/server.go @@ -66,7 +66,7 @@ func MigrateGraph(ctx context.Context, db graph.Database, schema graph.Schema) e } // MigrateDB runs database migrations on PG -func MigrateDB(cfg config.Configuration, db database.Database) error { +func MigrateDB(ctx context.Context, cfg config.Configuration, db database.Database) error { if err := db.Migrate(); err != nil { return err } @@ -79,7 +79,7 @@ func MigrateDB(cfg config.Configuration, db database.Database) error { secretDigester := cfg.Crypto.Argon2.NewDigester() - if roles, err := db.GetAllRoles("", model.SQLFilter{}); err != nil { + if roles, err := db.GetAllRoles(ctx, "", model.SQLFilter{}); err != nil { return fmt.Errorf("error while attempting to fetch user roles: %w", err) } else if secretDigest, err := secretDigester.Digest(cfg.DefaultAdmin.Password); err != nil { return fmt.Errorf("error while attempting to digest secret for user: %w", err) diff --git a/cmd/api/src/database/auth.go b/cmd/api/src/database/auth.go index f7d2b28481..da26896459 100644 --- a/cmd/api/src/database/auth.go +++ b/cmd/api/src/database/auth.go @@ -88,65 +88,30 @@ func (s contextInitializer) InitContextFromToken(ctx context.Context, authToken return auth.Context{}, ErrNotFound } -func (s *BloodhoundDB) CreateRole(role model.Role) (model.Role, error) { - var ( - updatedRole = role - result = s.db.Create(&updatedRole) - ) - - return updatedRole, CheckError(result) -} - -// UpdateRole updates permissions for the row matching the provided Role struct -// UPDATE roles SET permissions=.... WHERE role_id = ... -func (s *BloodhoundDB) UpdateRole(role model.Role) error { - // Update permissions first - if err := s.db.Model(&role).Association("Permissions").Replace(&role.Permissions); err != nil { - return err - } - - result := s.db.Save(&role) - return CheckError(result) -} - // GetAllRoles retrieves all available roles in the db // SELECT * FROM roles -func (s *BloodhoundDB) GetAllRoles(order string, filter model.SQLFilter) (model.Roles, error) { +func (s *BloodhoundDB) GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error) { var ( roles model.Roles - result *gorm.DB + cursor = s.preload(model.RoleAssociations()).WithContext(ctx) ) - if order == "" && filter.SQLString == "" { - result = s.preload(model.RoleAssociations()).Find(&roles) - } else if order == "" && filter.SQLString != "" { - result = s.preload(model.RoleAssociations()).Where(filter.SQLString, filter.Params).Find(&roles) - } else if order != "" && filter.SQLString == "" { - result = s.preload(model.RoleAssociations()).Order(order).Find(&roles) - } else { - result = s.preload(model.RoleAssociations()).Where(filter.SQLString, filter.Params).Order(order).Find(&roles) + if order != "" && filter.SQLString == "" { + cursor = cursor.Order(order) + } + if filter.SQLString != "" { + cursor = cursor.Where(filter.SQLString, filter.Params) } - return roles, CheckError(result) + return roles, CheckError(cursor.Find(&roles)) } // GetRoles retrieves all rows in the Roles table corresponding to the provided list of IDs // SELECT * FROM roles where ID in (...) -func (s *BloodhoundDB) GetRoles(ids []int32) (model.Roles, error) { - var ( - roles model.Roles - result = s.preload(model.RoleAssociations()).Where("id in ?", ids).Find(&roles) - ) - - return roles, CheckError(result) -} - -// GetRolesByName retrieves all rows in the Roles table corresponding to the provided list of role names -// SELECT * FROM roles WHERE role_name IN (..) -func (s *BloodhoundDB) GetRolesByName(names []string) (model.Roles, error) { +func (s *BloodhoundDB) GetRoles(ctx context.Context, ids []int32) (model.Roles, error) { var ( roles model.Roles - result = s.preload(model.RoleAssociations()).Where("name in ?", names).Find(&roles) + result = s.preload(model.RoleAssociations()).WithContext(ctx).Where("id in ?", ids).Find(&roles) ) return roles, CheckError(result) @@ -154,21 +119,10 @@ func (s *BloodhoundDB) GetRolesByName(names []string) (model.Roles, error) { // GetRole retrieves the role associated with the provided ID // SELECT * FROM roles WHERE role_id = .... -func (s *BloodhoundDB) GetRole(id int32) (model.Role, error) { - var ( - role model.Role - result = s.preload(model.RoleAssociations()).First(&role, id) - ) - - return role, CheckError(result) -} - -// LookupRoleByName retrieves a row from the Roles table corresponding to the role name provided -// SELECT * FROM roles WHERE role_name = .... -func (s *BloodhoundDB) LookupRoleByName(name string) (model.Role, error) { +func (s *BloodhoundDB) GetRole(ctx context.Context, id int32) (model.Role, error) { var ( role model.Role - result = s.preload(model.RoleAssociations()).Where("name = ?", name).First(&role) + result = s.preload(model.RoleAssociations()).WithContext(ctx).First(&role, id) ) return role, CheckError(result) diff --git a/cmd/api/src/database/auth_test.go b/cmd/api/src/database/auth_test.go index 36148067d1..80f0d74979 100644 --- a/cmd/api/src/database/auth_test.go +++ b/cmd/api/src/database/auth_test.go @@ -45,7 +45,7 @@ func initAndGetRoles(t *testing.T) (database.Database, model.Roles) { t.Fatalf("Failed preparing DB: %v", err) } - if roles, err := dbInst.GetAllRoles("", model.SQLFilter{}); err != nil { + if roles, err := dbInst.GetAllRoles(context.Background(), "", model.SQLFilter{}); err != nil { t.Fatalf("Error fetching roles: %v", err) } else { return dbInst, roles @@ -140,41 +140,6 @@ func TestDatabase_InitializeRoles(t *testing.T) { } } -func TestDatabase_UpdateRole(t *testing.T) { - dbInst, roles := initAndGetRoles(t) - - if role, found := roles.FindByName(auth.RoleReadOnly); !found { - t.Fatal("Unable to find role") - } else if allPermissions, err := dbInst.GetAllPermissions(context.Background(), "", model.SQLFilter{}); err != nil { - t.Fatalf("Failed fetching all permissions: %v", err) - } else { - role.Permissions = allPermissions - - if err := dbInst.UpdateRole(role); err != nil { - t.Fatalf("Failed updating role %s: %v", role.Name, err) - } - - if updatedRole, err := dbInst.GetRole(role.ID); err != nil { - t.Fatalf("Failed fetching updated role %s: %v", role.Name, err) - } else { - for _, permission := range role.Permissions { - found := false - - for _, updatedPermission := range updatedRole.Permissions { - if permission.Equals(updatedPermission) { - found = true - break - } - } - - if !found { - t.Fatalf("Updated role %s missing expected permission %s", role.Name, permission) - } - } - } - } -} - func TestDatabase_CreateGetDeleteUser(t *testing.T) { var ( ctx = context.Background() diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index fc13a10088..4c6ede00df 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -84,13 +84,9 @@ type Database interface { ListAuditLogs(ctx context.Context, before, after time.Time, offset, limit int, order string, filter model.SQLFilter) (model.AuditLogs, int, error) // Roles - CreateRole(role model.Role) (model.Role, error) - UpdateRole(role model.Role) error - GetAllRoles(order string, filter model.SQLFilter) (model.Roles, error) - GetRoles(ids []int32) (model.Roles, error) - GetRolesByName(names []string) (model.Roles, error) - GetRole(id int32) (model.Role, error) - LookupRoleByName(name string) (model.Role, error) + GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error) + GetRoles(ctx context.Context, ids []int32) (model.Roles, error) + GetRole(ctx context.Context, id int32) (model.Role, error) // Permissions GetAllPermissions(ctx context.Context, order string, filter model.SQLFilter) (model.Permissions, error) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 52bb7b17b7..ca15e1a259 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -258,21 +258,6 @@ func (mr *MockDatabaseMockRecorder) CreateInstallation() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateInstallation", reflect.TypeOf((*MockDatabase)(nil).CreateInstallation)) } -// CreateRole mocks base method. -func (m *MockDatabase) CreateRole(arg0 model.Role) (model.Role, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRole", arg0) - ret0, _ := ret[0].(model.Role) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateRole indicates an expected call of CreateRole. -func (mr *MockDatabaseMockRecorder) CreateRole(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRole", reflect.TypeOf((*MockDatabase)(nil).CreateRole), arg0) -} - // CreateSAMLIdentityProvider mocks base method. func (m *MockDatabase) CreateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) (model.SAMLProvider, error) { m.ctrl.T.Helper() @@ -638,18 +623,18 @@ func (mr *MockDatabaseMockRecorder) GetAllPermissions(arg0, arg1, arg2 interface } // GetAllRoles mocks base method. -func (m *MockDatabase) GetAllRoles(arg0 string, arg1 model.SQLFilter) (model.Roles, error) { +func (m *MockDatabase) GetAllRoles(arg0 context.Context, arg1 string, arg2 model.SQLFilter) (model.Roles, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllRoles", arg0, arg1) + ret := m.ctrl.Call(m, "GetAllRoles", arg0, arg1, arg2) ret0, _ := ret[0].(model.Roles) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAllRoles indicates an expected call of GetAllRoles. -func (mr *MockDatabaseMockRecorder) GetAllRoles(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) GetAllRoles(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllRoles", reflect.TypeOf((*MockDatabase)(nil).GetAllRoles), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllRoles", reflect.TypeOf((*MockDatabase)(nil).GetAllRoles), arg0, arg1, arg2) } // GetAllSAMLProviders mocks base method. @@ -925,48 +910,33 @@ func (mr *MockDatabaseMockRecorder) GetPermission(arg0, arg1 interface{}) *gomoc } // GetRole mocks base method. -func (m *MockDatabase) GetRole(arg0 int32) (model.Role, error) { +func (m *MockDatabase) GetRole(arg0 context.Context, arg1 int32) (model.Role, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRole", arg0) + ret := m.ctrl.Call(m, "GetRole", arg0, arg1) ret0, _ := ret[0].(model.Role) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRole indicates an expected call of GetRole. -func (mr *MockDatabaseMockRecorder) GetRole(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) GetRole(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockDatabase)(nil).GetRole), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockDatabase)(nil).GetRole), arg0, arg1) } // GetRoles mocks base method. -func (m *MockDatabase) GetRoles(arg0 []int32) (model.Roles, error) { +func (m *MockDatabase) GetRoles(arg0 context.Context, arg1 []int32) (model.Roles, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRoles", arg0) + ret := m.ctrl.Call(m, "GetRoles", arg0, arg1) ret0, _ := ret[0].(model.Roles) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRoles indicates an expected call of GetRoles. -func (mr *MockDatabaseMockRecorder) GetRoles(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoles", reflect.TypeOf((*MockDatabase)(nil).GetRoles), arg0) -} - -// GetRolesByName mocks base method. -func (m *MockDatabase) GetRolesByName(arg0 []string) (model.Roles, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRolesByName", arg0) - ret0, _ := ret[0].(model.Roles) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRolesByName indicates an expected call of GetRolesByName. -func (mr *MockDatabaseMockRecorder) GetRolesByName(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) GetRoles(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRolesByName", reflect.TypeOf((*MockDatabase)(nil).GetRolesByName), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoles", reflect.TypeOf((*MockDatabase)(nil).GetRoles), arg0, arg1) } // GetSAMLProvider mocks base method. @@ -1167,21 +1137,6 @@ func (mr *MockDatabaseMockRecorder) LookupActiveSessionsByUser(arg0 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupActiveSessionsByUser", reflect.TypeOf((*MockDatabase)(nil).LookupActiveSessionsByUser), arg0) } -// LookupRoleByName mocks base method. -func (m *MockDatabase) LookupRoleByName(arg0 string) (model.Role, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LookupRoleByName", arg0) - ret0, _ := ret[0].(model.Role) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LookupRoleByName indicates an expected call of LookupRoleByName. -func (mr *MockDatabaseMockRecorder) LookupRoleByName(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupRoleByName", reflect.TypeOf((*MockDatabase)(nil).LookupRoleByName), arg0) -} - // LookupSAMLProviderByName mocks base method. func (m *MockDatabase) LookupSAMLProviderByName(arg0 string) (model.SAMLProvider, error) { m.ctrl.T.Helper() @@ -1379,20 +1334,6 @@ func (mr *MockDatabaseMockRecorder) UpdateFileUploadJob(arg0 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateFileUploadJob", reflect.TypeOf((*MockDatabase)(nil).UpdateFileUploadJob), arg0) } -// UpdateRole mocks base method. -func (m *MockDatabase) UpdateRole(arg0 model.Role) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateRole", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateRole indicates an expected call of UpdateRole. -func (mr *MockDatabaseMockRecorder) UpdateRole(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRole", reflect.TypeOf((*MockDatabase)(nil).UpdateRole), arg0) -} - // UpdateSAMLIdentityProvider mocks base method. func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) error { m.ctrl.T.Helper() diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 175fc92a04..71afffe962 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -69,7 +69,7 @@ func ConnectDatabases(ctx context.Context, cfg config.Configuration) (bootstrap. func Entrypoint(ctx context.Context, cfg config.Configuration, connections bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch]) ([]daemons.Daemon, error) { if !cfg.DisableMigrations { - if err := bootstrap.MigrateDB(cfg, connections.RDMS); err != nil { + if err := bootstrap.MigrateDB(ctx, cfg, connections.RDMS); err != nil { return nil, fmt.Errorf("rdms migration error: %w", err) } else if err := bootstrap.MigrateGraph(ctx, connections.Graph, schema.DefaultGraphSchema()); err != nil { return nil, fmt.Errorf("graph migration error: %w", err) diff --git a/cmd/api/src/test/lab/fixtures/postgres.go b/cmd/api/src/test/lab/fixtures/postgres.go index d7768bdce9..2720a6ebeb 100644 --- a/cmd/api/src/test/lab/fixtures/postgres.go +++ b/cmd/api/src/test/lab/fixtures/postgres.go @@ -17,6 +17,7 @@ package fixtures import ( + "context" "fmt" "log" @@ -35,7 +36,7 @@ var PostgresFixture = lab.NewFixture(func(harness *lab.Harness) (*database.Blood return nil, err } else if err := integration.Prepare(database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { return nil, fmt.Errorf("failed ensuring database: %v", err) - } else if err := bootstrap.MigrateDB(config, database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { + } else if err := bootstrap.MigrateDB(context.Background(), config, database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { return nil, fmt.Errorf("failed migrating database: %v", err) } else { return database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver()), nil