Skip to content

Commit

Permalink
BED-5060: Prevent user from changing their own role/auth (#984)
Browse files Browse the repository at this point in the history
* BED-5060: Add check for if user is attempting to edit their own role/SSO

* BED-5060: Disable editing roles and auth method for oneself in UI

* BED-5060: Finished implementation of roles/sso guarding
Fix bug where user's password could get deleted
Fix bug with GetAuthSecret not returning errnotfound

* BED-5060: run prepare-for-codereview

* BED-5060: Fix API tests

* refactor: ensure audit secret audit logs are maintained

* BED-5060: add tests for new errors on the API
Adjust update user form to hide fields

* BED-5060: Fix whitespace

* BED-5060: Remove extraneous admin password

---------

Co-authored-by: Mistah J <[email protected]>
  • Loading branch information
wes-mil and mistahj67 authored Dec 6, 2024
1 parent da6297a commit 5cb0295
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 191 deletions.
2 changes: 2 additions & 0 deletions cmd/api/src/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const (
ErrorResponsePayloadUnmarshalError = "error unmarshalling JSON payload"
ErrorResponseRequestTimeout = "request timed out"
ErrorResponseUserSelfDisable = "user attempted to disable themselves"
ErrorResponseUserSelfRoleChange = "user attempted to change own role"
ErrorResponseUserSelfSSOProviderChange = "user attempted to change own SSO Provider"
ErrorResponseAGTagWhiteSpace = "asset group tags must not contain whitespace"
ErrorResponseAGNameTagEmpty = "asset group name or tag must not be empty"
ErrorResponseAGDuplicateName = "asset group name must be unique"
Expand Down
33 changes: 15 additions & 18 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,6 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
}
}

func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error {
if user.AuthSecret != nil {
if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil {
return api.FormatDatabaseError(err)
}
}

return nil
}

func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *http.Request) {
var (
updateUserRequest v2.UpdateUserRequest
Expand All @@ -400,8 +390,10 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.PrincipalName = updateUserRequest.Principal
user.IsDisabled = updateUserRequest.IsDisabled

loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx)

if user.IsDisabled {
if loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx); user.ID == loggedInUser.ID {
if user.ID == loggedInUser.ID {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfDisable, request), response)
return
} else if userSessions, err := s.db.LookupActiveSessionsByUser(request.Context(), user); err != nil {
Expand All @@ -419,9 +411,6 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
Expand All @@ -431,10 +420,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
Expand All @@ -447,6 +433,17 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = null.NewInt32(0, false)
}

// Prevent a user from modifying their own roles/permissions
if user.ID == loggedInUser.ID {
if !slices.Equal(roles.IDs(), loggedInUser.Roles.IDs()) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfRoleChange, request), response)
return
} else if !user.SSOProviderID.Equal(loggedInUser.SSOProviderID) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfSSOProviderChange, request), response)
return
}
}

