diff --git a/.github/config/integration.config.json b/.github/config/integration.config.json index 90a5973b3c..bdfd7c58ac 100644 --- a/.github/config/integration.config.json +++ b/.github/config/integration.config.json @@ -7,6 +7,8 @@ "collectors_base_path": "/tmp/collectors", "log_level": "ERROR", "log_path": "bhapi.log", + "enable_startup_wait_period": false, + "datapipe_interval": 1, "features": { "enable_auth": true }, diff --git a/cmd/api/src/api/middleware/auth.go b/cmd/api/src/api/middleware/auth.go index 6296dab345..375d93e8b1 100644 --- a/cmd/api/src/api/middleware/auth.go +++ b/cmd/api/src/api/middleware/auth.go @@ -22,10 +22,9 @@ import ( "strings" "time" - "github.com/specterops/bloodhound/src/ctx" - "github.com/gofrs/uuid" "github.com/gorilla/mux" + "github.com/specterops/bloodhound/src/ctx" "github.com/specterops/bloodhound/src/api" "github.com/specterops/bloodhound/src/auth" @@ -107,6 +106,7 @@ func PermissionsCheckAll(authorizer auth.Authorizer, permissions ...model.Permis if bhCtx := ctx.FromRequest(request); !bhCtx.AuthCtx.Authenticated() { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, "not authenticated", request), response) } else if !authorizer.AllowsAllPermissions(bhCtx.AuthCtx, permissions) { + authorizer.AuditLogUnauthorizedAccess(request) api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "not authorized", request), response) } else { next.ServeHTTP(response, request) @@ -123,6 +123,7 @@ func PermissionsCheckAtLeastOne(authorizer auth.Authorizer, permissions ...model if bhCtx := ctx.FromRequest(request); !bhCtx.AuthCtx.Authenticated() { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, "not authenticated", request), response) } else if !authorizer.AllowsAtLeastOnePermission(bhCtx.AuthCtx, permissions) { + authorizer.AuditLogUnauthorizedAccess(request) api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "not authorized", request), response) } else { next.ServeHTTP(response, request) @@ -188,6 +189,7 @@ func AuthorizeAuthManagementAccess(permissions auth.PermissionSet, authorizer au } if !authorized { + authorizer.AuditLogUnauthorizedAccess(request) api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, fmt.Sprintf("not authorized for %s", userID), request), response) } else { next.ServeHTTP(response, request) diff --git a/cmd/api/src/api/middleware/auth_test.go b/cmd/api/src/api/middleware/auth_test.go index 4bb73cbca9..94ab0e8423 100644 --- a/cmd/api/src/api/middleware/auth_test.go +++ b/cmd/api/src/api/middleware/auth_test.go @@ -25,18 +25,20 @@ import ( "github.com/specterops/bloodhound/src/api" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/ctx" + dbmocks "github.com/specterops/bloodhound/src/database/mocks" "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/test/must" "github.com/specterops/bloodhound/src/utils/test" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) -func permissionsCheckAllHandler(internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { - return PermissionsCheckAll(auth.NewAuthorizer(), permissions...)(internalHandler) +func permissionsCheckAllHandler(db *dbmocks.MockDatabase, internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { + return PermissionsCheckAll(auth.NewAuthorizer(db), permissions...)(internalHandler) } -func permissionsCheckAtLeastOneHandler(internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { - return PermissionsCheckAtLeastOne(auth.NewAuthorizer(), permissions...)(internalHandler) +func permissionsCheckAtLeastOneHandler(db *dbmocks.MockDatabase, internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { + return PermissionsCheckAtLeastOne(auth.NewAuthorizer(db), permissions...)(internalHandler) } func Test_parseAuthorizationHeader(t *testing.T) { @@ -61,13 +63,17 @@ func TestPermissionsCheckAll(t *testing.T) { handlerReturn200 = func(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusOK) } + mockCtrl = gomock.NewController(t) + mockDB = dbmocks.NewMockDatabase(mockCtrl) ) + defer mockCtrl.Finish() + mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() test.Request(t). WithURL("http//example.com"). WithHeader(headers.RequestID.String(), "requestID"). WithContext(&ctx.Context{}). - OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAllHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusUnauthorized) @@ -87,10 +93,11 @@ func TestPermissionsCheckAll(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAllHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusForbidden) + mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).Times(0) // No audit logs should be created on successful login test.Request(t). WithURL("http//example.com"). WithHeader(headers.RequestID.String(), "requestID"). @@ -109,7 +116,7 @@ func TestPermissionsCheckAll(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAllHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusOK) } @@ -119,8 +126,12 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { handlerReturn200 = func(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusOK) } + mockCtrl = gomock.NewController(t) + mockDB = dbmocks.NewMockDatabase(mockCtrl) ) + defer mockCtrl.Finish() + mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).Times(0) test.Request(t). WithURL("http//example.com"). WithContext(&ctx.Context{ @@ -138,7 +149,7 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAtLeastOneHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusOK) @@ -159,7 +170,7 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAtLeastOneHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusOK) @@ -180,12 +191,13 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().GraphDBRead)). + OnHandler(permissionsCheckAtLeastOneHandler(mockDB, handlerReturn200, auth.Permissions().GraphDBRead)). Require(). ResponseStatusCode(http.StatusOK) test.Request(t). WithURL("http//example.com"). + WithHeader(headers.RequestID.String(), "requestID"). WithContext(&ctx.Context{ AuthCtx: auth.Context{ PermissionOverrides: auth.PermissionOverrides{}, @@ -194,20 +206,20 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { { Name: "Big Boy", Description: "The big boy.", - Permissions: model.Permissions{auth.Permissions().AuthManageSelf, auth.Permissions().GraphDBRead}, + Permissions: auth.Permissions().All(), }, }, }, Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().GraphDBWrite)). + OnHandler(permissionsCheckAtLeastOneHandler(mockDB, handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). - ResponseStatusCode(http.StatusForbidden) + ResponseStatusCode(http.StatusOK) + mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).Times(1) test.Request(t). WithURL("http//example.com"). - WithHeader(headers.RequestID.String(), "requestID"). WithContext(&ctx.Context{ AuthCtx: auth.Context{ PermissionOverrides: auth.PermissionOverrides{}, @@ -216,14 +228,14 @@ func TestPermissionsCheckAtLeastOne(t *testing.T) { { Name: "Big Boy", Description: "The big boy.", - Permissions: auth.Permissions().All(), + Permissions: model.Permissions{auth.Permissions().AuthManageSelf, auth.Permissions().GraphDBRead}, }, }, }, Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAtLeastOneHandler(mockDB, handlerReturn200, auth.Permissions().GraphDBWrite)). Require(). - ResponseStatusCode(http.StatusOK) + ResponseStatusCode(http.StatusForbidden) } diff --git a/cmd/api/src/api/middleware/logging.go b/cmd/api/src/api/middleware/logging.go index 605bb852cf..a78524d175 100644 --- a/cmd/api/src/api/middleware/logging.go +++ b/cmd/api/src/api/middleware/logging.go @@ -29,6 +29,7 @@ import ( "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/ctx" + "github.com/specterops/bloodhound/src/database" ) // PanicHandler is a middleware func that sets up a defer-recovery trap to capture any unhandled panics that bubble @@ -114,7 +115,7 @@ func setSignedRequestFields(request *http.Request, logEvent log.Event) { // LoggingMiddleware is a middleware func that outputs a log for each request-response lifecycle. It includes timestamped // information organized into fields suitable for searching or parsing. -func LoggingMiddleware(cfg config.Configuration, idResolver auth.IdentityResolver) func(http.Handler) http.Handler { +func LoggingMiddleware(cfg config.Configuration, idResolver auth.IdentityResolver, db *database.BloodhoundDB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { var ( diff --git a/cmd/api/src/api/middleware/middleware.go b/cmd/api/src/api/middleware/middleware.go index 71f1fdbdf4..353f6fbc9a 100644 --- a/cmd/api/src/api/middleware/middleware.go +++ b/cmd/api/src/api/middleware/middleware.go @@ -129,8 +129,8 @@ func ContextMiddleware(next http.Handler) http.Handler { // Create a new context with the timeout requestCtx, cancel := context.WithTimeout(request.Context(), requestedWaitDuration.Value) defer cancel() - // Insert the bh context + requestCtx = ctx.Set(requestCtx, &ctx.Context{ StartTime: startTime, Timeout: requestedWaitDuration, @@ -139,6 +139,7 @@ func ContextMiddleware(next http.Handler) http.Handler { Scheme: getScheme(request), Host: request.Host, }, + RequestIP: parseUserIP(request), }) // Route the request with the embedded context @@ -147,6 +148,16 @@ func ContextMiddleware(next http.Handler) http.Handler { }) } +func parseUserIP(r *http.Request) string { + if result := r.Header.Get("X-Forwarded-For"); result == "" { + log.Warnf("No data found in X-Forwarded-For header") + return r.RemoteAddr + } else { + result += "," + r.RemoteAddr + return result + } +} + func ParseHeaderValues(values string) map[string]string { parsed := map[string]string{} diff --git a/cmd/api/src/api/middleware/middleware_internal_test.go b/cmd/api/src/api/middleware/middleware_internal_test.go index 3b79dd5383..e1c26fa135 100644 --- a/cmd/api/src/api/middleware/middleware_internal_test.go +++ b/cmd/api/src/api/middleware/middleware_internal_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "net/http" "net/url" + "strings" "testing" "time" @@ -79,6 +80,27 @@ func TestRequestWaitDuration(t *testing.T) { require.True(t, requestedWaitDuration.UserSet) } +func TestParseUserIP_XForwardedFor_RemoteAddr(t *testing.T) { + req, err := http.NewRequest("GET", "/teapot", nil) + require.Nil(t, err) + + ip1 := "192.168.1.1:8080" + ip2 := "192.168.1.2" + ip3 := "192.168.1.3" + + req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ",")) + req.RemoteAddr = "0.0.0.0:3000" + + require.Equal(t, parseUserIP(req), strings.Join([]string{ip1, ip2, ip3, req.RemoteAddr}, ",")) +} + +func TestParseUserIP_RemoteAddrOnly(t *testing.T) { + req, err := http.NewRequest("GET", "/teapot", nil) + require.Nil(t, err) + req.RemoteAddr = "0.0.0.0:3000" + require.Equal(t, parseUserIP(req), req.RemoteAddr) +} + func TestParsePreferHeaderWait(t *testing.T) { _, err := parsePreferHeaderWait("wait=1.5", 30*time.Second) require.NotNil(t, err) diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index 4a5e9e3611..8a18a45a68 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -33,14 +33,14 @@ import ( "github.com/specterops/bloodhound/src/queries" ) -func RegisterFossGlobalMiddleware(routerInst *router.Router, cfg config.Configuration, identityResolver auth.IdentityResolver, authenticator api.Authenticator) { +func RegisterFossGlobalMiddleware(routerInst *router.Router, cfg config.Configuration, db *database.BloodhoundDB, identityResolver auth.IdentityResolver, authenticator api.Authenticator) { // Set up the middleware stack routerInst.UsePrerouting(middleware.ContextMiddleware) routerInst.UsePrerouting(middleware.CORSMiddleware()) // Set up logging. This must be done after ContextMiddleware is initialized so the context can be accessed in the log logic if cfg.EnableAPILogging { - routerInst.UsePrerouting(middleware.LoggingMiddleware(cfg, identityResolver)) + routerInst.UsePrerouting(middleware.LoggingMiddleware(cfg, identityResolver, db)) } routerInst.UsePostrouting( diff --git a/cmd/api/src/api/registration/v2.go b/cmd/api/src/api/registration/v2.go index 07e7a41312..64b7820909 100644 --- a/cmd/api/src/api/registration/v2.go +++ b/cmd/api/src/api/registration/v2.go @@ -40,7 +40,7 @@ func samlWriteAPIErrorResponse(request *http.Request, response http.ResponseWrit func registerV2Auth(cfg config.Configuration, db database.Database, permissions auth.PermissionSet, routerInst *router.Router, authenticator api.Authenticator) { var ( loginResource = authapi.NewLoginResource(cfg, authenticator, db) - managementResource = authapi.NewManagementResource(cfg, db, auth.NewAuthorizer()) + managementResource = authapi.NewManagementResource(cfg, db, auth.NewAuthorizer(db)) samlResource = saml.NewSAMLRootResource(cfg, db, samlWriteAPIErrorResponse) ) diff --git a/cmd/api/src/api/v2/agi.go b/cmd/api/src/api/v2/agi.go index ac824efcdb..66d6be9ca1 100644 --- a/cmd/api/src/api/v2/agi.go +++ b/cmd/api/src/api/v2/agi.go @@ -175,20 +175,19 @@ func (s Resources) UpdateAssetGroup(response http.ResponseWriter, request *http. pathVars = mux.Vars(request) rawAssetGroupID = pathVars[api.URIPathVariableAssetGroupID] updateAssetGroupRequest UpdateAssetGroupRequest + assetGroup model.AssetGroup ) if assetGroupID, err := strconv.Atoi(rawAssetGroupID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) } else if err := api.ReadJSONRequestPayloadLimited(&updateAssetGroupRequest, request); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if err := s.DB.AppendAuditLog(*ctx.FromRequest(request), "UpdateAssetGroup", updateAssetGroupRequest); err != nil { - api.HandleDatabaseError(request, response, err) - } else if assetGroup, err := s.DB.GetAssetGroup(int32(assetGroupID)); err != nil { + } else if assetGroup, err = s.DB.GetAssetGroup(int32(assetGroupID)); err != nil { api.HandleDatabaseError(request, response, err) } else { assetGroup.Name = updateAssetGroupRequest.Name - if err := s.DB.UpdateAssetGroup(assetGroup); err != nil { + if err := s.DB.UpdateAssetGroup(request.Context(), assetGroup); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), assetGroup, http.StatusOK, response) @@ -201,9 +200,7 @@ func (s Resources) CreateAssetGroup(response http.ResponseWriter, request *http. if err := api.ReadJSONRequestPayloadLimited(&createRequest, request); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if err := s.DB.AppendAuditLog(*ctx.FromRequest(request), "CreateAssetGroup", createRequest); err != nil { - api.HandleDatabaseError(request, response, err) - } else if newAssetGroup, err := s.DB.CreateAssetGroup(createRequest.Name, createRequest.Tag, false); err != nil { + } else if newAssetGroup, err := s.DB.CreateAssetGroup(request.Context(), createRequest.Name, createRequest.Tag, false); err != nil { api.HandleDatabaseError(request, response, err) } else { assetGroupURL := *ctx.Get(request.Context()).Host @@ -218,17 +215,16 @@ func (s Resources) DeleteAssetGroup(response http.ResponseWriter, request *http. var ( pathVars = mux.Vars(request) rawAssetGroupID = pathVars[api.URIPathVariableAssetGroupID] + assetGroup model.AssetGroup ) if assetGroupID, err := strconv.Atoi(rawAssetGroupID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) - } else if assetGroup, err := s.DB.GetAssetGroup(int32(assetGroupID)); err != nil { + } else if assetGroup, err = s.DB.GetAssetGroup(int32(assetGroupID)); err != nil { api.HandleDatabaseError(request, response, err) } else if assetGroup.SystemGroup { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "Cannot delete a system defined asset group.", request), response) - } else if err := s.DB.AppendAuditLog(*ctx.FromRequest(request), "DeleteAssetGroup", assetGroup); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.DB.DeleteAssetGroup(assetGroup); err != nil { + } else if err := s.DB.DeleteAssetGroup(request.Context(), assetGroup); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) @@ -274,6 +270,7 @@ func (s Resources) UpdateAssetGroupSelectors(response http.ResponseWriter, reque func (s Resources) DeleteAssetGroupSelector(response http.ResponseWriter, request *http.Request) { var ( + assetGroupSelector model.AssetGroupSelector pathVars = mux.Vars(request) rawAssetGroupID = pathVars[api.URIPathVariableAssetGroupID] rawAssetGroupSelectorID = pathVars[api.URIPathVariableAssetGroupSelectorID] @@ -285,13 +282,11 @@ func (s Resources) DeleteAssetGroupSelector(response http.ResponseWriter, reques api.HandleDatabaseError(request, response, err) } else if assetGroupSelectorID, err := strconv.Atoi(rawAssetGroupSelectorID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) - } else if assetGroupSelector, err := s.DB.GetAssetGroupSelector(int32(assetGroupSelectorID)); err != nil { + } else if assetGroupSelector, err = s.DB.GetAssetGroupSelector(int32(assetGroupSelectorID)); err != nil { api.HandleDatabaseError(request, response, err) } else if assetGroupSelector.SystemSelector { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "Cannot delete a system defined asset group selector.", request), response) - } else if err := s.DB.AppendAuditLog(*ctx.FromRequest(request), "DeleteAssetGroupSelector", assetGroupSelector); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.DB.DeleteAssetGroupSelector(assetGroupSelector); err != nil { + } else if err := s.DB.DeleteAssetGroupSelector(request.Context(), assetGroupSelector); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) diff --git a/cmd/api/src/api/v2/agi_test.go b/cmd/api/src/api/v2/agi_test.go index eb9e244119..d3841cc6b5 100644 --- a/cmd/api/src/api/v2/agi_test.go +++ b/cmd/api/src/api/v2/agi_test.go @@ -20,15 +20,16 @@ import ( "context" "encoding/json" "fmt" - "github.com/specterops/bloodhound/headers" - "github.com/specterops/bloodhound/mediatypes" - "github.com/specterops/bloodhound/src/auth" - "github.com/specterops/bloodhound/src/test/must" "net/http" "net/http/httptest" "net/url" "testing" + "github.com/specterops/bloodhound/headers" + "github.com/specterops/bloodhound/mediatypes" + "github.com/specterops/bloodhound/src/auth" + "github.com/specterops/bloodhound/src/test/must" + "github.com/gorilla/mux" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/errors" @@ -348,20 +349,7 @@ func TestResources_UpdateAssetGroup(t *testing.T) { Require(). ResponseStatusCode(http.StatusBadRequest) - // Audit Log fails - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("exploded")) - - requestTemplate. - WithURLPathVars(map[string]string{ - "asset_group_id": "1234", - }). - WithBody(v2.UpdateAssetGroupRequest{}). - OnHandlerFunc(resources.UpdateAssetGroup). - Require(). - ResponseStatusCode(http.StatusInternalServerError) - // GetAssetGroup DB fails - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, fmt.Errorf("exploded")) requestTemplate. @@ -374,9 +362,8 @@ func TestResources_UpdateAssetGroup(t *testing.T) { ResponseStatusCode(http.StatusInternalServerError) // UpdateAssetGroup DB fails - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(fmt.Errorf("exploded")) + mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(fmt.Errorf("exploded")) requestTemplate. WithURLPathVars(map[string]string{ @@ -388,9 +375,8 @@ func TestResources_UpdateAssetGroup(t *testing.T) { ResponseStatusCode(http.StatusInternalServerError) // Success - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().UpdateAssetGroup(model.AssetGroup{}).Return(nil) + mockDB.EXPECT().UpdateAssetGroup(gomock.Any(), model.AssetGroup{}).Return(nil) requestTemplate. WithURLPathVars(map[string]string{ @@ -420,18 +406,8 @@ func TestResources_CreateAssetGroup(t *testing.T) { Require(). ResponseStatusCode(http.StatusBadRequest) - // Audit Log fails - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("exploded")) - - requestTemplate. - WithBody(v2.CreateAssetGroupRequest{}). - OnHandlerFunc(resources.CreateAssetGroup). - Require(). - ResponseStatusCode(http.StatusInternalServerError) - // Create DB Query fails - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateAssetGroup(gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AssetGroup{}, fmt.Errorf("exploded")) + mockDB.EXPECT().CreateAssetGroup(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AssetGroup{}, fmt.Errorf("exploded")) requestTemplate. WithBody(v2.CreateAssetGroupRequest{}). @@ -440,8 +416,7 @@ func TestResources_CreateAssetGroup(t *testing.T) { ResponseStatusCode(http.StatusInternalServerError) // Success - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateAssetGroup(gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AssetGroup{}, nil) + mockDB.EXPECT().CreateAssetGroup(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AssetGroup{}, nil) requestTemplate. WithContext(&ctx.Context{ @@ -730,18 +705,6 @@ func TestResources_DeleteAssetGroup(t *testing.T) { // GetAssetGroup DB fails mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, fmt.Errorf("exploded")) - requestTemplate. - WithURLPathVars(map[string]string{ - "asset_group_id": "1234", - }). - OnHandlerFunc(resources.DeleteAssetGroup). - Require(). - ResponseStatusCode(http.StatusInternalServerError) - - // Audit Log DB fails - mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("exploded")) - requestTemplate. WithURLPathVars(map[string]string{ "asset_group_id": "1234", @@ -752,8 +715,7 @@ func TestResources_DeleteAssetGroup(t *testing.T) { // DeleteAssetGroup DB fails mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().DeleteAssetGroup(model.AssetGroup{}).Return(fmt.Errorf("exploded")) + mockDB.EXPECT().DeleteAssetGroup(gomock.Any(), model.AssetGroup{}).Return(fmt.Errorf("exploded")) requestTemplate. WithURLPathVars(map[string]string{ @@ -765,8 +727,7 @@ func TestResources_DeleteAssetGroup(t *testing.T) { // Success mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().DeleteAssetGroup(model.AssetGroup{}).Return(nil) + mockDB.EXPECT().DeleteAssetGroup(gomock.Any(), model.AssetGroup{}).Return(nil) requestTemplate. WithURLPathVars(map[string]string{ @@ -847,25 +808,10 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) { Require(). ResponseStatusCode(http.StatusConflict) - // Audit Log DB fails - mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) - mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("exploded")) - - requestTemplate. - WithURLPathVars(map[string]string{ - "asset_group_id": "1234", - "asset_group_selector_id": "1234", - }). - OnHandlerFunc(resources.DeleteAssetGroupSelector). - Require(). - ResponseStatusCode(http.StatusInternalServerError) - // DeleteAssetGroupSelector DB fails mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(fmt.Errorf("exploded")) + mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(fmt.Errorf("exploded")) requestTemplate. WithURLPathVars(map[string]string{ @@ -879,8 +825,7 @@ func TestResources_DeleteAssetGroupSelector(t *testing.T) { // Success mockDB.EXPECT().GetAssetGroup(int32(1234)).Return(model.AssetGroup{}, nil) mockDB.EXPECT().GetAssetGroupSelector(int32(1234)).Return(model.AssetGroupSelector{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().DeleteAssetGroupSelector(model.AssetGroupSelector{}).Return(nil) + mockDB.EXPECT().DeleteAssetGroupSelector(gomock.Any(), model.AssetGroupSelector{}).Return(nil) requestTemplate. WithURLPathVars(map[string]string{ diff --git a/cmd/api/src/api/v2/apitest/test.go b/cmd/api/src/api/v2/apitest/test.go index a8cc063f16..7b4af1937c 100644 --- a/cmd/api/src/api/v2/apitest/test.go +++ b/cmd/api/src/api/v2/apitest/test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package apitest @@ -36,7 +36,7 @@ func NewAuthManagementResource(mockCtrl *gomock.Controller) (auth.ManagementReso cfg.Crypto.Argon2.NumThreads = 1 mockDB := mocks.NewMockDatabase(mockCtrl) - resources := auth.NewManagementResource(cfg, mockDB, authPkg.NewAuthorizer()) + resources := auth.NewManagementResource(cfg, mockDB, authPkg.NewAuthorizer(mockDB)) return resources, mockDB } diff --git a/cmd/api/src/api/v2/audit_integration_test.go b/cmd/api/src/api/v2/audit_integration_test.go index e27d0f7101..43b5988aad 100644 --- a/cmd/api/src/api/v2/audit_integration_test.go +++ b/cmd/api/src/api/v2/audit_integration_test.go @@ -57,11 +57,26 @@ func Test_ListAuditLogs(t *testing.T) { // Expect one audit log entry from the deletion auditLogs := testCtx.ListAuditLogs(deletionTimestamp, time.Now(), 0, 1000) - require.Equal(t, 1, len(auditLogs), "Expected only 1 audit log entry but saw %d", len(auditLogs)) + require.Equal(t, 2, len(auditLogs), "Expected exactly 2 audit log entries but saw %d", len(auditLogs)) + + // Make sure these two actions are from the same request + require.Equal(t, auditLogs[0].RequestID, auditLogs[1].RequestID) + + // Makes sure these two actions are from the same two phase commit + require.Equal(t, auditLogs[0].CommitID, auditLogs[1].CommitID) + + // Audit logs are in LIFO order + require.Equal(t, auditLogs[0].Status, "success") + require.Equal(t, auditLogs[1].Status, "intent") testCtx.AssetAuditLog(auditLogs[0], "DeleteAssetGroup", map[string]any{ "asset_group_name": newAssetGroup.Name, "asset_group_tag": newAssetGroup.Tag, }) + + testCtx.AssetAuditLog(auditLogs[1], "DeleteAssetGroup", map[string]any{ + "asset_group_name": newAssetGroup.Name, + "asset_group_tag": newAssetGroup.Tag, + }) }) } diff --git a/cmd/api/src/api/v2/auth/auth.go b/cmd/api/src/api/v2/auth/auth.go index b0e3ca5c04..607378926b 100644 --- a/cmd/api/src/api/v2/auth/auth.go +++ b/cmd/api/src/api/v2/auth/auth.go @@ -17,6 +17,7 @@ package auth import ( + "context" "fmt" "io" "net/http" @@ -118,6 +119,8 @@ func (s ManagementResource) GetSAMLProvider(response http.ResponseWriter, reques } func (s ManagementResource) CreateSAMLProviderMultipart(response http.ResponseWriter, request *http.Request) { + var samlIdentityProvider model.SAMLProvider + if err := request.ParseMultipartForm(api.DefaultAPIPayloadReadLimitBytes); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) } else if providerNames, hasProviderName := request.MultipartForm.Value["name"]; !hasProviderName { @@ -142,17 +145,13 @@ func (s ManagementResource) CreateSAMLProviderMultipart(response http.ResponseWr } else if ssoURL, err := bhsaml.GetIDPSingleSignOnServiceURL(ssoDescriptor, saml.HTTPPostBinding); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "metadata does not have a SSO service that supports HTTP POST binding", request), response) } else { - samlIdentityProvider := model.SAMLProvider{ - Name: providerNames[0], - DisplayName: providerNames[0], - MetadataXML: metadataXML, - IssuerURI: metadata.EntityID, - SingleSignOnURI: ssoURL, - } + samlIdentityProvider.Name = providerNames[0] + samlIdentityProvider.DisplayName = providerNames[0] + samlIdentityProvider.MetadataXML = metadataXML + samlIdentityProvider.IssuerURI = metadata.EntityID + samlIdentityProvider.SingleSignOnURI = ssoURL - if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "CreateSAMLIdentityProvider", samlIdentityProvider); err != nil { - api.HandleDatabaseError(request, response, err) - } else if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(samlIdentityProvider); err != nil { + if newSAMLProvider, err := s.db.CreateSAMLIdentityProvider(request.Context(), samlIdentityProvider); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), newSAMLProvider, http.StatusOK, response) @@ -166,9 +165,7 @@ func (s ManagementResource) disassociateUsersFromSAMLProvider(request *http.Requ user.SAMLProvider = nil user.SAMLProviderID = null.NewInt32(0, false) - if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "RemoveSAMLProvider", user); err != nil { - return api.FormatDatabaseError(err) - } else if err := s.db.UpdateUser(user); err != nil { + if err := s.db.UpdateUser(request.Context(), user); err != nil { return api.FormatDatabaseError(err) } } @@ -178,23 +175,22 @@ func (s ManagementResource) disassociateUsersFromSAMLProvider(request *http.Requ func (s ManagementResource) DeleteSAMLProvider(response http.ResponseWriter, request *http.Request) { var ( - rawProviderID = mux.Vars(request)[api.URIPathVariableSAMLProviderID] - requestContext = ctx.FromRequest(request) + identityProvider model.SAMLProvider + rawProviderID = mux.Vars(request)[api.URIPathVariableSAMLProviderID] + requestContext = ctx.FromRequest(request) ) if providerID, err := strconv.ParseInt(rawProviderID, 10, 32); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response) - } else if identityProvider, err := s.db.GetSAMLProvider(int32(providerID)); err != nil { + } else if identityProvider, err = s.db.GetSAMLProvider(int32(providerID)); err != nil { api.HandleDatabaseError(request, response, err) } else if user, isUser := auth.GetUserFromAuthCtx(requestContext.AuthCtx); isUser && int64(user.SAMLProviderID.Int32) == providerID { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "user may not delete their own SAML auth provider", request), response) } else if providerUsers, err := s.db.GetSAMLProviderUsers(identityProvider.ID); err != nil { api.HandleDatabaseError(request, response, err) - } else if err := s.db.AppendAuditLog(*requestContext, "DeleteSAMLProvider", identityProvider); err != nil { - api.HandleDatabaseError(request, response, err) } else if err := s.disassociateUsersFromSAMLProvider(request, providerUsers); err != nil { api.HandleDatabaseError(request, response, err) - } else if err := s.db.DeleteSAMLProvider(identityProvider); err != nil { + } else if err := s.db.DeleteSAMLProvider(request.Context(), identityProvider); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), v2.DeleteSAMLProviderResponse{ @@ -423,7 +419,10 @@ func (s ManagementResource) ListUsers(response http.ResponseWriter, request *htt } func (s ManagementResource) CreateUser(response http.ResponseWriter, request *http.Request) { - var createUserRequest v2.CreateUserRequest + var ( + createUserRequest v2.CreateUserRequest + userTemplate model.User + ) if err := api.ReadJSONRequestPayloadLimited(&createUserRequest, request); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) @@ -432,15 +431,13 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht } else if roles, err := s.db.GetRoles(createUserRequest.Roles); err != nil { api.HandleDatabaseError(request, response, err) } else { - userTemplate := model.User{ - Roles: roles, - FirstName: null.StringFrom(createUserRequest.FirstName), - LastName: null.StringFrom(createUserRequest.LastName), - EmailAddress: null.StringFrom(createUserRequest.EmailAddress), - PrincipalName: createUserRequest.Principal, - // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users. - EULAAccepted: true, - } + userTemplate.Roles = roles + userTemplate.FirstName = null.StringFrom(createUserRequest.FirstName) + userTemplate.LastName = null.StringFrom(createUserRequest.LastName) + userTemplate.EmailAddress = null.StringFrom(createUserRequest.EmailAddress) + userTemplate.PrincipalName = createUserRequest.Principal + // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users. + userTemplate.EULAAccepted = true if createUserRequest.Secret != "" { if errs := validation.Validate(createUserRequest.SetUserSecretRequest); errs != nil { @@ -478,33 +475,26 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht } } - if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "CreateUser", userTemplate); err != nil { - api.HandleDatabaseError(request, response, err) - } else if newUser, err := s.db.CreateUser(userTemplate); err != nil { + if newUser, err := s.db.CreateUser(request.Context(), userTemplate); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), newUser, http.StatusOK, response) } + } } func (s ManagementResource) updateUser(response http.ResponseWriter, request *http.Request, user model.User) { - if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "UpdateUser", user); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.db.UpdateUser(user); err != nil { + if err := s.db.UpdateUser(request.Context(), user); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) } } -func (s ManagementResource) ensureUserHasNoAuthSecret(context ctx.Context, user model.User) error { +func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error { if user.AuthSecret != nil { - if err := s.db.AppendAuditLog(context, "DeleteUserAuthSecret", user); err != nil { - return api.FormatDatabaseError(err) - } - - if err := s.db.DeleteAuthSecret(*user.AuthSecret); err != nil { + if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil { return api.FormatDatabaseError(err) } else { return nil @@ -558,12 +548,10 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht // We're setting a SAML provider. If the user has an associated secret the secret will be removed. if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response) - } else if err := s.ensureUserHasNoAuthSecret(context, user); err != nil { + } else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil { api.HandleDatabaseError(request, response, err) } else if provider, err := s.db.GetSAMLProvider(samlProviderID); err != nil { api.HandleDatabaseError(request, response, err) - } else if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "SetUserSAMLProvider", user); err != nil { - api.HandleDatabaseError(request, response, err) } else { // Ensure that the AuthSecret reference is nil and that the SAML provider is set user.AuthSecret = nil @@ -604,36 +592,35 @@ func (s ManagementResource) GetSelf(response http.ResponseWriter, request *http. func (s ManagementResource) DeleteUser(response http.ResponseWriter, request *http.Request) { var ( + user model.User pathVars = mux.Vars(request) rawUserID = pathVars[api.URIPathVariableUserID] ) if userID, err := uuid.FromString(rawUserID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) - } else if user, err := s.db.GetUser(userID); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "DeleteUser", user); err != nil { + } else if user, err = s.db.GetUser(userID); err != nil { api.HandleDatabaseError(request, response, err) - } else if err := s.db.DeleteUser(user); err != nil { + } else if err := s.db.DeleteUser(request.Context(), user); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) } } -func (s ManagementResource) setUserSecret(user model.User, authSecret model.AuthSecret) error { +func (s ManagementResource) setUserSecret(ctx context.Context, user model.User, authSecret model.AuthSecret) error { if user.AuthSecret != nil { user.AuthSecret.Digest = authSecret.Digest user.AuthSecret.DigestMethod = authSecret.DigestMethod user.AuthSecret.ExpiresAt = authSecret.ExpiresAt.UTC() - if err := s.db.UpdateAuthSecret(*user.AuthSecret); err != nil { + if err := s.db.UpdateAuthSecret(ctx, *user.AuthSecret); err != nil { return api.FormatDatabaseError(err) } else { return nil } } else { - if _, err := s.db.CreateAuthSecret(authSecret); err != nil { + if _, err := s.db.CreateAuthSecret(ctx, authSecret); err != nil { return api.FormatDatabaseError(err) } else { return nil @@ -643,10 +630,10 @@ func (s ManagementResource) setUserSecret(user model.User, authSecret model.Auth func (s ManagementResource) PutUserAuthSecret(response http.ResponseWriter, request *http.Request) { var ( + authSecret model.AuthSecret setUserSecretRequest v2.SetUserSecretRequest pathVars = mux.Vars(request) rawUserID = pathVars[api.URIPathVariableUserID] - context = *ctx.FromRequest(request) ) if userID, err := uuid.FromString(rawUserID); err != nil { @@ -667,20 +654,16 @@ func (s ManagementResource) PutUserAuthSecret(response http.ResponseWriter, requ log.Errorf("Error while attempting to digest secret for user: %v", err) api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, api.ErrorResponseDetailsInternalServerError, request), response) } else { - authSecret := model.AuthSecret{ - UserID: targetUser.ID, - Digest: secretDigest.String(), - DigestMethod: s.secretDigester.Method(), - ExpiresAt: time.Now().Add(passwordExpiration).UTC(), - } + authSecret.UserID = targetUser.ID + authSecret.Digest = secretDigest.String() + authSecret.DigestMethod = s.secretDigester.Method() + authSecret.ExpiresAt = time.Now().Add(passwordExpiration).UTC() if setUserSecretRequest.NeedsPasswordReset { authSecret.ExpiresAt = time.Time{} } - if err := s.db.AppendAuditLog(context, "PutUserAuthSecret", authSecret); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.setUserSecret(targetUser, authSecret); err != nil { + if err := s.setUserSecret(request.Context(), targetUser, authSecret); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) @@ -689,7 +672,9 @@ func (s ManagementResource) PutUserAuthSecret(response http.ResponseWriter, requ } func (s ManagementResource) ExpireUserAuthSecret(response http.ResponseWriter, request *http.Request) { - rawUserID := mux.Vars(request)[api.URIPathVariableUserID] + var ( + rawUserID = mux.Vars(request)[api.URIPathVariableUserID] + ) if userID, err := uuid.FromString(rawUserID); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseDetailsIDMalformed, request), response) @@ -697,13 +682,11 @@ func (s ManagementResource) ExpireUserAuthSecret(response http.ResponseWriter, r api.HandleDatabaseError(request, response, err) } else if targetUser.SAMLProviderID.Valid { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, "user has SAML auth enabled", request), response) - } else if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "InvalidateUserAuthSecret", model.AuditData{"user_id": targetUser.ID}); err != nil { - api.HandleDatabaseError(request, response, err) } else { authSecret := targetUser.AuthSecret authSecret.ExpiresAt = time.Time{} - if err := s.db.UpdateAuthSecret(*authSecret); err != nil { + if err := s.db.UpdateAuthSecret(request.Context(), *authSecret); err != nil { api.HandleDatabaseError(request, response, err) } else { // NOTE: This "should" be a 204 since we're not returning a payload but am returning a 200 to retain @@ -802,13 +785,11 @@ func (s ManagementResource) CreateAuthToken(response http.ResponseWriter, reques api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponsePayloadUnmarshalError, request), response) } else if user, err := s.db.GetUser(user.ID); err != nil { api.HandleDatabaseError(request, response, err) - } else if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "CreateAuthToken", model.AuditData{"user_id": user.ID}); err != nil { - api.HandleDatabaseError(request, response, err) } else if err := verifyUserID(&createUserTokenRequest, user, bhCtx, s.authorizer); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, err.Error(), request), response) } else if authToken, err := auth.NewUserAuthToken(createUserTokenRequest.UserID, createUserTokenRequest.TokenName, auth.HMAC_SHA2_256); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, api.ErrorResponseDetailsInternalServerError, request), response) - } else if newAuthToken, err := s.db.CreateAuthToken(authToken); err != nil { + } else if newAuthToken, err := s.db.CreateAuthToken(request.Context(), authToken); err != nil { api.HandleDatabaseError(request, response, err) } else { api.WriteBasicResponse(request.Context(), newAuthToken, http.StatusOK, response) @@ -846,9 +827,7 @@ func (s ManagementResource) DeleteAuthToken(response http.ResponseWriter, reques } else if token.UserID.Valid && token.UserID.UUID != user.ID && !s.authorizer.AllowsPermission(bhCtx.AuthCtx, auth.Permissions().AuthManageUsers) { log.Errorf("Bad user ID: %s != %s", token.UserID.UUID.String(), user.ID.String()) api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusNotFound, api.ErrorResponseDetailsResourceNotFound, request), response) - } else if err := s.db.AppendAuditLog(*ctx.FromRequest(request), "DeleteAuthToken", model.AuditData{"user_id": user.ID.String(), "token_id": token.ID}); err != nil { - api.HandleDatabaseError(request, response, err) - } else if err := s.db.DeleteAuthToken(token); err != nil { + } else if err := s.db.DeleteAuthToken(request.Context(), token); err != nil { api.HandleDatabaseError(request, response, err) } else { response.WriteHeader(http.StatusOK) @@ -907,7 +886,7 @@ func (s ManagementResource) EnrollMFA(response http.ResponseWriter, request *htt } else { user.AuthSecret.TOTPSecret = totpSecret.Secret() - if err := s.db.UpdateAuthSecret(*user.AuthSecret); err != nil { + if err := s.db.UpdateAuthSecret(request.Context(), *user.AuthSecret); err != nil { api.HandleDatabaseError(request, response, err) } else if qrCode, err := auth.GenerateQRCodeBase64(*totpSecret); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, api.ErrorResponseDetailsInternalServerError, request), response) @@ -962,7 +941,7 @@ func (s ManagementResource) DisenrollMFA(response http.ResponseWriter, request * user.AuthSecret.TOTPSecret = "" user.AuthSecret.TOTPActivated = false - if err := s.db.UpdateAuthSecret(*user.AuthSecret); err != nil { + if err := s.db.UpdateAuthSecret(request.Context(), *user.AuthSecret); err != nil { api.HandleDatabaseError(request, response, err) } else { responseBody := MFAStatusResponse{MFADeactivated} @@ -1013,7 +992,7 @@ func (s ManagementResource) ActivateMFA(response http.ResponseWriter, request *h } else { user.AuthSecret.TOTPActivated = true - if err := s.db.UpdateAuthSecret(*user.AuthSecret); err != nil { + if err := s.db.UpdateAuthSecret(request.Context(), *user.AuthSecret); err != nil { api.HandleDatabaseError(request, response, err) } else { responseBody := MFAStatusResponse{MFAActivated} diff --git a/cmd/api/src/api/v2/auth/auth_test.go b/cmd/api/src/api/v2/auth/auth_test.go index 4093b03956..1aa2ef64c4 100644 --- a/cmd/api/src/api/v2/auth/auth_test.go +++ b/cmd/api/src/api/v2/auth/auth_test.go @@ -81,8 +81,7 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) { Duration: appcfg.DefaultPasswordExpirationWindow, }), }, nil).Times(1) - mockDB.EXPECT().CreateAuthSecret(gomock.Any()).Return(model.AuthSecret{}, nil).Times(1) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + mockDB.EXPECT().CreateAuthSecret(gomock.Any(), gomock.Any()).Return(model.AuthSecret{}, nil).Times(1) // Happy path test.Request(t). @@ -141,9 +140,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { mockDB.EXPECT().GetUser(badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil) mockDB.EXPECT().GetUser(goodUserID).Return(model.User{}, nil) mockDB.EXPECT().GetSAMLProvider(samlProviderID).Return(model.SAMLProvider{}, nil).Times(2) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(5) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(nil).Times(2) - mockDB.EXPECT().DeleteAuthSecret(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil).Times(2) + mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil) // Happy path test.Request(t). @@ -206,12 +204,11 @@ func TestManagementResource_DeleteSAMLProvider(t *testing.T) { mockDB.EXPECT().GetSAMLProvider(goodSAMLProvider.ID).Return(goodSAMLProvider, nil) mockDB.EXPECT().GetSAMLProvider(samlProviderWithUsers.ID).Return(samlProviderWithUsers, nil) - mockDB.EXPECT().DeleteSAMLProvider(gomock.Eq(goodSAMLProvider)).Return(nil) - mockDB.EXPECT().DeleteSAMLProvider(gomock.Eq(samlProviderWithUsers)).Return(nil) + mockDB.EXPECT().DeleteSAMLProvider(gomock.Any(), gomock.Eq(goodSAMLProvider)).Return(nil) + mockDB.EXPECT().DeleteSAMLProvider(gomock.Any(), gomock.Eq(samlProviderWithUsers)).Return(nil) mockDB.EXPECT().GetSAMLProviderUsers(goodSAMLProvider.ID).Return(nil, nil) mockDB.EXPECT().GetSAMLProviderUsers(samlProviderWithUsers.ID).Return(model.Users{samlEnabledUser}, nil) - mockDB.EXPECT().UpdateUser(gomock.Eq(samlEnabledUser)).Return(nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(3) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Eq(samlEnabledUser)).Return(nil) // Happy path test.Request(t). @@ -247,7 +244,7 @@ func TestManagementResource_ListPermissions_SortingError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -280,7 +277,7 @@ func TestManagementResource_ListPermissions_InvalidFilterPredicate(t *testing.T) require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -313,7 +310,7 @@ func TestManagementResource_ListPermissions_PredicateMismatchWithColumn(t *testi require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -347,7 +344,7 @@ func TestManagementResource_ListPermissions_DBError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -439,7 +436,7 @@ func TestManagementResource_ListRoles_SortingError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -472,7 +469,7 @@ func TestManagementResource_ListRoles_InvalidColumn(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -505,7 +502,7 @@ func TestManagementResource_ListRoles_InvalidFilterPredicate(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -538,7 +535,7 @@ func TestManagementResource_ListRoles_PredicateMismatchWithColumn(t *testing.T) require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -572,7 +569,7 @@ func TestManagementResource_ListRoles_DBError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -761,8 +758,7 @@ func TestExpireUserAuthSecret_Success(t *testing.T) { resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) mockDB.EXPECT().GetUser(userId).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().UpdateAuthSecret(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateAuthSecret(gomock.Any(), gomock.Any()).Return(nil) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "DELETE", fmt.Sprintf(endpoint, userId), nil); err != nil { @@ -788,7 +784,7 @@ func TestManagementResource_ListUsers_SortingError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -821,7 +817,7 @@ func TestManagementResource_ListUsers_InvalidColumn(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -854,7 +850,7 @@ func TestManagementResource_ListUsers_InvalidFilterPredicate(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -887,7 +883,7 @@ func TestManagementResource_ListUsers_PredicateMismatchWithColumn(t *testing.T) require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -921,7 +917,7 @@ func TestManagementResource_ListUsers_DBError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -1060,8 +1056,8 @@ func TestCreateUser_Failure(t *testing.T) { }, nil).AnyTimes() mockDB.EXPECT().GetRoles(badRole).Return(model.Roles{}, fmt.Errorf("db error")) mockDB.EXPECT().GetRoles(gomock.Not(badRole)).Return(model.Roles{}, nil).AnyTimes() - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockDB.EXPECT().CreateUser(badUser).Return(model.User{}, fmt.Errorf("db error")) + mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), badUser).Return(model.User{}, fmt.Errorf("db error")) type Input struct { Body v2.CreateUserRequest @@ -1175,8 +1171,7 @@ func TestCreateUser_Success(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) input := v2.CreateUserRequest{ @@ -1229,8 +1224,7 @@ func TestCreateUser_ResetPassword(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil) + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil) input := struct { Body v2.CreateUserRequest @@ -1303,8 +1297,7 @@ func TestManagementResource_UpdateUser_IDMalformed(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) input := v2.CreateUserRequest{ @@ -1367,8 +1360,7 @@ func TestManagementResource_UpdateUser_GetUserError(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(model.User{}, fmt.Errorf("foo")) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) @@ -1432,8 +1424,7 @@ func TestManagementResource_UpdateUser_GetRolesError(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, fmt.Errorf("foo")) @@ -1491,8 +1482,7 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", @@ -1573,8 +1563,7 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", @@ -1631,7 +1620,7 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) { require.Contains(t, response.Body.String(), api.ErrorResponseDetailsInternalServerError) } -func TestManagementResource_UpdateUser_AppendAuditLogError(t *testing.T) { +func TestManagementResource_UpdateUser_DBError(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -1655,8 +1644,7 @@ func TestManagementResource_UpdateUser_AppendAuditLogError(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", @@ -1668,7 +1656,7 @@ func TestManagementResource_UpdateUser_AppendAuditLogError(t *testing.T) { }}, Serial: model.Serial{}, }}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("foo")) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(fmt.Errorf("foo")) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) input := v2.CreateUserRequest{ @@ -1683,8 +1671,10 @@ func TestManagementResource_UpdateUser_AppendAuditLogError(t *testing.T) { payload, err := json.Marshal(input) require.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(payload)) require.Nil(t, err) + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) router := mux.NewRouter() router.HandleFunc(endpoint, resources.CreateUser).Methods("POST") @@ -1712,88 +1702,147 @@ func TestManagementResource_UpdateUser_AppendAuditLogError(t *testing.T) { require.Equal(t, http.StatusInternalServerError, response.Code) } -func TestManagementResource_UpdateUser_DBError(t *testing.T) { +func TestManagementResource_DeleteUser_BadUserID(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() - endpoint := "/api/v2/auth/users" + endpoint := "/api/v2/bloodhound-users" + userID := "badUserID" - goodUserID, err := uuid.NewV4() + resources, _ := apitest.NewAuthManagementResource(mockCtrl) + + ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) + req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil) require.Nil(t, err) - goodUser := model.User{ - PrincipalName: "good user", - Unique: model.Unique{ - ID: goodUserID, - }, - } + req = mux.SetURLVars(req, map[string]string{api.URIPathVariableUserID: userID}) + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(resources.DeleteUser) + handler.ServeHTTP(rr, req) + + require.Equal(t, rr.Code, http.StatusBadRequest) +} + +func TestManagementResource_DeleteUser_UserNotFound(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + endpoint := "/api/v2/bloodhound-users" + + userID, err := uuid.NewV4() + require.Nil(t, err) resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) - mockDB.EXPECT().GetConfigurationParameter(appcfg.PasswordExpirationWindow).Return(appcfg.Parameter{ - Key: appcfg.PasswordExpirationWindow, - Value: must.NewJSONBObject(appcfg.PasswordExpiration{ - Duration: appcfg.DefaultPasswordExpirationWindow, - }), - }, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() - mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) - mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ - Name: "admin", - Description: "admin", - Permissions: model.Permissions{model.Permission{ - Authority: "admin", - Name: "admin", - Serial: model.Serial{}, - }}, - Serial: model.Serial{}, - }}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(fmt.Errorf("foo")) + mockDB.EXPECT().GetUser(userID).Return(model.User{}, database.ErrNotFound) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) - input := v2.CreateUserRequest{ - UpdateUserRequest: v2.UpdateUserRequest{ - Principal: "good user", - }, - SetUserSecretRequest: v2.SetUserSecretRequest{ - Secret: "abcDEF123456$$", - NeedsPasswordReset: true, - }, - } + req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil) + require.Nil(t, err) - payload, err := json.Marshal(input) + req = mux.SetURLVars(req, map[string]string{api.URIPathVariableUserID: userID.String()}) + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(resources.DeleteUser) + handler.ServeHTTP(rr, req) + + require.Equal(t, rr.Code, http.StatusNotFound) +} + +func TestManagementResource_DeleteUser_GetUserError(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + endpoint := "/api/v2/bloodhound-users" + + userID, err := uuid.NewV4() require.Nil(t, err) - req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(payload)) + resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) + mockDB.EXPECT().GetUser(userID).Return(model.User{}, fmt.Errorf("foo")) + + ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) + req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil) require.Nil(t, err) + req = mux.SetURLVars(req, map[string]string{api.URIPathVariableUserID: userID.String()}) req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - router := mux.NewRouter() - router.HandleFunc(endpoint, resources.CreateUser).Methods("POST") + rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + handler := http.HandlerFunc(resources.DeleteUser) + handler.ServeHTTP(rr, req) - require.Equal(t, rr.Code, http.StatusOK) - require.Contains(t, rr.Body.String(), "good user") + require.Equal(t, rr.Code, http.StatusInternalServerError) +} + +func TestManagementResource_DeleteUser_DeleteUserError(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + endpoint := "/api/v2/bloodhound-users" userID, err := uuid.NewV4() require.Nil(t, err) - payload, err = json.Marshal(v2.UpdateUserRequest{}) + user := model.User{ + PrincipalName: "good user", + Unique: model.Unique{ + ID: userID, + }, + } + + resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) + mockDB.EXPECT().GetUser(userID).Return(user, nil) + mockDB.EXPECT().DeleteUser(gomock.Any(), user).Return(fmt.Errorf("foo")) + + ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) + req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil) require.Nil(t, err) - endpoint = fmt.Sprintf("/api/v2/bloodhound-users/%v", userID) - req, err = http.NewRequestWithContext(ctx, "PATCH", endpoint, bytes.NewReader(payload)) + req = mux.SetURLVars(req, map[string]string{api.URIPathVariableUserID: userID.String()}) + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(resources.DeleteUser) + handler.ServeHTTP(rr, req) + + require.Equal(t, rr.Code, http.StatusInternalServerError) +} + +func TestManagementResource_DeleteUser_Success(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + endpoint := "/api/v2/bloodhound-users" + + userID, err := uuid.NewV4() + require.Nil(t, err) + + user := model.User{ + PrincipalName: "good user", + Unique: model.Unique{ + ID: userID, + }, + } + + resources, mockDB := apitest.NewAuthManagementResource(mockCtrl) + mockDB.EXPECT().GetUser(userID).Return(user, nil) + mockDB.EXPECT().DeleteUser(gomock.Any(), user).Return(nil) + + ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) + req, err := http.NewRequestWithContext(ctx, "DELETE", endpoint, nil) require.Nil(t, err) req = mux.SetURLVars(req, map[string]string{api.URIPathVariableUserID: userID.String()}) req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - response := httptest.NewRecorder() - handler := http.HandlerFunc(resources.UpdateUser) - handler.ServeHTTP(response, req) - require.Equal(t, http.StatusInternalServerError, response.Code) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(resources.DeleteUser) + handler.ServeHTTP(rr, req) + + require.Equal(t, rr.Code, http.StatusOK) } func TestManagementResource_UpdateUser_Success(t *testing.T) { @@ -1820,8 +1869,7 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) { }), }, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().CreateUser(gomock.Any()).Return(goodUser, nil).AnyTimes() + mockDB.EXPECT().CreateUser(gomock.Any(), gomock.Any()).Return(goodUser, nil).AnyTimes() mockDB.EXPECT().GetUser(gomock.Any()).Return(goodUser, nil) mockDB.EXPECT().GetRoles(gomock.Any()).Return(model.Roles{model.Role{ Name: "admin", @@ -1834,8 +1882,7 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) { Serial: model.Serial{}, }}, nil) mockDB.EXPECT().LookupActiveSessionsByUser(gomock.Any()).Return([]model.UserSession{}, nil) - mockDB.EXPECT().AppendAuditLog(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) input := v2.CreateUserRequest{ @@ -1946,7 +1993,7 @@ func TestManagementResource_ListAuthTokens_SortingError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) endpoint := "/api/v2/auth/tokens" if req, err := http.NewRequestWithContext(c, "GET", endpoint, nil); err != nil { @@ -1979,7 +2026,7 @@ func TestManagementResource_ListAuthTokens_InvalidColumn(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -2012,7 +2059,7 @@ func TestManagementResource_ListAuthTokens_InvalidFilterPredicate(t *testing.T) require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -2045,7 +2092,7 @@ func TestManagementResource_ListAuthTokens_PredicateMismatchWithColumn(t *testin require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) if req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil); err != nil { @@ -2094,7 +2141,7 @@ func TestManagementResource_ListAuthTokens_DBError(t *testing.T) { require.Nilf(t, err, "Failed to create default configuration: %v", err) config.Crypto.Argon2.NumIterations = 1 config.Crypto.Argon2.NumThreads = 1 - resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer()) + resources := auth.NewManagementResource(config, mockDB, authz.NewAuthorizer(mockDB)) endpoint := "/api/v2/auth/tokens" if req, err := http.NewRequestWithContext(c, "GET", endpoint, nil); err != nil { @@ -2629,7 +2676,7 @@ func TestDisenrollMFA_Success(t *testing.T) { userId := test.NewUUIDv4(t) mockDB.EXPECT().GetUser(userId).Return(model.User{AuthSecret: defaultDigestAuthSecret(t, "password")}, nil) - mockDB.EXPECT().UpdateAuthSecret(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateAuthSecret(gomock.Any(), gomock.Any()).Return(nil) input := auth.MFAEnrollmentRequest{"password"} @@ -2669,7 +2716,7 @@ func TestDisenrollMFA_Admin_Success(t *testing.T) { nonAdminId := test.NewUUIDv4(t) mockDB.EXPECT().GetUser(nonAdminId).Return(model.User{AuthSecret: defaultDigestAuthSecret(t, "password")}, nil) - mockDB.EXPECT().UpdateAuthSecret(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateAuthSecret(gomock.Any(), gomock.Any()).Return(nil) adminContext := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) bhCtx := ctx.Get(adminContext) @@ -2950,7 +2997,7 @@ func TestActivateMFA_Success(t *testing.T) { endpoint := "/api/v2/auth/users/%s/mfa-activation" userId := test.NewUUIDv4(t) mockDB.EXPECT().GetUser(userId).Return(model.User{AuthSecret: defaultDigestAuthSecretWithTOTP(t, "password", totpSecret.Secret())}, nil) - mockDB.EXPECT().UpdateAuthSecret(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateAuthSecret(gomock.Any(), gomock.Any()).Return(nil) ctx := context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}) inputBody := auth.MFAActivationRequest{passcode} diff --git a/cmd/api/src/api/v2/auth/login.go b/cmd/api/src/api/v2/auth/login.go index a73318e4bb..1892ce161b 100644 --- a/cmd/api/src/api/v2/auth/login.go +++ b/cmd/api/src/api/v2/auth/login.go @@ -1,22 +1,23 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package auth import ( + "context" "fmt" "net/http" "strings" @@ -73,7 +74,7 @@ func (s LoginResource) Login(response http.ResponseWriter, request *http.Request if err := api.ReadJSONRequestPayloadLimited(&loginRequest, request); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) - } else if err = s.patchEULAAcceptance(loginRequest.Username); err != nil { + } else if err = s.patchEULAAcceptance(request.Context(), loginRequest.Username); err != nil { api.HandleDatabaseError(request, response, err) } else { switch strings.ToLower(loginRequest.LoginMethod) { @@ -87,15 +88,16 @@ func (s LoginResource) Login(response http.ResponseWriter, request *http.Request } // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users. -func (s LoginResource) patchEULAAcceptance(username string) error { +func (s LoginResource) patchEULAAcceptance(ctx context.Context, username string) error { if user, err := s.db.LookupUser(username); err != nil { return err - } else { + } else if !user.EULAAccepted { user.EULAAccepted = true - if err = s.db.UpdateUser(user); err != nil { + if err = s.db.UpdateUser(ctx, user); err != nil { return err } } + return nil } diff --git a/cmd/api/src/api/v2/auth/login_internal_test.go b/cmd/api/src/api/v2/auth/login_internal_test.go index 61df022a6d..f468401d0a 100644 --- a/cmd/api/src/api/v2/auth/login_internal_test.go +++ b/cmd/api/src/api/v2/auth/login_internal_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package auth @@ -89,7 +89,7 @@ func TestLoginFailure(t *testing.T) { mockAuthenticator.EXPECT().LoginWithSecret(gomock.Any(), req3).Return(api.LoginDetails{User: model.User{EULAAccepted: true}}, fmt.Errorf("db error")) mockAuthenticator.EXPECT().LoginWithSecret(gomock.Any(), req4).Return(api.LoginDetails{User: model.User{EULAAccepted: true}}, api.ErrUserDisabled) mockDB.EXPECT().LookupUser(gomock.Any()).Return(model.User{EULAAccepted: false}, nil).Times(5) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(nil).Times(5) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil).Times(5) resources := NewLoginResource(config.Configuration{}, mockAuthenticator, mockDB) @@ -211,7 +211,7 @@ func TestLoginSuccess(t *testing.T) { mockAuthenticator := api_mocks.NewMockAuthenticator(mockCtrl) mockAuthenticator.EXPECT().LoginWithSecret(gomock.Any(), input).Return(api.LoginDetails{User: model.User{AuthSecret: &model.AuthSecret{}, EULAAccepted: true}, SessionToken: "imasessiontoken"}, nil) mockDB.EXPECT().LookupUser(gomock.Any()).Return(model.User{EULAAccepted: false}, nil) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(nil) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil) resources := NewLoginResource(config.Configuration{}, mockAuthenticator, mockDB) diff --git a/cmd/api/src/api/v2/auth/login_test.go b/cmd/api/src/api/v2/auth/login_test.go index 66dd6505da..a017049a09 100644 --- a/cmd/api/src/api/v2/auth/login_test.go +++ b/cmd/api/src/api/v2/auth/login_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package auth @@ -63,7 +63,7 @@ func TestLoginExpiry(t *testing.T) { mockAuthenticator.EXPECT().LoginWithSecret(gomock.Any(), req1).Return(api.LoginDetails{User: model.User{AuthSecret: &model.AuthSecret{ExpiresAt: time.Now().UTC().Add(time.Hour * 24)}, EULAAccepted: true}, SessionToken: "imasession"}, nil) mockAuthenticator.EXPECT().LoginWithSecret(gomock.Any(), req2).Return(api.LoginDetails{User: model.User{AuthSecret: &model.AuthSecret{ExpiresAt: time.Now().UTC().Add(time.Hour * 24 * -1)}, EULAAccepted: true}, SessionToken: "imasession"}, nil) mockDB.EXPECT().LookupUser(gomock.Any()).Return(model.User{EULAAccepted: false}, nil).Times(2) - mockDB.EXPECT().UpdateUser(gomock.Any()).Return(nil).Times(2) + mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil).Times(2) resources := NewLoginResource(config.Configuration{}, mockAuthenticator, mockDB) diff --git a/cmd/api/src/api/v2/integration/audit.go b/cmd/api/src/api/v2/integration/audit.go index b58149f266..90d8d6f570 100644 --- a/cmd/api/src/api/v2/integration/audit.go +++ b/cmd/api/src/api/v2/integration/audit.go @@ -20,6 +20,7 @@ import ( "time" "github.com/specterops/bloodhound/src/model" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,13 +38,13 @@ func (s *Context) ListAuditLogs(after, before time.Time, offset, limit int) mode } func (s *Context) AssetAuditLog(auditLog model.AuditLog, expectedAction string, expectedFields map[string]any) { - require.Equal(s.TestCtrl, auditLog.Action, expectedAction) + assert.Equal(s.TestCtrl, auditLog.Action, expectedAction) for expectedFieldName, expectedFieldValue := range expectedFields { actualFieldValue, hasField := auditLog.Fields[expectedFieldName] - require.True(s.TestCtrl, hasField) - require.Equal(s.TestCtrl, expectedFieldValue, actualFieldValue) + assert.True(s.TestCtrl, hasField) + assert.Equal(s.TestCtrl, expectedFieldValue, actualFieldValue) } } diff --git a/cmd/api/src/auth/model.go b/cmd/api/src/auth/model.go index 5a009dcff4..b851e42326 100644 --- a/cmd/api/src/auth/model.go +++ b/cmd/api/src/auth/model.go @@ -17,15 +17,18 @@ package auth import ( + "context" "crypto/rand" "encoding/base64" "fmt" + "net/http" "strconv" "time" "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v4" "github.com/specterops/bloodhound/errors" + "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/database/types/null" "github.com/specterops/bloodhound/src/model" ) @@ -56,9 +59,10 @@ type PermissionOverrides struct { } type SimpleIdentity struct { - ID uuid.UUID - Name string - Key string + ID uuid.UUID + Name string + Email string + Key string } type IdentityResolver interface { @@ -76,41 +80,57 @@ func (s idResolver) GetIdentity(ctx Context) (SimpleIdentity, error) { return SimpleIdentity{}, errors.New("error retrieving user from auth context") } else { return SimpleIdentity{ - ID: user.ID, - Name: user.PrincipalName, - Key: "user_id", + ID: user.ID, + Name: user.PrincipalName, + Email: user.EmailAddress.String, + Key: "user_id", }, nil } } +type AuditLogger interface { + AppendAuditLog(ctx context.Context, entry model.AuditEntry) error +} + type Authorizer interface { + HasPermission(ctx Context, requiredPermission model.Permission, grantedPermissions model.Permissions) bool AllowsPermission(ctx Context, requiredPermission model.Permission) bool AllowsAllPermissions(ctx Context, requiredPermissions model.Permissions) bool AllowsAtLeastOnePermission(ctx Context, requiredPermissions model.Permissions) bool + AuditLogUnauthorizedAccess(request *http.Request) } -type authorizer struct{} +type authorizer struct { + auditLogger AuditLogger +} -func NewAuthorizer() Authorizer { - return authorizer{} +func NewAuthorizer(auditLogger AuditLogger) Authorizer { + return authorizer{auditLogger: auditLogger} } -func (s authorizer) AllowsPermission(ctx Context, requiredPermission model.Permission) bool { +func (s authorizer) HasPermission(ctx Context, requiredPermission model.Permission, grantedPermissions model.Permissions) bool { if ctx.PermissionOverrides.Enabled { return ctx.PermissionOverrides.Permissions.Has(requiredPermission) } + return grantedPermissions.Has(requiredPermission) +} + +func (s authorizer) AllowsPermission(ctx Context, requiredPermission model.Permission) bool { if user, isUser := GetUserFromAuthCtx(ctx); isUser { - return user.Roles.Permissions().Has(requiredPermission) + return s.HasPermission(ctx, requiredPermission, user.Roles.Permissions()) } return false } func (s authorizer) AllowsAllPermissions(ctx Context, requiredPermissions model.Permissions) bool { - for _, permission := range requiredPermissions { - if !s.AllowsPermission(ctx, permission) { - return false + if user, isUser := GetUserFromAuthCtx(ctx); isUser { + grantedPermissions := user.Roles.Permissions() + for _, permission := range requiredPermissions { + if !s.HasPermission(ctx, permission, grantedPermissions) { + return false + } } } @@ -118,15 +138,34 @@ func (s authorizer) AllowsAllPermissions(ctx Context, requiredPermissions model. } func (s authorizer) AllowsAtLeastOnePermission(ctx Context, requiredPermissions model.Permissions) bool { - for _, permission := range requiredPermissions { - if s.AllowsPermission(ctx, permission) { - return true + if user, isUser := GetUserFromAuthCtx(ctx); isUser { + grantedPermissions := user.Roles.Permissions() + for _, permission := range requiredPermissions { + if s.HasPermission(ctx, permission, grantedPermissions) { + return true + } } } return false } +func (s authorizer) AuditLogUnauthorizedAccess(request *http.Request) { + // Ignore read logs as they are less likely to occur from malicious access + if request.Method != "GET" { + if err := s.auditLogger.AppendAuditLog( + request.Context(), + model.AuditEntry{ + Action: "UnauthorizedAccessAttempt", + Model: model.AuditData{"endpoint": request.Method + " " + request.URL.Path}, + Status: model.AuditStatusFailure, + }, + ); err != nil { + log.Errorf("error creating audit log for unauthorized access: %s", err.Error()) + } + } +} + type Context struct { PermissionOverrides PermissionOverrides Owner any diff --git a/cmd/api/src/ctx/ctx.go b/cmd/api/src/ctx/ctx.go index c9b733988f..3231dc8515 100644 --- a/cmd/api/src/ctx/ctx.go +++ b/cmd/api/src/ctx/ctx.go @@ -44,6 +44,7 @@ type Context struct { RequestID string AuthCtx auth.Context Host *url.URL + RequestIP string } func (s *Context) ConstructGoContext() context.Context { diff --git a/cmd/api/src/database/agi.go b/cmd/api/src/database/agi.go index 4a72c1b030..8d25fe769c 100644 --- a/cmd/api/src/database/agi.go +++ b/cmd/api/src/database/agi.go @@ -17,6 +17,7 @@ package database import ( + "context" "time" "gorm.io/gorm" @@ -26,22 +27,49 @@ import ( "github.com/specterops/bloodhound/src/model" ) -func (s *BloodhoundDB) CreateAssetGroup(name, tag string, systemGroup bool) (model.AssetGroup, error) { - assetGroup := model.AssetGroup{ - Name: name, - Tag: tag, - SystemGroup: systemGroup, - } +func (s *BloodhoundDB) CreateAssetGroup(ctx context.Context, name, tag string, systemGroup bool) (model.AssetGroup, error) { + var ( + assetGroup = model.AssetGroup{ + Name: name, + Tag: tag, + SystemGroup: systemGroup, + } + + auditEntry = model.AuditEntry{ + Action: "CreateAssetGroup", + Model: &assetGroup, // Pointer is required to ensure success log contains updated fields after transaction + } + ) - return assetGroup, CheckError(s.db.Create(&assetGroup)) + return assetGroup, s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&assetGroup)) + }) } -func (s *BloodhoundDB) UpdateAssetGroup(assetGroup model.AssetGroup) error { - return CheckError(s.db.Save(&assetGroup)) +func (s *BloodhoundDB) UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateAssetGroup", + Model: &assetGroup, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&assetGroup)) + }) } -func (s *BloodhoundDB) DeleteAssetGroup(assetGroup model.AssetGroup) error { - return CheckError(s.db.Delete(&assetGroup)) +func (s *BloodhoundDB) DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error { + var ( + auditEntry = model.AuditEntry{ + Action: "DeleteAssetGroup", + Model: &assetGroup, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Delete(&assetGroup)) + }) } func (s *BloodhoundDB) GetAssetGroup(id int32) (model.AssetGroup, error) { @@ -150,16 +178,30 @@ func (s *BloodhoundDB) GetAssetGroupSelector(id int32) (model.AssetGroupSelector return assetGroupSelector, CheckError(s.db.Find(&assetGroupSelector, id)) } -func (s *BloodhoundDB) UpdateAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Save(&selector)) -} +func (s *BloodhoundDB) UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateAssetGroupSelector", + Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction + } + ) -func (s *BloodhoundDB) DeleteAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Delete(&selector)) + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&selector)) + }) } -func (s *BloodhoundDB) RemoveAssetGroupSelector(selector model.AssetGroupSelector) error { - return CheckError(s.db.Where("asset_group_id=? AND name=?", selector.AssetGroupID, selector.Name).Delete(&model.AssetGroupSelector{})) +func (s *BloodhoundDB) DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error { + var ( + auditEntry = model.AuditEntry{ + Action: "DeleteAssetGroupSelector", + Model: &selector, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Delete(&selector)) + }) } func (s *BloodhoundDB) CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error) { @@ -217,12 +259,6 @@ func (s *BloodhoundDB) UpdateAssetGroupSelectors(ctx ctx.Context, assetGroup mod }) } } - - if auditLog, err := newAuditLog(ctx, "UpdateAssetGroupSelectors", assetGroup.AuditData().MergeLeft(selectorSpec), s.idResolver); err != nil { - return err - } else if result := tx.Create(&auditLog); result.Error != nil { - return result.Error - } } return nil diff --git a/cmd/api/src/database/audit.go b/cmd/api/src/database/audit.go index 941e6a5695..3e3f5290d0 100644 --- a/cmd/api/src/database/audit.go +++ b/cmd/api/src/database/audit.go @@ -17,45 +17,53 @@ package database import ( + "context" + "database/sql" + "fmt" "time" + "github.com/gofrs/uuid" + "github.com/specterops/bloodhound/errors" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/ctx" - "gorm.io/gorm" - - "github.com/specterops/bloodhound/errors" "github.com/specterops/bloodhound/src/database/types" "github.com/specterops/bloodhound/src/model" + "gorm.io/gorm" ) const ( ErrAuthContextInvalid = errors.Error("auth context is invalid") ) -func newAuditLog(ctx ctx.Context, action string, data model.Auditable, idResolver auth.IdentityResolver) (model.AuditLog, error) { +func newAuditLog(context context.Context, entry model.AuditEntry, idResolver auth.IdentityResolver) (model.AuditLog, error) { + bheCtx := ctx.Get(context) + auditLog := model.AuditLog{ - Action: action, - Fields: types.JSONUntypedObject(data.AuditData()), - RequestID: ctx.RequestID, - Status: "success", // TODO: parameterize this so we can pass the actual status instead of hard-coding + Action: entry.Action, + Fields: types.JSONUntypedObject(entry.Model.AuditData()), + RequestID: bheCtx.RequestID, + SourceIpAddress: bheCtx.RequestIP, + Status: string(entry.Status), + CommitID: entry.CommitID, } - authContext := ctx.AuthCtx + authContext := bheCtx.AuthCtx if !authContext.Authenticated() { return auditLog, ErrAuthContextInvalid - } else if identity, err := idResolver.GetIdentity(ctx.AuthCtx); err != nil { + } else if identity, err := idResolver.GetIdentity(bheCtx.AuthCtx); err != nil { return auditLog, ErrAuthContextInvalid } else { auditLog.ActorID = identity.ID.String() auditLog.ActorName = identity.Name + auditLog.ActorEmail = identity.Email } return auditLog, nil } -func (s *BloodhoundDB) AppendAuditLog(ctx ctx.Context, action string, data model.Auditable) error { - if auditLog, err := newAuditLog(ctx, action, data, s.idResolver); err != nil { - return err +func (s *BloodhoundDB) AppendAuditLog(ctx context.Context, entry model.AuditEntry) error { + if auditLog, err := newAuditLog(ctx, entry, s.idResolver); err != nil && err != ErrAuthContextInvalid { + return fmt.Errorf("audit log append: %w", err) } else { return CheckError(s.db.Create(&auditLog)) } @@ -94,3 +102,34 @@ func (s *BloodhoundDB) ListAuditLogs(before, after time.Time, offset, limit int, return auditLogs, int(count), CheckError(result) } + +func (s *BloodhoundDB) AuditableTransaction(ctx context.Context, auditEntry model.AuditEntry, f func(tx *gorm.DB) error, opts ...*sql.TxOptions) error { + var ( + commitID, err = uuid.NewV4() + ) + + if err != nil { + return fmt.Errorf("commitID could not be created: %w", err) + } + + auditEntry.CommitID = commitID + auditEntry.Status = model.AuditStatusIntent + + if err := s.AppendAuditLog(ctx, auditEntry); err != nil { + return fmt.Errorf("could not append intent to audit log: %w", err) + } + + err = s.db.Transaction(f, opts...) + + if err != nil { + auditEntry.Status = model.AuditStatusFailure + } else { + auditEntry.Status = model.AuditStatusSuccess + } + + if err := s.AppendAuditLog(ctx, auditEntry); err != nil { + return fmt.Errorf("could not append %s to audit log: %w", auditEntry.Status, err) + } + + return err +} diff --git a/cmd/api/src/database/audit_test.go b/cmd/api/src/database/audit_test.go index 683fe5c35c..c1b33cf641 100644 --- a/cmd/api/src/database/audit_test.go +++ b/cmd/api/src/database/audit_test.go @@ -20,12 +20,14 @@ package database_test import ( + "context" + "testing" + "time" + "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/ctx" "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/test/integration" - "testing" - "time" ) func TestDatabase_ListAuditLogs(t *testing.T) { @@ -53,7 +55,7 @@ func TestDatabase_ListAuditLogs(t *testing.T) { }, } for i := 0; i < 7; i++ { - if err := dbInst.AppendAuditLog(mockCtx, "CreateUser", model.User{}); err != nil { + if err := dbInst.AppendAuditLog(ctx.Set(context.Background(), &mockCtx), model.AuditEntry{Model: &model.User{}, Action: "CreateUser", Status: model.AuditStatusSuccess}); err != nil { t.Fatalf("Error creating audit log: %v", err) } } diff --git a/cmd/api/src/database/auth.go b/cmd/api/src/database/auth.go index 2a7768174c..0073c53b4d 100644 --- a/cmd/api/src/database/auth.go +++ b/cmd/api/src/database/auth.go @@ -1,32 +1,33 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package database import ( + "context" "crypto/rand" "encoding/base64" "fmt" "strings" "time" + "github.com/gofrs/uuid" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/database/types/null" "github.com/specterops/bloodhound/src/model" - "github.com/gofrs/uuid" "gorm.io/gorm" ) @@ -342,7 +343,7 @@ func (s *BloodhoundDB) HasInstallation() (bool, error) { // CreateUser creates a new user // INSERT INTO users (...) VALUES (...) -func (s *BloodhoundDB) CreateUser(user model.User) (model.User, error) { +func (s *BloodhoundDB) CreateUser(ctx context.Context, user model.User) (model.User, error) { updatedUser := user if newID, err := uuid.NewV4(); err != nil { @@ -354,20 +355,34 @@ func (s *BloodhoundDB) CreateUser(user model.User) (model.User, error) { updatedUser.ID = newID } - result := s.db.Create(&updatedUser) - return updatedUser, CheckError(result) + auditEntry := model.AuditEntry{ + Action: "CreateUser", + Model: &updatedUser, + } + return updatedUser, s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&updatedUser)) + }) } // UpdateUser updates the roles associated with the user according to the input struct // UPDATE users SET roles = .... -func (s *BloodhoundDB) UpdateUser(user model.User) error { - // Update roles first - if err := s.db.Model(&user).Association("Roles").Replace(&user.Roles); err != nil { - return err - } +func (s *BloodhoundDB) UpdateUser(ctx context.Context, user model.User) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateUser", + Model: &user, // Pointer is required to ensure success log contains updated fields after transaction + } + ) - result := s.db.Save(&user) - return CheckError(result) + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + // Update roles first + if err := tx.Model(&user).Association("Roles").Replace(&user.Roles); err != nil { + return err + } + + result := tx.Save(&user) + return CheckError(result) + }) } func (s *BloodhoundDB) GetAllUsers(order string, filter model.SQLFilter) (model.Users, error) { @@ -402,18 +417,20 @@ func (s *BloodhoundDB) GetUser(id uuid.UUID) (model.User, error) { // DeleteUser removes all roles for a given user, thereby revoking all permissions // UPDATE users SET roles = nil WHERE user_id = .... -func (s *BloodhoundDB) DeleteUser(user model.User) error { - err := s.db.Transaction(func(tx *gorm.DB) error { +func (s *BloodhoundDB) DeleteUser(ctx context.Context, user model.User) error { + auditEntry := model.AuditEntry{ + Action: "DeleteUser", + Model: &user, + } + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { // Clear associations first if err := tx.Model(&user).Association("Roles").Clear(); err != nil { return err } - result := tx.Delete(&user) - return CheckError(result) + return CheckError(tx.Delete(&user)) }) - - return err } // LookupUser retrieves the User row associated with the provided name. The name is matched against both the @@ -432,13 +449,15 @@ func (s *BloodhoundDB) LookupUser(name string) (model.User, error) { // CreateAuthToken creates a new AuthToken row using the provided struct // INSERT INTO auth_tokens (...) VALUES (....) -func (s *BloodhoundDB) CreateAuthToken(authToken model.AuthToken) (model.AuthToken, error) { - var ( - updatedAuthToken = authToken - result = s.db.Create(&updatedAuthToken) - ) +func (s *BloodhoundDB) CreateAuthToken(ctx context.Context, authToken model.AuthToken) (model.AuthToken, error) { + auditEntry := model.AuditEntry{ + Action: "CreateAuthToken", + Model: &authToken, + } - return updatedAuthToken, CheckError(result) + return authToken, s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&authToken)) + }) } // UpdateAuthToken updates all fields in the AuthToken row as specified in the provided struct @@ -508,16 +527,28 @@ func (s *BloodhoundDB) GetUserToken(userId, tokenId uuid.UUID) (model.AuthToken, // DeleteAuthToken deletes the provided AuthToken row // DELETE FROM auth_tokens WHERE id = ... -func (s *BloodhoundDB) DeleteAuthToken(authToken model.AuthToken) error { - result := s.db.Where("id = ?", authToken.ID).Delete(&authToken) - return CheckError(result) +func (s *BloodhoundDB) DeleteAuthToken(ctx context.Context, authToken model.AuthToken) error { + auditEntry := model.AuditEntry{ + Action: "DeleteAuthToken", + Model: &authToken, + } + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Where("id = ?", authToken.ID).Delete(&authToken)) + }) } // CreateAuthSecret creates a new AuthSecret row // INSERT INTO auth_secrets (...) VALUES (....) -func (s *BloodhoundDB) CreateAuthSecret(authSecret model.AuthSecret) (model.AuthSecret, error) { - result := s.db.Create(&authSecret) - return authSecret, CheckError(result) +func (s *BloodhoundDB) CreateAuthSecret(ctx context.Context, authSecret model.AuthSecret) (model.AuthSecret, error) { + auditEntry := model.AuditEntry{ + Action: "CreateAuthSecret", + Model: &authSecret, + } + + return authSecret, s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&authSecret)) + }) } // GetAuthSecret retrieves the AuthSecret row associated with the provided ID @@ -534,34 +565,60 @@ func (s *BloodhoundDB) GetAuthSecret(id int32) (model.AuthSecret, error) { // UpdateAuthSecret updates the auth secret with the input struct specified // UPDATE auth_secrets SET digest = .., hmac_method = ..., expires_at = ... // WHERE user_id = .... -func (s *BloodhoundDB) UpdateAuthSecret(authSecret model.AuthSecret) error { - result := s.db.Save(&authSecret) - return CheckError(result) +func (s *BloodhoundDB) UpdateAuthSecret(ctx context.Context, authSecret model.AuthSecret) error { + auditEntry := model.AuditEntry{ + Action: "UpdateAuthSecret", + Model: &authSecret, + } + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&authSecret)) + }) } // DeleteAuthSecret deletes the auth secret row corresponding to the struct specified // DELETE FROM auth_secrets WHERE user_id = ... -func (s *BloodhoundDB) DeleteAuthSecret(authSecret model.AuthSecret) error { - result := s.db.Delete(&authSecret) - return CheckError(result) +func (s *BloodhoundDB) DeleteAuthSecret(ctx context.Context, authSecret model.AuthSecret) error { + auditEntry := model.AuditEntry{ + Action: "DeleteAuthSecret", + Model: &authSecret, + } + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Delete(&authSecret)) + }) } // CreateSAMLProvider creates a new saml_providers row using the data in the input struct // INSERT INTO saml_identity_providers (...) VALUES (...) -func (s *BloodhoundDB) CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error) { +func (s *BloodhoundDB) CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error) { var ( - updatedSAMLProvider = samlProvider - result = s.db.Create(&updatedSAMLProvider) + auditEntry = model.AuditEntry{ + Action: "CreateSAMLIdentityProvider", + Model: &samlProvider, // Pointer is required to ensure success log contains updated fields after transaction + } ) - return updatedSAMLProvider, CheckError(result) + err := s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Create(&samlProvider)) + }) + + return samlProvider, err } // CreateSAMLProvider updates a saml_providers row using the data in the input struct // UPDATE saml_identity_providers SET (...) VALUES (...) WHERE id = ... -func (s *BloodhoundDB) UpdateSAMLIdentityProvider(provider model.SAMLProvider) error { - result := s.db.Save(&provider) - return CheckError(result) +func (s *BloodhoundDB) UpdateSAMLIdentityProvider(ctx context.Context, provider model.SAMLProvider) error { + var ( + auditEntry = model.AuditEntry{ + Action: "UpdateSAMLIdentityProvider", + Model: &provider, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Save(&provider)) + }) } // LookupSAMLProviderByName returns a SAML provider corresponding to the name provided @@ -597,8 +654,17 @@ func (s *BloodhoundDB) GetSAMLProvider(id int32) (model.SAMLProvider, error) { return samlProvider, CheckError(result) } -func (s *BloodhoundDB) DeleteSAMLProvider(provider model.SAMLProvider) error { - return CheckError(s.db.Delete(&provider)) +func (s *BloodhoundDB) DeleteSAMLProvider(ctx context.Context, provider model.SAMLProvider) error { + var ( + auditEntry = model.AuditEntry{ + Action: "DeleteSAMLProvider", + Model: &provider, // Pointer is required to ensure success log contains updated fields after transaction + } + ) + + return s.AuditableTransaction(ctx, auditEntry, func(tx *gorm.DB) error { + return CheckError(tx.Delete(&provider)) + }) } // GetSAMLProviderUsers returns all users that are bound to the SAML provider ID provided diff --git a/cmd/api/src/database/auth_test.go b/cmd/api/src/database/auth_test.go index 93a3d45902..2544c1d37e 100644 --- a/cmd/api/src/database/auth_test.go +++ b/cmd/api/src/database/auth_test.go @@ -20,6 +20,7 @@ package database_test import ( + "context" "testing" "time" @@ -64,7 +65,7 @@ func initAndCreateUser(t *testing.T) (database.Database, model.User) { } ) - if newUser, err := dbInst.CreateUser(user); err != nil { + if newUser, err := dbInst.CreateUser(context.Background(), user); err != nil { t.Fatalf("Error creating user: %v", err) } else { return dbInst, newUser @@ -197,7 +198,7 @@ func TestDatabase_CreateGetUser(t *testing.T) { ) for _, user := range users { - if _, err := dbInst.CreateUser(user); err != nil { + if _, err := dbInst.CreateUser(context.Background(), user); err != nil { t.Fatalf("Error creating user: %v", err) } else if newUser, err := dbInst.LookupUser(user.PrincipalName); err != nil { t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err) @@ -219,7 +220,7 @@ func TestDatabase_CreateGetUser(t *testing.T) { newUser.Roles = newUser.Roles.RemoveByName(roleToDelete) - if err := dbInst.UpdateUser(newUser); err != nil { + if err := dbInst.UpdateUser(context.Background(), newUser); err != nil { t.Fatalf("Failed to update user: %v", err) } @@ -240,6 +241,7 @@ func TestDatabase_CreateGetUser(t *testing.T) { func TestDatabase_CreateGetDeleteAuthToken(t *testing.T) { var ( + ctx = context.Background() dbInst, user = initAndCreateUser(t) expectedName = "test" token = model.AuthToken{ @@ -250,7 +252,7 @@ func TestDatabase_CreateGetDeleteAuthToken(t *testing.T) { } ) - if newToken, err := dbInst.CreateAuthToken(token); err != nil { + if newToken, err := dbInst.CreateAuthToken(ctx, token); err != nil { t.Fatalf("Failed to create auth token: %v", err) } else if updatedUser, err := dbInst.GetUser(user.ID); err != nil { t.Fatalf("Failed to fetch updated user: %v", err) @@ -260,7 +262,7 @@ func TestDatabase_CreateGetDeleteAuthToken(t *testing.T) { t.Fatalf("Expected auth token to have valid name") } else if newToken.Name.String != expectedName { t.Fatalf("Expected auth token to have name %s but saw %v", expectedName, newToken.Name.String) - } else if err := dbInst.DeleteAuthToken(newToken); err != nil { + } else if err := dbInst.DeleteAuthToken(ctx, newToken); err != nil { t.Fatalf("Failed to delete auth token: %v", err) } @@ -275,6 +277,7 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) { const updatedDigest = "updated" var ( + ctx = context.Background() dbInst, user = initAndCreateUser(t) secret = model.AuthSecret{ UserID: user.ID, @@ -284,7 +287,7 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) { } ) - if newSecret, err := dbInst.CreateAuthSecret(secret); err != nil { + if newSecret, err := dbInst.CreateAuthSecret(ctx, secret); err != nil { t.Fatalf("Failed to create auth secret: %v", err) } else if updatedUser, err := dbInst.GetUser(user.ID); err != nil { t.Fatalf("Failed to fetch updated user: %v", err) @@ -293,7 +296,7 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) { } else { newSecret.Digest = updatedDigest - if err := dbInst.UpdateAuthSecret(newSecret); err != nil { + if err := dbInst.UpdateAuthSecret(ctx, newSecret); err != nil { t.Fatalf("Failed to update auth secret %d: %v", newSecret.ID, err) } else if updatedSecret, err := dbInst.GetAuthSecret(newSecret.ID); err != nil { t.Fatalf("Failed to fetch updated auth secret: %v", err) @@ -301,7 +304,7 @@ func TestDatabase_CreateGetDeleteAuthSecret(t *testing.T) { t.Fatalf("Expected updated auth secret digest to be %s but saw %s", updatedDigest, updatedSecret.Digest) } - if err := dbInst.DeleteAuthSecret(newSecret); err != nil { + if err := dbInst.DeleteAuthSecret(ctx, newSecret); err != nil { t.Fatalf("Failed to delete auth token: %v", err) } } @@ -323,12 +326,12 @@ func TestDatabase_CreateSAMLProvider(t *testing.T) { SingleSignOnURI: "https://idp.example.com/sso", } - if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(samlProvider); err != nil { + if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(context.Background(), samlProvider); err != nil { t.Fatalf("Failed to create SAML provider: %v", err) } else { user.SAMLProviderID = null.Int32From(newSAMLProvider.ID) - if err := dbInst.UpdateUser(user); err != nil { + if err := dbInst.UpdateUser(context.Background(), user); err != nil { t.Fatalf("Failed to update user: %v", err) } else if updatedUser, err := dbInst.GetUser(user.ID); err != nil { t.Fatalf("Failed to fetch updated user: %v", err) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 21bbe8de89..a5b4643fcf 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -19,6 +19,7 @@ package database //go:generate go run go.uber.org/mock/mockgen -copyright_file=../../../../LICENSE.header -destination=./mocks/db.go -package=mocks . Database import ( + "context" "fmt" "time" @@ -60,9 +61,9 @@ type Database interface { DeleteIngestTask(ingestTask model.IngestTask) error GetIngestTasksForJob(jobID int64) (model.IngestTasks, error) GetUnfinishedIngestIDs() ([]int64, error) - CreateAssetGroup(name, tag string, systemGroup bool) (model.AssetGroup, error) - UpdateAssetGroup(assetGroup model.AssetGroup) error - DeleteAssetGroup(assetGroup model.AssetGroup) error + CreateAssetGroup(ctx context.Context, name, tag string, systemGroup bool) (model.AssetGroup, error) + UpdateAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error + DeleteAssetGroup(ctx context.Context, assetGroup model.AssetGroup) error GetAssetGroup(id int32) (model.AssetGroup, error) GetAllAssetGroups(order string, filter model.SQLFilter) (model.AssetGroups, error) SweepAssetGroupCollections() @@ -71,9 +72,8 @@ type Database interface { GetTimeRangedAssetGroupCollections(assetGroupID int32, from int64, to int64, order string) (model.AssetGroupCollections, error) GetAllAssetGroupCollections() (model.AssetGroupCollections, error) GetAssetGroupSelector(id int32) (model.AssetGroupSelector, error) - UpdateAssetGroupSelector(selector model.AssetGroupSelector) error - DeleteAssetGroupSelector(selector model.AssetGroupSelector) error - RemoveAssetGroupSelector(selector model.AssetGroupSelector) error + UpdateAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error + DeleteAssetGroupSelector(ctx context.Context, selector model.AssetGroupSelector) error CreateRawAssetGroupSelector(assetGroup model.AssetGroup, name, selector string) (model.AssetGroupSelector, error) CreateAssetGroupSelector(assetGroup model.AssetGroup, spec model.AssetGroupSelectorSpec, systemSelector bool) (model.AssetGroupSelector, error) UpdateAssetGroupSelectors(ctx ctx.Context, assetGroup model.AssetGroup, selectorSpecs []model.AssetGroupSelectorSpec, systemSelector bool) (model.UpdatedAssetGroupSelectors, error) @@ -82,7 +82,7 @@ type Database interface { RawFirst(value any) error Wipe() error Migrate() error - AppendAuditLog(ctx ctx.Context, action string, data model.Auditable) error + AppendAuditLog(ctx context.Context, entry model.AuditEntry) error ListAuditLogs(before, after time.Time, offset, limit int, order string, filter model.SQLFilter) (model.AuditLogs, int, error) CreateRole(role model.Role) (model.Role, error) UpdateRole(role model.Role) error @@ -99,30 +99,30 @@ type Database interface { CreateInstallation() (model.Installation, error) GetInstallation() (model.Installation, error) HasInstallation() (bool, error) - CreateUser(user model.User) (model.User, error) - UpdateUser(user model.User) error + CreateUser(ctx context.Context, user model.User) (model.User, error) + UpdateUser(ctx context.Context, user model.User) error GetAllUsers(order string, filter model.SQLFilter) (model.Users, error) GetUser(id uuid.UUID) (model.User, error) - DeleteUser(user model.User) error + DeleteUser(ctx context.Context, user model.User) error LookupUser(principalName string) (model.User, error) - CreateAuthToken(authToken model.AuthToken) (model.AuthToken, error) + CreateAuthToken(ctx context.Context, authToken model.AuthToken) (model.AuthToken, error) UpdateAuthToken(authToken model.AuthToken) error GetAllAuthTokens(order string, filter model.SQLFilter) (model.AuthTokens, error) GetAuthToken(id uuid.UUID) (model.AuthToken, error) ListUserTokens(userID uuid.UUID, order string, filter model.SQLFilter) (model.AuthTokens, error) GetUserToken(userId, tokenId uuid.UUID) (model.AuthToken, error) - DeleteAuthToken(authToken model.AuthToken) error - CreateAuthSecret(authSecret model.AuthSecret) (model.AuthSecret, error) + DeleteAuthToken(ctx context.Context, authToken model.AuthToken) error + CreateAuthSecret(ctx context.Context, authSecret model.AuthSecret) (model.AuthSecret, error) GetAuthSecret(id int32) (model.AuthSecret, error) - UpdateAuthSecret(authSecret model.AuthSecret) error - DeleteAuthSecret(authSecret model.AuthSecret) error - CreateSAMLIdentityProvider(samlProvider model.SAMLProvider) (model.SAMLProvider, error) - UpdateSAMLIdentityProvider(samlProvider model.SAMLProvider) error + UpdateAuthSecret(ctx context.Context, authSecret model.AuthSecret) error + DeleteAuthSecret(ctx context.Context, authSecret model.AuthSecret) error + CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) (model.SAMLProvider, error) + UpdateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider) error LookupSAMLProviderByName(name string) (model.SAMLProvider, error) GetAllSAMLProviders() (model.SAMLProviders, error) GetSAMLProvider(id int32) (model.SAMLProvider, error) GetSAMLProviderUsers(id int32) (model.Users, error) - DeleteSAMLProvider(samlProvider model.SAMLProvider) error + DeleteSAMLProvider(ctx context.Context, samlProvider model.SAMLProvider) error CreateUserSession(userSession model.UserSession) (model.UserSession, error) LookupActiveSessionsByUser(user model.User) ([]model.UserSession, error) EndUserSession(userSession model.UserSession) diff --git a/cmd/api/src/database/migration/migrations/v5.5.0.sql b/cmd/api/src/database/migration/migrations/v5.5.0.sql index b3dbbf5b5e..a54fbda460 100644 --- a/cmd/api/src/database/migration/migrations/v5.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v5.5.0.sql @@ -15,7 +15,7 @@ -- SPDX-License-Identifier: Apache-2.0 -- Add new columns for audit_logs -ALTER TABLE audit_logs +ALTER TABLE audit_logs ADD COLUMN IF NOT EXISTS actor_email VARCHAR(330) DEFAULT NULL, ADD COLUMN IF NOT EXISTS source VARCHAR(40) DEFAULT NULL, ADD COLUMN IF NOT EXISTS status VARCHAR(15) CHECK (status IN ('success', 'failure')) DEFAULT 'success'; diff --git a/cmd/api/src/database/migration/migrations/v5.6.0.sql b/cmd/api/src/database/migration/migrations/v5.6.0.sql new file mode 100644 index 0000000000..cad1cf7dec --- /dev/null +++ b/cmd/api/src/database/migration/migrations/v5.6.0.sql @@ -0,0 +1,31 @@ +-- Copyright 2024 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +ALTER TABLE IF EXISTS audit_logs + RENAME COLUMN source TO source_ip_address; + +ALTER TABLE IF EXISTS audit_logs + DROP CONSTRAINT IF EXISTS audit_logs_status_check, + ADD CONSTRAINT status_check + CHECK (status IN ('intent', 'success', 'failure')), + ALTER COLUMN status SET DEFAULT 'intent', + ALTER COLUMN source_ip_address TYPE TEXT, + ADD COLUMN IF NOT EXISTS commit_id TEXT; + +-- Add indices for scalability +CREATE INDEX IF NOT EXISTS idx_audit_logs_actor_email ON audit_logs USING btree (actor_email); +CREATE INDEX IF NOT EXISTS idx_audit_logs_source_ip_address ON audit_logs USING btree (source_ip_address); +CREATE INDEX IF NOT EXISTS idx_audit_logs_status ON audit_logs USING btree (status); diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 6e58a3d86e..c8320a944b 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -21,6 +21,7 @@ package mocks import ( + context "context" reflect "reflect" time "time" @@ -55,17 +56,17 @@ func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { } // AppendAuditLog mocks base method. -func (m *MockDatabase) AppendAuditLog(arg0 ctx.Context, arg1 string, arg2 model.Auditable) error { +func (m *MockDatabase) AppendAuditLog(arg0 context.Context, arg1 model.AuditEntry) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendAuditLog", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "AppendAuditLog", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // AppendAuditLog indicates an expected call of AppendAuditLog. -func (mr *MockDatabaseMockRecorder) AppendAuditLog(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) AppendAuditLog(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuditLog", reflect.TypeOf((*MockDatabase)(nil).AppendAuditLog), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuditLog", reflect.TypeOf((*MockDatabase)(nil).AppendAuditLog), arg0, arg1) } // Close mocks base method. @@ -111,18 +112,18 @@ func (mr *MockDatabaseMockRecorder) CreateADDataQualityStats(arg0 interface{}) * } // CreateAssetGroup mocks base method. -func (m *MockDatabase) CreateAssetGroup(arg0, arg1 string, arg2 bool) (model.AssetGroup, error) { +func (m *MockDatabase) CreateAssetGroup(arg0 context.Context, arg1, arg2 string, arg3 bool) (model.AssetGroup, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAssetGroup", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "CreateAssetGroup", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(model.AssetGroup) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateAssetGroup indicates an expected call of CreateAssetGroup. -func (mr *MockDatabaseMockRecorder) CreateAssetGroup(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateAssetGroup(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).CreateAssetGroup), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).CreateAssetGroup), arg0, arg1, arg2, arg3) } // CreateAssetGroupCollection mocks base method. @@ -155,33 +156,33 @@ func (mr *MockDatabaseMockRecorder) CreateAssetGroupSelector(arg0, arg1, arg2 in } // CreateAuthSecret mocks base method. -func (m *MockDatabase) CreateAuthSecret(arg0 model.AuthSecret) (model.AuthSecret, error) { +func (m *MockDatabase) CreateAuthSecret(arg0 context.Context, arg1 model.AuthSecret) (model.AuthSecret, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAuthSecret", arg0) + ret := m.ctrl.Call(m, "CreateAuthSecret", arg0, arg1) ret0, _ := ret[0].(model.AuthSecret) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateAuthSecret indicates an expected call of CreateAuthSecret. -func (mr *MockDatabaseMockRecorder) CreateAuthSecret(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateAuthSecret(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthSecret", reflect.TypeOf((*MockDatabase)(nil).CreateAuthSecret), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthSecret", reflect.TypeOf((*MockDatabase)(nil).CreateAuthSecret), arg0, arg1) } // CreateAuthToken mocks base method. -func (m *MockDatabase) CreateAuthToken(arg0 model.AuthToken) (model.AuthToken, error) { +func (m *MockDatabase) CreateAuthToken(arg0 context.Context, arg1 model.AuthToken) (model.AuthToken, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAuthToken", arg0) + ret := m.ctrl.Call(m, "CreateAuthToken", arg0, arg1) ret0, _ := ret[0].(model.AuthToken) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateAuthToken indicates an expected call of CreateAuthToken. -func (mr *MockDatabaseMockRecorder) CreateAuthToken(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateAuthToken(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthToken", reflect.TypeOf((*MockDatabase)(nil).CreateAuthToken), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthToken", reflect.TypeOf((*MockDatabase)(nil).CreateAuthToken), arg0, arg1) } // CreateAzureDataQualityAggregation mocks base method. @@ -305,18 +306,18 @@ func (mr *MockDatabaseMockRecorder) CreateRole(arg0 interface{}) *gomock.Call { } // CreateSAMLIdentityProvider mocks base method. -func (m *MockDatabase) CreateSAMLIdentityProvider(arg0 model.SAMLProvider) (model.SAMLProvider, error) { +func (m *MockDatabase) CreateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) (model.SAMLProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSAMLIdentityProvider", arg0) + ret := m.ctrl.Call(m, "CreateSAMLIdentityProvider", arg0, arg1) ret0, _ := ret[0].(model.SAMLProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateSAMLIdentityProvider indicates an expected call of CreateSAMLIdentityProvider. -func (mr *MockDatabaseMockRecorder) CreateSAMLIdentityProvider(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateSAMLIdentityProvider(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).CreateSAMLIdentityProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).CreateSAMLIdentityProvider), arg0, arg1) } // CreateSavedQuery mocks base method. @@ -335,18 +336,18 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQuery(arg0, arg1, arg2 interface{ } // CreateUser mocks base method. -func (m *MockDatabase) CreateUser(arg0 model.User) (model.User, error) { +func (m *MockDatabase) CreateUser(arg0 context.Context, arg1 model.User) (model.User, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateUser", arg0) + ret := m.ctrl.Call(m, "CreateUser", arg0, arg1) ret0, _ := ret[0].(model.User) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateUser indicates an expected call of CreateUser. -func (mr *MockDatabaseMockRecorder) CreateUser(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) CreateUser(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockDatabase)(nil).CreateUser), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateUser", reflect.TypeOf((*MockDatabase)(nil).CreateUser), arg0, arg1) } // CreateUserSession mocks base method. @@ -365,59 +366,59 @@ func (mr *MockDatabaseMockRecorder) CreateUserSession(arg0 interface{}) *gomock. } // DeleteAssetGroup mocks base method. -func (m *MockDatabase) DeleteAssetGroup(arg0 model.AssetGroup) error { +func (m *MockDatabase) DeleteAssetGroup(arg0 context.Context, arg1 model.AssetGroup) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAssetGroup", arg0) + ret := m.ctrl.Call(m, "DeleteAssetGroup", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAssetGroup indicates an expected call of DeleteAssetGroup. -func (mr *MockDatabaseMockRecorder) DeleteAssetGroup(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteAssetGroup(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroup", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroup), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroup", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroup), arg0, arg1) } // DeleteAssetGroupSelector mocks base method. -func (m *MockDatabase) DeleteAssetGroupSelector(arg0 model.AssetGroupSelector) error { +func (m *MockDatabase) DeleteAssetGroupSelector(arg0 context.Context, arg1 model.AssetGroupSelector) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAssetGroupSelector", arg0) + ret := m.ctrl.Call(m, "DeleteAssetGroupSelector", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAssetGroupSelector indicates an expected call of DeleteAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) DeleteAssetGroupSelector(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteAssetGroupSelector(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroupSelector), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).DeleteAssetGroupSelector), arg0, arg1) } // DeleteAuthSecret mocks base method. -func (m *MockDatabase) DeleteAuthSecret(arg0 model.AuthSecret) error { +func (m *MockDatabase) DeleteAuthSecret(arg0 context.Context, arg1 model.AuthSecret) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAuthSecret", arg0) + ret := m.ctrl.Call(m, "DeleteAuthSecret", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAuthSecret indicates an expected call of DeleteAuthSecret. -func (mr *MockDatabaseMockRecorder) DeleteAuthSecret(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteAuthSecret(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthSecret", reflect.TypeOf((*MockDatabase)(nil).DeleteAuthSecret), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthSecret", reflect.TypeOf((*MockDatabase)(nil).DeleteAuthSecret), arg0, arg1) } // DeleteAuthToken mocks base method. -func (m *MockDatabase) DeleteAuthToken(arg0 model.AuthToken) error { +func (m *MockDatabase) DeleteAuthToken(arg0 context.Context, arg1 model.AuthToken) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAuthToken", arg0) + ret := m.ctrl.Call(m, "DeleteAuthToken", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAuthToken indicates an expected call of DeleteAuthToken. -func (mr *MockDatabaseMockRecorder) DeleteAuthToken(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteAuthToken(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthToken", reflect.TypeOf((*MockDatabase)(nil).DeleteAuthToken), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthToken", reflect.TypeOf((*MockDatabase)(nil).DeleteAuthToken), arg0, arg1) } // DeleteIngestTask mocks base method. @@ -435,17 +436,17 @@ func (mr *MockDatabaseMockRecorder) DeleteIngestTask(arg0 interface{}) *gomock.C } // DeleteSAMLProvider mocks base method. -func (m *MockDatabase) DeleteSAMLProvider(arg0 model.SAMLProvider) error { +func (m *MockDatabase) DeleteSAMLProvider(arg0 context.Context, arg1 model.SAMLProvider) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSAMLProvider", arg0) + ret := m.ctrl.Call(m, "DeleteSAMLProvider", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteSAMLProvider indicates an expected call of DeleteSAMLProvider. -func (mr *MockDatabaseMockRecorder) DeleteSAMLProvider(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteSAMLProvider(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSAMLProvider", reflect.TypeOf((*MockDatabase)(nil).DeleteSAMLProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSAMLProvider", reflect.TypeOf((*MockDatabase)(nil).DeleteSAMLProvider), arg0, arg1) } // DeleteSavedQuery mocks base method. @@ -463,17 +464,17 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQuery(arg0 interface{}) *gomock.C } // DeleteUser mocks base method. -func (m *MockDatabase) DeleteUser(arg0 model.User) error { +func (m *MockDatabase) DeleteUser(arg0 context.Context, arg1 model.User) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUser", arg0) + ret := m.ctrl.Call(m, "DeleteUser", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteUser indicates an expected call of DeleteUser. -func (mr *MockDatabaseMockRecorder) DeleteUser(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) DeleteUser(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockDatabase)(nil).DeleteUser), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUser", reflect.TypeOf((*MockDatabase)(nil).DeleteUser), arg0, arg1) } // EndUserSession mocks base method. @@ -1289,20 +1290,6 @@ func (mr *MockDatabaseMockRecorder) RawFirst(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RawFirst", reflect.TypeOf((*MockDatabase)(nil).RawFirst), arg0) } -// RemoveAssetGroupSelector mocks base method. -func (m *MockDatabase) RemoveAssetGroupSelector(arg0 model.AssetGroupSelector) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveAssetGroupSelector", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveAssetGroupSelector indicates an expected call of RemoveAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) RemoveAssetGroupSelector(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).RemoveAssetGroupSelector), arg0) -} - // SavedQueryBelongsToUser mocks base method. func (m *MockDatabase) SavedQueryBelongsToUser(arg0 uuid.UUID, arg1 int) (bool, error) { m.ctrl.T.Helper() @@ -1371,31 +1358,31 @@ func (mr *MockDatabaseMockRecorder) SweepSessions() *gomock.Call { } // UpdateAssetGroup mocks base method. -func (m *MockDatabase) UpdateAssetGroup(arg0 model.AssetGroup) error { +func (m *MockDatabase) UpdateAssetGroup(arg0 context.Context, arg1 model.AssetGroup) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAssetGroup", arg0) + ret := m.ctrl.Call(m, "UpdateAssetGroup", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateAssetGroup indicates an expected call of UpdateAssetGroup. -func (mr *MockDatabaseMockRecorder) UpdateAssetGroup(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateAssetGroup(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroup), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroup", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroup), arg0, arg1) } // UpdateAssetGroupSelector mocks base method. -func (m *MockDatabase) UpdateAssetGroupSelector(arg0 model.AssetGroupSelector) error { +func (m *MockDatabase) UpdateAssetGroupSelector(arg0 context.Context, arg1 model.AssetGroupSelector) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAssetGroupSelector", arg0) + ret := m.ctrl.Call(m, "UpdateAssetGroupSelector", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateAssetGroupSelector indicates an expected call of UpdateAssetGroupSelector. -func (mr *MockDatabaseMockRecorder) UpdateAssetGroupSelector(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateAssetGroupSelector(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroupSelector), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAssetGroupSelector", reflect.TypeOf((*MockDatabase)(nil).UpdateAssetGroupSelector), arg0, arg1) } // UpdateAssetGroupSelectors mocks base method. @@ -1414,17 +1401,17 @@ func (mr *MockDatabaseMockRecorder) UpdateAssetGroupSelectors(arg0, arg1, arg2, } // UpdateAuthSecret mocks base method. -func (m *MockDatabase) UpdateAuthSecret(arg0 model.AuthSecret) error { +func (m *MockDatabase) UpdateAuthSecret(arg0 context.Context, arg1 model.AuthSecret) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAuthSecret", arg0) + ret := m.ctrl.Call(m, "UpdateAuthSecret", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateAuthSecret indicates an expected call of UpdateAuthSecret. -func (mr *MockDatabaseMockRecorder) UpdateAuthSecret(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateAuthSecret(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthSecret", reflect.TypeOf((*MockDatabase)(nil).UpdateAuthSecret), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAuthSecret", reflect.TypeOf((*MockDatabase)(nil).UpdateAuthSecret), arg0, arg1) } // UpdateAuthToken mocks base method. @@ -1470,31 +1457,31 @@ func (mr *MockDatabaseMockRecorder) UpdateRole(arg0 interface{}) *gomock.Call { } // UpdateSAMLIdentityProvider mocks base method. -func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 model.SAMLProvider) error { +func (m *MockDatabase) UpdateSAMLIdentityProvider(arg0 context.Context, arg1 model.SAMLProvider) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSAMLIdentityProvider", arg0) + ret := m.ctrl.Call(m, "UpdateSAMLIdentityProvider", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateSAMLIdentityProvider indicates an expected call of UpdateSAMLIdentityProvider. -func (mr *MockDatabaseMockRecorder) UpdateSAMLIdentityProvider(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateSAMLIdentityProvider(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSAMLIdentityProvider), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSAMLIdentityProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateSAMLIdentityProvider), arg0, arg1) } // UpdateUser mocks base method. -func (m *MockDatabase) UpdateUser(arg0 model.User) error { +func (m *MockDatabase) UpdateUser(arg0 context.Context, arg1 model.User) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUser", arg0) + ret := m.ctrl.Call(m, "UpdateUser", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UpdateUser indicates an expected call of UpdateUser. -func (mr *MockDatabaseMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) UpdateUser(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockDatabase)(nil).UpdateUser), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockDatabase)(nil).UpdateUser), arg0, arg1) } // Wipe mocks base method. diff --git a/cmd/api/src/docs/json/definitions/models.json b/cmd/api/src/docs/json/definitions/models.json index 12dd046379..cbaaf813cf 100644 --- a/cmd/api/src/docs/json/definitions/models.json +++ b/cmd/api/src/docs/json/definitions/models.json @@ -128,9 +128,13 @@ "source": { "type": "string" }, + "commit_id": { + "type": "string" + }, "status": { "string": "enum", "enum": [ + "intent", "success", "failure" ] diff --git a/cmd/api/src/docs/json/paths/v2/auth.json b/cmd/api/src/docs/json/paths/v2/auth.json index d8e09e52be..61c8f58a05 100644 --- a/cmd/api/src/docs/json/paths/v2/auth.json +++ b/cmd/api/src/docs/json/paths/v2/auth.json @@ -280,46 +280,6 @@ "$ref": "#/components/responses/defaultError" } } - }, - "post": { - "description": "Create a new SAML provider authentication endpoint.", - "tags": [ - "Auth", - "Community", - "Enterprise" - ], - "summary": "Create SAML Provider", - "parameters": [ - { - "$ref": "#/definitions/parameters.PreferHeader" - } - ], - "requestBody": { - "description": "The request body for creating a SAML Provider", - "required": true, - "content": { - "application/json": { - "schema": { - "$ref": "#/definitions/v2.CreateSAMLAuthProviderRequest" - } - } - } - }, - "responses": { - "200": { - "description": "OK", - "content": { - "application/json": { - "schema": { - "$ref": "#/definitions/v2.SAMLProviderResponse" - } - } - } - }, - "Error": { - "$ref": "#/components/responses/defaultError" - } - } } }, "api/v2/saml/sso": { diff --git a/cmd/api/src/model/audit.go b/cmd/api/src/model/audit.go index bee470c112..1163f885f0 100644 --- a/cmd/api/src/model/audit.go +++ b/cmd/api/src/model/audit.go @@ -20,20 +20,31 @@ import ( "fmt" "time" + "github.com/gofrs/uuid" "github.com/specterops/bloodhound/src/database/types" ) +type AuditEntryStatus string + +const ( + AuditStatusSuccess AuditEntryStatus = "success" + AuditStatusFailure AuditEntryStatus = "failure" + AuditStatusIntent AuditEntryStatus = "intent" +) + +// TODO embed Basic into this struct instead of declaring the ID and CreatedAt fields. This will require a migration type AuditLog struct { - ID int64 `json:"id" gorm:"primaryKey"` - CreatedAt time.Time `json:"created_at" gorm:"index"` - ActorID string `json:"actor_id" gorm:"index"` - ActorName string `json:"actor_name"` - ActorEmail string `json:"actor_email"` - Action string `json:"action" gorm:"index"` - Fields types.JSONUntypedObject `json:"fields"` - RequestID string `json:"request_id"` - Source string `json:"source"` - Status string `json:"status"` + ID int64 `json:"id" gorm:"primaryKey"` + CreatedAt time.Time `json:"created_at" gorm:"index"` + ActorID string `json:"actor_id" gorm:"index"` + ActorName string `json:"actor_name"` + ActorEmail string `json:"actor_email"` + Action string `json:"action" gorm:"index"` + Fields types.JSONUntypedObject `json:"fields"` + RequestID string `json:"request_id"` + SourceIpAddress string `json:"source_ip_address"` + Status string `json:"status"` + CommitID uuid.UUID `json:"commit_id" gorm:"type:text"` } func (s AuditLog) String() string { @@ -51,7 +62,7 @@ func (s AuditLogs) IsSortable(column string) bool { "action", "request_id", "created_at", - "source", + "source_ip_address", "status": return true default: @@ -61,15 +72,15 @@ func (s AuditLogs) IsSortable(column string) bool { func (s AuditLogs) ValidFilters() map[string][]FilterOperator { return map[string][]FilterOperator{ - "id": {Equals, GreaterThan, GreaterThanOrEquals, LessThan, LessThanOrEquals, NotEquals}, - "actor_id": {Equals, NotEquals}, - "actor_name": {Equals, NotEquals}, - "actor_email": {Equals, NotEquals}, - "action": {Equals, NotEquals}, - "request_id": {Equals, NotEquals}, - "created_at": {Equals, GreaterThan, GreaterThanOrEquals, LessThan, LessThanOrEquals, NotEquals}, - "source": {Equals, NotEquals}, - "status": {Equals, NotEquals}, + "id": {Equals, GreaterThan, GreaterThanOrEquals, LessThan, LessThanOrEquals, NotEquals}, + "actor_id": {Equals, NotEquals}, + "actor_name": {Equals, NotEquals}, + "actor_email": {Equals, NotEquals}, + "action": {Equals, NotEquals}, + "request_id": {Equals, NotEquals}, + "created_at": {Equals, GreaterThan, GreaterThanOrEquals, LessThan, LessThanOrEquals, NotEquals}, + "source_ip_address": {Equals, NotEquals}, + "status": {Equals, NotEquals}, } } @@ -80,7 +91,7 @@ func (s AuditLogs) IsString(column string) bool { "actor_email", "action", "request_id", - "source", + "source_ip_address", "status": return true default: @@ -136,3 +147,11 @@ func (s AuditData) MergeLeft(rightSide Auditable) AuditData { type Auditable interface { AuditData() AuditData } + +type AuditEntry struct { + CommitID uuid.UUID + Action string + Model Auditable + Status AuditEntryStatus + ErrorMsg string +} diff --git a/cmd/api/src/model/auth.go b/cmd/api/src/model/auth.go index df01dc5c97..5fff3a40bd 100644 --- a/cmd/api/src/model/auth.go +++ b/cmd/api/src/model/auth.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package model @@ -21,9 +21,9 @@ import ( "net/url" "time" + "github.com/gofrs/uuid" "github.com/specterops/bloodhound/src/database/types/null" "github.com/specterops/bloodhound/src/serde" - "github.com/gofrs/uuid" ) const PermissionURIScheme = "permission" @@ -155,6 +155,16 @@ type AuthToken struct { Unique } +func (s AuthToken) AuditData() AuditData { + return AuditData{ + "id": s.ID, + "user_id": s.UserID, + "client_id": s.ClientID, + "name": s.Name, + "last_access": s.LastAccess, + } +} + func (s AuthToken) StripKey() AuthToken { return AuthToken{ UserID: s.UserID, @@ -460,18 +470,18 @@ type User struct { Unique } -func (s User) AuditData() AuditData { - data := AuditData{ - "id": s.ID, - "principal_name": s.PrincipalName, - "roles": s.Roles.IDs(), - } - - if s.SAMLProviderID.Valid { - data["saml_provider_id"] = s.SAMLProviderID +func (s *User) AuditData() AuditData { + return AuditData{ + "id": s.ID, + "principal_name": s.PrincipalName, + "first_name": s.FirstName, + "last_name": s.LastName, + "email_address": s.EmailAddress, + "roles": s.Roles.IDs(), + "saml_provider_id": s.SAMLProviderID.ValueOrZero(), + "is_disabled": s.IsDisabled, + "eula_accepted": s.EULAAccepted, } - - return data } func (s *User) RemoveRole(role Role) { diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 223a51d5a9..57e7b22445 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -19,11 +19,12 @@ package services import ( "context" "fmt" + "time" + schema "github.com/specterops/bloodhound/graphschema" "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/bootstrap" "github.com/specterops/bloodhound/src/queries" - "time" "github.com/specterops/bloodhound/cache" "github.com/specterops/bloodhound/dawgs/graph" @@ -87,12 +88,12 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots var ( graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) datapipeDaemon = datapipe.NewDaemon(ctx, cfg, connections, graphQueryCache, time.Duration(cfg.DatapipeInterval)*time.Second) - routerInst = router.NewRouter(cfg, auth.NewAuthorizer(), bootstrap.ContentSecurityPolicy) + routerInst = router.NewRouter(cfg, auth.NewAuthorizer(connections.RDMS), bootstrap.ContentSecurityPolicy) ctxInitializer = database.NewContextInitializer(connections.RDMS) authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) ) - registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) + registration.RegisterFossGlobalMiddleware(&routerInst, cfg, connections.RDMS, auth.NewIdentityResolver(), authenticator) registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, datapipeDaemon) // Set neo4j batch and flush sizes diff --git a/local-harnesses/build.config.json.template b/local-harnesses/build.config.json.template index 3556eaec5b..4ba7a719ce 100644 --- a/local-harnesses/build.config.json.template +++ b/local-harnesses/build.config.json.template @@ -7,6 +7,8 @@ "collectors_base_path": "/bhapi/collectors", "log_level": "INFO", "log_path": "bhapi.log", + "enable_startup_wait_period": false, + "datapipe_interval": 1, "features": { "enable_auth": true }, diff --git a/local-harnesses/integration.config.json.template b/local-harnesses/integration.config.json.template index 35b1713d28..9293aed87c 100644 --- a/local-harnesses/integration.config.json.template +++ b/local-harnesses/integration.config.json.template @@ -7,6 +7,8 @@ "collectors_base_path": "/tmp/collectors", "log_level": "ERROR", "log_path": "bhapi.log", + "enable_startup_wait_period": false, + "datapipe_interval": 1, "features": { "enable_auth": true },