Skip to content

Commit

Permalink
Bed 3859 - Fix inaccurate count strategy in paginated SQL queries (#317)
Browse files Browse the repository at this point in the history
* fix: counts when filtering AuditLogs

* fix: counts when filtering SavedQueries
  • Loading branch information
mistahj67 authored Jan 17, 2024
1 parent 95df8a4 commit a2fca90
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 11 deletions.
18 changes: 14 additions & 4 deletions cmd/api/src/database/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,24 @@ func (s *BloodhoundDB) ListAuditLogs(before, after time.Time, offset, limit int,
// This code went through a partial refactor when adding support for new fields.
// See the comments here for more information: https://github.com/SpecterOps/BloodHound/pull/297#issuecomment-1887640827

if filter.SQLString != "" {
result = s.db.Model(&auditLogs).Where(filter.SQLString, filter.Params).Count(&count)
} else {
result = s.db.Model(&auditLogs).Count(&count)
}

if result.Error != nil {
return nil, 0, CheckError(result)
}

if order != "" && filter.SQLString == "" {
result = cursor.Order(order).Find(&auditLogs).Count(&count)
result = cursor.Order(order).Find(&auditLogs)
} else if order != "" && filter.SQLString != "" {
result = cursor.Where(filter.SQLString, filter.Params).Order(order).Find(&auditLogs).Count(&count)
result = cursor.Where(filter.SQLString, filter.Params).Order(order).Find(&auditLogs)
} else if order == "" && filter.SQLString != "" {
result = cursor.Where(filter.SQLString, filter.Params).Find(&auditLogs).Count(&count)
result = cursor.Where(filter.SQLString, filter.Params).Find(&auditLogs)
} else {
result = cursor.Find(&auditLogs).Count(&count)
result = cursor.Find(&auditLogs)
}

return auditLogs, int(count), CheckError(result)
Expand Down
73 changes: 73 additions & 0 deletions cmd/api/src/database/audit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

//go:build integration
// +build integration

package database_test

import (
"github.com/specterops/bloodhound/src/auth"
"github.com/specterops/bloodhound/src/ctx"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/test/integration"
"testing"
"time"
)

func TestDatabase_ListAuditLogs(t *testing.T) {
var (
dbInst = integration.OpenDatabase(t)

auditLogIdFilter = model.QueryParameterFilter{
Name: "id",
Operator: model.GreaterThan,
Value: "4",
IsStringData: false,
}
auditLogIdFilterMap = model.QueryParameterFilterMap{auditLogIdFilter.Name: model.QueryParameterFilters{auditLogIdFilter}}
)

if err := integration.Prepare(dbInst); err != nil {
t.Fatalf("Failed preparing DB: %v", err)
}

mockCtx := ctx.Context{
RequestID: "requestID",
AuthCtx: auth.Context{
Owner: model.User{},
Session: model.UserSession{},
},
}
for i := 0; i < 7; i++ {
if err := dbInst.AppendAuditLog(mockCtx, "CreateUser", model.User{}); err != nil {
t.Fatalf("Error creating audit log: %v", err)
}
}

if _, count, err := dbInst.ListAuditLogs(time.Now(), time.Now(), 0, 10, "", model.SQLFilter{}); err != nil {
t.Fatalf("Failed to list all audit logs: %v", err)
} else if count != 7 {
t.Fatalf("Expected 7 audit logs to be returned")
} else if filter, err := auditLogIdFilterMap.BuildSQLFilter(); err != nil {
t.Fatalf("Failed to generate SQL Filter: %v", err)
// Limit is set to 1 to verify that count is total filtered count, not response size
} else if _, count, err = dbInst.ListAuditLogs(time.Now(), time.Now(), 0, 1, "", filter); err != nil {
t.Fatalf("Failed to list filtered events: %v", err)
} else if count != 3 {
t.Fatalf("Expected 3 audit logs to be returned")
}
}
17 changes: 10 additions & 7 deletions cmd/api/src/database/saved_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,33 @@ package database
import (
"github.com/gofrs/uuid"
"github.com/specterops/bloodhound/src/model"
"gorm.io/gorm"
)

func (s *BloodhoundDB) ListSavedQueries(userID uuid.UUID, order string, filter model.SQLFilter, skip, limit int) (model.SavedQueries, int, error) {
var (
queries model.SavedQueries
result *gorm.DB
count int64
cursor = s.Scope(Paginate(skip, limit)).Where("user_id = ?", userID)
)

cursor := s.Scope(Paginate(skip, limit)).Where("user_id = ?", userID)

if filter.SQLString != "" {
cursor = cursor.Where(filter.SQLString, filter.Params)
result = s.db.Model(&queries).Where("user_id = ?", userID).Where(filter.SQLString, filter.Params).Count(&count)
} else {
result = s.db.Model(&queries).Where("user_id = ?", userID).Count(&count)
}

if order != "" {
cursor = cursor.Order(order)
}

result := s.db.Where("user_id = ?", userID).Find(&queries).Count(&count)
if result.Error != nil {
return queries, 0, result.Error
}

if order != "" {
cursor = cursor.Order(order)
}
result = cursor.Find(&queries)

return queries, int(count), CheckError(result)
}

Expand Down
69 changes: 69 additions & 0 deletions cmd/api/src/database/saved_queries_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

//go:build integration
// +build integration

package database_test

import (
"fmt"
"github.com/gofrs/uuid"
"github.com/specterops/bloodhound/src/model"
"github.com/specterops/bloodhound/src/test/integration"
"github.com/stretchr/testify/require"
"testing"
)

func TestSavedQueries_ListSavedQueries(t *testing.T) {
var (
dbInst = integration.OpenDatabase(t)

savedQueriesFilter = model.QueryParameterFilter{
Name: "id",
Operator: model.GreaterThan,
Value: "4",
IsStringData: false,
}
savedQueriesFilterMap = model.QueryParameterFilterMap{savedQueriesFilter.Name: model.QueryParameterFilters{savedQueriesFilter}}
)

if err := integration.Prepare(dbInst); err != nil {
t.Fatalf("Failed preparing DB: %v", err)
}

userUUID, err := uuid.NewV4()
require.Nil(t, err)

for i := 0; i < 7; i++ {
if _, err := dbInst.CreateSavedQuery(userUUID, fmt.Sprintf("saved_query_%d", i), ""); err != nil {
t.Fatalf("Error creating audit log: %v", err)
}
}

if _, count, err := dbInst.ListSavedQueries(userUUID, "", model.SQLFilter{}, 0, 10); err != nil {
t.Fatalf("Failed to list all saved queries: %v", err)
} else if count != 7 {
t.Fatalf("Expected 7 saved queries to be returned")
} else if filter, err := savedQueriesFilterMap.BuildSQLFilter(); err != nil {
t.Fatalf("Failed to generate SQL Filter: %v", err)
// Limit is set to 1 to verify that count is total filtered count, not response size
} else if _, count, err = dbInst.ListSavedQueries(userUUID, "", filter, 0, 1); err != nil {
t.Fatalf("Failed to list filtered saved queries: %v", err)
} else if count != 3 {
t.Fatalf("Expected 3 saved queries to be returned")
}
}

0 comments on commit a2fca90

Please sign in to comment.