if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
Expand Down
66 changes: 65 additions & 1 deletion cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,17 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) {

func TestManagementResource_EnableUserSAML(t *testing.T) {
var (
adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
Expand All @@ -181,6 +185,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -197,9 +202,9 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": badUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -218,6 +223,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand Down Expand Up @@ -1511,6 +1517,64 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
require.Contains(t, response.Body.String(), api.ErrorResponseUserSelfDisable)
}

func TestManagementResource_UpdateUser_UserSelfModify(t *testing.T) {
var (
adminRole = model.Role{
Serial: model.Serial{
ID: 1,
},
}
goodRoles = []int32{1}
badRole = model.Role{
Serial: model.Serial{
ID: 2,
},
}
badRoles = []int32{2}
adminUser = model.User{AuthSecret: defaultDigestAuthSecret(t, "currentPassword"), Unique: model.Unique{ID: must.NewUUIDv4()}, Roles: model.Roles{adminRole}}
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Prevent users from changing their own SSO provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{adminRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("Prevent users from changing their own roles", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{badRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: badRoles,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})
}

func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand Down
15 changes: 14 additions & 1 deletion cmd/api/src/database/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,19 @@ func (s *BloodhoundDB) UpdateUser(ctx context.Context, user model.User) error {
return err
}

// AuthSecret must be manually retrieved and deleted
if user.AuthSecret == nil {
var authSecret model.AuthSecret
if err := tx.Raw("SELECT * FROM auth_secrets WHERE user_id = ?", user.ID).First(&authSecret).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
} else if authSecret.ID > 0 {
bhdb := NewBloodhoundDB(tx, s.idResolver)
if err := bhdb.DeleteAuthSecret(ctx, authSecret); err != nil {
return err
}
}
}

result := tx.WithContext(ctx).Save(&user)
return CheckError(result)
})
Expand Down Expand Up @@ -431,7 +444,7 @@ func (s *BloodhoundDB) CreateAuthSecret(ctx context.Context, authSecret model.Au
func (s *BloodhoundDB) GetAuthSecret(ctx context.Context, id int32) (model.AuthSecret, error) {
var (
authSecret model.AuthSecret
result = s.db.WithContext(ctx).Find(&authSecret, id)
result = s.db.WithContext(ctx).First(&authSecret, id)
)

return authSecret, CheckError(result)
Expand Down
65 changes: 65 additions & 0 deletions cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,71 @@ func TestDatabase_CreateGetDeleteUser(t *testing.T) {
}
}

func TestDatabase_UpdateUserAuth(t *testing.T) {
var (
ctx = context.Background()
dbInst, user = initAndCreateUser(t)
secret = model.AuthSecret{
UserID: user.ID,
Digest: "digest",
DigestMethod: "fake",
ExpiresAt: time.Now().Add(1 * time.Hour),
}
samlProvider = model.SAMLProvider{
Name: "provider",
DisplayName: "provider name",
IssuerURI: "https://idp.example.com/idp.xml",
SingleSignOnURI: "https://idp.example.com/sso",
}
)

if newSecret, err := dbInst.CreateAuthSecret(ctx, secret); err != nil {
t.Fatalf("Failed to create auth secret: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionCreateAuthSecret, "secret_user_id", newSecret.UserID.String()); err != nil {
t.Fatalf("Failed to validate CreateAuthSecret audit logs:\n%v", err)
} else {
if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(ctx, samlProvider); err != nil {
t.Fatalf("Failed to create SAML provider: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionCreateSAMLIdentityProvider, "saml_name", newSAMLProvider.Name); err != nil {
t.Fatalf("Failed to validate CreateSAMLIdentityProvider audit logs:\n%v", err)
} else {
user, err = dbInst.GetUser(ctx, user.ID)
if err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
}

user.FirstName = null.StringFrom("friendly man")

if err := dbInst.UpdateUser(ctx, user); err != nil {
t.Fatalf("Failed to update user: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateUser, "principal_name", user.PrincipalName); err != nil {
t.Fatalf("Failed to validate UpdateUser audit logs:\n%v", err)
} else if updatedUser, err := dbInst.GetUser(ctx, user.ID); err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
} else if updatedUser.AuthSecret == nil {
t.Fatalf("Failed to find authsecret for user %s", user.PrincipalName)
} else if _, err := dbInst.GetAuthSecret(ctx, updatedUser.AuthSecret.ID); err != nil {
t.Fatalf("Failed to get authsecret by id %d", updatedUser.AuthSecret.ID)
}

user.AuthSecret = nil
user.SSOProviderID = newSAMLProvider.SSOProviderID

if err := dbInst.UpdateUser(ctx, user); err != nil {
t.Fatalf("Failed to update user: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateUser, "principal_name", user.PrincipalName); err != nil {
t.Fatalf("Failed to validate UpdateUser audit logs:\n%v", err)
} else if updatedUser, err := dbInst.GetUser(ctx, user.ID); err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
} else if updatedUser.AuthSecret != nil {
t.Fatalf("Found authsecret for user %s but expected it to be removed", user.PrincipalName)
} else if _, err := dbInst.GetAuthSecret(ctx, newSecret.ID); err == nil {
t.Fatalf("Found authsecret for id %d but expected it to be removed", newSecret.ID)
}
}
}
}

func TestDatabase_CreateGetDeleteAuthToken(t *testing.T) {
var (
ctx = context.Background()
Expand Down
16 changes: 5 additions & 11 deletions cmd/api/src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/go-chi/chi/v5 v5.0.8
github.com/gobeam/stringy v0.0.6
github.com/gofrs/uuid v4.4.0+incompatible
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/golang/mock v1.6.0
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
Expand All @@ -44,12 +44,11 @@ require (
github.com/unrolled/secure v1.13.0
go.uber.org/mock v0.2.0
golang.org/x/oauth2 v0.23.0
gorm.io/driver/postgres v1.3.8
gorm.io/gorm v1.23.8
gorm.io/driver/postgres v1.5.10
gorm.io/gorm v1.25.12
)

require (
github.com/Masterminds/semver/v3 v3.2.1 // indirect
github.com/beevik/etree v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1 // indirect
Expand All @@ -61,18 +60,12 @@ require (
github.com/go-pkgz/expirable-cache v1.0.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.14.3 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgtype v1.14.4 // indirect
github.com/jackc/pgx/v4 v4.18.3 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
Expand All @@ -82,6 +75,7 @@ require (
github.com/prometheus/procfs v0.11.0 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
Expand Down
Loading

0 comments on commit 5cb0295

Please sign in to comment.